Merge pull request #5191 from AlanCoding/UJT_capability_prefetch

Unified Job Template user_capability prefetch + listview optimizations
This commit is contained in:
Alan Rominger
2017-02-06 10:18:26 -05:00
committed by GitHub
9 changed files with 92 additions and 34 deletions

View File

@@ -331,13 +331,7 @@ class BaseSerializer(serializers.ModelSerializer):
roles = {} roles = {}
for field in obj._meta.get_fields(): for field in obj._meta.get_fields():
if type(field) is ImplicitRoleField: if type(field) is ImplicitRoleField:
role = getattr(obj, field.name) roles[field.name] = role_summary_fields_generator(obj, field.name)
#roles[field.name] = RoleSerializer(data=role).to_representation(role)
roles[field.name] = {
'id': role.id,
'name': role.name,
'description': role.get_description(reference_content_object=obj),
}
if len(roles) > 0: if len(roles) > 0:
summary_fields['object_roles'] = roles summary_fields['object_roles'] = roles
@@ -1839,11 +1833,15 @@ class OrganizationCredentialSerializerCreate(CredentialSerializerCreate):
class LabelsListMixin(object): class LabelsListMixin(object):
def _summary_field_labels(self, obj): def _summary_field_labels(self, obj):
label_list = [{'id': x.id, 'name': x.name} for x in obj.labels.all().order_by('name')[:10]] if hasattr(obj, '_prefetched_objects_cache') and obj.labels.prefetch_cache_name in obj._prefetched_objects_cache:
if len(label_list) < 10: label_list = [{'id': x.id, 'name': x.name} for x in obj.labels.all()[:10]]
label_ct = len(label_list) label_ct = len(obj.labels.all())
else: else:
label_ct = obj.labels.count() label_list = [{'id': x.id, 'name': x.name} for x in obj.labels.all().order_by('name')[:10]]
if len(label_list) < 10:
label_ct = len(label_list)
else:
label_ct = obj.labels.count()
return {'count': label_ct, 'results': label_list} return {'count': label_ct, 'results': label_list}
def get_summary_fields(self, obj): def get_summary_fields(self, obj):

View File

@@ -2568,6 +2568,9 @@ class JobTemplateLabelList(DeleteLastUnattachLabelMixin, SubListCreateAttachDeta
request.data['id'] = existing.id request.data['id'] = existing.id
del request.data['name'] del request.data['name']
del request.data['organization'] del request.data['organization']
if Label.objects.filter(unifiedjobtemplate_labels=self.kwargs['pk']).count() > 100:
return Response(dict(msg=_('Maximum number of labels for {} reached.'.format(
self.parent_model._meta.verbose_name_raw))), status=status.HTTP_400_BAD_REQUEST)
return super(JobTemplateLabelList, self).post(request, *args, **kwargs) return super(JobTemplateLabelList, self).post(request, *args, **kwargs)
@@ -3783,6 +3786,12 @@ class UnifiedJobTemplateList(ListAPIView):
model = UnifiedJobTemplate model = UnifiedJobTemplate
serializer_class = UnifiedJobTemplateSerializer serializer_class = UnifiedJobTemplateSerializer
new_in_148 = True new_in_148 = True
capabilities_prefetch = [
'admin', 'execute',
{'copy': ['jobtemplate.project.use', 'jobtemplate.inventory.use', 'jobtemplate.credential.use',
'jobtemplate.cloud_credential.use', 'jobtemplate.network_credential.use',
'workflowjobtemplate.organization.admin']}
]
class UnifiedJobList(ListAPIView): class UnifiedJobList(ListAPIView):

View File

@@ -8,7 +8,7 @@ import logging
# Django # Django
from django.conf import settings from django.conf import settings
from django.db.models import Q from django.db.models import Q, Prefetch
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@@ -1868,6 +1868,7 @@ class UnifiedJobTemplateAccess(BaseAccess):
qs = qs.prefetch_related( qs = qs.prefetch_related(
'last_job', 'last_job',
'current_job', 'current_job',
Prefetch('labels', queryset=Label.objects.all().order_by('name'))
) )
# WISH - sure would be nice if the following worked, but it does not. # WISH - sure would be nice if the following worked, but it does not.
@@ -1915,6 +1916,7 @@ class UnifiedJobAccess(BaseAccess):
'modified_by', 'modified_by',
'unified_job_node__workflow_job', 'unified_job_node__workflow_job',
'unified_job_template', 'unified_job_template',
Prefetch('labels', queryset=Label.objects.all().order_by('name'))
) )
# WISH - sure would be nice if the following worked, but it does not. # WISH - sure would be nice if the following worked, but it does not.

