diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py
index 5a83491a97..b89697b69c 100644
--- a/awx/main/models/jobs.py
+++ b/awx/main/models/jobs.py
@@ -12,6 +12,8 @@ from urlparse import urljoin
 # Django
 from django.conf import settings
 from django.db import models
+#from django.core.cache import cache
+import memcache
 from django.db.models import Q, Count
 from django.utils.dateparse import parse_datetime
 from django.utils.encoding import force_text
@@ -38,6 +40,8 @@ from awx.main.fields import JSONField
 from awx.main.consumers import emit_channel_notification
 
 
+TIMEOUT = 60
+
 logger = logging.getLogger('awx.main.models.jobs')
 analytics_logger = logging.getLogger('awx.analytics.job_events')
 
@@ -703,6 +707,110 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
             self.project_update.cancel(job_explanation=job_explanation)
         return res
 
+    @property
+    def store_facts_enabled(self):
+        '''
+        Whether facts should be cached for this job.
+
+        True whenever the job has a (truthy) job template.
+        '''
+        # NOTE(review): the former `self.job_template is False` test could
+        # never fire after the falsiness check; presumably a per-template
+        # setting was meant to be consulted here -- confirm.
+        return bool(self.job_template)
+
+    @property
+    def memcached_fact_key(self):
+        '''memcached key holding the list of host names of the inventory.'''
+        return '{}'.format(self.inventory.id)
+
+    def memcached_fact_host_key(self, host_name):
+        '''Return the memcached key holding the facts of `host_name`.'''
+        return '{}-{}'.format(self.inventory.id, host_name)
+
+    def memcached_fact_modified_key(self, host_name):
+        '''Return the memcached key of the "facts changed" flag of `host_name`.'''
+        return '{}-{}-modified'.format(self.inventory.id, host_name)
+
+    def _get_inventory_hosts(self, only=('name', 'ansible_facts', 'modified',)):
+        '''Return the inventory's hosts, loading just the `only` fields.'''
+        # Immutable default: a list default would be shared between calls.
+        return self.inventory.hosts.only(*only)
+
+    def _get_memcache_connection(self):
+        '''Return a memcached client for the default Django cache location.'''
+        return memcache.Client([settings.CACHES['default']['LOCATION']], debug=0)
+
+    def start_job_fact_cache(self):
+        '''
+        Seed memcached with the facts of every host in the inventory.
+
+        Hosts that already have a modified flag in the cache are left alone
+        so that existing cache entries are not overwritten. Does nothing
+        when the job has no inventory.
+        '''
+        if not self.inventory:
+            return
+
+        cache = self._get_memcache_connection()
+
+        host_names = []
+
+        for host in self._get_inventory_hosts():
+            host_key = self.memcached_fact_host_key(host.name)
+            modified_key = self.memcached_fact_modified_key(host.name)
+            # Only add host/facts if host doesn't already exist in the cache
+            if cache.get(modified_key) is None:
+                cache.set(host_key, host.ansible_facts)
+                cache.set(modified_key, False)
+
+            host_names.append(host.name)
+
+        cache.set(self.memcached_fact_key, host_names)
+
+    def finish_job_fact_cache(self):
+        '''
+        Save facts flagged as modified in memcached back onto their hosts.
+
+        Hosts whose modified flag is missing or false are skipped. Does
+        nothing when the job has no inventory.
+        '''
+        # Local import so this method does not rely on a module-level one.
+        import json
+
+        if not self.inventory:
+            # TODO: Uh oh, we need to clean up the cache
+            return
+
+        cache = self._get_memcache_connection()
+
+        for host in self._get_inventory_hosts():
+            host_key = self.memcached_fact_host_key(host.name)
+            modified_key = self.memcached_fact_modified_key(host.name)
+
+            # None: host never seeded; False: facts unchanged.
+            if not cache.get(modified_key):
+                continue
+
+            ansible_facts = cache.get(host_key)
+            if ansible_facts is None:
+                # Flagged as modified but the facts are gone. Deleting the
+                # (already missing) facts key would be a no-op, so drop the
+                # stale flag instead; the host gets re-seeded on the next run.
+                cache.delete(modified_key)
+                logger.warning('Fact cache inconsistency for host %s in inventory %s',
+                               host.name, self.inventory.id)
+                continue
+
+            # The fact caching plugin stores facts JSON-encoded; values that
+            # are not JSON strings are saved as found.
+            try:
+                ansible_facts = json.loads(ansible_facts)
+            except (TypeError, ValueError):
+                pass
+            host.ansible_facts = ansible_facts
+            host.save()
+
 
 class JobHostSummary(CreatedModifiedModel):
     '''
@@ -1357,3 +1465,4 @@ class SystemJob(UnifiedJob, SystemJobOptions, JobNotificationMixin):
 
     def get_notification_friendly_name(self):
         return "System Job"
+
diff --git a/awx/main/tasks.py b/awx/main/tasks.py
index 19bd26bb1f..665389db77 100644
--- a/awx/main/tasks.py
+++ b/awx/main/tasks.py
@@ -877,6 +877,9 @@ class RunJob(BaseTask):
         # callbacks to work.
env['JOB_ID'] = str(job.pk) env['INVENTORY_ID'] = str(job.inventory.pk) + if job.store_facts_enabled: + env['MEMCACHED_PREPEND_KEY'] = job.memcached_fact_key + env['MEMCACHED_LOCATION'] = settings.CACHES['default']['LOCATION'] if job.project: env['PROJECT_REVISION'] = job.project.scm_revision env['ANSIBLE_RETRY_FILES_ENABLED'] = "False" @@ -1140,8 +1143,14 @@ class RunJob(BaseTask): ('project_update', local_project_sync.name, local_project_sync.id))) raise + if job.store_facts_enabled: + job.start_job_fact_cache() + + def final_run_hook(self, job, status, **kwargs): super(RunJob, self).final_run_hook(job, status, **kwargs) + if job.store_facts_enabled: + job.finish_job_fact_cache() try: inventory = job.inventory except Inventory.DoesNotExist: diff --git a/awx/main/tests/unit/models/test_jobs.py b/awx/main/tests/unit/models/test_jobs.py new file mode 100644 index 0000000000..8c75541bdf --- /dev/null +++ b/awx/main/tests/unit/models/test_jobs.py @@ -0,0 +1,115 @@ +import pytest + +from awx.main.models import ( + Job, + Inventory, + Host, +) + + +class CacheMock(object): + def __init__(self): + self.d = dict() + + def get(self, key): + if key not in self.d: + return None + return self.d[key] + + def set(self, key, val): + self.d[key] = val + + def delete(self, key): + del self.d[key] + + +@pytest.fixture +def hosts(): + return [ + Host(name='host1', ansible_facts={"a": 1, "b": 2}), + Host(name='host2', ansible_facts={"a": 1, "b": 2}), + Host(name='host3', ansible_facts={"a": 1, "b": 2}), + ] + + +@pytest.fixture +def hosts2(): + return [ + Host(name='host2', ansible_facts="foobar"), + ] + + +@pytest.fixture +def inventory(): + return Inventory(id=5) + + +@pytest.fixture +def mock_cache(mocker): + cache = CacheMock() + mocker.patch.object(cache, 'set', wraps=cache.set) + mocker.patch.object(cache, 'get', wraps=cache.get) + mocker.patch.object(cache, 'delete', wraps=cache.delete) + return cache + + +@pytest.fixture +def job(mocker, hosts, inventory, mock_cache): + j = 
Job(inventory=inventory, id=2) + j._get_inventory_hosts = mocker.Mock(return_value=hosts) + j._get_memcache_connection = mocker.Mock(return_value=mock_cache) + return j + + +@pytest.fixture +def job2(mocker, hosts2, inventory, mock_cache): + j = Job(inventory=inventory, id=3) + j._get_inventory_hosts = mocker.Mock(return_value=hosts2) + j._get_memcache_connection = mocker.Mock(return_value=mock_cache) + return j + + +def test_start_job_fact_cache(hosts, job, inventory, mocker): + + job.start_job_fact_cache() + + job._get_memcache_connection().set.assert_any_call('{}'.format(5), [h.name for h in hosts]) + for host in hosts: + job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts) + job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False) + + +def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker): + + job.start_job_fact_cache() + + for host in hosts: + job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), host.ansible_facts) + job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), False) + + job._get_memcache_connection().set.reset_mock() + + job2.start_job_fact_cache() + + # Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts + ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, hosts2[0].name)) + assert ansible_facts_cached == hosts[1].ansible_facts + + +def test_finish_job_fact_cache(job, hosts, inventory, mocker): + + job.start_job_fact_cache() + + host = hosts[1] + host_key = job.memcached_fact_host_key(host.name) + modified_key = job.memcached_fact_modified_key(host.name) + host.save = mocker.Mock() + + job._get_memcache_connection().set(host_key, 'blah') + job._get_memcache_connection().set(modified_key, True) + + job.finish_job_fact_cache() + + assert host.ansible_facts == 'blah' + host.save.assert_called_once_with() + diff --git 
a/awx/plugins/fact_caching/tower.py b/awx/plugins/fact_caching/tower.py
index 427cce8501..353c49010a 100755
--- a/awx/plugins/fact_caching/tower.py
+++ b/awx/plugins/fact_caching/tower.py
@@ -30,100 +30,87 @@
 # POSSIBILITY OF SUCH DAMAGE.
 
 import os
-import time
+import memcache
+import json
 
 try:
     from ansible.cache.base import BaseCacheModule
 except:
     from ansible.plugins.cache.base import BaseCacheModule
 
-from kombu import Connection, Exchange, Producer
-
 
 class CacheModule(BaseCacheModule):
 
     def __init__(self, *args, **kwargs):
         # Basic in-memory caching for typical runs
-        self._cache = {}
-        self._all_keys = {}
+        self.mc = memcache.Client([os.environ['MEMCACHED_LOCATION']], debug=0)
+        self.inventory_id = os.environ['INVENTORY_ID']
 
-        self.date_key = time.time()
-        self.callback_connection = os.environ['CALLBACK_CONNECTION']
-        self.callback_queue = os.environ['FACT_QUEUE']
-        self.connection = Connection(self.callback_connection)
-        self.exchange = Exchange(self.callback_queue, type='direct')
-        self.producer = Producer(self.connection)
+    @property
+    def host_names_key(self):
+        '''Key under which the inventory's list of host names is stored.'''
+        return '{}'.format(self.inventory_id)
 
