diff --git a/awx/main/migrations/_credentialtypes.py b/awx/main/migrations/_credentialtypes.py index 1c90822bab..104caa334a 100644 --- a/awx/main/migrations/_credentialtypes.py +++ b/awx/main/migrations/_credentialtypes.py @@ -175,4 +175,4 @@ def migrate_job_credentials(apps, schema_editor): def create_ovirt4_credtype(apps, schema_editor): - CredentialType.defaults['ovirt4']().save() + CredentialType.setup_tower_managed_defaults() diff --git a/awx/main/models/credential.py b/awx/main/models/credential.py index 5a1832f323..7db8db04e9 100644 --- a/awx/main/models/credential.py +++ b/awx/main/models/credential.py @@ -3,6 +3,7 @@ from collections import OrderedDict import functools import json +import logging import operator import os import stat @@ -35,6 +36,8 @@ from awx.main.utils import encrypt_field __all__ = ['Credential', 'CredentialType', 'V1Credential'] +logger = logging.getLogger('awx.main.models.credential') + class V1Credential(object): @@ -468,6 +471,11 @@ class CredentialType(CommonModelNameNotUnique): for default in cls.defaults.values(): default_ = default() if persisted: + if CredentialType.objects.filter(name=default_.name, kind=default_.kind).count(): + continue + logger.debug(_( + "adding %s credential type" % default_.name + )) default_.save() @classmethod diff --git a/awx/main/tests/functional/api/test_credential.py b/awx/main/tests/functional/api/test_credential.py index b60d215e14..ff29f9ed3f 100644 --- a/awx/main/tests/functional/api/test_credential.py +++ b/awx/main/tests/functional/api/test_credential.py @@ -14,6 +14,17 @@ EXAMPLE_PRIVATE_KEY = '-----BEGIN PRIVATE KEY-----\nxyz==\n-----END PRIVATE KEY- EXAMPLE_ENCRYPTED_PRIVATE_KEY = '-----BEGIN PRIVATE KEY-----\nProc-Type: 4,ENCRYPTED\nxyz==\n-----END PRIVATE KEY-----' +@pytest.mark.django_db +def test_idempotent_credential_type_setup(): + assert CredentialType.objects.count() == 0 + CredentialType.setup_tower_managed_defaults() + total = CredentialType.objects.count() + assert total > 0 + + CredentialType.setup_tower_managed_defaults() + assert CredentialType.objects.count() == total + + @pytest.mark.django_db @pytest.mark.parametrize('kind, total', [ ('ssh', 1), ('net', 0)