adds per-host timeout

This commit is contained in:
Chris Meyers
2017-06-19 10:59:37 -04:00
parent 817dbe8d33
commit 12cdbcf8b5
3 changed files with 76 additions and 37 deletions

View File

@@ -17,6 +17,8 @@ from django.db import models
import memcache import memcache
from django.db.models import Q, Count from django.db.models import Q, Count
from django.utils.dateparse import parse_datetime from django.utils.dateparse import parse_datetime
from dateutil import parser
from dateutil.tz import tzutc
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.timezone import utc from django.utils.timezone import utc
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@@ -41,8 +43,6 @@ from awx.main.fields import JSONField
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
TIMEOUT = 60
logger = logging.getLogger('awx.main.models.jobs') logger = logging.getLogger('awx.main.models.jobs')
analytics_logger = logging.getLogger('awx.analytics.job_events') analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -735,10 +735,14 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
for host in self._get_inventory_hosts(): for host in self._get_inventory_hosts():
host_key = self.memcached_fact_host_key(host.name) host_key = self.memcached_fact_host_key(host.name)
modified_key = self.memcached_fact_modified_key(host.name) modified_key = self.memcached_fact_modified_key(host.name)
# Only add host/facts if host doesn't already exist in the cache
if cache.get(modified_key) is None: if cache.get(modified_key) is None:
if host.ansible_facts_modified:
host_modified = host.ansible_facts_modified.replace(tzinfo=tzutc()).isoformat()
else:
host_modified = datetime.datetime.now(tzutc()).isoformat()
cache.set(host_key, json.dumps(host.ansible_facts)) cache.set(host_key, json.dumps(host.ansible_facts))
cache.set(modified_key, False) cache.set(modified_key, host_modified)
host_names.append(host.name) host_names.append(host.name)
@@ -746,7 +750,6 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
def finish_job_fact_cache(self): def finish_job_fact_cache(self):
if not self.inventory: if not self.inventory:
# TODO: Uh oh, we need to clean up the cache
return return
cache = self._get_memcache_connection() cache = self._get_memcache_connection()
@@ -758,16 +761,18 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
modified = cache.get(modified_key) modified = cache.get(modified_key)
if modified is None: if modified is None:
cache.delete(host_key)
continue continue
# Save facts that have changed # Save facts if cache is newer than DB
if modified: modified = parser.parse(modified, tzinfos=[tzutc()])
if not host.ansible_facts_modified or modified > host.ansible_facts_modified:
ansible_facts = cache.get(host_key) ansible_facts = cache.get(host_key)
if ansible_facts is None: if ansible_facts is None:
cache.delete(host_key) cache.delete(host_key)
# TODO: Log cache inconsistency
continue continue
host.ansible_facts = ansible_facts host.ansible_facts = ansible_facts
host.ansible_facts_modified = modified
host.save() host.save()

View File

@@ -6,6 +6,10 @@ from awx.main.models import (
Host, Host,
) )
import datetime
import json
from dateutil.tz import tzutc
class CacheMock(object): class CacheMock(object):
def __init__(self): def __init__(self):
@@ -24,18 +28,28 @@ class CacheMock(object):
@pytest.fixture @pytest.fixture
def hosts(): def old_time():
return (datetime.datetime.now(tzutc()) - datetime.timedelta(minutes=60))
@pytest.fixture()
def new_time():
return (datetime.datetime.now(tzutc()))
@pytest.fixture
def hosts(old_time):
return [ return [
Host(name='host1', ansible_facts={"a": 1, "b": 2}), Host(name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time),
Host(name='host2', ansible_facts={"a": 1, "b": 2}), Host(name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time),
Host(name='host3', ansible_facts={"a": 1, "b": 2}), Host(name='host3', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time),
] ]
@pytest.fixture @pytest.fixture
def hosts2(): def hosts2():
return [ return [
Host(name='host2', ansible_facts="foobar"), Host(name='host2', ansible_facts="foobar", ansible_facts_modified=old_time),
] ]
@@ -75,8 +89,8 @@ def test_start_job_fact_cache(hosts, job, inventory, mocker):
job._get_memcache_connection().set.assert_any_call('{}'.format(5), [h.name for h in hosts]) job._get_memcache_connection().set.assert_any_call('{}'.format(5), [h.name for h in hosts])
for host in hosts: for host in hosts:
job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts) job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), json.dumps(host.ansible_facts))
job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False) job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), host.ansible_facts_modified.isoformat())
def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker): def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker):
@@ -84,8 +98,8 @@ def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory,
job.start_job_fact_cache() job.start_job_fact_cache()
for host in hosts: for host in hosts:
job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts) job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), json.dumps(host.ansible_facts))
job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False) job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), host.ansible_facts_modified.isoformat())
job._get_memcache_connection().set.reset_mock() job._get_memcache_connection().set.reset_mock()
@@ -93,23 +107,25 @@ def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory,
# Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts # Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts
ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, hosts2[0].name)) ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, hosts2[0].name))
assert ansible_facts_cached == hosts[1].ansible_facts assert ansible_facts_cached == json.dumps(hosts[1].ansible_facts)
def test_finish_job_fact_cache(job, hosts, inventory, mocker): def test_finish_job_fact_cache(job, hosts, inventory, mocker, new_time):
job.start_job_fact_cache() job.start_job_fact_cache()
for h in hosts:
h.save = mocker.Mock()
host = hosts[1] host_key = job.memcached_fact_host_key(hosts[1].name)
host_key = job.memcached_fact_host_key(host.name) modified_key = job.memcached_fact_modified_key(hosts[1].name)
modified_key = job.memcached_fact_modified_key(host.name)
host.save = mocker.Mock()
job._get_memcache_connection().set(host_key, 'blah') job._get_memcache_connection().set(host_key, 'blah')
job._get_memcache_connection().set(modified_key, True) job._get_memcache_connection().set(modified_key, new_time.isoformat())
job.finish_job_fact_cache() job.finish_job_fact_cache()
assert host.ansible_facts == 'blah' hosts[0].save.assert_not_called()
host.save.assert_called_once_with() hosts[2].save.assert_not_called()
assert hosts[1].ansible_facts == 'blah'
hosts[1].save.assert_called_once_with()

