support specifying multiple vault IDs for a playbook run

see: https://github.com/ansible/awx/issues/352
This commit is contained in:
Ryan Petrello
2017-11-30 12:49:54 -05:00
parent fde5a8850d
commit a1f8f65add
8 changed files with 190 additions and 24 deletions

View File

@@ -415,6 +415,13 @@ class JSONSchemaField(JSONBField):
return value return value
@JSONSchemaField.format_checker.checks('vault_id')
def format_vault_id(value):
if '@' in value:
raise jsonschema.exceptions.FormatError('@ is not an allowed character')
return True
@JSONSchemaField.format_checker.checks('ssh_private_key') @JSONSchemaField.format_checker.checks('ssh_private_key')
def format_ssh_private_key(value): def format_ssh_private_key(value):
# Sanity check: GCE, in particular, provides JSON-encoded private # Sanity check: GCE, in particular, provides JSON-encoded private

View File

@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from django.db import migrations, models from django.db import migrations, models
from awx.main.migrations import _migration_utils as migration_utils from awx.main.migrations import _migration_utils as migration_utils
from awx.main.migrations import _credentialtypes as credentialtypes
from awx.main.migrations._multi_cred import migrate_to_multi_cred from awx.main.migrations._multi_cred import migrate_to_multi_cred
@@ -50,4 +51,5 @@ class Migration(migrations.Migration):
model_name='jobtemplate', model_name='jobtemplate',
name='vault_credential', name='vault_credential',
), ),
migrations.RunPython(credentialtypes.add_vault_id_field)
] ]

View File

@@ -173,3 +173,8 @@ def migrate_job_credentials(apps, schema_editor):
finally: finally:
utils.get_current_apps = orig_current_apps utils.get_current_apps = orig_current_apps
def add_vault_id_field(apps, schema_editor):
vault_credtype = CredentialType.objects.get(kind='vault')
vault_credtype.inputs = CredentialType.defaults.get('vault')().inputs
vault_credtype.save()

View File

@@ -689,6 +689,16 @@ def vault(cls):
'type': 'string', 'type': 'string',
'secret': True, 'secret': True,
'ask_at_runtime': True 'ask_at_runtime': True
}, {
'id': 'vault_id',
'label': 'Vault Identifier',
'type': 'string',
'format': 'vault_id',
'help_text': ('Specify an (optional) Vault ID. This is '
'equivalent to specifying the --vault-id '
'Ansible parameter for providing multiple Vault '
'passwords. Note: this feature only works in '
'Ansible 2.4+.')
}], }],
'required': ['vault_password'], 'required': ['vault_password'],
} }

View File

@@ -172,6 +172,10 @@ class JobOptions(BaseModel):
def cloud_credentials(self): def cloud_credentials(self):
return list(self.credentials.filter(credential_type__kind='cloud')) return list(self.credentials.filter(credential_type__kind='cloud'))
@property
def vault_credentials(self):
return list(self.credentials.filter(credential_type__kind='vault'))
@property @property
def credential(self): def credential(self):
cred = self.get_deprecated_credential('ssh') cred = self.get_deprecated_credential('ssh')

View File

