2-level memoize

* Allows for invalidating an entire function from the memoizer
This commit is contained in:
Chris Meyers 2017-09-21 15:34:51 -04:00
parent 44af8ac629
commit 26d393e5c2
3 changed files with 103 additions and 12 deletions

View File

@ -2,7 +2,6 @@
# All Rights Reserved.
# Django
from django.core.cache import cache
from django.core.signals import setting_changed
from django.dispatch import receiver
from django.utils.translation import ugettext_lazy as _
@ -12,7 +11,7 @@ from rest_framework.exceptions import APIException
# Tower
from awx.main.utils.common import get_licenser
from awx.main.utils import memoize
from awx.main.utils import memoize, memoize_delete
__all__ = ['LicenseForbids', 'get_license', 'get_licensed_features',
'feature_enabled', 'feature_exists']
@ -23,7 +22,6 @@ class LicenseForbids(APIException):
default_detail = _('Your Tower license does not allow that.')
@memoize(cache_key='_validated_license_data')
def _get_validated_license_data():
return get_licenser().validate()
@ -32,7 +30,7 @@ def _get_validated_license_data():
def _on_setting_changed(sender, **kwargs):
# Clear cached result above when license changes.
if kwargs.get('setting', None) == 'LICENSE':
cache.delete('_validated_license_data')
memoize_delete('feature_enabled')
def get_license(show_key=False):
@ -52,6 +50,7 @@ def get_licensed_features():
return features
@memoize(track_function=True)
def feature_enabled(name):
"""Return True if the requested feature is enabled, False otherwise."""
validated_license_data = _get_validated_license_data()

View File

@ -6,6 +6,8 @@ import os
import pytest
from uuid import uuid4
from django.core.cache import cache
from awx.main.utils import common
from awx.main.models import (
@ -18,6 +20,14 @@ from awx.main.models import (
)
@pytest.fixture(autouse=True)
def clear_cache():
'''
Clear cache (local memory) for each test to prevent using cached settings.
'''
cache.clear()
@pytest.mark.parametrize('input_, output', [
({"foo": "bar"}, {"foo": "bar"}),
('{"foo": "bar"}', {"foo": "bar"}),
@ -49,3 +59,59 @@ def test_set_environ():
])
def test_get_type_for_model(model, name):
assert common.get_type_for_model(model) == name
@pytest.fixture
def memoized_function(mocker):
@common.memoize(track_function=True)
def myfunction(key, value):
if key not in myfunction.calls:
myfunction.calls[key] = 0
myfunction.calls[key] += 1
if myfunction.calls[key] == 1:
return value
else:
return '%s called %s times' % (value, myfunction.calls[key])
myfunction.calls = dict()
return myfunction
def test_memoize_track_function(memoized_function):
assert memoized_function('scott', 'scotterson') == 'scotterson'
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson'}
assert memoized_function('scott', 'scotterson') == 'scotterson'
assert memoized_function.calls['scott'] == 1
assert memoized_function('john', 'smith') == 'smith'
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson', u'john-smith': 'smith'}
assert memoized_function('john', 'smith') == 'smith'
assert memoized_function.calls['john'] == 1
def test_memoize_delete(memoized_function):
assert memoized_function('john', 'smith') == 'smith'
assert memoized_function('john', 'smith') == 'smith'
assert memoized_function.calls['john'] == 1
assert cache.get('myfunction') == {u'john-smith': 'smith'}
common.memoize_delete('myfunction')
assert cache.get('myfunction') is None
assert memoized_function('john', 'smith') == 'smith called 2 times'
assert memoized_function.calls['john'] == 2
def test_memoize_parameter_error():
@common.memoize(cache_key='foo', track_function=True)
def fn():
return
with pytest.raises(common.IllegalArgumentError):
fn()

View File

@ -34,7 +34,7 @@ from django.apps import apps
logger = logging.getLogger('awx.main.utils')
__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize',
__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize', 'memoize_delete',
'get_ansible_version', 'get_ssh_version', 'get_licenser', 'get_awx_version', 'update_scm_url',
'get_type_for_model', 'get_model_for_type', 'copy_model_by_class',
'copy_m2m_relationships' ,'cache_list_capabilities', 'to_python_boolean',
@ -44,7 +44,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore',
'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity',
'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict',
'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
'has_model_field_prefetched', 'set_environ']
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',]
def get_object_or_400(klass, *args, **kwargs):
@ -107,22 +107,48 @@ class RequireDebugTrueOrTest(logging.Filter):
return settings.DEBUG or 'test' in sys.argv
def memoize(ttl=60, cache_key=None):
class IllegalArgumentError(ValueError):
pass
def memoize(ttl=60, cache_key=None, track_function=False):
'''
Decorator to wrap a function and cache its result.
'''
from django.core.cache import cache
def _memoizer(f, *args, **kwargs):
key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs))
value = cache.get(key)
if value is None:
value = f(*args, **kwargs)
cache.set(key, value, ttl)
if cache_key and track_function:
raise IllegalArgumentError("Can not specify cache_key when track_function is True")
if track_function:
cache_dict_key = slugify('%r %r' % (args, kwargs))
key = slugify("%s" % f.__name__)
cache_dict = cache.get(key) or dict()
if cache_dict_key not in cache_dict:
value = f(*args, **kwargs)
cache_dict[cache_dict_key] = value
cache.set(key, cache_dict, ttl)
else:
value = cache_dict[cache_dict_key]
else:
key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs))
value = cache.get(key)
if value is None:
value = f(*args, **kwargs)
cache.set(key, value, ttl)
return value
return decorator(_memoizer)
def memoize_delete(function_name):
from django.core.cache import cache
return cache.delete(function_name)
@memoize()
def get_ansible_version():
'''