-    def filter_ansible_facts(self, facts):
-        return dict((k, facts[k]) for k in facts.keys() if k.startswith('ansible_'))
+    def translate_host_key(self, host_name):
+        '''Return the memcached key holding the facts of `host_name`.'''
+        return '{}-{}'.format(self.inventory_id, host_name)
 
-    def identify_new_module(self, key, value):
-        # Return the first key found that doesn't exist in the
-        # previous set of facts
-        if key in self._all_keys:
-            for k in value.iterkeys():
-                if k not in self._all_keys[key] and not k.startswith('ansible_'):
-                    return k
-        # First time we have seen facts from this host
-        # it's either ansible facts or a module facts (including module_setup)
-        elif len(value) == 1:
-            return value.iterkeys().next()
-        return None
+    def translate_modified_key(self, host_name):
+        '''Return the memcached key of the "facts changed" flag of `host_name`.'''
+        return '{}-{}-modified'.format(self.inventory_id, host_name)
 
     def get(self, key):
-        return self._cache.get(key)
+        '''
+        Return the cached facts of host `key`.
+
+        Raises KeyError when nothing (or an empty value) is cached.
+        '''
+        value = self.mc.get(self.translate_host_key(key))
+        if not value:
+            raise KeyError(key)
+        try:
+            return json.loads(value)
+        except TypeError:
+            # Not a JSON string: the entry was stored as a native object
+            # (python-memcached serializes those itself), so it is
+            # already deserialized.
+            return value
 
