From a1f8f65addf03348210e3fd1f34d1bd14604cb92 Mon Sep 17 00:00:00 2001 From: Ryan Petrello Date: Thu, 30 Nov 2017 12:49:54 -0500 Subject: [PATCH] support specifying multiple vault IDs for a playbook run see: https://github.com/ansible/awx/issues/352 --- awx/main/fields.py | 7 ++ .../migrations/0009_v330_multi_credential.py | 2 + awx/main/migrations/_credentialtypes.py | 5 + awx/main/models/credential.py | 10 ++ awx/main/models/jobs.py | 4 + awx/main/tasks.py | 70 +++++++++----- awx/main/tests/functional/test_credential.py | 25 +++++ awx/main/tests/unit/test_tasks.py | 91 +++++++++++++++++++ 8 files changed, 190 insertions(+), 24 deletions(-) diff --git a/awx/main/fields.py b/awx/main/fields.py index 19a1f1b78e..63e8743709 100644 --- a/awx/main/fields.py +++ b/awx/main/fields.py @@ -415,6 +415,13 @@ class JSONSchemaField(JSONBField): return value +@JSONSchemaField.format_checker.checks('vault_id') +def format_vault_id(value): + if '@' in value: + raise jsonschema.exceptions.FormatError('@ is not an allowed character') + return True + + @JSONSchemaField.format_checker.checks('ssh_private_key') def format_ssh_private_key(value): # Sanity check: GCE, in particular, provides JSON-encoded private diff --git a/awx/main/migrations/0009_v330_multi_credential.py b/awx/main/migrations/0009_v330_multi_credential.py index 5eef9126db..69c602a3e0 100644 --- a/awx/main/migrations/0009_v330_multi_credential.py +++ b/awx/main/migrations/0009_v330_multi_credential.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from django.db import migrations, models from awx.main.migrations import _migration_utils as migration_utils +from awx.main.migrations import _credentialtypes as credentialtypes from awx.main.migrations._multi_cred import migrate_to_multi_cred @@ -50,4 +51,5 @@ class Migration(migrations.Migration): model_name='jobtemplate', name='vault_credential', ), + migrations.RunPython(credentialtypes.add_vault_id_field) ] diff --git a/awx/main/migrations/_credentialtypes.py b/awx/main/migrations/_credentialtypes.py index f34d08903a..f9f0c8eab0 100644 --- a/awx/main/migrations/_credentialtypes.py +++ b/awx/main/migrations/_credentialtypes.py @@ -173,3 +173,8 @@ def migrate_job_credentials(apps, schema_editor): finally: utils.get_current_apps = orig_current_apps + +def add_vault_id_field(apps, schema_editor): + vault_credtype = CredentialType.objects.get(kind='vault') + vault_credtype.inputs = CredentialType.defaults.get('vault')().inputs + vault_credtype.save() diff --git a/awx/main/models/credential.py b/awx/main/models/credential.py index 180085fbc0..7e6bfc238d 100644 --- a/awx/main/models/credential.py +++ b/awx/main/models/credential.py @@ -689,6 +689,16 @@ def vault(cls): 'type': 'string', 'secret': True, 'ask_at_runtime': True + }, { + 'id': 'vault_id', + 'label': 'Vault Identifier', + 'type': 'string', + 'format': 'vault_id', + 'help_text': ('Specify an (optional) Vault ID. This is ' + 'equivalent to specifying the --vault-id ' + 'Ansible parameter for providing multiple Vault ' + 'passwords. Note: this feature only works in ' + 'Ansible 2.4+.') }], 'required': ['vault_password'], } diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py index 9807942925..a8dfa9388e 100644 --- a/awx/main/models/jobs.py +++ b/awx/main/models/jobs.py @@ -172,6 +172,10 @@ class JobOptions(BaseModel): def cloud_credentials(self): return list(self.credentials.filter(credential_type__kind='cloud')) + @property + def vault_credentials(self): + return list(self.credentials.filter(credential_type__kind='vault')) + @property def credential(self): cred = self.get_deprecated_credential('ssh') diff --git a/awx/main/tasks.py b/awx/main/tasks.py index a32c9dca5d..b3cc37c752 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -712,7 +712,7 @@ class BaseTask(LogErrorsTask): job_timeout = 0 return job_timeout - def get_password_prompts(self): + def get_password_prompts(self, **kwargs): ''' Return a dictionary where keys are strings or regular expressions for prompts, and values are password lookup keys (keys that are returned @@ -833,7 +833,7 @@ class BaseTask(LogErrorsTask): job_cwd=cwd, job_env=safe_env, result_stdout_file=stdout_handle.name) expect_passwords = {} - for k, v in self.get_password_prompts().items(): + for k, v in self.get_password_prompts(**kwargs).items(): expect_passwords[k] = kwargs['passwords'].get(v, '') or '' _kw = dict( expect_passwords=expect_passwords, @@ -961,19 +961,30 @@ class RunJob(BaseTask): and ansible-vault. ''' passwords = super(RunJob, self).build_passwords(job, **kwargs) - for kind, fields in { - 'ssh': ('ssh_key_unlock', 'ssh_password', 'become_password'), - 'vault': ('vault_password',) - }.items(): - cred = job.get_deprecated_credential(kind) - if cred: - for field in fields: - if field == 'ssh_password': - value = kwargs.get(field, decrypt_field(cred, 'password')) - else: - value = kwargs.get(field, decrypt_field(cred, field)) - if value not in ('', 'ASK'): - passwords[field] = value + cred = job.get_deprecated_credential('ssh') + if cred: + for field in ('ssh_key_unlock', 'ssh_password', 'become_password'): + value = kwargs.get( + field, + decrypt_field(cred, 'password' if field == 'ssh_password' else field) + ) + if value not in ('', 'ASK'): + passwords[field] = value + + for cred in job.vault_credentials: + field = 'vault_password' + if cred.inputs.get('vault_id'): + field = 'vault_password.{}'.format(cred.inputs['vault_id']) + if field in passwords: + raise RuntimeError( + 'multiple vault credentials were specified with --vault-id {}@prompt'.format( + cred.inputs['vault_id'] + ) + ) + value = kwargs.get(field, decrypt_field(cred, 'vault_password')) + if value not in ('', 'ASK'): + passwords[field] = value + return passwords def build_env(self, job, **kwargs): @@ -1107,9 +1118,16 @@ class RunJob(BaseTask): args.extend(['--become-user', become_username]) if 'become_password' in kwargs.get('passwords', {}): args.append('--ask-become-pass') - # Support prompting for a vault password. - if 'vault_password' in kwargs.get('passwords', {}): - args.append('--ask-vault-pass') + + # Support prompting for multiple vault passwords + for k, v in kwargs.get('passwords', {}).items(): + if k.startswith('vault_password'): + if k == 'vault_password': + args.append('--ask-vault-pass') + else: + vault_id = k.split('.')[1] + args.append('--vault-id') + args.append('{}@prompt'.format(vault_id)) if job.forks: # FIXME: Max limit? args.append('--forks=%d' % job.forks) @@ -1177,8 +1195,8 @@ class RunJob(BaseTask): def get_idle_timeout(self): return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None) - def get_password_prompts(self): - d = super(RunJob, self).get_password_prompts() + def get_password_prompts(self, **kwargs): + d = super(RunJob, self).get_password_prompts(**kwargs) d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock' d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = '' for method in PRIVILEGE_ESCALATION_METHODS: @@ -1187,6 +1205,10 @@ class RunJob(BaseTask): d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password' d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password' d[re.compile(r'Vault password:\s*?$', re.M)] = 'vault_password' + for k, v in kwargs.get('passwords', {}).items(): + if k.startswith('vault_password.'): + vault_id = k.split('.')[1] + d[re.compile(r'Vault password \({}\):\s*?$'.format(vault_id), re.M)] = k return d def get_stdout_handle(self, instance): @@ -1442,8 +1464,8 @@ class RunProjectUpdate(BaseTask): output_replacements.append((pattern2 % d_before, pattern2 % d_after)) return output_replacements - def get_password_prompts(self): - d = super(RunProjectUpdate, self).get_password_prompts() + def get_password_prompts(self, **kwargs): + d = super(RunProjectUpdate, self).get_password_prompts(**kwargs) d[re.compile(r'Username for.*:\s*?$', re.M)] = 'scm_username' d[re.compile(r'Password for.*:\s*?$', re.M)] = 'scm_password' d[re.compile(r'Password:\s*?$', re.M)] = 'scm_password' @@ -2142,8 +2164,8 @@ class RunAdHocCommand(BaseTask): def get_idle_timeout(self): return getattr(settings, 'JOB_RUN_IDLE_TIMEOUT', None) - def get_password_prompts(self): - d = super(RunAdHocCommand, self).get_password_prompts() + def get_password_prompts(self, **kwargs): + d = super(RunAdHocCommand, self).get_password_prompts(**kwargs) d[re.compile(r'Enter passphrase for .*:\s*?$', re.M)] = 'ssh_key_unlock' d[re.compile(r'Bad passphrase, try again for .*:\s*?$', re.M)] = '' for method in PRIVILEGE_ESCALATION_METHODS: diff --git a/awx/main/tests/functional/test_credential.py b/awx/main/tests/functional/test_credential.py index b4dd0cb0e4..5bae4cc065 100644 --- a/awx/main/tests/functional/test_credential.py +++ b/awx/main/tests/functional/test_credential.py @@ -247,6 +247,31 @@ def test_ssh_key_data_validation(organization, kind, ssh_key_data, ssh_key_unloc assert e.type in (ValidationError, serializers.ValidationError) +@pytest.mark.django_db +@pytest.mark.parametrize('inputs, valid', [ + ({'vault_password': 'some-pass'}, True), + ({}, False), + ({'vault_password': 'dev-pass', 'vault_id': 'dev'}, True), + ({'vault_password': 'dev-pass', 'vault_id': 'dev@prompt'}, False), # @ not allowed +]) +def test_vault_validation(organization, inputs, valid): + cred_type = CredentialType.defaults['vault']() + cred_type.save() + cred = Credential( + credential_type=cred_type, + name="Best credential ever", + inputs=inputs, + organization=organization + ) + cred.save() + if valid: + cred.full_clean() + else: + with pytest.raises(Exception) as e: + cred.full_clean() + assert e.type in (ValidationError, serializers.ValidationError) + + @pytest.mark.django_db @pytest.mark.parametrize('become_method, valid', zip( dict(V1Credential.FIELDS['become_method'].choices).keys(), diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py index 75b1ab0207..066cfaf684 100644 --- a/awx/main/tests/unit/test_tasks.py +++ b/awx/main/tests/unit/test_tasks.py @@ -450,6 +450,97 @@ class TestJobCredentials(TestJobExecution): ] == 'vault-me' assert '--ask-vault-pass' in ' '.join(args) + def test_vault_password_ask(self): + vault = CredentialType.defaults['vault']() + credential = Credential( + pk=1, + credential_type=vault, + inputs={'vault_password': 'ASK'} + ) + credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') + self.instance.credentials.add(credential) + self.task.run(self.pk, vault_password='provided-at-launch') + + assert self.run_pexpect.call_count == 1 + call_args, call_kwargs = self.run_pexpect.call_args_list[0] + args, cwd, env, stdout = call_args + + assert call_kwargs.get('expect_passwords')[ + re.compile(r'Vault password:\s*?$', re.M) + ] == 'provided-at-launch' + assert '--ask-vault-pass' in ' '.join(args) + + def test_multi_vault_password(self): + vault = CredentialType.defaults['vault']() + for i, label in enumerate(['dev', 'prod']): + credential = Credential( + pk=i, + credential_type=vault, + inputs={'vault_password': 'pass@{}'.format(label), 'vault_id': label} + ) + credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') + self.instance.credentials.add(credential) + self.task.run(self.pk) + + assert self.run_pexpect.call_count == 1 + call_args, call_kwargs = self.run_pexpect.call_args_list[0] + args, cwd, env, stdout = call_args + + vault_passwords = dict( + (k.pattern, v) for k, v in call_kwargs['expect_passwords'].items() + if 'Vault' in k.pattern + ) + assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'pass@prod' + assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'pass@dev' + assert vault_passwords['Vault password:\\s*?$'] == '' + assert '--ask-vault-pass' not in ' '.join(args) + assert '--vault-id dev@prompt' in ' '.join(args) + assert '--vault-id prod@prompt' in ' '.join(args) + + def test_multi_vault_id_conflict(self): + vault = CredentialType.defaults['vault']() + for i in range(2): + credential = Credential( + pk=i, + credential_type=vault, + inputs={'vault_password': 'some-pass', 'vault_id': 'conflict'} + ) + credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') + self.instance.credentials.add(credential) + + with pytest.raises(Exception): + self.task.run(self.pk) + + def test_multi_vault_password_ask(self): + vault = CredentialType.defaults['vault']() + for i, label in enumerate(['dev', 'prod']): + credential = Credential( + pk=i, + credential_type=vault, + inputs={'vault_password': 'ASK', 'vault_id': label} + ) + credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') + self.instance.credentials.add(credential) + self.task.run(self.pk, **{ + 'vault_password.dev': 'provided-at-launch@dev', + 'vault_password.prod': 'provided-at-launch@prod' + }) + + assert self.run_pexpect.call_count == 1 + call_args, call_kwargs = self.run_pexpect.call_args_list[0] + args, cwd, env, stdout = call_args + + vault_passwords = dict( + (k.pattern, v) for k, v in call_kwargs['expect_passwords'].items() + if 'Vault' in k.pattern + ) + assert vault_passwords['Vault password \(prod\):\\s*?$'] == 'provided-at-launch@prod' + assert vault_passwords['Vault password \(dev\):\\s*?$'] == 'provided-at-launch@dev' + assert vault_passwords['Vault password:\\s*?$'] == '' + assert '--ask-vault-pass' not in ' '.join(args) + assert '--vault-id dev@prompt' in ' '.join(args) + assert '--vault-id prod@prompt' in ' '.join(args) + def test_ssh_key_with_agent(self): ssh = CredentialType.defaults['ssh']() credential = Credential(