From b4a446dba040993b72dc98489b683f755816d26d Mon Sep 17 00:00:00 2001 From: AlanCoding Date: Thu, 26 Oct 2017 11:25:40 -0400 Subject: [PATCH] raise error for invalid type lookup --- .../tests/functional/utils/test_common.py | 21 +++++++++++++++++++ awx/main/utils/common.py | 3 +++ 2 files changed, 24 insertions(+) diff --git a/awx/main/tests/functional/utils/test_common.py b/awx/main/tests/functional/utils/test_common.py index f9dcc3769b..5ef89b0d06 100644 --- a/awx/main/tests/functional/utils/test_common.py +++ b/awx/main/tests/functional/utils/test_common.py @@ -3,10 +3,14 @@ import pytest import copy import json +from django.db import DatabaseError + from awx.main.utils.common import ( model_instance_diff, model_to_dict, + get_model_for_type ) +from awx.main import models @pytest.mark.django_db @@ -58,3 +62,20 @@ def test_model_instance_diff(alice, bob): assert hasattr(alice, 'is_superuser') assert hasattr(bob, 'is_superuser') assert 'is_superuser' not in output_dict + + +@pytest.mark.django_db +def test_get_model_for_invalid_type(): + with pytest.raises(DatabaseError) as exc: + get_model_for_type('foobar') + assert 'not a valid AWX model' in str(exc) + + +@pytest.mark.django_db +@pytest.mark.parametrize("model_type,model_class", [ + ('inventory', models.Inventory), + ('job_template', models.JobTemplate), + ('unified_job_template', models.UnifiedJobTemplate) +]) +def test_get_model_for_valid_type(model_type, model_class): + assert get_model_for_type(model_type) == model_class diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index ccd4322b5d..449f32bcf8 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -24,6 +24,7 @@ from decorator import decorator # Django from django.core.exceptions import ObjectDoesNotExist +from django.db import DatabaseError from django.utils.translation import ugettext_lazy as _ from django.db.models.fields.related import ForeignObjectRel, ManyToManyField @@ -506,6 +507,8 @@ def get_model_for_type(type): ct_type = get_type_for_model(ct_model) if type == ct_type: return ct_model + else: + raise DatabaseError('"{}" is not a valid AWX model.'.format(type)) def cache_list_capabilities(page, prefetch_list, model, user):