diff --git a/awx/api/filters.py b/awx/api/filters.py index ea9d011562..213040ff73 100644 --- a/awx/api/filters.py +++ b/awx/api/filters.py @@ -9,7 +9,7 @@ from functools import reduce # Django from django.core.exceptions import FieldError, ValidationError from django.db import models -from django.db.models import Q +from django.db.models import Q, CharField, IntegerField, BooleanField from django.db.models.fields import FieldDoesNotExist from django.db.models.fields.related import ForeignObjectRel, ManyToManyField, ForeignKey from django.contrib.contenttypes.models import ContentType @@ -63,19 +63,19 @@ class TypeFilterBackend(BaseFilterBackend): raise ParseError(*e.args) -def get_field_from_path(model, path): +def get_fields_from_path(model, path): ''' Given a Django ORM lookup path (possibly over multiple models) - Returns the last field in the line, and also the revised lookup path + Returns the fields in the line, and also the revised lookup path ex., given model=Organization path='project__timeout' - returns tuple of field at the end of the line as well as a corrected - path, for special cases we do substitutions - (<IntegerField for timeout>, 'project__timeout') + returns tuple of fields traversed as well as a corrected path, + for special cases we do substitutions + ([<IntegerField for timeout>], 'project__timeout') ''' # Store of all the fields used to detect repeats - field_set = set([]) + field_list = [] new_parts = [] for name in path.split('__'): if model is None: @@ -111,13 +111,24 @@ def get_field_from_path(model, path): raise PermissionDenied(_('Filtering on %s is not allowed.' % name)) elif getattr(field, '__prevent_search__', False): raise PermissionDenied(_('Filtering on %s is not allowed.' 
% name)) - if field in field_set: + if field in field_list: # Field traversed twice, could create infinite JOINs, DoSing Tower raise ParseError(_('Loops not allowed in filters, detected on field {}.').format(field.name)) - field_set.add(field) + field_list.append(field) model = getattr(field, 'related_model', None) - return field, '__'.join(new_parts) + return field_list, '__'.join(new_parts) + + +def get_field_from_path(model, path): + ''' + Given a Django ORM lookup path (possibly over multiple models) + Returns the last field in the line, and the revised lookup path + ex. + (<IntegerField for timeout>, 'project__timeout') + ''' + field_list, new_path = get_fields_from_path(model, path) + return (field_list[-1], new_path) class FieldLookupBackend(BaseFilterBackend): @@ -133,7 +144,11 @@ class FieldLookupBackend(BaseFilterBackend): 'regex', 'iregex', 'gt', 'gte', 'lt', 'lte', 'in', 'isnull', 'search') - def get_field_from_lookup(self, model, lookup): + # A list of fields that we know can be filtered on without the possibility + # of introducing duplicates + NO_DUPLICATES_WHITELIST = (CharField, IntegerField, BooleanField) + + def get_fields_from_lookup(self, model, lookup): if '__' in lookup and lookup.rsplit('__', 1)[-1] in self.SUPPORTED_LOOKUPS: path, suffix = lookup.rsplit('__', 1) @@ -147,11 +162,16 @@ class FieldLookupBackend(BaseFilterBackend): # FIXME: Could build up a list of models used across relationships, use # those lookups combined with request.user.get_queryset(Model) to make # sure user cannot query using objects he could not view. 
- field, new_path = get_field_from_path(model, path) + field_list, new_path = get_fields_from_path(model, path) new_lookup = new_path new_lookup = '__'.join([new_path, suffix]) - return field, new_lookup + return field_list, new_lookup + + def get_field_from_lookup(self, model, lookup): + '''Method to match return type of single field, if needed.''' + field_list, new_lookup = self.get_fields_from_lookup(model, lookup) + return (field_list[-1], new_lookup) def to_python_related(self, value): value = force_text(value) @@ -182,7 +202,10 @@ class FieldLookupBackend(BaseFilterBackend): except UnicodeEncodeError: raise ValueError("%r is not an allowed field name. Must be ascii encodable." % lookup) - field, new_lookup = self.get_field_from_lookup(model, lookup) + field_list, new_lookup = self.get_fields_from_lookup(model, lookup) + field = field_list[-1] + + needs_distinct = (not all(isinstance(f, self.NO_DUPLICATES_WHITELIST) for f in field_list)) # Type names are stored without underscores internally, but are presented and # and serialized over the API containing underscores so we remove `_` @@ -211,10 +234,10 @@ class FieldLookupBackend(BaseFilterBackend): for rm_field in related_model._meta.fields: if rm_field.name in ('username', 'first_name', 'last_name', 'email', 'name', 'description', 'playbook'): new_lookups.append('{}__{}__icontains'.format(new_lookup[:-8], rm_field.name)) - return value, new_lookups + return value, new_lookups, needs_distinct else: value = self.value_to_python_for_field(field, value) - return value, new_lookup + return value, new_lookup, needs_distinct def filter_queryset(self, request, queryset, view): try: @@ -225,6 +248,7 @@ class FieldLookupBackend(BaseFilterBackend): chain_filters = [] role_filters = [] search_filters = {} + needs_distinct = False # Can only have two values: 'AND', 'OR' # If 'AND' is used, an iterm must satisfy all condition to show up in the results. 
# If 'OR' is used, an item just need to satisfy one condition to appear in results. @@ -256,7 +280,9 @@ class FieldLookupBackend(BaseFilterBackend): search_filter_relation = 'AND' values = reduce(lambda list1, list2: list1 + list2, [i.split(',') for i in values]) for value in values: - search_value, new_keys = self.value_to_python(queryset.model, key, force_text(value)) + search_value, new_keys, _ = self.value_to_python(queryset.model, key, force_text(value)) assert isinstance(new_keys, list) search_filters[search_value] = new_keys + # A search lookup always joins across a relation, so it can yield duplicate rows and always needs .distinct() + needs_distinct = True continue @@ -282,7 +308,9 @@ class FieldLookupBackend(BaseFilterBackend): for value in values: if q_int: value = int(value) - value, new_key = self.value_to_python(queryset.model, key, value) + value, new_key, distinct = self.value_to_python(queryset.model, key, value) + if distinct: + needs_distinct = True if q_chain: chain_filters.append((q_not, new_key, value)) elif q_or: @@ -332,7 +360,9 @@ class FieldLookupBackend(BaseFilterBackend): else: q = Q(**{k:v}) queryset = queryset.filter(q) - queryset = queryset.filter(*args).distinct() + queryset = queryset.filter(*args) + if needs_distinct: + queryset = queryset.distinct() return queryset except (FieldError, FieldDoesNotExist, ValueError, TypeError) as e: raise ParseError(e.args[0]) diff --git a/awx/main/tests/unit/api/test_filters.py b/awx/main/tests/unit/api/test_filters.py index 913413a35f..4a951890e7 100644 --- a/awx/main/tests/unit/api/test_filters.py +++ b/awx/main/tests/unit/api/test_filters.py @@ -57,7 +57,7 @@ def test_empty_in(empty_value): @pytest.mark.parametrize(u"valid_value", [u'foo', u'foo,']) def test_valid_in(valid_value): field_lookup = FieldLookupBackend() - value, new_lookup = field_lookup.value_to_python(JobTemplate, 'project__name__in', valid_value) + value, new_lookup, _ = field_lookup.value_to_python(JobTemplate, 'project__name__in', valid_value) assert 'foo' in value diff --git a/awx/main/tests/unit/utils/test_filters.py 
b/awx/main/tests/unit/utils/test_filters.py index 54ae9c9691..76effe8284 100644 --- a/awx/main/tests/unit/utils/test_filters.py +++ b/awx/main/tests/unit/utils/test_filters.py @@ -79,8 +79,8 @@ class mockHost: @mock.patch('awx.main.utils.filters.get_model', return_value=mockHost()) class TestSmartFilterQueryFromString(): @mock.patch( - 'awx.api.filters.get_field_from_path', - lambda model, path: (model, path) # disable field filtering, because a__b isn't a real Host field + 'awx.api.filters.get_fields_from_path', + lambda model, path: ([model], path) # disable field filtering, because a__b isn't a real Host field ) @pytest.mark.parametrize("filter_string,q_expected", [ ('facts__facts__blank=""', Q(**{u"facts__facts__blank": u""})),