View File

@@ -32,6 +32,9 @@
import os import os
import memcache import memcache
import json import json
import datetime
from dateutil import parser
from dateutil.tz import tzutc
from ansible import constants as C from ansible import constants as C
@@ -45,21 +48,34 @@ class CacheModule(BaseCacheModule):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.mc = memcache.Client([C.CACHE_PLUGIN_CONNECTION], debug=0) self.mc = memcache.Client([C.CACHE_PLUGIN_CONNECTION], debug=0)
self.timeout = int(C.CACHE_PLUGIN_TIMEOUT) self._timeout = int(C.CACHE_PLUGIN_TIMEOUT)
self.inventory_id = os.environ['INVENTORY_ID'] self._inventory_id = os.environ['INVENTORY_ID']
@property @property
def host_names_key(self): def host_names_key(self):
return '{}'.format(self.inventory_id) return '{}'.format(self._inventory_id)
def translate_host_key(self, host_name): def translate_host_key(self, host_name):
return '{}-{}'.format(self.inventory_id, host_name) return '{}-{}'.format(self._inventory_id, host_name)
def translate_modified_key(self, host_name): def translate_modified_key(self, host_name):
return '{}-{}-modified'.format(self.inventory_id, host_name) return '{}-{}-modified'.format(self._inventory_id, host_name)
def get(self, key): def get(self, key):
host_key = self.translate_host_key(key) host_key = self.translate_host_key(key)
modified_key = self.translate_modified_key(key)
'''
Cache entry expired
'''
modified = self.mc.get(modified_key)
if modified is None:
raise KeyError
modified = parser.parse(modified).replace(tzinfo=tzutc())
now_utc = datetime.datetime.now(tzutc())
if self._timeout != 0 and (modified + datetime.timedelta(seconds=self._timeout)) < now_utc:
raise KeyError
value_json = self.mc.get(host_key) value_json = self.mc.get(host_key)
if value_json is None: if value_json is None:
raise KeyError raise KeyError
@@ -75,17 +91,17 @@ class CacheModule(BaseCacheModule):
modified_key = self.translate_modified_key(key) modified_key = self.translate_modified_key(key)
self.mc.set(host_key, json.dumps(value)) self.mc.set(host_key, json.dumps(value))
self.mc.set(modified_key, True) self.mc.set(modified_key, datetime.datetime.now(tzutc()).isoformat())
def keys(self): def keys(self):
return self.mc.get(self.host_names_key) return self.mc.get(self.host_names_key)
def contains(self, key): def contains(self, key):
host_key = self.translate_host_key(key) try:
val = self.mc.get(host_key) self.get(key)
if val is None: return True
except KeyError:
return False return False
return True
def delete(self, key): def delete(self, key):
self.mc.delete(self.translate_host_key(key)) self.mc.delete(self.translate_host_key(key))
@@ -106,5 +122,7 @@ class CacheModule(BaseCacheModule):
if not host_names: if not host_names:
return return
return [self.mc.get(self.translate_host_key(k)) for k in host_names] for k in host_names:
ret[k] = self.mc.get(self.translate_host_key(k))
return ret