diff --git a/awx/api/permissions.py b/awx/api/permissions.py index cf29f3454a..e01ab4d9fb 100644 --- a/awx/api/permissions.py +++ b/awx/api/permissions.py @@ -34,7 +34,7 @@ class ModelAccessPermission(permissions.BasePermission): def check_get_permissions(self, request, view, obj=None): if hasattr(view, 'parent_model'): - parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) + parent_obj = view.get_parent_object() if not check_user_access(request.user, view.parent_model, 'read', parent_obj): return False @@ -44,12 +44,12 @@ class ModelAccessPermission(permissions.BasePermission): def check_post_permissions(self, request, view, obj=None): if hasattr(view, 'parent_model'): - parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) + parent_obj = view.get_parent_object() if not check_user_access(request.user, view.parent_model, 'read', parent_obj): return False if hasattr(view, 'parent_key'): - if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj.pk}): + if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj}): return False return True elif getattr(view, 'is_job_start', False): diff --git a/awx/api/renderers.py b/awx/api/renderers.py index 006057a09b..f6fb089d47 100644 --- a/awx/api/renderers.py +++ b/awx/api/renderers.py @@ -48,7 +48,8 @@ class BrowsableAPIRenderer(renderers.BrowsableAPIRenderer): obj = getattr(view, 'object', None) if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'): try: - obj = view.get_object() + view.object = view.get_object() + obj = view.object except Exception: obj = None with override_method(view, request, method) as request: diff --git a/awx/main/access.py b/awx/main/access.py index d73b877389..bf72dd3d34 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -17,7 +17,12 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError # AWX -from awx.main.utils import * # noqa +from awx.main.utils import ( + get_object_or_400, + get_pk_from_dict, + to_python_boolean, + get_licenser, +) from awx.main.models import * # noqa from awx.main.models.unified_jobs import ACTIVE_STATES from awx.main.models.mixins import ResourceMixin @@ -36,6 +41,36 @@ access_registry = { } +def get_object_from_data(field, Model, data, obj=None): + """ + Utility method to obtain related object in data according to fallbacks: + - if data contains key with pointer to Django object, return that + - if contains integer, get object from database + - if this does not work, raise exception + """ + try: + raw_value = data[field] + except KeyError: + # Calling method needs to deal with non-existence of key + raise ParseError(_("Required related field %s for permission check." % field)) + + if isinstance(raw_value, Model): + return raw_value + elif raw_value is None: + return None + else: + try: + new_pk = int(raw_value) + # Avoid database query by comparing pk to model for similarity + if obj and new_pk == getattr(obj, '%s_id' % field, None): + return getattr(obj, field) + else: + # Get the new resource from the database + return get_object_or_400(Model, pk=new_pk) + except (TypeError, ValueError): + raise ParseError(_("Bad data found in related field %s." % field)) + + class StateConflict(ValidationError): status_code = 409 @@ -205,24 +240,8 @@ class BaseAccess(object): # Use reference object's related fields, if given new = getattr(data['reference_obj'], field) elif data and field in data: - # Obtain the resource specified in `data` - raw_value = data[field] - if isinstance(raw_value, Model): - new = raw_value - elif raw_value is None: - new = None - else: - try: - new_pk = int(raw_value) - # Avoid database query by comparing pk to model for similarity - if obj and new_pk == getattr(obj, '%s_id' % field, None): - changed = False - else: - # Get the new resource from the database - new = get_object_or_400(Model, pk=new_pk) - except (TypeError, ValueError): - raise ParseError(_("Bad data found in related field %s." % field)) - elif data is None or field not in data: + new = get_object_from_data(field, Model, data, obj=obj) + else: changed = False # Obtain existing related resource @@ -940,17 +959,14 @@ class CredentialAccess(BaseAccess): def can_add(self, data): if not data: # So the browseable API will work return True - user_pk = get_pk_from_dict(data, 'user') - if user_pk: - user_obj = get_object_or_400(User, pk=user_pk) + if data and data.get('user', None): + user_obj = get_object_from_data('user', User, data) return check_user_access(self.user, User, 'change', user_obj, None) - team_pk = get_pk_from_dict(data, 'team') - if team_pk: - team_obj = get_object_or_400(Team, pk=team_pk) + if data and data.get('team', None): + team_obj = get_object_from_data('team', Team, data) return check_user_access(self.user, Team, 'change', team_obj, None) - organization_pk = get_pk_from_dict(data, 'organization') - if organization_pk: - organization_obj = get_object_or_400(Organization, pk=organization_pk) + if data and data.get('organization', None): + organization_obj = get_object_from_data('organization', Organization, data) return check_user_access(self.user, Organization, 'change', organization_obj, None) return False @@ -1173,9 +1189,8 @@ class JobTemplateAccess(BaseAccess): if reference_obj: return getattr(reference_obj, field, None) else: - pk = get_pk_from_dict(data, field) - if pk: - return get_object_or_400(Class, pk=pk) + if data and data.get(field, None): + return get_object_from_data(field, Class, data) else: return None @@ -1261,23 +1276,6 @@ class JobTemplateAccess(BaseAccess): return False return True - def can_update_sensitive_fields(self, obj, data): - project_id = data.get('project', obj.project.id if obj.project else None) - inventory_id = data.get('inventory', obj.inventory.id if obj.inventory else None) - credential_id = data.get('credential', obj.credential.id if obj.credential else None) - vault_credential_id = data.get('credential', obj.vault_credential.id if obj.vault_credential else None) - - if project_id and self.user not in Project.objects.get(pk=project_id).use_role: - return False - if inventory_id and self.user not in Inventory.objects.get(pk=inventory_id).use_role: - return False - if credential_id and self.user not in Credential.objects.get(pk=credential_id).use_role: - return False - if vault_credential_id and self.user not in Credential.objects.get(pk=vault_credential_id).use_role: - return False - - return True - def can_delete(self, obj): is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role if not is_delete_allowed: @@ -1387,9 +1385,8 @@ class JobAccess(BaseAccess): add_data = dict(data.items()) # If a job template is provided, the user should have read access to it. - job_template_pk = get_pk_from_dict(data, 'job_template') - if job_template_pk: - job_template = get_object_or_400(JobTemplate, pk=job_template_pk) + if data and data.get('job_template', None): + job_template = get_object_from_data('job_template', JobTemplate, data) add_data.setdefault('inventory', job_template.inventory.pk) add_data.setdefault('project', job_template.project.pk) add_data.setdefault('job_type', job_template.job_type) diff --git a/awx/main/tests/unit/api/test_generics.py b/awx/main/tests/unit/api/test_generics.py index 10baf7eab1..62eac9d99c 100644 --- a/awx/main/tests/unit/api/test_generics.py +++ b/awx/main/tests/unit/api/test_generics.py @@ -242,28 +242,28 @@ class TestResourceAccessList: ), method='GET') - def mock_view(self): + def mock_view(self, parent=None): view = ResourceAccessList() view.parent_model = Organization view.kwargs = {'pk': 4} + if parent: + view.get_parent_object = lambda: parent return view def test_parent_access_check_failed(self, mocker, mock_organization): - with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): - mock_access = mocker.MagicMock(__name__='for logger', return_value=False) - with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): - with pytest.raises(PermissionDenied): - self.mock_view().check_permissions(self.mock_request()) - mock_access.assert_called_once_with(mock_organization) + mock_access = mocker.MagicMock(__name__='for logger', return_value=False) + with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): + with pytest.raises(PermissionDenied): + self.mock_view(parent=mock_organization).check_permissions(self.mock_request()) + mock_access.assert_called_once_with(mock_organization) def test_parent_access_check_worked(self, mocker, mock_organization): - with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): - mock_access = mocker.MagicMock(__name__='for logger', return_value=True) - with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): - self.mock_view().check_permissions(self.mock_request()) - mock_access.assert_called_once_with(mock_organization) + mock_access = mocker.MagicMock(__name__='for logger', return_value=True) + with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): + self.mock_view(parent=mock_organization).check_permissions(self.mock_request()) + mock_access.assert_called_once_with(mock_organization) def test_related_search_reverse_FK_field(): diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index 58d795567f..291ff0722e 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -724,7 +724,10 @@ def get_pk_from_dict(_dict, key): Helper for obtaining a pk from user data dict or None if not present. ''' try: - return int(_dict[key]) + val = _dict[key] + if isinstance(val, object) and hasattr(val, 'id'): + return val.id # return id if given model object + return int(val) except (TypeError, KeyError, ValueError): return None