Tower fact cache implementation

* Tower now injects facts into jobs via memcached for use by Ansible
playbooks. On the Ansible side, this is accomplished by the existing
mechanism, an Ansible fact cache plugin backed by memcached. On the
Tower side, memcached is used to seed host facts before a job runs and
to harvest changed facts after it finishes.
Chris Meyers 2017-06-13 12:41:35 -04:00
parent 0121f5cde4
commit 626e2d1c9b
4 changed files with 232 additions and 68 deletions
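
For orientation, the memcached key scheme shared by both sides of this commit,
as a minimal sketch (inventory id, host name, and memcached address are
illustrative):

    import memcache

    # Key scheme used throughout this commit, for inventory id 5, host 'web1':
    #   '5'               -> list of host names in the inventory (the index)
    #   '5-web1'          -> that host's ansible_facts
    #   '5-web1-modified' -> flag set True once the plugin writes new facts
    mc = memcache.Client(['127.0.0.1:11211'])  # assumption: local memcached
    mc.set('5', ['web1'])
    mc.set('5-web1', {'ansible_distribution': 'Fedora'})
    mc.set('5-web1-modified', False)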

View File

@@ -12,6 +12,8 @@ from urlparse import urljoin
# Django
from django.conf import settings
from django.db import models
#from django.core.cache import cache
import memcache
from django.db.models import Q, Count
from django.utils.dateparse import parse_datetime
from django.utils.encoding import force_text
@@ -38,6 +40,8 @@ from awx.main.fields import JSONField
from awx.main.consumers import emit_channel_notification

TIMEOUT = 60

logger = logging.getLogger('awx.main.models.jobs')
analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -703,6 +707,74 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
            self.project_update.cancel(job_explanation=job_explanation)
        return res
    @property
    def store_facts_enabled(self):
        # Facts can only be stored for jobs launched from a job template.
        if not self.job_template:
            return False
        return True

    @property
    def memcached_fact_key(self):
        return '{}'.format(self.inventory.id)

    def memcached_fact_host_key(self, host_name):
        return '{}-{}'.format(self.inventory.id, host_name)

    def memcached_fact_modified_key(self, host_name):
        return '{}-{}-modified'.format(self.inventory.id, host_name)

    def _get_inventory_hosts(self, only=['name', 'ansible_facts', 'modified']):
        return self.inventory.hosts.only(*only)

    def _get_memcache_connection(self):
        return memcache.Client([settings.CACHES['default']['LOCATION']], debug=0)
    def start_job_fact_cache(self):
        if not self.inventory:
            return
        cache = self._get_memcache_connection()

        host_names = []
        for host in self._get_inventory_hosts():
            host_key = self.memcached_fact_host_key(host.name)
            modified_key = self.memcached_fact_modified_key(host.name)

            # Only add host/facts if the host doesn't already exist in the cache
            if cache.get(modified_key) is None:
                cache.set(host_key, host.ansible_facts)
                cache.set(modified_key, False)
            host_names.append(host.name)

        cache.set(self.memcached_fact_key, host_names)
    def finish_job_fact_cache(self):
        if not self.inventory:
            # TODO: Uh oh, we need to clean up the cache
            return
        cache = self._get_memcache_connection()

        hosts = self._get_inventory_hosts()
        for host in hosts:
            host_key = self.memcached_fact_host_key(host.name)
            modified_key = self.memcached_fact_modified_key(host.name)

            modified = cache.get(modified_key)
            if modified is None:
                continue

            # Save facts that have changed
            if modified:
                ansible_facts = cache.get(host_key)
                if ansible_facts is None:
                    # TODO: Log cache inconsistency
                    cache.delete(host_key)
                    continue
                host.ansible_facts = ansible_facts
                host.save()


class JobHostSummary(CreatedModifiedModel):
    '''
@@ -1357,3 +1429,4 @@ class SystemJob(UnifiedJob, SystemJobOptions, JobNotificationMixin):
    def get_notification_friendly_name(self):
        return "System Job"

View File

@@ -877,6 +877,9 @@ class RunJob(BaseTask):
        # callbacks to work.
        env['JOB_ID'] = str(job.pk)
        env['INVENTORY_ID'] = str(job.inventory.pk)
        if job.store_facts_enabled:
            env['MEMCACHED_PREPEND_KEY'] = job.memcached_fact_key
            env['MEMCACHED_LOCATION'] = settings.CACHES['default']['LOCATION']
        if job.project:
            env['PROJECT_REVISION'] = job.project.scm_revision
        env['ANSIBLE_RETRY_FILES_ENABLED'] = "False"
@@ -1140,8 +1143,14 @@ class RunJob(BaseTask):
                    ('project_update', local_project_sync.name, local_project_sync.id)))
                raise

        if job.store_facts_enabled:
            job.start_job_fact_cache()

    def final_run_hook(self, job, status, **kwargs):
        super(RunJob, self).final_run_hook(job, status, **kwargs)
        if job.store_facts_enabled:
            job.finish_job_fact_cache()

        try:
            inventory = job.inventory
        except Inventory.DoesNotExist:
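
The diff exports MEMCACHED_LOCATION and friends, but doesn't show how the cache
plugin itself is selected for the spawned ansible-playbook process. One
plausible wiring via Ansible's standard environment settings would be (an
assumption, not part of this commit; the path is hypothetical):

    # Point Ansible at a custom cache plugin directory and select the plugin.
    env['ANSIBLE_CACHE_PLUGINS'] = '/path/to/plugins/fact_caching'  # hypothetical
    env['ANSIBLE_CACHE_PLUGIN'] = 'tower'  # selects tower.py from that directory
    env['ANSIBLE_GATHERING'] = 'smart'     # skip gathering when cached facts exist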

View File

@@ -0,0 +1,115 @@
import pytest

from awx.main.models import (
    Job,
    Inventory,
    Host,
)


class CacheMock(object):
    def __init__(self):
        self.d = dict()

    def get(self, key):
        # Mirror python-memcached: a missing key reads as None
        return self.d.get(key)

    def set(self, key, val):
        self.d[key] = val

    def delete(self, key):
        # Mirror python-memcached: deleting a missing key is not an error
        self.d.pop(key, None)
@pytest.fixture
def hosts():
    return [
        Host(name='host1', ansible_facts={"a": 1, "b": 2}),
        Host(name='host2', ansible_facts={"a": 1, "b": 2}),
        Host(name='host3', ansible_facts={"a": 1, "b": 2}),
    ]


@pytest.fixture
def hosts2():
    return [
        Host(name='host2', ansible_facts="foobar"),
    ]


@pytest.fixture
def inventory():
    return Inventory(id=5)


@pytest.fixture
def mock_cache(mocker):
    cache = CacheMock()
    mocker.patch.object(cache, 'set', wraps=cache.set)
    mocker.patch.object(cache, 'get', wraps=cache.get)
    mocker.patch.object(cache, 'delete', wraps=cache.delete)
    return cache


@pytest.fixture
def job(mocker, hosts, inventory, mock_cache):
    j = Job(inventory=inventory, id=2)
    j._get_inventory_hosts = mocker.Mock(return_value=hosts)
    j._get_memcache_connection = mocker.Mock(return_value=mock_cache)
    return j


@pytest.fixture
def job2(mocker, hosts2, inventory, mock_cache):
    j = Job(inventory=inventory, id=3)
    j._get_inventory_hosts = mocker.Mock(return_value=hosts2)
    j._get_memcache_connection = mocker.Mock(return_value=mock_cache)
    return j
def test_start_job_fact_cache(hosts, job, inventory, mocker):
    job.start_job_fact_cache()

    job._get_memcache_connection().set.assert_any_call('{}'.format(5), [h.name for h in hosts])
    for host in hosts:
        job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts)
        job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False)


def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker):
    job.start_job_fact_cache()
    for host in hosts:
        job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts)
        job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False)
    job._get_memcache_connection().set.reset_mock()

    job2.start_job_fact_cache()

    # Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts
    ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, hosts2[0].name))
    assert ansible_facts_cached == hosts[1].ansible_facts


def test_finish_job_fact_cache(job, hosts, inventory, mocker):
    job.start_job_fact_cache()

    host = hosts[1]
    host_key = job.memcached_fact_host_key(host.name)
    modified_key = job.memcached_fact_modified_key(host.name)
    host.save = mocker.Mock()

    job._get_memcache_connection().set(host_key, 'blah')
    job._get_memcache_connection().set(modified_key, True)

    job.finish_job_fact_cache()

    assert host.ansible_facts == 'blah'
    host.save.assert_called_once_with()
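
A hypothetical companion test (not in this diff) for the inconsistency branch,
where a host's '-modified' flag is set but its cached facts are missing:

    def test_finish_job_fact_cache_missing_facts(job, hosts, mocker):
        # Hypothetical: host2's modified flag is True but its facts read as
        # None, so finish_job_fact_cache() should drop the key, not save.
        job.start_job_fact_cache()
        host = hosts[1]
        host.save = mocker.Mock()

        job._get_memcache_connection().set(job.memcached_fact_host_key(host.name), None)
        job._get_memcache_connection().set(job.memcached_fact_modified_key(host.name), True)

        job.finish_job_fact_cache()
        assert not host.save.called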

View File

@@ -30,100 +30,67 @@
# POSSIBILITY OF SUCH DAMAGE.

import os
import time
import memcache
import json

try:
    from ansible.cache.base import BaseCacheModule
except ImportError:
    from ansible.plugins.cache.base import BaseCacheModule

from kombu import Connection, Exchange, Producer


class CacheModule(BaseCacheModule):

    def __init__(self, *args, **kwargs):
        # Basic in-memory caching for typical runs
        self._cache = {}
        self._all_keys = {}
        self.mc = memcache.Client([os.environ['MEMCACHED_LOCATION']], debug=0)
        self.inventory_id = os.environ['INVENTORY_ID']
        self.date_key = time.time()
        self.callback_connection = os.environ['CALLBACK_CONNECTION']
        self.callback_queue = os.environ['FACT_QUEUE']
        self.connection = Connection(self.callback_connection)
        self.exchange = Exchange(self.callback_queue, type='direct')
        self.producer = Producer(self.connection)
    @property
    def host_names_key(self):
        return '{}'.format(self.inventory_id)

    def filter_ansible_facts(self, facts):
        # Keep only standard fact keys, i.e. those prefixed with 'ansible_'
        return dict((k, facts[k]) for k in facts.keys() if k.startswith('ansible_'))

    def translate_host_key(self, host_name):
        return '{}-{}'.format(self.inventory_id, host_name)
    def identify_new_module(self, key, value):
        # Return the first key found that doesn't exist in the
        # previous set of facts
        if key in self._all_keys:
            for k in value.iterkeys():
                if k not in self._all_keys[key] and not k.startswith('ansible_'):
                    return k
        # First time we have seen facts from this host; it's either
        # Ansible facts or module facts (including module_setup)
        elif len(value) == 1:
            return value.iterkeys().next()
        return None

    def translate_modified_key(self, host_name):
        return '{}-{}-modified'.format(self.inventory_id, host_name)
    def get(self, key):
        host_key = self.translate_host_key(key)
        value_json = self.mc.get(host_key)
        if not value_json:
            raise KeyError
        return json.loads(value_json)
    '''
    get() returns a reference to the fact object (usually a dict); callers
    modify that object directly and then call set(). In effect, the caller
    pre-determines the set logic.

    The logic below backs up the cache on each set(), so the previous values
    are preserved across set() calls for a given key. The previous value is
    inspected for new top-level keys that aren't of the form 'ansible_*'.
    If such a key is found, emit only the value of that key.
    If none is found, presume set() was called because of an Ansible fact
    module invocation, and emit all key/value pairs of the form 'ansible_*'.

    More simply stated: if a new top-level key appears in value, treat this
    as a module request and emit only the facts for that key; otherwise,
    assume set() resulted from an Ansible fact scan and emit all 'ansible_*'
    facts.
    '''
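    # Example (illustrative): if self._all_keys['web1'] was previously
    # ['ansible_distribution', 'module_setup'] and set() receives
    # {'ansible_distribution': 'Fedora', 'my_module': {'rc': 0}}, then
    # identify_new_module() returns 'my_module' and only value['my_module']
    # is emitted; with no new non-'ansible_' key, only the 'ansible_'-prefixed
    # facts would be emitted instead.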
    def set(self, key, value):
        module = self.identify_new_module(key, value)
        # Assume an Ansible fact scan triggered the set if no new module is found
        facts = self.filter_ansible_facts(value) if not module else {module: value[module]}

        self._cache[key] = value
        self._all_keys[key] = value.keys()

        packet = {
            'host': key,
            'inventory_id': os.environ['INVENTORY_ID'],
            'job_id': os.getenv('JOB_ID', ''),
            'facts': facts,
            'date_key': self.date_key,
        }

        host_key = self.translate_host_key(key)
        modified_key = self.translate_modified_key(key)

        # Emit fact data to Tower for processing
        self.producer.publish(packet,
                              serializer='json',
                              compression='bzip2',
                              exchange=self.exchange,
                              declare=[self.exchange],
                              routing_key=self.callback_queue)

        self.mc.set(host_key, json.dumps(value))
        self.mc.set(modified_key, True)
    def keys(self):
        return self.mc.get(self.host_names_key)

    def contains(self, key):
        val = self.mc.get(key)
        if val is None:
            return False
        return True

    def delete(self, key):
        self._cache.pop(key, None)
        self.mc.delete(self.translate_host_key(key))
        self.mc.delete(self.translate_modified_key(key))

    def flush(self):
        self._cache = {}
        for k in self.mc.get(self.host_names_key):
            self.mc.delete(self.translate_host_key(k))
            self.mc.delete(self.translate_modified_key(k))

    def copy(self):
        ret = dict()
        for k in self.mc.get(self.host_names_key):
            ret[k] = self.mc.get(self.translate_host_key(k))
        return ret
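
A minimal sketch of exercising the plugin by hand, assuming a reachable
memcached that start_job_fact_cache() has already seeded (environment values
are illustrative; the kombu Connection is lazy, so no broker is contacted
until set() publishes):

    import os

    # Illustrative environment; in a real job these are exported by RunJob.
    os.environ['MEMCACHED_LOCATION'] = '127.0.0.1:11211'
    os.environ['INVENTORY_ID'] = '5'
    os.environ['JOB_ID'] = '42'
    os.environ['CALLBACK_CONNECTION'] = 'amqp://guest:guest@localhost:5672//'
    os.environ['FACT_QUEUE'] = 'facts'  # hypothetical queue name

    from tower import CacheModule  # assumption: this plugin file is tower.py

    cm = CacheModule()
    print(cm.keys())  # host-name index seeded by start_job_fact_cache()
    for name in cm.keys() or []:
        print(cm.mc.get(cm.translate_host_key(name)))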