Pass existing object references within access methods

This avoids re-loading objects from the database in our
chain of permission checking, wherever possible.
access.py is equiped to handle object references instead
of pk ints, and permissions.py is changed to pass those refs.
This commit is contained in:
AlanCoding
2017-08-30 16:05:02 -04:00
parent bfea00f6dc
commit 41940687f1
5 changed files with 69 additions and 68 deletions

View File

@@ -34,7 +34,7 @@ class ModelAccessPermission(permissions.BasePermission):
def check_get_permissions(self, request, view, obj=None): def check_get_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'): if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read', if not check_user_access(request.user, view.parent_model, 'read',
parent_obj): parent_obj):
return False return False
@@ -44,12 +44,12 @@ class ModelAccessPermission(permissions.BasePermission):
def check_post_permissions(self, request, view, obj=None): def check_post_permissions(self, request, view, obj=None):
if hasattr(view, 'parent_model'): if hasattr(view, 'parent_model'):
parent_obj = get_object_or_400(view.parent_model, pk=view.kwargs['pk']) parent_obj = view.get_parent_object()
if not check_user_access(request.user, view.parent_model, 'read', if not check_user_access(request.user, view.parent_model, 'read',
parent_obj): parent_obj):
return False return False
if hasattr(view, 'parent_key'): if hasattr(view, 'parent_key'):
if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj.pk}): if not check_user_access(request.user, view.model, 'add', {view.parent_key: parent_obj}):
return False return False
return True return True
elif getattr(view, 'is_job_start', False): elif getattr(view, 'is_job_start', False):

View File

@@ -48,7 +48,8 @@ class BrowsableAPIRenderer(renderers.BrowsableAPIRenderer):
obj = getattr(view, 'object', None) obj = getattr(view, 'object', None)
if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'): if obj is None and hasattr(view, 'get_object') and hasattr(view, 'retrieve'):
try: try:
obj = view.get_object() view.object = view.get_object()
obj = view.object
except Exception: except Exception:
obj = None obj = None
with override_method(view, request, method) as request: with override_method(view, request, method) as request:

View File

@@ -17,7 +17,12 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError from rest_framework.exceptions import ParseError, PermissionDenied, ValidationError
# AWX # AWX
from awx.main.utils import * # noqa from awx.main.utils import (
get_object_or_400,
get_pk_from_dict,
to_python_boolean,
get_licenser,
)
from awx.main.models import * # noqa from awx.main.models import * # noqa
from awx.main.models.unified_jobs import ACTIVE_STATES from awx.main.models.unified_jobs import ACTIVE_STATES
from awx.main.models.mixins import ResourceMixin from awx.main.models.mixins import ResourceMixin
@@ -36,6 +41,36 @@ access_registry = {
} }
def get_object_from_data(field, Model, data, obj=None):
"""
Utility method to obtain related object in data according to fallbacks:
- if data contains key with pointer to Django object, return that
- if contains integer, get object from database
- if this does not work, raise exception
"""
try:
raw_value = data[field]
except KeyError:
# Calling method needs to deal with non-existence of key
raise ParseError(_("Required related field %s for permission check." % field))
if isinstance(raw_value, Model):
return raw_value
elif raw_value is None:
return None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
return getattr(obj, field)
else:
# Get the new resource from the database
return get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))
class StateConflict(ValidationError): class StateConflict(ValidationError):
status_code = 409 status_code = 409
@@ -205,24 +240,8 @@ class BaseAccess(object):
# Use reference object's related fields, if given # Use reference object's related fields, if given
new = getattr(data['reference_obj'], field) new = getattr(data['reference_obj'], field)
elif data and field in data: elif data and field in data:
# Obtain the resource specified in `data` new = get_object_from_data(field, Model, data, obj=obj)
raw_value = data[field] else:
if isinstance(raw_value, Model):
new = raw_value
elif raw_value is None:
new = None
else:
try:
new_pk = int(raw_value)
# Avoid database query by comparing pk to model for similarity
if obj and new_pk == getattr(obj, '%s_id' % field, None):
changed = False
else:
# Get the new resource from the database
new = get_object_or_400(Model, pk=new_pk)
except (TypeError, ValueError):
raise ParseError(_("Bad data found in related field %s." % field))
elif data is None or field not in data:
changed = False changed = False
# Obtain existing related resource # Obtain existing related resource
@@ -940,17 +959,14 @@ class CredentialAccess(BaseAccess):
def can_add(self, data): def can_add(self, data):
if not data: # So the browseable API will work if not data: # So the browseable API will work
return True return True
user_pk = get_pk_from_dict(data, 'user') if data and data.get('user', None):
if user_pk: user_obj = get_object_from_data('user', User, data)
user_obj = get_object_or_400(User, pk=user_pk)
return check_user_access(self.user, User, 'change', user_obj, None) return check_user_access(self.user, User, 'change', user_obj, None)
team_pk = get_pk_from_dict(data, 'team') if data and data.get('team', None):
if team_pk: team_obj = get_object_from_data('team', Team, data)
team_obj = get_object_or_400(Team, pk=team_pk)
return check_user_access(self.user, Team, 'change', team_obj, None) return check_user_access(self.user, Team, 'change', team_obj, None)
organization_pk = get_pk_from_dict(data, 'organization') if data and data.get('organization', None):
if organization_pk: organization_obj = get_object_from_data('organization', Organization, data)
organization_obj = get_object_or_400(Organization, pk=organization_pk)
return check_user_access(self.user, Organization, 'change', organization_obj, None) return check_user_access(self.user, Organization, 'change', organization_obj, None)
return False return False
@@ -1173,9 +1189,8 @@ class JobTemplateAccess(BaseAccess):
if reference_obj: if reference_obj:
return getattr(reference_obj, field, None) return getattr(reference_obj, field, None)
else: else:
pk = get_pk_from_dict(data, field) if data and data.get(field, None):
if pk: return get_object_from_data(field, Class, data)
return get_object_or_400(Class, pk=pk)
else: else:
return None return None
@@ -1261,23 +1276,6 @@ class JobTemplateAccess(BaseAccess):
return False return False
return True return True
def can_update_sensitive_fields(self, obj, data):
project_id = data.get('project', obj.project.id if obj.project else None)
inventory_id = data.get('inventory', obj.inventory.id if obj.inventory else None)
credential_id = data.get('credential', obj.credential.id if obj.credential else None)
vault_credential_id = data.get('credential', obj.vault_credential.id if obj.vault_credential else None)
if project_id and self.user not in Project.objects.get(pk=project_id).use_role:
return False
if inventory_id and self.user not in Inventory.objects.get(pk=inventory_id).use_role:
return False
if credential_id and self.user not in Credential.objects.get(pk=credential_id).use_role:
return False
if vault_credential_id and self.user not in Credential.objects.get(pk=vault_credential_id).use_role:
return False
return True
def can_delete(self, obj): def can_delete(self, obj):
is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role is_delete_allowed = self.user.is_superuser or self.user in obj.admin_role
if not is_delete_allowed: if not is_delete_allowed:
@@ -1387,9 +1385,8 @@ class JobAccess(BaseAccess):
add_data = dict(data.items()) add_data = dict(data.items())
# If a job template is provided, the user should have read access to it. # If a job template is provided, the user should have read access to it.
job_template_pk = get_pk_from_dict(data, 'job_template') if data and data.get('job_template', None):
if job_template_pk: job_template = get_object_from_data('job_template', JobTemplate, data)
job_template = get_object_or_400(JobTemplate, pk=job_template_pk)
add_data.setdefault('inventory', job_template.inventory.pk) add_data.setdefault('inventory', job_template.inventory.pk)
add_data.setdefault('project', job_template.project.pk) add_data.setdefault('project', job_template.project.pk)
add_data.setdefault('job_type', job_template.job_type) add_data.setdefault('job_type', job_template.job_type)