-    '''
-    get() returns a reference to the fact object (usually a dict). The object is modified directly,
-    then set is called. Effectively, pre-determining the set logic.
-
-    The below logic creates a backup of the cache each set. The values are now preserved across set() calls.
-
-    For a given key. The previous value is looked at for new keys that aren't of the form 'ansible_'.
-    If found, send the value of the found key.
-    If not found, send all the key value pairs of the form 'ansible_' (we presume set() is called because
-    of an ansible fact module invocation)
-
-    More simply stated...
-    In value, if a new key is found at the top most dict then consider this a module request and only
-    emit the facts for the found top-level key.
-
-    If a new key is not found, assume set() was called as a result of ansible facts scan. Thus, emit
-    all facts of the form 'ansible_'.
-    '''
     def set(self, key, value):
-        module = self.identify_new_module(key, value)
-        # Assume ansible fact triggered the set if no new module found
-        facts = self.filter_ansible_facts(value) if not module else dict({ module : value[module]})
-        self._cache[key] = value
-        self._all_keys[key] = value.keys()
-        packet = {
-            'host': key,
-            'inventory_id': os.environ['INVENTORY_ID'],
-            'job_id': os.getenv('JOB_ID', ''),
-            'facts': facts,
-            'date_key': self.date_key,
-        }
+        '''Store the facts of host `key` as JSON and flag the host as modified.'''
+        host_key = self.translate_host_key(key)
+        modified_key = self.translate_modified_key(key)
 
-        # Emit fact data to tower for processing
-        self.producer.publish(packet,
-                              serializer='json',
-                              compression='bzip2',
-                              exchange=self.exchange,
-                              declare=[self.exchange],
-                              routing_key=self.callback_queue)
+        self.mc.set(host_key, json.dumps(value))
+        self.mc.set(modified_key, True)
 
     def keys(self):
-        return self._cache.keys()
+        '''Return the cached host names (an empty list when none are cached).'''
+        return self.mc.get(self.host_names_key) or []
 
     def contains(self, key):
-        return key in self._cache
+        '''Return True if a facts entry exists for host `key`.'''
+        # Facts are stored under the translated key, never the bare host name.
+        return self.mc.get(self.translate_host_key(key)) is not None
 
     def delete(self, key):
-        del self._cache[key]
+        '''Remove the facts and the modified flag of host `key`.'''
+        self.mc.delete(self.translate_host_key(key))
+        self.mc.delete(self.translate_modified_key(key))
 
     def flush(self):
-        self._cache = {}
+        '''Remove the facts and modified flags of every cached host.'''
+        for k in self.keys():
+            self.delete(k)
 
     def copy(self):
-        return self._cache.copy()
+        '''Return a dict mapping each cached host name to its facts.'''
+        ret = dict()
+        for k in self.keys():
+            try:
+                ret[k] = self.get(k)
+            except KeyError:
+                # Listed host without a facts entry; nothing to copy.
+                continue
+        return ret
+