From b7b9fb531e30e8e6b2a9709eb727e938f2ec70fa Mon Sep 17 00:00:00 2001 From: Ryan Petrello Date: Mon, 8 May 2017 12:56:57 -0400 Subject: [PATCH] properly support `(cloud|network)_credential` for JT update *and* create fix a bug which caused `POST /api/v1/job_templates/` to not properly set `JobTemplate.extra_credentials`. see: #5807 --- awx/api/serializers.py | 70 +++++++++++++------ .../tests/functional/api/test_job_template.py | 48 +++++++++++++ 2 files changed, 95 insertions(+), 23 deletions(-) diff --git a/awx/api/serializers.py b/awx/api/serializers.py index a353f67aca..3c76bcbb0a 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -2187,30 +2187,52 @@ class JobOptionsSerializer(LabelsListMixin, BaseSerializer): ret['network_credential'] = obj.network_credential return ret + def create(self, validated_data): + deprecated_fields = {} + for key in ('cloud_credential', 'network_credential'): + if key in validated_data: + deprecated_fields[key] = validated_data.pop(key) + obj = super(JobOptionsSerializer, self).create(validated_data) + if self.version == 1 and deprecated_fields: # TODO: remove in 3.3 + self._update_deprecated_fields(deprecated_fields, obj) + return obj + + def update(self, obj, validated_data): + deprecated_fields = {} + for key in ('cloud_credential', 'network_credential'): + if key in validated_data: + deprecated_fields[key] = validated_data.pop(key) + obj = super(JobOptionsSerializer, self).update(obj, validated_data) + if self.version == 1 and deprecated_fields: # TODO: remove in 3.3 + self._update_deprecated_fields(deprecated_fields, obj) + return obj + + def _update_deprecated_fields(self, fields, obj): + for key, existing in ( + ('cloud_credential', obj.cloud_credentials), + ('network_credential', obj.network_credentials), + ): + if key in fields: + for cred in existing: + obj.extra_credentials.remove(cred) + if fields[key]: + obj.extra_credentials.add(fields[key]) + obj.save() + def validate(self, attrs): + v1_credentials = {} if self.version == 1: # TODO: remove in 3.3 - if 'cloud_credential' in attrs: - pk = attrs.pop('cloud_credential') - for cred in self.instance.cloud_credentials: - self.instance.extra_credentials.remove(cred) - if pk: - cred = Credential.objects.get(pk=pk) - if cred.credential_type.kind != 'cloud': - raise serializers.ValidationError({ - 'cloud_credential': _('You must provide a cloud credential.'), - }) - self.instance.extra_credentials.add(cred) - if 'network_credential' in attrs: - pk = attrs.pop('network_credential') - for cred in self.instance.network_credentials: - self.instance.extra_credentials.remove(cred) - if pk: - cred = Credential.objects.get(pk=pk) - if cred.credential_type.kind != 'net': - raise serializers.ValidationError({ - 'network_credential': _('You must provide a network credential.'), - }) - self.instance.extra_credentials.add(cred) + for attr, kind, error in ( + ('cloud_credential', 'cloud', _('You must provide a cloud credential.')), + ('network_credential', 'net', _('You must provide a network credential.')) + ): + if attr in attrs: + v1_credentials[attr] = None + pk = attrs.pop(attr) + if pk: + cred = v1_credentials[attr] = Credential.objects.get(pk=pk) + if cred.credential_type.kind != kind: + raise serializers.ValidationError({attr: error}) if 'project' in self.fields and 'playbook' in self.fields: project = attrs.get('project', self.instance and self.instance.project or None) @@ -2225,7 +2247,9 @@ class JobOptionsSerializer(LabelsListMixin, BaseSerializer): if project and not playbook: raise serializers.ValidationError({'playbook': _('Must select playbook for project.')}) - return super(JobOptionsSerializer, self).validate(attrs) + ret = super(JobOptionsSerializer, self).validate(attrs) + ret.update(v1_credentials) + return ret class JobTemplateMixin(object): diff --git a/awx/main/tests/functional/api/test_job_template.py b/awx/main/tests/functional/api/test_job_template.py index 1289e0f8ee..ab40ab1066 100644 --- a/awx/main/tests/functional/api/test_job_template.py +++ b/awx/main/tests/functional/api/test_job_template.py @@ -36,6 +36,52 @@ def test_create(post, project, machine_credential, inventory, alice, grant_proje }, alice, expect=expect) +# TODO: remove in 3.3 +@pytest.mark.django_db +def test_create_with_v1_deprecated_credentials(get, post, project, machine_credential, credential, net_credential, inventory, alice): + project.use_role.members.add(alice) + machine_credential.use_role.members.add(alice) + inventory.use_role.members.add(alice) + + pk = post(reverse('api:job_template_list', kwargs={'version': 'v1'}), { + 'name': 'Some name', + 'project': project.id, + 'credential': machine_credential.id, + 'cloud_credential': credential.id, + 'network_credential': net_credential.id, + 'inventory': inventory.id, + 'playbook': 'helloworld.yml', + }, alice, expect=201).data['id'] + + url = reverse('api:job_template_detail', kwargs={'version': 'v1', 'pk': pk}) + response = get(url, alice) + assert response.data.get('cloud_credential') == credential.pk + assert response.data.get('network_credential') == net_credential.pk + + +# TODO: remove in 3.3 +@pytest.mark.django_db +def test_create_with_empty_v1_deprecated_credentials(get, post, project, machine_credential, inventory, alice): + project.use_role.members.add(alice) + machine_credential.use_role.members.add(alice) + inventory.use_role.members.add(alice) + + pk = post(reverse('api:job_template_list', kwargs={'version': 'v1'}), { + 'name': 'Some name', + 'project': project.id, + 'credential': machine_credential.id, + 'cloud_credential': None, + 'network_credential': None, + 'inventory': inventory.id, + 'playbook': 'helloworld.yml', + }, alice, expect=201).data['id'] + + url = reverse('api:job_template_detail', kwargs={'version': 'v1', 'pk': pk}) + response = get(url, alice) + assert response.data.get('cloud_credential') is None + assert response.data.get('network_credential') is None + + # TODO: test this with RBAC and lower-priveleged users @pytest.mark.django_db def test_extra_credential_creation(get, post, organization_factory, job_template_factory, credentialtype_aws): @@ -150,6 +196,7 @@ def test_attach_extra_credential_wrong_kind_xfail(get, post, organization_factor assert response.data.get('count') == 0 +# TODO: remove in 3.3 @pytest.mark.django_db def test_v1_extra_credentials_detail(get, organization_factory, job_template_factory, credential, net_credential): objs = organization_factory("org", superusers=['admin']) @@ -165,6 +212,7 @@ def test_v1_extra_credentials_detail(get, organization_factory, job_template_fac assert response.data.get('network_credential') == net_credential.pk +# TODO: remove in 3.3 @pytest.mark.django_db def test_v1_set_extra_credentials_assignment(get, patch, organization_factory, job_template_factory, credential, net_credential): objs = organization_factory("org", superusers=['admin'])