From 26d393e5c2f33d3baa4f8f7b2e24bb8b1fee8dfb Mon Sep 17 00:00:00 2001 From: Chris Meyers Date: Thu, 21 Sep 2017 15:34:51 -0400 Subject: [PATCH] 2-level memoize * Allows for invalidating an entire function from the memoizer --- awx/conf/license.py | 7 ++- awx/main/tests/unit/utils/test_common.py | 66 ++++++++++++++++++++++++ awx/main/utils/common.py | 42 ++++++++++++--- 3 files changed, 103 insertions(+), 12 deletions(-) diff --git a/awx/conf/license.py b/awx/conf/license.py index 820f4f2d58..a2e3588470 100644 --- a/awx/conf/license.py +++ b/awx/conf/license.py @@ -2,7 +2,6 @@ # All Rights Reserved. # Django -from django.core.cache import cache from django.core.signals import setting_changed from django.dispatch import receiver from django.utils.translation import ugettext_lazy as _ @@ -12,7 +11,7 @@ from rest_framework.exceptions import APIException # Tower from awx.main.utils.common import get_licenser -from awx.main.utils import memoize +from awx.main.utils import memoize, memoize_delete __all__ = ['LicenseForbids', 'get_license', 'get_licensed_features', 'feature_enabled', 'feature_exists'] @@ -23,7 +22,6 @@ class LicenseForbids(APIException): default_detail = _('Your Tower license does not allow that.') -@memoize(cache_key='_validated_license_data') def _get_validated_license_data(): return get_licenser().validate() @@ -32,7 +30,7 @@ def _get_validated_license_data(): def _on_setting_changed(sender, **kwargs): # Clear cached result above when license changes. if kwargs.get('setting', None) == 'LICENSE': - cache.delete('_validated_license_data') + memoize_delete('feature_enabled') def get_license(show_key=False): @@ -52,6 +50,7 @@ def get_licensed_features(): return features +@memoize(track_function=True) def feature_enabled(name): """Return True if the requested feature is enabled, False otherwise.""" validated_license_data = _get_validated_license_data() diff --git a/awx/main/tests/unit/utils/test_common.py b/awx/main/tests/unit/utils/test_common.py index 44e0461a9a..41f9012040 100644 --- a/awx/main/tests/unit/utils/test_common.py +++ b/awx/main/tests/unit/utils/test_common.py @@ -6,6 +6,8 @@ import os import pytest from uuid import uuid4 +from django.core.cache import cache + from awx.main.utils import common from awx.main.models import ( @@ -18,6 +20,14 @@ from awx.main.models import ( ) +@pytest.fixture(autouse=True) +def clear_cache(): + ''' + Clear cache (local memory) for each test to prevent using cached settings. + ''' + cache.clear() + + @pytest.mark.parametrize('input_, output', [ ({"foo": "bar"}, {"foo": "bar"}), ('{"foo": "bar"}', {"foo": "bar"}), @@ -49,3 +59,59 @@ def test_set_environ(): ]) def test_get_type_for_model(model, name): assert common.get_type_for_model(model) == name + + +@pytest.fixture +def memoized_function(mocker): + @common.memoize(track_function=True) + def myfunction(key, value): + if key not in myfunction.calls: + myfunction.calls[key] = 0 + + myfunction.calls[key] += 1 + + if myfunction.calls[key] == 1: + return value + else: + return '%s called %s times' % (value, myfunction.calls[key]) + myfunction.calls = dict() + return myfunction + + +def test_memoize_track_function(memoized_function): + assert memoized_function('scott', 'scotterson') == 'scotterson' + assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson'} + assert memoized_function('scott', 'scotterson') == 'scotterson' + + assert memoized_function.calls['scott'] == 1 + + assert memoized_function('john', 'smith') == 'smith' + assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson', u'john-smith': 'smith'} + assert memoized_function('john', 'smith') == 'smith' + + assert memoized_function.calls['john'] == 1 + + +def test_memoize_delete(memoized_function): + assert memoized_function('john', 'smith') == 'smith' + assert memoized_function('john', 'smith') == 'smith' + assert memoized_function.calls['john'] == 1 + + assert cache.get('myfunction') == {u'john-smith': 'smith'} + + common.memoize_delete('myfunction') + + assert cache.get('myfunction') is None + + assert memoized_function('john', 'smith') == 'smith called 2 times' + assert memoized_function.calls['john'] == 2 + + +def test_memoize_parameter_error(): + @common.memoize(cache_key='foo', track_function=True) + def fn(): + return + + with pytest.raises(common.IllegalArgumentError): + fn() + diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index c8c382472a..dbbc589392 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -34,7 +34,7 @@ from django.apps import apps logger = logging.getLogger('awx.main.utils') -__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize', +__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize', 'memoize_delete', 'get_ansible_version', 'get_ssh_version', 'get_licenser', 'get_awx_version', 'update_scm_url', 'get_type_for_model', 'get_model_for_type', 'copy_model_by_class', 'copy_m2m_relationships' ,'cache_list_capabilities', 'to_python_boolean', @@ -44,7 +44,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity', 'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict', 'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest', - 'has_model_field_prefetched', 'set_environ'] + 'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',] def get_object_or_400(klass, *args, **kwargs): @@ -107,22 +107,48 @@ class RequireDebugTrueOrTest(logging.Filter): return settings.DEBUG or 'test' in sys.argv -def memoize(ttl=60, cache_key=None): +class IllegalArgumentError(ValueError): + pass + + +def memoize(ttl=60, cache_key=None, track_function=False): ''' Decorator to wrap a function and cache its result. ''' from django.core.cache import cache + def _memoizer(f, *args, **kwargs): - key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs)) - value = cache.get(key) - if value is None: - value = f(*args, **kwargs) - cache.set(key, value, ttl) + if cache_key and track_function: + raise IllegalArgumentError("Can not specify cache_key when track_function is True") + + if track_function: + cache_dict_key = slugify('%r %r' % (args, kwargs)) + key = slugify("%s" % f.__name__) + cache_dict = cache.get(key) or dict() + if cache_dict_key not in cache_dict: + value = f(*args, **kwargs) + cache_dict[cache_dict_key] = value + cache.set(key, cache_dict, ttl) + else: + value = cache_dict[cache_dict_key] + else: + key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs)) + value = cache.get(key) + if value is None: + value = f(*args, **kwargs) + cache.set(key, value, ttl) + return value return decorator(_memoizer) +def memoize_delete(function_name): + from django.core.cache import cache + + return cache.delete(function_name) + + @memoize() def get_ansible_version(): '''