Merge pull request #3021 from jakemcdermott/credential_input_access_methods

add input access methods to credentials

Reviewed-by: https://github.com/softwarefactory-project-zuul[bot]
This commit is contained in:
softwarefactory-project-zuul[bot]
2019-01-21 21:10:12 +00:00
committed by GitHub
8 changed files with 178 additions and 87 deletions

View File

@@ -4334,7 +4334,7 @@ class JobLaunchSerializer(BaseSerializer):
passwords_needed=cred.passwords_needed passwords_needed=cred.passwords_needed
) )
if cred.credential_type.managed_by_tower and 'vault_id' in cred.credential_type.defined_fields: if cred.credential_type.managed_by_tower and 'vault_id' in cred.credential_type.defined_fields:
cred_dict['vault_id'] = cred.inputs.get('vault_id') or None cred_dict['vault_id'] = cred.get_input('vault_id', default=None)
defaults_dict.setdefault(field_name, []).append(cred_dict) defaults_dict.setdefault(field_name, []).append(cred_dict)
else: else:
defaults_dict[field_name] = getattr(obj, field_name) defaults_dict[field_name] = getattr(obj, field_name)

View File

@@ -70,7 +70,6 @@ from awx.main.models import * # noqa
from awx.main.utils import * # noqa from awx.main.utils import * # noqa
from awx.main.utils import ( from awx.main.utils import (
extract_ansible_vars, extract_ansible_vars,
decrypt_field,
) )
from awx.main.utils.encryption import encrypt_value from awx.main.utils.encryption import encrypt_value
from awx.main.utils.filters import SmartFilter from awx.main.utils.filters import SmartFilter
@@ -1592,7 +1591,7 @@ class HostInsights(GenericAPIView):
serializer_class = EmptySerializer serializer_class = EmptySerializer
def _extract_insights_creds(self, credential): def _extract_insights_creds(self, credential):
return (credential.inputs['username'], decrypt_field(credential, 'password')) return (credential.get_input('username', default=''), credential.get_input('password', default=''))
def _get_insights(self, url, username, password): def _get_insights(self, url, username, password):
session = requests.Session() session = requests.Session()

View File

@@ -385,6 +385,7 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
def encrypt_field(self, field, ask): def encrypt_field(self, field, ask):
if not hasattr(self, field): if not hasattr(self, field):
return None return None
encrypted = encrypt_field(self, field, ask=ask) encrypted = encrypt_field(self, field, ask=ask)
if encrypted: if encrypted:
self.inputs[field] = encrypted self.inputs[field] = encrypted
@@ -415,12 +416,12 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
type_alias = self.credential_type.name type_alias = self.credential_type.name
else: else:
type_alias = self.credential_type_id type_alias = self.credential_type_id
if self.kind == 'vault' and self.inputs.get('vault_id', None): if self.kind == 'vault' and self.has_input('vault_id'):
if display: if display:
fmt_str = six.text_type('{} (id={})') fmt_str = six.text_type('{} (id={})')
else: else:
fmt_str = six.text_type('{}_{}') fmt_str = six.text_type('{}_{}')
return fmt_str.format(type_alias, self.inputs.get('vault_id')) return fmt_str.format(type_alias, self.get_input('vault_id'))
return six.text_type(type_alias) return six.text_type(type_alias)
@staticmethod @staticmethod
@@ -430,6 +431,29 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique, ResourceMixin):
ret[cred.unique_hash()] = cred ret[cred.unique_hash()] = cred
return ret return ret
def get_input(self, field_name, **kwargs):
"""
Get an injectable and decrypted value for an input field.
Retrieves the value for a given credential input field name. Return
values for secret input fields are decrypted. If the credential doesn't
have an input value defined for the given field name, an AttributeError
is raised unless a default value is provided.
:param field_name(str): The name of the input field.
:param default(optional[str]): A default return value to use.
"""
if field_name in self.credential_type.secret_fields:
return decrypt_field(self, field_name)
if field_name in self.inputs:
return self.inputs[field_name]
if 'default' in kwargs:
return kwargs['default']
raise AttributeError(field_name)
def has_input(self, field_name):
return field_name in self.inputs and self.inputs[field_name] not in ('', None)
class CredentialType(CommonModelNameNotUnique): class CredentialType(CommonModelNameNotUnique):
''' '''
@@ -611,8 +635,9 @@ class CredentialType(CommonModelNameNotUnique):
safe_namespace[field_name] = namespace[field_name] = value safe_namespace[field_name] = namespace[field_name] = value
continue continue
value = credential.get_input(field_name)
if field_name in self.secret_fields: if field_name in self.secret_fields:
value = decrypt_field(credential, field_name)
safe_namespace[field_name] = '**********' safe_namespace[field_name] = '**********'
elif len(value): elif len(value):
safe_namespace[field_name] = value safe_namespace[field_name] = value