@@ -712,7 +712,7 @@ class BaseTask(LogErrorsTask):
job_timeout = 0 job_timeout = 0
return job_timeout return job_timeout
def get_password_prompts(self): def get_password_prompts(self, **kwargs):
''' '''
Return a dictionary where keys are strings or regular expressions for Return a dictionary where keys are strings or regular expressions for
prompts, and values are password lookup keys (keys that are returned prompts, and values are password lookup keys (keys that are returned
@@ -833,7 +833,7 @@ class BaseTask(LogErrorsTask):
job_cwd=cwd, job_env=safe_env, result_stdout_file=stdout_handle.name) job_cwd=cwd, job_env=safe_env, result_stdout_file=stdout_handle.name)
expect_passwords = {} expect_passwords = {}
for k, v in self.get_password_prompts().items(): for k, v in self.get_password_prompts(**kwargs).items():
expect_passwords[k] = kwargs['passwords'].get(v, '') or '' expect_passwords[k] = kwargs['passwords'].get(v, '') or ''
_kw = dict( _kw = dict(
expect_passwords=expect_passwords, expect_passwords=expect_passwords,
@@ -961,19 +961,30 @@ class RunJob(BaseTask):
and ansible-vault. and ansible-vault.
''' '''
passwords = super(RunJob, self).build_passwords(job, **kwargs) passwords = super(RunJob, self).build_passwords(job, **kwargs)
for kind, fields in { cred = job.get_deprecated_credential('ssh')
'ssh': ('ssh_key_unlock', 'ssh_password', 'become_password'), if cred:
'vault': ('vault_password',) for field in ('ssh_key_unlock', 'ssh_password', 'become_password'):
}.items(): value = kwargs.get(
cred = job.get_deprecated_credential(kind) field,
if cred: decrypt_field(cred, 'password' if field == 'ssh_password' else field)
for field in fields: )
if field == 'ssh_password': if value not in ('', 'ASK'):
value = kwargs.get(field, decrypt_field(cred, 'password')) passwords[field] = value
else:
value = kwargs.get(field, decrypt_field(cred, field)) for cred in job.vault_credentials:
if value not in ('', 'ASK'): field = 'vault_password'
passwords[field] = value if cred.inputs.get('vault_id'):
field = 'vault_password.{}'.format(cred.inputs['vault_id'])
if field in passwords:
raise RuntimeError(
'multiple vault credentials were specified with --vault-id {}@prompt'.format(
cred.inputs['vault_id']
)
)
value = kwargs.get(field, decrypt_field(cred, 'vault_password'))
if value not in ('', 'ASK'):
passwords[field] = value
return passwords return passwords
def build_env(self, job, **kwargs): def build_env(self, job, **kwargs):
@@ -1107,9 +1118,16 @@ class RunJob(BaseTask):
args.extend(['--become-user', become_username]) args.extend(['--become-user', become_username])
if 'become_password' in kwargs.get('passwords', {}): if 'become_password' in kwargs.get('passwords', {}):
args.append('--ask-become-pass') args.append('--ask-become-pass')
# Support prompting for a vault password.
if 'vault_password' in kwargs.get('passwords', {}): # Support prompting for multiple vault passwords
args.append('--ask-vault-pass') for k, v in kwargs.get('passwords', {}).items():
if k.startswith('vault_password'):
if k == 'vault_password':
args.append('--ask-vault-pass')
else:
vault_id = k.split('.')[1]
args.append('--vault-id')
args.append('{}@prompt'.format(vault_id))
if job.forks: # FIXME: Max limit? if job.forks: # FIXME: Max limit?
args.append('--forks=%d' % job.forks) args.append('--forks=%d' % job.forks)
@@ -1177,8 +1195,8 @@ class RunJob(BaseTask):
def get_idle_timeout(self): def get_idle_timeout(self):
return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None) return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None)
def get_password_prompts(self): def get_password_prompts(self, **kwargs):
d = super(RunJob, self).get_password_prompts() d = super(RunJob, self).get_password_prompts(**kwargs)
d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock' d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock'
d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = '' d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = ''
for method in PRIVILEGE_ESCALATION_METHODS: for method in PRIVILEGE_ESCALATION_METHODS:
@@ -1187,6 +1205,10 @@ class RunJob(BaseTask):
d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password' d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password'
d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password' d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password'
d[re.compile(r'Vault password:\s*?$', re.M)] = 'vault_password' d[re.compile(r'Vault password:\s*?$', re.M)] = 'vault_password'
for k, v in kwargs.get('passwords', {}).items():
if k.startswith('vault_password.'):
vault_id = k.split('.')[1]
d[re.compile(r'Vault password \({}\):\s*?$'.format(vault_id), re.M)] = k
return d return d
def get_stdout_handle(self, instance): def get_stdout_handle(self, instance):
@@ -1442,8 +1464,8 @@ class RunProjectUpdate(BaseTask):
output_replacements.append((pattern2 % d_before, pattern2 % d_after)) output_replacements.append((pattern2 % d_before, pattern2 % d_after))
return output_replacements return output_replacements
def get_password_prompts(self): def get_password_prompts(self, **kwargs):
d = super(RunProjectUpdate, self).get_password_prompts() d = super(RunProjectUpdate, self).get_password_prompts(**kwargs)
d[re.compile(r'Username for.*:\s*?$', re.M)] = 'scm_username' d[re.compile(r'Username for.*:\s*?$', re.M)] = 'scm_username'
d[re.compile(r'Password for.*:\s*?$', re.M)] = 'scm_password' d[re.compile(r'Password for.*:\s*?$', re.M)] = 'scm_password'
d[re.compile(r'Password:\s*?$', re.M)] = 'scm_password' d[re.compile(r'Password:\s*?$', re.M)] = 'scm_password'
@@ -2142,8 +2164,8 @@ class RunAdHocCommand(BaseTask):
def get_idle_timeout(self): def get_idle_timeout(self):
return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None) return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None)
def get_password_prompts(self): def get_password_prompts(self, **kwargs):
d = super(RunAdHocCommand, self).get_password_prompts() d = super(RunAdHocCommand, self).get_password_prompts(**kwargs)
d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock' d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock'
d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = '' d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = ''
for method in PRIVILEGE_ESCALATION_METHODS: for method in PRIVILEGE_ESCALATION_METHODS:

View File

@@ -247,6 +247,31 @@ def test_ssh_key_data_validation(organization, kind, ssh_key_data, ssh_key_unloc
assert e.type in (ValidationError, serializers.ValidationError) assert e.type in (ValidationError, serializers.ValidationError)
@pytest.mark.django_db
@pytest.mark.parametrize('inputs, valid', [
({'vault_password': 'some-pass'}, True),
({}, False),
({'vault_password': 'dev-pass', 'vault_id': 'dev'}, True),
({'vault_password': 'dev-pass', 'vault_id': 'dev@prompt'}, False), # @ not allowed
])
def test_vault_validation(organization, inputs, valid):
cred_type = CredentialType.defaults['vault']()
cred_type.save()
cred = Credential(
credential_type=cred_type,
name="Best credential ever",
inputs=inputs,
organization=organization
)
cred.save()
if valid:
cred.full_clean()
else:
with pytest.raises(Exception) as e:
cred.full_clean()
assert e.type in (ValidationError, serializers.ValidationError)
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize('become_method, valid', zip( @pytest.mark.parametrize('become_method, valid', zip(
dict(V1Credential.FIELDS['become_method'].choices).keys(), dict(V1Credential.FIELDS['become_method'].choices).keys(),

View File

@@ -450,6 +450,97 @@ class TestJobCredentials(TestJobExecution):
] == 'vault-me' ] == 'vault-me'
assert '--ask-vault-pass' in ' '.join(args) assert '--ask-vault-pass' in ' '.join(args)
def test_vault_password_ask(self):
vault = CredentialType.defaults['vault']()
credential = Credential(
pk=1,
credential_type=vault,
inputs={'vault_password': 'ASK'}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk, vault_password='provided-at-launch')
assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args
assert call_kwargs.get('expect_passwords')[
re.compile(r'Vault password:\s*?$', re.M)
] == 'provided-at-launch'
assert '--ask-vault-pass' in ' '.join(args)
def test_multi_vault_password(self):
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod']):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'pass@{}'.format(label), 'vault_id': label}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk)
assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args
vault_passwords = dict(
(k.pattern, v) for k, v in call_kwargs['expect_passwords'].items()
if 'Vault' in k.pattern
)
assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'pass@prod'
assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'pass@dev'
assert vault_passwords['Vault password:\\s*?$'] == ''
assert '--ask-vault-pass' not in ' '.join(args)
assert '--vault-id dev@prompt' in ' '.join(args)
assert '--vault-id prod@prompt' in ' '.join(args)
def test_multi_vault_id_conflict(self):
vault = CredentialType.defaults['vault']()
for i in range(2):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'some-pass', 'vault_id': 'conflict'}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
with pytest.raises(Exception):
self.task.run(self.pk)
def test_multi_vault_password_ask(self):
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod']):
credential = Credential(
pk=i,
credential_type=vault,
inputs={'vault_password': 'ASK', 'vault_id': label}
)
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
self.instance.credentials.add(credential)
self.task.run(self.pk, **{
'vault_password.dev': 'provided-at-launch@dev',
'vault_password.prod': 'provided-at-launch@prod'
})
assert self.run_pexpect.call_count == 1
call_args, call_kwargs = self.run_pexpect.call_args_list[0]
args, cwd, env, stdout = call_args
vault_passwords = dict(
(k.pattern, v) for k, v in call_kwargs['expect_passwords'].items()
if 'Vault' in k.pattern
)
assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'provided-at-launch@prod'
assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'provided-at-launch@dev'
assert vault_passwords['Vault password:\\s*?$'] == ''
assert '--ask-vault-pass' not in ' '.join(args)
assert '--vault-id dev@prompt' in ' '.join(args)
assert '--vault-id prod@prompt' in ' '.join(args)
def test_ssh_key_with_agent(self): def test_ssh_key_with_agent(self):
ssh = CredentialType.defaults['ssh']() ssh = CredentialType.defaults['ssh']()
credential = Credential( credential = Credential(