take into account memcached key restrictions

* Keys can't contain spaces or control characters
This commit is contained in:
Chris Meyers
2017-07-05 10:53:09 -04:00
parent 3d4f8b0074
commit 1331865749
3 changed files with 12 additions and 9 deletions

View File

@@ -8,6 +8,7 @@ import hmac
import logging import logging
import time import time
import json import json
import base64
from urlparse import urljoin from urlparse import urljoin
# Django # Django
@@ -705,10 +706,10 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin):
return '{}'.format(self.inventory.id) return '{}'.format(self.inventory.id)
def memcached_fact_host_key(self, host_name): def memcached_fact_host_key(self, host_name):
return '{}-{}'.format(self.inventory.id, host_name) return '{}-{}'.format(self.inventory.id, base64.b64encode(host_name))
def memcached_fact_modified_key(self, host_name): def memcached_fact_modified_key(self, host_name):
return '{}-{}-modified'.format(self.inventory.id, host_name) return '{}-{}-modified'.format(self.inventory.id, base64.b64encode(host_name))
def _get_inventory_hosts(self, only=['name', 'ansible_facts', 'modified',]): def _get_inventory_hosts(self, only=['name', 'ansible_facts', 'modified',]):
return self.inventory.hosts.only(*only) return self.inventory.hosts.only(*only)

View File

@@ -8,6 +8,7 @@ from awx.main.models import (
import datetime import datetime
import json import json
import base64
from dateutil.tz import tzutc from dateutil.tz import tzutc
@@ -89,8 +90,8 @@ def test_start_job_fact_cache(hosts, job, inventory, mocker):
job._get_memcache_connection().set.assert_any_call('5', [h.name for h in hosts]) job._get_memcache_connection().set.assert_any_call('5', [h.name for h in hosts])
for host in hosts: for host in hosts:
job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), json.dumps(host.ansible_facts)) job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, base64.b64encode(host.name)), json.dumps(host.ansible_facts))
job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), host.ansible_facts_modified.isoformat()) job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, base64.b64encode(host.name)), host.ansible_facts_modified.isoformat())
def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker): def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker):
@@ -98,15 +99,15 @@ def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory,
job.start_job_fact_cache() job.start_job_fact_cache()
for host in hosts: for host in hosts:
job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, host.name), json.dumps(host.ansible_facts)) job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, base64.b64encode(host.name)), json.dumps(host.ansible_facts))
job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, host.name), host.ansible_facts_modified.isoformat()) job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, base64.b64encode(host.name)), host.ansible_facts_modified.isoformat())
job._get_memcache_connection().set.reset_mock() job._get_memcache_connection().set.reset_mock()
job2.start_job_fact_cache() job2.start_job_fact_cache()
# Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts # Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts
ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, hosts2[0].name)) ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, base64.b64encode(hosts2[0].name)))
assert ansible_facts_cached == json.dumps(hosts[1].ansible_facts) assert ansible_facts_cached == json.dumps(hosts[1].ansible_facts)

View File

@@ -33,6 +33,7 @@ import os
import memcache import memcache
import json import json
import datetime import datetime
import base64
from dateutil import parser from dateutil import parser
from dateutil.tz import tzutc from dateutil.tz import tzutc
@@ -56,10 +57,10 @@ class CacheModule(BaseCacheModule):
return '{}'.format(self._inventory_id) return '{}'.format(self._inventory_id)
def translate_host_key(self, host_name): def translate_host_key(self, host_name):
return '{}-{}'.format(self._inventory_id, host_name) return '{}-{}'.format(self._inventory_id, base64.b64encode(host_name))
def translate_modified_key(self, host_name): def translate_modified_key(self, host_name):
return '{}-{}-modified'.format(self._inventory_id, host_name) return '{}-{}-modified'.format(self._inventory_id, base64.b64encode(host_name))
def get(self, key): def get(self, key):
host_key = self.translate_host_key(key) host_key = self.translate_host_key(key)