View File

@@ -3,25 +3,28 @@ import os
import stat import stat
import tempfile import tempfile
from awx.main.utils import decrypt_field
from django.conf import settings from django.conf import settings
def aws(cred, env, private_data_dir): def aws(cred, env, private_data_dir):
env['AWS_ACCESS_KEY_ID'] = cred.username env['AWS_ACCESS_KEY_ID'] = cred.get_input('username', default='')
env['AWS_SECRET_ACCESS_KEY'] = decrypt_field(cred, 'password') env['AWS_SECRET_ACCESS_KEY'] = cred.get_input('password', default='')
if len(cred.security_token) > 0:
env['AWS_SECURITY_TOKEN'] = decrypt_field(cred, 'security_token') if cred.has_input('security_token'):
env['AWS_SECURITY_TOKEN'] = cred.get_input('security_token', default='')
def gce(cred, env, private_data_dir): def gce(cred, env, private_data_dir):
env['GCE_EMAIL'] = cred.username project = cred.get_input('project', default='')
env['GCE_PROJECT'] = cred.project username = cred.get_input('username', default='')
env['GCE_EMAIL'] = username
env['GCE_PROJECT'] = project
json_cred = { json_cred = {
'type': 'service_account', 'type': 'service_account',
'private_key': decrypt_field(cred, 'ssh_key_data'), 'private_key': cred.get_input('ssh_key_data', default=''),
'client_email': cred.username, 'client_email': username,
'project_id': cred.project 'project_id': project
} }
handle, path = tempfile.mkstemp(dir=private_data_dir) handle, path = tempfile.mkstemp(dir=private_data_dir)
f = os.fdopen(handle, 'w') f = os.fdopen(handle, 'w')
@@ -32,21 +35,25 @@ def gce(cred, env, private_data_dir):
def azure_rm(cred, env, private_data_dir): def azure_rm(cred, env, private_data_dir):
if len(cred.client) and len(cred.tenant): client = cred.get_input('client', default='')
env['AZURE_CLIENT_ID'] = cred.client tenant = cred.get_input('tenant', default='')
env['AZURE_SECRET'] = decrypt_field(cred, 'secret')
env['AZURE_TENANT'] = cred.tenant if len(client) and len(tenant):
env['AZURE_SUBSCRIPTION_ID'] = cred.subscription env['AZURE_CLIENT_ID'] = client
env['AZURE_TENANT'] = tenant
env['AZURE_SECRET'] = cred.get_input('secret', default='')
env['AZURE_SUBSCRIPTION_ID'] = cred.get_input('subscription', default='')
else: else:
env['AZURE_SUBSCRIPTION_ID'] = cred.subscription env['AZURE_SUBSCRIPTION_ID'] = cred.get_input('subscription', default='')
env['AZURE_AD_USER'] = cred.username env['AZURE_AD_USER'] = cred.get_input('username', default='')
env['AZURE_PASSWORD'] = decrypt_field(cred, 'password') env['AZURE_PASSWORD'] = cred.get_input('password', default='')
if cred.inputs.get('cloud_environment', None):
env['AZURE_CLOUD_ENVIRONMENT'] = cred.inputs['cloud_environment'] if cred.has_input('cloud_environment'):
env['AZURE_CLOUD_ENVIRONMENT'] = cred.get_input('cloud_environment')
def vmware(cred, env, private_data_dir): def vmware(cred, env, private_data_dir):
env['VMWARE_USER'] = cred.username env['VMWARE_USER'] = cred.get_input('username', default='')
env['VMWARE_PASSWORD'] = decrypt_field(cred, 'password') env['VMWARE_PASSWORD'] = cred.get_input('password', default='')
env['VMWARE_HOST'] = cred.host env['VMWARE_HOST'] = cred.get_input('host', default='')
env['VMWARE_VALIDATE_CERTS'] = str(settings.VMWARE_VALIDATE_CERTS) env['VMWARE_VALIDATE_CERTS'] = str(settings.VMWARE_VALIDATE_CERTS)