View File

@@ -25,6 +25,7 @@ __all__ = [
'get_roles_on_resource', 'get_roles_on_resource',
'ROLE_SINGLETON_SYSTEM_ADMINISTRATOR', 'ROLE_SINGLETON_SYSTEM_ADMINISTRATOR',
'ROLE_SINGLETON_SYSTEM_AUDITOR', 'ROLE_SINGLETON_SYSTEM_AUDITOR',
'role_summary_fields_generator'
] ]
logger = logging.getLogger('awx.main.models.rbac') logger = logging.getLogger('awx.main.models.rbac')
@@ -165,13 +166,11 @@ class Role(models.Model):
global role_names global role_names
return role_names[self.role_field] return role_names[self.role_field]
def get_description(self, reference_content_object=None): @property
def description(self):
global role_descriptions global role_descriptions
description = role_descriptions[self.role_field] description = role_descriptions[self.role_field]
if reference_content_object: content_type = self.content_type
content_type = ContentType.objects.get_for_model(reference_content_object)
else:
content_type = self.content_type
if '%s' in description and content_type: if '%s' in description and content_type:
model = content_type.model_class() model = content_type.model_class()
model_name = re.sub(r'([a-z])([A-Z])', r'\1 \2', model.__name__).lower() model_name = re.sub(r'([a-z])([A-Z])', r'\1 \2', model.__name__).lower()
@@ -179,8 +178,6 @@ class Role(models.Model):
return description return description
description = property(get_description)
@staticmethod @staticmethod
def rebuild_role_ancestor_list(additions, removals): def rebuild_role_ancestor_list(additions, removals):
''' '''
@@ -474,3 +471,20 @@ def get_roles_on_resource(resource, accessor):
object_id=resource.id object_id=resource.id
).values_list('role_field', flat=True).distinct() ).values_list('role_field', flat=True).distinct()
] ]
def role_summary_fields_generator(content_object, role_field):
global role_descriptions
global role_names
summary = {}
description = role_descriptions[role_field]
content_type = ContentType.objects.get_for_model(content_object)
if '%s' in description and content_type:
model = content_object.__class__
model_name = re.sub(r'([a-z])([A-Z])', r'\1 \2', model.__name__).lower()
description = description % model_name
summary['description'] = description
summary['name'] = role_names[role_field]
summary['id'] = getattr(content_object, '{}_id'.format(role_field))
return summary

View File

