mirror of
https://github.com/ansible/awx.git
synced 2026-02-28 00:08:44 -03:30
2-level memoize
* Allows for invalidating an entire function from the memoizer
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
# All Rights Reserved.
|
# All Rights Reserved.
|
||||||
|
|
||||||
# Django
|
# Django
|
||||||
from django.core.cache import cache
|
|
||||||
from django.core.signals import setting_changed
|
from django.core.signals import setting_changed
|
||||||
from django.dispatch import receiver
|
from django.dispatch import receiver
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
@@ -12,7 +11,7 @@ from rest_framework.exceptions import APIException
|
|||||||
|
|
||||||
# Tower
|
# Tower
|
||||||
from awx.main.utils.common import get_licenser
|
from awx.main.utils.common import get_licenser
|
||||||
from awx.main.utils import memoize
|
from awx.main.utils import memoize, memoize_delete
|
||||||
|
|
||||||
__all__ = ['LicenseForbids', 'get_license', 'get_licensed_features',
|
__all__ = ['LicenseForbids', 'get_license', 'get_licensed_features',
|
||||||
'feature_enabled', 'feature_exists']
|
'feature_enabled', 'feature_exists']
|
||||||
@@ -23,7 +22,6 @@ class LicenseForbids(APIException):
|
|||||||
default_detail = _('Your Tower license does not allow that.')
|
default_detail = _('Your Tower license does not allow that.')
|
||||||
|
|
||||||
|
|
||||||
@memoize(cache_key='_validated_license_data')
|
|
||||||
def _get_validated_license_data():
|
def _get_validated_license_data():
|
||||||
return get_licenser().validate()
|
return get_licenser().validate()
|
||||||
|
|
||||||
@@ -32,7 +30,7 @@ def _get_validated_license_data():
|
|||||||
def _on_setting_changed(sender, **kwargs):
|
def _on_setting_changed(sender, **kwargs):
|
||||||
# Clear cached result above when license changes.
|
# Clear cached result above when license changes.
|
||||||
if kwargs.get('setting', None) == 'LICENSE':
|
if kwargs.get('setting', None) == 'LICENSE':
|
||||||
cache.delete('_validated_license_data')
|
memoize_delete('feature_enabled')
|
||||||
|
|
||||||
|
|
||||||
def get_license(show_key=False):
|
def get_license(show_key=False):
|
||||||
@@ -52,6 +50,7 @@ def get_licensed_features():
|
|||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@memoize(track_function=True)
|
||||||
def feature_enabled(name):
|
def feature_enabled(name):
|
||||||
"""Return True if the requested feature is enabled, False otherwise."""
|
"""Return True if the requested feature is enabled, False otherwise."""
|
||||||
validated_license_data = _get_validated_license_data()
|
validated_license_data = _get_validated_license_data()
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from django.core.cache import cache
|
||||||
|
|
||||||
from awx.main.utils import common
|
from awx.main.utils import common
|
||||||
|
|
||||||
from awx.main.models import (
|
from awx.main.models import (
|
||||||
@@ -18,6 +20,14 @@ from awx.main.models import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
'''
|
||||||
|
Clear cache (local memory) for each test to prevent using cached settings.
|
||||||
|
'''
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('input_, output', [
|
@pytest.mark.parametrize('input_, output', [
|
||||||
({"foo": "bar"}, {"foo": "bar"}),
|
({"foo": "bar"}, {"foo": "bar"}),
|
||||||
('{"foo": "bar"}', {"foo": "bar"}),
|
('{"foo": "bar"}', {"foo": "bar"}),
|
||||||
@@ -49,3 +59,59 @@ def test_set_environ():
|
|||||||
])
|
])
|
||||||
def test_get_type_for_model(model, name):
|
def test_get_type_for_model(model, name):
|
||||||
assert common.get_type_for_model(model) == name
|
assert common.get_type_for_model(model) == name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def memoized_function(mocker):
|
||||||
|
@common.memoize(track_function=True)
|
||||||
|
def myfunction(key, value):
|
||||||
|
if key not in myfunction.calls:
|
||||||
|
myfunction.calls[key] = 0
|
||||||
|
|
||||||
|
myfunction.calls[key] += 1
|
||||||
|
|
||||||
|
if myfunction.calls[key] == 1:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
return '%s called %s times' % (value, myfunction.calls[key])
|
||||||
|
myfunction.calls = dict()
|
||||||
|
return myfunction
|
||||||
|
|
||||||
|
|
||||||
|
def test_memoize_track_function(memoized_function):
|
||||||
|
assert memoized_function('scott', 'scotterson') == 'scotterson'
|
||||||
|
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson'}
|
||||||
|
assert memoized_function('scott', 'scotterson') == 'scotterson'
|
||||||
|
|
||||||
|
assert memoized_function.calls['scott'] == 1
|
||||||
|
|
||||||
|
assert memoized_function('john', 'smith') == 'smith'
|
||||||
|
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson', u'john-smith': 'smith'}
|
||||||
|
assert memoized_function('john', 'smith') == 'smith'
|
||||||
|
|
||||||
|
assert memoized_function.calls['john'] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_memoize_delete(memoized_function):
|
||||||
|
assert memoized_function('john', 'smith') == 'smith'
|
||||||
|
assert memoized_function('john', 'smith') == 'smith'
|
||||||
|
assert memoized_function.calls['john'] == 1
|
||||||
|
|
||||||
|
assert cache.get('myfunction') == {u'john-smith': 'smith'}
|
||||||
|
|
||||||
|
common.memoize_delete('myfunction')
|
||||||
|
|
||||||
|
assert cache.get('myfunction') is None
|
||||||
|
|
||||||
|
assert memoized_function('john', 'smith') == 'smith called 2 times'
|
||||||
|
assert memoized_function.calls['john'] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_memoize_parameter_error():
|
||||||
|
@common.memoize(cache_key='foo', track_function=True)
|
||||||
|
def fn():
|
||||||
|
return
|
||||||
|
|
||||||
|
with pytest.raises(common.IllegalArgumentError):
|
||||||
|
fn()
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from django.apps import apps
|
|||||||
|
|
||||||
logger = logging.getLogger('awx.main.utils')
|
logger = logging.getLogger('awx.main.utils')
|
||||||
|
|
||||||
__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize',
|
__all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'memoize', 'memoize_delete',
|
||||||
'get_ansible_version', 'get_ssh_version', 'get_licenser', 'get_awx_version', 'update_scm_url',
|
'get_ansible_version', 'get_ssh_version', 'get_licenser', 'get_awx_version', 'update_scm_url',
|
||||||
'get_type_for_model', 'get_model_for_type', 'copy_model_by_class',
|
'get_type_for_model', 'get_model_for_type', 'copy_model_by_class',
|
||||||
'copy_m2m_relationships' ,'cache_list_capabilities', 'to_python_boolean',
|
'copy_m2m_relationships' ,'cache_list_capabilities', 'to_python_boolean',
|
||||||
@@ -44,7 +44,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore',
|
|||||||
'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity',
|
'callback_filter_out_ansible_extra_vars', 'get_search_fields', 'get_system_task_capacity',
|
||||||
'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict',
|
'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict',
|
||||||
'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
|
'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
|
||||||
'has_model_field_prefetched', 'set_environ']
|
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',]
|
||||||
|
|
||||||
|
|
||||||
def get_object_or_400(klass, *args, **kwargs):
|
def get_object_or_400(klass, *args, **kwargs):
|
||||||
@@ -107,22 +107,48 @@ class RequireDebugTrueOrTest(logging.Filter):
|
|||||||
return settings.DEBUG or 'test' in sys.argv
|
return settings.DEBUG or 'test' in sys.argv
|
||||||
|
|
||||||
|
|
||||||
def memoize(ttl=60, cache_key=None):
|
class IllegalArgumentError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def memoize(ttl=60, cache_key=None, track_function=False):
|
||||||
'''
|
'''
|
||||||
Decorator to wrap a function and cache its result.
|
Decorator to wrap a function and cache its result.
|
||||||
'''
|
'''
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
|
||||||
|
|
||||||
def _memoizer(f, *args, **kwargs):
|
def _memoizer(f, *args, **kwargs):
|
||||||
key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs))
|
if cache_key and track_function:
|
||||||
value = cache.get(key)
|
raise IllegalArgumentError("Can not specify cache_key when track_function is True")
|
||||||
if value is None:
|
|
||||||
value = f(*args, **kwargs)
|
if track_function:
|
||||||
cache.set(key, value, ttl)
|
cache_dict_key = slugify('%r %r' % (args, kwargs))
|
||||||
|
key = slugify("%s" % f.__name__)
|
||||||
|
cache_dict = cache.get(key) or dict()
|
||||||
|
if cache_dict_key not in cache_dict:
|
||||||
|
value = f(*args, **kwargs)
|
||||||
|
cache_dict[cache_dict_key] = value
|
||||||
|
cache.set(key, cache_dict, ttl)
|
||||||
|
else:
|
||||||
|
value = cache_dict[cache_dict_key]
|
||||||
|
else:
|
||||||
|
key = cache_key or slugify('%s %r %r' % (f.__name__, args, kwargs))
|
||||||
|
value = cache.get(key)
|
||||||
|
if value is None:
|
||||||
|
value = f(*args, **kwargs)
|
||||||
|
cache.set(key, value, ttl)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
return decorator(_memoizer)
|
return decorator(_memoizer)
|
||||||
|
|
||||||
|
|
||||||
|
def memoize_delete(function_name):
|
||||||
|
from django.core.cache import cache
|
||||||
|
|
||||||
|
return cache.delete(function_name)
|
||||||
|
|
||||||
|
|
||||||
@memoize()
|
@memoize()
|
||||||
def get_ansible_version():
|
def get_ansible_version():
|
||||||
'''
|
'''
|
||||||
|
|||||||
Reference in New Issue
Block a user