diff --git a/awx/api/permissions.py b/awx/api/permissions.py index a655360dc8..8ec26a2cc8 100644 --- a/awx/api/permissions.py +++ b/awx/api/permissions.py @@ -4,9 +4,6 @@ # Python import logging -# Django -from django.http import Http404 - # Django REST Framework from rest_framework.exceptions import MethodNotAllowed, PermissionDenied from rest_framework import permissions @@ -19,7 +16,7 @@ from awx.main.utils import get_object_or_400 logger = logging.getLogger('awx.api.permissions') __all__ = ['ModelAccessPermission', 'JobTemplateCallbackPermission', - 'TaskPermission', 'ProjectUpdatePermission', 'UserPermission'] + 'TaskPermission', 'ProjectUpdatePermission', 'UserPermission',] class ModelAccessPermission(permissions.BasePermission): @@ -96,13 +93,6 @@ class ModelAccessPermission(permissions.BasePermission): method based on the request method. ''' - # Check that obj (if given) is active, otherwise raise a 404. - active = getattr(obj, 'active', getattr(obj, 'is_active', True)) - if callable(active): - active = active() - if not active: - raise Http404() - # Don't allow anonymous users. 401, not 403, hence no raised exception. if not request.user or request.user.is_anonymous(): return False @@ -216,3 +206,5 @@ class UserPermission(ModelAccessPermission): elif request.user.is_superuser: return True raise PermissionDenied() + + diff --git a/awx/api/views.py b/awx/api/views.py index bf59089cdf..f36404d710 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -1685,6 +1685,7 @@ class HostList(ListCreateAPIView): class HostDetail(RetrieveUpdateDestroyAPIView): + always_allow_superuser = False model = Host serializer_class = HostSerializer diff --git a/awx/main/access.py b/awx/main/access.py index 196d58b6c6..9417563f75 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -285,7 +285,7 @@ class BaseAccess(object): return True # User has access to both, permission check passed - def check_license(self, add_host=False, feature=None, check_expiration=True): + def check_license(self, add_host_name=None, feature=None, check_expiration=True): validation_info = TaskEnhancer().validate_enhancements() if ('test' in sys.argv or 'py.test' in sys.argv[0] or 'jenkins' in sys.argv) and not os.environ.get('SKIP_LICENSE_FIXUP_FOR_TEST', ''): validation_info['free_instances'] = 99999999 @@ -299,11 +299,14 @@ class BaseAccess(object): free_instances = validation_info.get('free_instances', 0) available_instances = validation_info.get('available_instances', 0) - if add_host and free_instances == 0: - raise PermissionDenied(_("License count of %s instances has been reached.") % available_instances) - elif add_host and free_instances < 0: - raise PermissionDenied(_("License count of %s instances has been exceeded.") % available_instances) - elif not add_host and free_instances < 0: + + if add_host_name: + host_exists = Host.objects.filter(name=add_host_name).exists() + if not host_exists and free_instances == 0: + raise PermissionDenied(_("License count of %s instances has been reached.") % available_instances) + elif not host_exists and free_instances < 0: + raise PermissionDenied(_("License count of %s instances has been exceeded.") % available_instances) + elif not add_host_name and free_instances < 0: raise PermissionDenied(_("Host count exceeds available instances.")) if feature is not None: @@ -612,7 +615,7 @@ class HostAccess(BaseAccess): return False # Check to see if we have enough licenses - self.check_license(add_host=True) + self.check_license(add_host_name=data.get('name', None)) return True def can_change(self, obj, data): @@ -620,6 +623,11 @@ class HostAccess(BaseAccess): inventory_pk = get_pk_from_dict(data, 'inventory') if obj and inventory_pk and obj.inventory.pk != inventory_pk: raise PermissionDenied(_('Unable to change inventory on a host.')) + + # Prevent renaming a host that might exceed license count + if 'name' in data: + self.check_license(add_host_name=data['name']) + # Checks for admin or change permission on inventory, controls whether # the user can edit variable data. return obj and self.user in obj.inventory.admin_role diff --git a/awx/main/tests/unit/test_access.py b/awx/main/tests/unit/test_access.py index 8a6687ba2f..05199fd5e3 100644 --- a/awx/main/tests/unit/test_access.py +++ b/awx/main/tests/unit/test_access.py @@ -1,9 +1,11 @@ import pytest import mock +import os from django.contrib.auth.models import User from django.forms.models import model_to_dict from rest_framework.exceptions import ParseError +from rest_framework.exceptions import PermissionDenied from awx.main.access import ( BaseAccess, @@ -14,7 +16,14 @@ from awx.main.access import ( ) from awx.conf.license import LicenseForbids -from awx.main.models import Credential, Inventory, Project, Role, Organization, Instance +from awx.main.models import ( + Credential, + Inventory, + Project, + Role, + Organization, + Instance, +) @pytest.fixture @@ -247,6 +256,41 @@ class TestWorkflowAccessMethods: assert access.can_add({'organization': 1}) +class TestCheckLicense: + @pytest.fixture + def validate_enhancements_mocker(self, mocker): + os.environ['SKIP_LICENSE_FIXUP_FOR_TEST'] = '1' + + def fn(available_instances=1, free_instances=0, host_exists=False): + + class MockFilter: + def exists(self): + return host_exists + + mocker.patch('awx.main.tasks.TaskEnhancer.validate_enhancements', return_value={'free_instances': free_instances, 'available_instances': available_instances, 'date_warning': True}) + + mock_filter = MockFilter() + mocker.patch('awx.main.models.Host.objects.filter', return_value=mock_filter) + + return fn + + def test_check_license_add_host_duplicate(self, validate_enhancements_mocker, user_unit): + validate_enhancements_mocker(available_instances=1, free_instances=0, host_exists=True) + + BaseAccess(None).check_license(add_host_name='blah', check_expiration=False) + + def test_check_license_add_host_new_exceed_licence(self, validate_enhancements_mocker, user_unit, mocker): + validate_enhancements_mocker(available_instances=1, free_instances=0, host_exists=False) + exception = None + + try: + BaseAccess(None).check_license(add_host_name='blah', check_expiration=False) + except PermissionDenied as e: + exception = e + + assert "License count of 1 instances has been reached." == str(exception) + + def test_user_capabilities_method(): """Unit test to verify that the user_capabilities method will defer to the appropriate sub-class methods of the access classes.