View File

@@ -166,8 +166,8 @@ class ProjectOptions(models.Model):
check_special_cases=False) check_special_cases=False)
scm_url_parts = urlparse.urlsplit(scm_url) scm_url_parts = urlparse.urlsplit(scm_url)
# Prefer the username/password in the URL, if provided. # Prefer the username/password in the URL, if provided.
scm_username = scm_url_parts.username or cred.username or '' scm_username = scm_url_parts.username or cred.get_input('username', default='')
if scm_url_parts.password or cred.password: if scm_url_parts.password or cred.has_input('password'):
scm_password = '********' scm_password = '********'
else: else:
scm_password = '' scm_password = ''

View File

@@ -54,7 +54,7 @@ from awx.main.queue import CallbackQueueDispatcher
from awx.main.expect import run, isolated_manager from awx.main.expect import run, isolated_manager
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename, reaper from awx.main.dispatch import get_local_queuename, reaper
from awx.main.utils import (get_ansible_version, get_ssh_version, decrypt_field, update_scm_url, from awx.main.utils import (get_ansible_version, get_ssh_version, update_scm_url,
check_proot_installed, build_proot_temp_dir, get_licenser, check_proot_installed, build_proot_temp_dir, get_licenser,
wrap_args_with_proot, OutputEventFilter, OutputVerboseFilter, ignore_inventory_computed_fields, wrap_args_with_proot, OutputEventFilter, OutputVerboseFilter, ignore_inventory_computed_fields,
ignore_inventory_group_removal, extract_ansible_vars, schedule_task_manager) ignore_inventory_group_removal, extract_ansible_vars, schedule_task_manager)
@@ -1124,16 +1124,16 @@ class RunJob(BaseTask):
for credential in job.credentials.all(): for credential in job.credentials.all():
# If we were sent SSH credentials, decrypt them and send them # If we were sent SSH credentials, decrypt them and send them
# back (they will be written to a temporary file). # back (they will be written to a temporary file).
if credential.ssh_key_data not in (None, ''): if credential.has_input('ssh_key_data'):
private_data['credentials'][credential] = decrypt_field(credential, 'ssh_key_data') or '' private_data['credentials'][credential] = credential.get_input('ssh_key_data', default='')
if credential.kind == 'openstack': if credential.kind == 'openstack':
openstack_auth = dict(auth_url=credential.host, openstack_auth = dict(auth_url=credential.get_input('host', default=''),
username=credential.username, username=credential.get_input('username', default=''),
password=decrypt_field(credential, "password"), password=credential.get_input('password', default=''),
project_name=credential.project) project_name=credential.get_input('project', default=''))
if credential.domain not in (None, ''): if credential.has_input('domain'):
openstack_auth['domain_name'] = credential.domain openstack_auth['domain_name'] = credential.get_input('domain', default='')
openstack_data = { openstack_data = {
'clouds': { 'clouds': {
'devstack': { 'devstack': {
@@ -1156,22 +1156,27 @@ class RunJob(BaseTask):
for field in ('ssh_key_unlock', 'ssh_password', 'become_password'): for field in ('ssh_key_unlock', 'ssh_password', 'become_password'):
value = kwargs.get( value = kwargs.get(
field, field,
decrypt_field(cred, 'password' if field == 'ssh_password' else field) cred.get_input('password' if field == 'ssh_password' else field, default='')
) )
if value not in ('', 'ASK'): if value not in ('', 'ASK'):
passwords[field] = value passwords[field] = value
for cred in job.vault_credentials: for cred in job.vault_credentials:
field = 'vault_password' field = 'vault_password'
if cred.inputs.get('vault_id'): vault_id = cred.get_input('vault_id', default=None)
field = 'vault_password.{}'.format(cred.inputs['vault_id']) if vault_id:
field = 'vault_password.{}'.format(vault_id)
if field in passwords: if field in passwords:
raise RuntimeError( raise RuntimeError(
'multiple vault credentials were specified with --vault-id {}@prompt'.format( 'multiple vault credentials were specified with --vault-id {}@prompt'.format(
cred.inputs['vault_id'] vault_id
) )
) )
value = kwargs.get(field, decrypt_field(cred, 'vault_password'))
value = kwargs.get(field, None)
if value is None:
value = cred.get_input('vault_password', default='')
if value not in ('', 'ASK'): if value not in ('', 'ASK'):
passwords[field] = value passwords[field] = value
@@ -1181,10 +1186,10 @@ class RunJob(BaseTask):
''' '''
if 'ssh_key_unlock' not in passwords: if 'ssh_key_unlock' not in passwords:
for cred in job.network_credentials: for cred in job.network_credentials:
if cred.inputs.get('ssh_key_unlock'): if cred.has_input('ssh_key_unlock'):
passwords['ssh_key_unlock'] = kwargs.get( passwords['ssh_key_unlock'] = kwargs.get(
'ssh_key_unlock', 'ssh_key_unlock',
decrypt_field(cred, 'ssh_key_unlock') cred.get_input('ssh_key_unlock', default='')
) )
break break
@@ -1240,17 +1245,17 @@ class RunJob(BaseTask):
env['OS_CLIENT_CONFIG_FILE'] = cred_files.get(cloud_cred, '') env['OS_CLIENT_CONFIG_FILE'] = cred_files.get(cloud_cred, '')
for network_cred in job.network_credentials: for network_cred in job.network_credentials:
env['ANSIBLE_NET_USERNAME'] = network_cred.username env['ANSIBLE_NET_USERNAME'] = network_cred.get_input('username', default='')
env['ANSIBLE_NET_PASSWORD'] = decrypt_field(network_cred, 'password') env['ANSIBLE_NET_PASSWORD'] = network_cred.get_input('password', default='')
ssh_keyfile = cred_files.get(network_cred, '') ssh_keyfile = cred_files.get(network_cred, '')
if ssh_keyfile: if ssh_keyfile:
env['ANSIBLE_NET_SSH_KEYFILE'] = ssh_keyfile env['ANSIBLE_NET_SSH_KEYFILE'] = ssh_keyfile
authorize = network_cred.authorize authorize = network_cred.get_input('authorize', default=False)
env['ANSIBLE_NET_AUTHORIZE'] = six.text_type(int(authorize)) env['ANSIBLE_NET_AUTHORIZE'] = six.text_type(int(authorize))
if authorize: if authorize:
env['ANSIBLE_NET_AUTH_PASS'] = decrypt_field(network_cred, 'authorize_password') env['ANSIBLE_NET_AUTH_PASS'] = network_cred.get_input('authorize_password', default='')
return env return env
@@ -1263,9 +1268,9 @@ class RunJob(BaseTask):
ssh_username, become_username, become_method = '', '', '' ssh_username, become_username, become_method = '', '', ''
if creds: if creds:
ssh_username = kwargs.get('username', creds.username) ssh_username = kwargs.get('username', creds.get_input('username', default=''))
become_method = kwargs.get('become_method', creds.become_method) become_method = kwargs.get('become_method', creds.get_input('become_method', default=''))
become_username = kwargs.get('become_username', creds.become_username) become_username = kwargs.get('become_username', creds.get_input('become_username', default=''))
else: else:
become_method = None become_method = None
become_username = "" become_username = ""
@@ -1490,8 +1495,8 @@ class RunProjectUpdate(BaseTask):
private_data = {'credentials': {}} private_data = {'credentials': {}}
if project_update.credential: if project_update.credential:
credential = project_update.credential credential = project_update.credential
if credential.ssh_key_data not in (None, ''): if credential.has_input('ssh_key_data'):
private_data['credentials'][credential] = decrypt_field(credential, 'ssh_key_data') private_data['credentials'][credential] = credential.get_input('ssh_key_data', default='')
return private_data return private_data
def build_passwords(self, project_update, **kwargs): def build_passwords(self, project_update, **kwargs):
@@ -1502,9 +1507,9 @@ class RunProjectUpdate(BaseTask):
passwords = super(RunProjectUpdate, self).build_passwords(project_update, passwords = super(RunProjectUpdate, self).build_passwords(project_update,
**kwargs) **kwargs)
if project_update.credential: if project_update.credential:
passwords['scm_key_unlock'] = decrypt_field(project_update.credential, 'ssh_key_unlock') passwords['scm_key_unlock'] = project_update.credential.get_input('ssh_key_unlock', default='')
passwords['scm_username'] = project_update.credential.username passwords['scm_username'] = project_update.credential.get_input('username', default='')
passwords['scm_password'] = decrypt_field(project_update.credential, 'password') passwords['scm_password'] = project_update.credential.get_input('password', default='')
return passwords return passwords
def build_env(self, project_update, **kwargs): def build_env(self, project_update, **kwargs):
@@ -1828,12 +1833,13 @@ class RunInventoryUpdate(BaseTask):
credential = inventory_update.get_cloud_credential() credential = inventory_update.get_cloud_credential()
if inventory_update.source == 'openstack': if inventory_update.source == 'openstack':
openstack_auth = dict(auth_url=credential.host, openstack_auth = dict(auth_url=credential.get_input('host', default=''),
username=credential.username, username=credential.get_input('username', default=''),
password=decrypt_field(credential, "password"), password=credential.get_input('password', default=''),
project_name=credential.project) project_name=credential.get_input('project', default=''))
if credential.domain not in (None, ''): if credential.has_input('domain'):
openstack_auth['domain_name'] = credential.domain openstack_auth['domain_name'] = credential.get_input('domain', default='')
private_state = inventory_update.source_vars_dict.get('private', True) private_state = inventory_update.source_vars_dict.get('private', True)
# Retrieve cache path from inventory update vars if available, # Retrieve cache path from inventory update vars if available,
# otherwise create a temporary cache path only for this update. # otherwise create a temporary cache path only for this update.
@@ -1909,9 +1915,9 @@ class RunInventoryUpdate(BaseTask):
cp.add_section(section) cp.add_section(section)
cp.set('vmware', 'cache_max_age', '0') cp.set('vmware', 'cache_max_age', '0')
cp.set('vmware', 'validate_certs', str(settings.VMWARE_VALIDATE_CERTS)) cp.set('vmware', 'validate_certs', str(settings.VMWARE_VALIDATE_CERTS))
cp.set('vmware', 'username', credential.username) cp.set('vmware', 'username', credential.get_input('username', default=''))
cp.set('vmware', 'password', decrypt_field(credential, 'password')) cp.set('vmware', 'password', credential.get_input('password', default=''))
cp.set('vmware', 'server', credential.host) cp.set('vmware', 'server', credential.get_input('host', default=''))
vmware_opts = dict(inventory_update.source_vars_dict.items()) vmware_opts = dict(inventory_update.source_vars_dict.items())
if inventory_update.instance_filters: if inventory_update.instance_filters:
@@ -1942,9 +1948,9 @@ class RunInventoryUpdate(BaseTask):
cp.set(section, k, six.text_type(v)) cp.set(section, k, six.text_type(v))
if credential: if credential:
cp.set(section, 'url', credential.host) cp.set(section, 'url', credential.get_input('host', default=''))
cp.set(section, 'user', credential.username) cp.set(section, 'user', credential.get_input('username', default=''))
cp.set(section, 'password', decrypt_field(credential, 'password')) cp.set(section, 'password', credential.get_input('password', default=''))
section = 'ansible' section = 'ansible'
cp.add_section(section) cp.add_section(section)
@@ -1963,9 +1969,9 @@ class RunInventoryUpdate(BaseTask):
cp.add_section(section) cp.add_section(section)
if credential: if credential:
cp.set(section, 'url', credential.host) cp.set(section, 'url', credential.get_input('host', default=''))
cp.set(section, 'username', credential.username) cp.set(section, 'username', credential.get_input('username', default=''))
cp.set(section, 'password', decrypt_field(credential, 'password')) cp.set(section, 'password', credential.get_input('password', default=''))
cp.set(section, 'ssl_verify', "false") cp.set(section, 'ssl_verify', "false")
cloudforms_opts = dict(inventory_update.source_vars_dict.items()) cloudforms_opts = dict(inventory_update.source_vars_dict.items())
@@ -2021,10 +2027,10 @@ class RunInventoryUpdate(BaseTask):
credential = inventory_update.get_cloud_credential() credential = inventory_update.get_cloud_credential()
if credential: if credential:
for subkey in ('username', 'host', 'project', 'client', 'tenant', 'subscription'): for subkey in ('username', 'host', 'project', 'client', 'tenant', 'subscription'):
passwords['source_%s' % subkey] = getattr(credential, subkey) passwords['source_%s' % subkey] = credential.get_input(subkey, default='')
for passkey in ('password', 'ssh_key_data', 'security_token', 'secret'): for passkey in ('password', 'ssh_key_data', 'security_token', 'secret'):
k = 'source_%s' % passkey k = 'source_%s' % passkey
passwords[k] = decrypt_field(credential, passkey) passwords[k] = credential.get_input(passkey, default='')
return passwords return passwords
def build_env(self, inventory_update, **kwargs): def build_env(self, inventory_update, **kwargs):
@@ -2229,8 +2235,8 @@ class RunAdHocCommand(BaseTask):
# back (they will be written to a temporary file). # back (they will be written to a temporary file).
creds = ad_hoc_command.credential creds = ad_hoc_command.credential
private_data = {'credentials': {}} private_data = {'credentials': {}}
if creds and creds.ssh_key_data not in (None, ''): if creds and creds.has_input('ssh_key_data'):
private_data['credentials'][creds] = decrypt_field(creds, 'ssh_key_data') or '' private_data['credentials'][creds] = creds.get_input('ssh_key_data', default='')
return private_data return private_data
def build_passwords(self, ad_hoc_command, **kwargs): def build_passwords(self, ad_hoc_command, **kwargs):
@@ -2243,9 +2249,9 @@ class RunAdHocCommand(BaseTask):
if creds: if creds:
for field in ('ssh_key_unlock', 'ssh_password', 'become_password'): for field in ('ssh_key_unlock', 'ssh_password', 'become_password'):
if field == 'ssh_password': if field == 'ssh_password':
value = kwargs.get(field, decrypt_field(creds, 'password')) value = kwargs.get(field, creds.get_input('password', default=''))
else: else:
value = kwargs.get(field, decrypt_field(creds, field)) value = kwargs.get(field, creds.get_input(field, default=''))
if value not in ('', 'ASK'): if value not in ('', 'ASK'):
passwords[field] = value passwords[field] = value
return passwords return passwords
@@ -2282,9 +2288,9 @@ class RunAdHocCommand(BaseTask):
creds = ad_hoc_command.credential creds = ad_hoc_command.credential
ssh_username, become_username, become_method = '', '', '' ssh_username, become_username, become_method = '', '', ''
if creds: if creds:
ssh_username = kwargs.get('username', creds.username) ssh_username = kwargs.get('username', creds.get_input('username', default=''))
become_method = kwargs.get('become_method', creds.become_method) become_method = kwargs.get('become_method', creds.get_input('become_method', default=''))
become_username = kwargs.get('become_username', creds.become_username) become_username = kwargs.get('become_username', creds.get_input('become_username', default=''))
else: else:
become_method = None become_method = None
become_username = "" become_username = ""

View File

@@ -327,3 +327,51 @@ def test_credential_update_with_prior(organization_factory, credentialtype_ssh):
assert cred.inputs['username'] == 'joe' assert cred.inputs['username'] == 'joe'
assert cred.inputs['password'].startswith('$encrypted$') assert cred.inputs['password'].startswith('$encrypted$')
assert decrypt_field(cred, 'password') == 'testing123' assert decrypt_field(cred, 'password') == 'testing123'
@pytest.mark.django_db
def test_credential_get_input(organization_factory):
organization = organization_factory('test').organization
type_ = CredentialType(
kind='vault',
name='somevault',
managed_by_tower=True,
inputs={
'fields': [{
'id': 'vault_password',
'type': 'string',
'secret': True,
}, {
'id': 'vault_id',
'type': 'string',
'secret': False
}]
}
)
type_.save()
cred = Credential(
organization=organization,
credential_type=type_,
name="Bob's Credential",
inputs={'vault_password': 'testing321'}
)
cred.save()
cred.full_clean()
assert isinstance(cred, Credential)
# verify expected exception is raised when attempting to access an unset
# input without providing a default
with pytest.raises(AttributeError):
cred.get_input('vault_id')
# verify that the provided default is used for unset inputs
assert cred.get_input('vault_id', default='foo') == 'foo'
# verify expected exception is raised when attempting to access an undefined
# input without providing a default
with pytest.raises(AttributeError):
cred.get_input('field_not_on_credential_type')
# verify that the provided default is used for undefined inputs
assert cred.get_input('field_not_on_credential_type', default='bar') == 'bar'
# verify return values for encrypted secret fields are decrypted
assert cred.inputs['vault_password'].startswith('$encrypted$')
assert cred.get_input('vault_password') == 'testing321'

View File

@@ -108,13 +108,16 @@ def test_safe_env_returns_new_copy():
def test_openstack_client_config_generation(mocker): def test_openstack_client_config_generation(mocker):
update = tasks.RunInventoryUpdate() update = tasks.RunInventoryUpdate()
credential = mocker.Mock(**{ credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org', 'host': 'https://keystone.openstack.example.org',
'username': 'demo', 'username': 'demo',
'password': 'secrete', 'password': 'secrete',
'project': 'demo-project', 'project': 'demo-project',
'domain': 'my-demo-domain', 'domain': 'my-demo-domain',
}) }
credential = Credential(pk=1, credential_type=credential_type, inputs=inputs)
cred_method = mocker.Mock(return_value=credential) cred_method = mocker.Mock(return_value=credential)
inventory_update = mocker.Mock(**{ inventory_update = mocker.Mock(**{
'source': 'openstack', 'source': 'openstack',
@@ -144,13 +147,16 @@ def test_openstack_client_config_generation(mocker):
]) ])
def test_openstack_client_config_generation_with_private_source_vars(mocker, source, expected): def test_openstack_client_config_generation_with_private_source_vars(mocker, source, expected):
update = tasks.RunInventoryUpdate() update = tasks.RunInventoryUpdate()
credential = mocker.Mock(**{ credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org', 'host': 'https://keystone.openstack.example.org',
'username': 'demo', 'username': 'demo',
'password': 'secrete', 'password': 'secrete',
'project': 'demo-project', 'project': 'demo-project',
'domain': None, 'domain': None,
}) }
credential = Credential(pk=1, credential_type=credential_type, inputs=inputs)
cred_method = mocker.Mock(return_value=credential) cred_method = mocker.Mock(return_value=credential)
inventory_update = mocker.Mock(**{ inventory_update = mocker.Mock(**{
'source': 'openstack', 'source': 'openstack',