Extend SmartFilter to expand search and related search fields

This commit is contained in:
Wayne Witzel III 2017-05-19 16:01:11 -04:00
parent 3cf698d4ae
commit bac1e8b4fe
4 changed files with 67 additions and 15 deletions

View File

@ -293,12 +293,7 @@ class ListAPIView(generics.ListAPIView, GenericAPIView):
@property
def search_fields(self):
fields = []
for field in self.model._meta.fields:
if field.name in ('username', 'first_name', 'last_name', 'email',
'name', 'description'):
fields.append(field.name)
return fields
return get_search_fields(self.model)
@property
def related_search_fields(self):

View File

@ -2,6 +2,7 @@
# Python
import pytest
import mock
from collections import namedtuple
# AWX
from awx.main.utils.filters import SmartFilter
@ -10,6 +11,10 @@ from awx.main.utils.filters import SmartFilter
from django.db.models import Q
Field = namedtuple('Field', 'name')
Meta = namedtuple('Meta', 'fields')
class mockObjects:
def filter(self, *args, **kwargs):
return Q(*args, **kwargs)
@ -19,9 +24,10 @@ class mockHost:
def __init__(self):
print("Host mock created")
self.objects = mockObjects()
self._meta = Meta(fields=(Field(name='name'), Field(name='description')))
@mock.patch('awx.main.utils.filters.get_host_model', return_value=mockHost())
@mock.patch('awx.main.utils.filters.get_model', return_value=mockHost())
class TestSmartFilterQueryFromString():
@pytest.mark.parametrize("filter_string,q_expected", [
('facts__facts__blank=""', Q(**{u"facts__facts__blank": u""})),
@ -109,6 +115,20 @@ class TestSmartFilterQueryFromString():
assert unicode(q) == unicode(q_expected)
@pytest.mark.parametrize("filter_string,q_expected", [
('search=foo', Q(**{u"name": u"foo"}) | Q(**{ u"description": u"foo"})),
('group__search=foo', Q(**{u"group__name": u"foo"}) | Q(**{u"group__description": u"foo"})),
('search=foo and group__search=foo', Q(
Q(**{u"name": u"foo"}) | Q(**{ u"description": u"foo"}),
Q(**{u"group__name": u"foo"}) | Q(**{u"group__description": u"foo"}))),
('search=foo or ansible_facts__a=null',
(Q(**{u"name": u"foo"}) | Q(**{u"description": u"foo"})) |
Q(**{u"ansible_facts__contains": {u"a": u"null"}})),
])
def test_search_related_fields(self, mock_get_host_model, filter_string, q_expected):
q = SmartFilter.query_from_string(filter_string)
assert unicode(q) == unicode(q_expected)
'''
#('"facts__quoted_val"="f\"oo"', 1),
#('facts__facts__arr[]="foo"', 1),

View File

@ -45,7 +45,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore',
'ignore_inventory_computed_fields', 'ignore_inventory_group_removal',
'_inventory_updates', 'get_pk_from_dict', 'getattrd', 'NoDefaultProvided',
'get_current_apps', 'set_current_apps', 'OutputEventFilter',
'callback_filter_out_ansible_extra_vars',]
'callback_filter_out_ansible_extra_vars', 'get_search_fields',]
def get_object_or_400(klass, *args, **kwargs):
@ -862,3 +862,12 @@ def callback_filter_out_ansible_extra_vars(extra_vars):
if not key.startswith('ansible_'):
extra_vars_redacted[key] = value
return extra_vars_redacted
def get_search_fields(model):
fields = []
for field in model._meta.fields:
if field.name in ('username', 'first_name', 'last_name', 'email',
'name', 'description'):
fields.append(field.name)
return fields

View File

@ -10,6 +10,8 @@ from pyparsing import (
import django
from awx.main.utils.common import get_search_fields
__all__ = ['SmartFilter']
unicode_spaces = [unichr(c) for c in xrange(sys.maxunicode) if unichr(c).isspace()]
@ -31,8 +33,8 @@ def string_to_type(t):
return t
def get_host_model():
return django.apps.apps.get_model('main', 'host')
def get_model(name):
return django.apps.apps.get_model('main', name)
class SmartFilter(object):
@ -43,11 +45,16 @@ class SmartFilter(object):
kwargs = dict()
k, v = self._extract_key_value(t)
k, v = self._json_path_to_contains(k, v)
kwargs[k] = v
# Avoid import circular dependency
Host = get_host_model()
self.result = Host.objects.filter(**kwargs)
Host = get_model('host')
search_kwargs = self._expand_search(k, v)
if search_kwargs:
kwargs.update(search_kwargs)
q = reduce(lambda x, y: x | y, [django.db.models.Q(**{u'%s' % _k:_v}) for _k, _v in kwargs.items()])
self.result = q
else:
kwargs[k] = v
self.result = Host.objects.filter(**kwargs)
def strip_quotes_traditional_logic(self, v):
if type(v) is unicode and v.startswith('"') and v.endswith('"'):
@ -145,6 +152,28 @@ class SmartFilter(object):
return (k, v)
def _expand_search(self, k, v):
if 'search' not in k:
return None
model, relation = None, None
if k == 'search':
model = get_model('host')
elif k.endswith('__search'):
relation = k.split('__')[0]
model = get_model(relation)
search_kwargs = {}
if model is not None:
search_fields = get_search_fields(model)
for field in search_fields:
if relation is not None:
k = '{0}__{1}'.format(relation, field)
else:
k = field
search_kwargs[k] = v
return search_kwargs
class BoolBinOp(object):
def __init__(self, t):
@ -206,7 +235,6 @@ class SmartFilter(object):
try:
res = boolExpr.parseString('(' + filter_string + ')')
#except ParseException as e:
except Exception:
raise RuntimeError(u"Invalid query %s" % filter_string_raw)