@@ -168,6 +168,12 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, Notificatio
else: else:
return super(UnifiedJobTemplate, self).unique_error_message(model_class, unique_check) return super(UnifiedJobTemplate, self).unique_error_message(model_class, unique_check)
@classmethod
def invalid_user_capabilities_prefetch_models(cls):
if cls != UnifiedJobTemplate:
return []
return ['project', 'inventorysource', 'systemjobtemplate']
@classmethod @classmethod
def accessible_pk_qs(cls, accessor, role_field): def accessible_pk_qs(cls, accessor, role_field):
''' '''
@@ -175,6 +181,9 @@ class UnifiedJobTemplate(PolymorphicModel, CommonModelNameNotUnique, Notificatio
Does not return inventory sources or system JTs, these should Does not return inventory sources or system JTs, these should
be handled inside of get_queryset where it is utilized. be handled inside of get_queryset where it is utilized.
''' '''
# do not use this if in a subclass
if cls != UnifiedJobTemplate:
return super(UnifiedJobTemplate, cls).accessible_pk_qs(accessor, role_field)
ujt_names = [c.__name__.lower() for c in cls.__subclasses__() ujt_names = [c.__name__.lower() for c in cls.__subclasses__()
if c.__name__.lower() not in ['inventorysource', 'systemjobtemplate']] if c.__name__.lower() not in ['inventorysource', 'systemjobtemplate']]
subclass_content_types = list(ContentType.objects.filter( subclass_content_types = list(ContentType.objects.filter(

View File

@@ -3,8 +3,7 @@ import pytest
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.test.client import RequestFactory from django.test.client import RequestFactory
from awx.main.models.jobs import JobTemplate from awx.main.models import Role, Group, UnifiedJobTemplate, JobTemplate
from awx.main.models import Role, Group
from awx.main.access import ( from awx.main.access import (
access_registry, access_registry,
get_user_capabilities get_user_capabilities
@@ -283,6 +282,25 @@ def test_prefetch_jt_capabilities(job_template, rando):
assert qs[0].capabilities_cache == {'edit': False, 'start': True} assert qs[0].capabilities_cache == {'edit': False, 'start': True}
@pytest.mark.django_db
def test_prefetch_ujt_job_template_capabilities(alice, bob, job_template):
job_template.execute_role.members.add(alice)
qs = UnifiedJobTemplate.objects.all()
cache_list_capabilities(qs, ['admin', 'execute'], UnifiedJobTemplate, alice)
assert qs[0].capabilities_cache == {'edit': False, 'start': True}
qs = UnifiedJobTemplate.objects.all()
cache_list_capabilities(qs, ['admin', 'execute'], UnifiedJobTemplate, bob)
assert qs[0].capabilities_cache == {'edit': False, 'start': False}
@pytest.mark.django_db
def test_prefetch_ujt_project_capabilities(alice, project):
project.update_role.members.add(alice)
qs = UnifiedJobTemplate.objects.all()
cache_list_capabilities(qs, ['admin', 'execute'], UnifiedJobTemplate, alice)
assert qs[0].capabilities_cache == {}
@pytest.mark.django_db @pytest.mark.django_db
def test_prefetch_group_capabilities(group, rando): def test_prefetch_group_capabilities(group, rando):
group.inventory.adhoc_role.members.add(rando) group.inventory.adhoc_role.members.add(rando)

View File

@@ -110,7 +110,7 @@ class TestJobTemplateSerializerGetSummaryFields():
view.request = request view.request = request
serializer.context['view'] = view serializer.context['view'] = view
with mocker.patch("awx.main.models.rbac.Role.get_description", return_value='Can eat pie'): with mocker.patch("awx.api.serializers.role_summary_fields_generator", return_value='Can eat pie'):
with mocker.patch("awx.main.access.JobTemplateAccess.can_change", return_value='foobar'): with mocker.patch("awx.main.access.JobTemplateAccess.can_change", return_value='foobar'):
with mocker.patch("awx.main.access.JobTemplateAccess.can_add", return_value='foo'): with mocker.patch("awx.main.access.JobTemplateAccess.can_add", return_value='foo'):
response = serializer.get_summary_fields(jt_obj) response = serializer.get_summary_fields(jt_obj)

View File

@@ -46,13 +46,13 @@ class TestLabelFilterMocked:
def test_is_candidate_for_detach(self, mocker, jt_count, j_count, expected): def test_is_candidate_for_detach(self, mocker, jt_count, j_count, expected):
mock_job_qs = mocker.MagicMock() mock_job_qs = mocker.MagicMock()
mock_job_qs.count = mocker.MagicMock(return_value=j_count) mock_job_qs.count = mocker.MagicMock(return_value=j_count)
UnifiedJob.objects = mocker.MagicMock() mocker.patch.object(UnifiedJob, 'objects', mocker.MagicMock(
UnifiedJob.objects.filter = mocker.MagicMock(return_value=mock_job_qs) filter=mocker.MagicMock(return_value=mock_job_qs)))
mock_jt_qs = mocker.MagicMock() mock_jt_qs = mocker.MagicMock()
mock_jt_qs.count = mocker.MagicMock(return_value=jt_count) mock_jt_qs.count = mocker.MagicMock(return_value=jt_count)
UnifiedJobTemplate.objects = mocker.MagicMock() mocker.patch.object(UnifiedJobTemplate, 'objects', mocker.MagicMock(
UnifiedJobTemplate.objects.filter = mocker.MagicMock(return_value=mock_jt_qs) filter=mocker.MagicMock(return_value=mock_jt_qs)))
label = Label(id=37) label = Label(id=37)
ret = label.is_candidate_for_detach() ret = label.is_candidate_for_detach()

View File

@@ -519,6 +519,10 @@ def cache_list_capabilities(page, prefetch_list, model, user):
for obj in page: for obj in page:
obj.capabilities_cache = {} obj.capabilities_cache = {}
skip_models = []
if hasattr(model, 'invalid_user_capabilities_prefetch_models'):
skip_models = model.invalid_user_capabilities_prefetch_models()
for prefetch_entry in prefetch_list: for prefetch_entry in prefetch_list:
display_method = None display_method = None
@@ -532,19 +536,20 @@ def cache_list_capabilities(page, prefetch_list, model, user):
paths = [paths] paths = [paths]
# Build the query for accessible_objects according the user & role(s) # Build the query for accessible_objects according the user & role(s)
qs_obj = None filter_args = []
for role_path in paths: for role_path in paths:
if '.' in role_path: if '.' in role_path:
res_path = '__'.join(role_path.split('.')[:-1]) res_path = '__'.join(role_path.split('.')[:-1])
role_type = role_path.split('.')[-1] role_type = role_path.split('.')[-1]
if qs_obj is None: parent_model = model
qs_obj = model.objects for subpath in role_path.split('.')[:-1]:
parent_model = model._meta.get_field(res_path).related_model parent_model = parent_model._meta.get_field(subpath).related_model
kwargs = {'%s__in' % res_path: parent_model.accessible_objects(user, '%s_role' % role_type)} filter_args.append(Q(
qs_obj = qs_obj.filter(Q(**kwargs) | Q(**{'%s__isnull' % res_path: True})) Q(**{'%s__pk__in' % res_path: parent_model.accessible_pk_qs(user, '%s_role' % role_type)}) |
Q(**{'%s__isnull' % res_path: True})))
else: else:
role_type = role_path role_type = role_path
qs_obj = model.accessible_objects(user, '%s_role' % role_type) filter_args.append(Q(**{'pk__in': model.accessible_pk_qs(user, '%s_role' % role_type)}))
if display_method is None: if display_method is None:
# Role name translation to UI names for methods # Role name translation to UI names for methods
@@ -555,10 +560,13 @@ def cache_list_capabilities(page, prefetch_list, model, user):
display_method = 'start' display_method = 'start'
# Union that query with the list of items on page # Union that query with the list of items on page
ids_with_role = set(qs_obj.filter(pk__in=page_ids).values_list('pk', flat=True)) filter_args.append(Q(pk__in=page_ids))
ids_with_role = set(model.objects.filter(*filter_args).values_list('pk', flat=True))
# Save data item-by-item # Save data item-by-item
for obj in page: for obj in page:
if skip_models and obj.__class__.__name__.lower() in skip_models:
continue
obj.capabilities_cache[display_method] = False obj.capabilities_cache[display_method] = False
if obj.pk in ids_with_role: if obj.pk in ids_with_role:
obj.capabilities_cache[display_method] = True obj.capabilities_cache[display_method] = True