diff --git a/awx/main/tests/unit/test_access.py b/awx/main/tests/unit/test_access.py index 8a6687ba2f..05199fd5e3 100644 --- a/awx/main/tests/unit/test_access.py +++ b/awx/main/tests/unit/test_access.py @@ -1,9 +1,11 @@ import pytest import mock +import os from django.contrib.auth.models import User from django.forms.models import model_to_dict from rest_framework.exceptions import ParseError +from rest_framework.exceptions import PermissionDenied from awx.main.access import ( BaseAccess, @@ -14,7 +16,14 @@ from awx.main.access import ( ) from awx.conf.license import LicenseForbids -from awx.main.models import Credential, Inventory, Project, Role, Organization, Instance +from awx.main.models import ( + Credential, + Inventory, + Project, + Role, + Organization, + Instance, +) @pytest.fixture @@ -247,6 +256,41 @@ class TestWorkflowAccessMethods: assert access.can_add({'organization': 1}) +class TestCheckLicense: + @pytest.fixture + def validate_enhancements_mocker(self, mocker): + os.environ['SKIP_LICENSE_FIXUP_FOR_TEST'] = '1' + + def fn(available_instances=1, free_instances=0, host_exists=False): + + class MockFilter: + def exists(self): + return host_exists + + mocker.patch('awx.main.tasks.TaskEnhancer.validate_enhancements', return_value={'free_instances': free_instances, 'available_instances': available_instances, 'date_warning': True}) + + mock_filter = MockFilter() + mocker.patch('awx.main.models.Host.objects.filter', return_value=mock_filter) + + return fn + + def test_check_license_add_host_duplicate(self, validate_enhancements_mocker, user_unit): + validate_enhancements_mocker(available_instances=1, free_instances=0, host_exists=True) + + BaseAccess(None).check_license(add_host_name='blah', check_expiration=False) + + def test_check_license_add_host_new_exceed_licence(self, validate_enhancements_mocker, user_unit, mocker): + validate_enhancements_mocker(available_instances=1, free_instances=0, host_exists=False) + exception = None + + try: + BaseAccess(None).check_license(add_host_name='blah', check_expiration=False) + except PermissionDenied as e: + exception = e + + assert "License count of 1 instances has been reached." == str(exception) + + def test_user_capabilities_method(): """Unit test to verify that the user_capabilities method will defer to the appropriate sub-class methods of the access classes.