diff --git a/awx/api/serializers.py b/awx/api/serializers.py index a5d6f9a4e1..38ca992021 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -2036,6 +2036,7 @@ class CredentialSerializer(BaseSerializer): return ret def to_internal_value(self, data): + # TODO: remove when API v1 is removed if 'credential_type' not in data: # If `credential_type` is not provided, assume the payload is a # v1 credential payload that specifies a `kind` and a flat list @@ -2050,6 +2051,21 @@ class CredentialSerializer(BaseSerializer): {'credential_type': credential_type}.items() + super(CredentialSerializer, self).to_internal_value(data).items() ) + + # Make a set of the keys in the POST/PUT payload + # - Subtract real fields (name, organization, inputs) + # - Subtract virtual v1 fields defined on the determined credential + # type (username, password, etc...) + # - Any leftovers are invalid for the determined credential type + valid_fields = set(super(CredentialSerializer, self).get_fields().keys()) + valid_fields.update(V2CredentialFields().get_fields().keys()) + valid_fields.update(['kind', 'cloud']) + + for field in set(data.keys()) - valid_fields - set(credential_type.defined_fields): + if data.get(field): + raise serializers.ValidationError( + {"detail": _("'%s' is not a valid field for %s") % (field, credential_type.name)} + ) value.pop('kind', None) return value return super(CredentialSerializer, self).to_internal_value(data) diff --git a/awx/main/models/credential.py b/awx/main/models/credential.py index ccded60c38..806c681cff 100644 --- a/awx/main/models/credential.py +++ b/awx/main/models/credential.py @@ -473,7 +473,7 @@ class CredentialType(CommonModelNameNotUnique): kind_choices = dict(V1Credential.KIND_CHOICES) requirements = {} if kind == 'ssh': - if 'vault_password' in data: + if data.get('vault_password'): requirements['kind'] = 'vault' else: requirements['kind'] = 'ssh' diff --git a/awx/main/tests/functional/api/test_credential.py b/awx/main/tests/functional/api/test_credential.py index 45096c2433..e837594866 100644 --- a/awx/main/tests/functional/api/test_credential.py +++ b/awx/main/tests/functional/api/test_credential.py @@ -681,6 +681,61 @@ def test_scm_create_ok(post, organization, admin, version, params): assert decrypt_field(cred, 'ssh_key_unlock') == 'some_key_unlock' +@pytest.mark.django_db +@pytest.mark.parametrize('version, params', [ + ['v1', { + 'kind': 'ssh', + 'name': 'Best credential ever', + 'password': 'secret', + 'vault_password': '', + }], + ['v2', { + 'credential_type': 1, + 'name': 'Best credential ever', + 'inputs': { + 'password': 'secret', + } + }] +]) +def test_ssh_create_ok(post, organization, admin, version, params): + ssh = CredentialType.defaults['ssh']() + ssh.save() + params['organization'] = organization.id + response = post( + reverse('api:credential_list', kwargs={'version': version}), + params, + admin + ) + assert response.status_code == 201 + + assert Credential.objects.count() == 1 + cred = Credential.objects.all()[:1].get() + assert cred.credential_type == ssh + assert decrypt_field(cred, 'password') == 'secret' + + +@pytest.mark.django_db +def test_v1_ssh_vault_ambiguity(post, organization, admin): + vault = CredentialType.defaults['vault']() + vault.save() + params = { + 'organization': organization.id, + 'kind': 'ssh', + 'name': 'Best credential ever', + 'username': 'joe', + 'password': 'secret', + 'ssh_key_data': 'some_key_data', + 'ssh_key_unlock': 'some_key_unlock', + 'vault_password': 'vault_password', + } + response = post( + reverse('api:credential_list', kwargs={'version': 'v1'}), + params, + admin + ) + assert response.status_code == 400 + + # # Vault Credentials #