View File

@@ -242,28 +242,28 @@ class TestResourceAccessList:
), method='GET') ), method='GET')
def mock_view(self): def mock_view(self, parent=None):
view = ResourceAccessList() view = ResourceAccessList()
view.parent_model = Organization view.parent_model = Organization
view.kwargs = {'pk': 4} view.kwargs = {'pk': 4}
if parent:
view.get_parent_object = lambda: parent
return view return view
def test_parent_access_check_failed(self, mocker, mock_organization): def test_parent_access_check_failed(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): mock_access = mocker.MagicMock(__name__='for logger', return_value=False)
mock_access = mocker.MagicMock(__name__='for logger', return_value=False) with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): with pytest.raises(PermissionDenied):
with pytest.raises(PermissionDenied): self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
self.mock_view().check_permissions(self.mock_request()) mock_access.assert_called_once_with(mock_organization)
mock_access.assert_called_once_with(mock_organization)
def test_parent_access_check_worked(self, mocker, mock_organization): def test_parent_access_check_worked(self, mocker, mock_organization):
with mocker.patch('awx.api.permissions.get_object_or_400', return_value=mock_organization): mock_access = mocker.MagicMock(__name__='for logger', return_value=True)
mock_access = mocker.MagicMock(__name__='for logger', return_value=True) with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access):
with mocker.patch('awx.main.access.BaseAccess.can_read', mock_access): self.mock_view(parent=mock_organization).check_permissions(self.mock_request())
self.mock_view().check_permissions(self.mock_request()) mock_access.assert_called_once_with(mock_organization)
mock_access.assert_called_once_with(mock_organization)
def test_related_search_reverse_FK_field(): def test_related_search_reverse_FK_field():

View File

@@ -724,7 +724,10 @@ def get_pk_from_dict(_dict, key):
Helper for obtaining a pk from user data dict or None if not present. Helper for obtaining a pk from user data dict or None if not present.
''' '''
try: try:
return int(_dict[key]) val = _dict[key]
if isinstance(val, object) and hasattr(val, 'id'):
return val.id # return id if given model object
return int(val)
except (TypeError, KeyError, ValueError): except (TypeError, KeyError, ValueError):
return None return None