diff --git a/awx/api/serializers.py b/awx/api/serializers.py
index fcdde9e512..97c369c051 100644
--- a/awx/api/serializers.py
+++ b/awx/api/serializers.py
@@ -2196,6 +2196,7 @@ class CredentialSerializer(BaseSerializer):
                     _('You cannot change the credential type of the credential, as it may break the functionality'
                       ' of the resources using it.'),
                 )
+        return credential_type
diff --git a/awx/conf/views.py b/awx/conf/views.py
index 189dd387dc..60ea39d911 100644
--- a/awx/conf/views.py
+++ b/awx/conf/views.py
@@ -21,7 +21,7 @@ from awx.api.generics import * # noqa
 from awx.api.permissions import IsSuperUser
 from awx.api.versioning import reverse, get_request_version
 from awx.main.utils import * # noqa
-from awx.main.utils.handlers import BaseHTTPSHandler, LoggingConnectivityException
+from awx.main.utils.handlers import BaseHTTPSHandler, UDPHandler, LoggingConnectivityException
 from awx.main.tasks import handle_setting_changes
 from awx.conf.license import get_licensed_features
 from awx.conf.models import Setting
@@ -199,7 +199,11 @@ class SettingLoggingTest(GenericAPIView):
             for k, v in serializer.validated_data.items():
                 setattr(mock_settings, k, v)
             mock_settings.LOG_AGGREGATOR_LEVEL = 'DEBUG'
-            BaseHTTPSHandler.perform_test(mock_settings)
+            if mock_settings.LOG_AGGREGATOR_PROTOCOL.upper() == 'UDP':
+                UDPHandler.perform_test(mock_settings)
+                return Response(status=status.HTTP_201_CREATED)
+            else:
+                BaseHTTPSHandler.perform_test(mock_settings)
         except LoggingConnectivityException as e:
             return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
         return Response(status=status.HTTP_200_OK)
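Note on the UDP branch above: the view returns 201 Created rather than 200 OK because a UDP send is fire-and-forget; unlike the HTTPS handler, there is no response to confirm delivery, so "accepted for transmission" is the strongest claim the API can make. A minimal illustrative sketch (not part of the patch) of why that is:

```python
# Illustrative only: sendto() returns as soon as the datagram leaves the
# socket, whether or not a log aggregator is listening at (host, port),
# so success here proves nothing about delivery.
import socket


def send_udp_test_message(host, port, payload=b'AWX Connection Test'):
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        return sock.sendto(payload, (host, port))
    finally:
        sock.close()
```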
diff --git a/awx/lib/awx_display_callback/module.py b/awx/lib/awx_display_callback/module.py
index 6800560cfc..368063d0d1 100644
--- a/awx/lib/awx_display_callback/module.py
+++ b/awx/lib/awx_display_callback/module.py
@@ -18,7 +18,11 @@ from __future__ import (absolute_import, division, print_function)
 
 # Python
+import codecs
 import contextlib
+import json
+import os
+import stat
 import sys
 import uuid
 from copy import copy
@@ -292,10 +296,22 @@ class BaseCallbackModule(CallbackBase):
             failures=stats.failures,
             ok=stats.ok,
             processed=stats.processed,
-            skipped=stats.skipped,
-            artifact_data=stats.custom.get('_run', {}) if hasattr(stats, 'custom') else {}
+            skipped=stats.skipped
         )
 
+        # write custom set_stat artifact data to the local disk so that it can
+        # be persisted by awx after the process exits
+        custom_artifact_data = stats.custom.get('_run', {}) if hasattr(stats, 'custom') else {}
+        if custom_artifact_data:
+            # create the directory for custom stats artifacts to live in (if it doesn't exist)
+            custom_artifacts_dir = os.path.join(os.getenv('AWX_PRIVATE_DATA_DIR'), 'artifacts')
+            os.makedirs(custom_artifacts_dir, mode=stat.S_IXUSR + stat.S_IWUSR + stat.S_IRUSR)
+
+            custom_artifacts_path = os.path.join(custom_artifacts_dir, 'custom')
+            with codecs.open(custom_artifacts_path, 'w', encoding='utf-8') as f:
+                os.chmod(custom_artifacts_path, stat.S_IRUSR | stat.S_IWUSR)
+                json.dump(custom_artifact_data, f)
+
         with self.capture_event_data('playbook_on_stats', **event_data):
             super(BaseCallbackModule, self).v2_playbook_on_stats(stats)
diff --git a/awx/lib/tests/test_display_callback.py b/awx/lib/tests/test_display_callback.py
index 34873dd6cc..d8c7923108 100644
--- a/awx/lib/tests/test_display_callback.py
+++ b/awx/lib/tests/test_display_callback.py
@@ -7,7 +7,9 @@ from collections import OrderedDict
 import json
 import mock
 import os
+import shutil
 import sys
+import tempfile
 
 import pytest
 
@@ -259,3 +261,26 @@ def test_callback_plugin_strips_task_environ_variables(executor, cache, playbook):
     assert len(cache)
     for event in cache.values():
         assert os.environ['PATH'] not in json.dumps(event)
+
+
+@pytest.mark.parametrize('playbook', [
+{'custom_set_stat.yml': '''
+- name: custom set_stat calls should persist to the local disk so awx can save them
+  connection: local
+  hosts: all
+  tasks:
+    - set_stats:
+        data:
+          foo: "bar"
+'''},  # noqa
+])
+def test_callback_plugin_saves_custom_stats(executor, cache, playbook):
+    try:
+        private_data_dir = tempfile.mkdtemp()
+        with mock.patch.dict(os.environ, {'AWX_PRIVATE_DATA_DIR': private_data_dir}):
+            executor.run()
+        artifacts_path = os.path.join(private_data_dir, 'artifacts', 'custom')
+        with open(artifacts_path, 'r') as f:
+            assert json.load(f) == {'foo': 'bar'}
+    finally:
+        shutil.rmtree(os.path.join(private_data_dir))
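The test above exercises the producer half of a file-based handoff: the callback plugin writes `<AWX_PRIVATE_DATA_DIR>/artifacts/custom`, and the task runner reads it back after the playbook process exits. A rough sketch of the consumer side (the helper name is hypothetical; the real consumer lives in `awx/main/tasks.py` later in this patch):

```python
# Illustrative only: read back the set_stats artifact file written by the
# callback plugin, tolerating a missing or unparseable file.
import json
import os


def read_custom_artifacts(private_data_dir):
    path = os.path.join(private_data_dir, 'artifacts', 'custom')
    if not os.path.exists(path):
        return None
    with open(path, 'r') as f:
        try:
            return json.load(f)
        except ValueError:
            # unparseable artifact data is ignored rather than fatal
            return None
```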
diff --git a/awx/main/fields.py b/awx/main/fields.py
index 5bec40d1de..d0f6081693 100644
--- a/awx/main/fields.py
+++ b/awx/main/fields.py
@@ -506,6 +506,12 @@ class CredentialInputField(JSONSchemaField):
                     v != '$encrypted$',
                     model_instance.pk
                 ]):
+                    if not isinstance(getattr(model_instance, k), six.string_types):
+                        raise django_exceptions.ValidationError(
+                            _('secret values must be of type string, not {}').format(type(v).__name__),
+                            code='invalid',
+                            params={'value': v},
+                        )
                     decrypted_values[k] = utils.decrypt_field(model_instance, k)
                 else:
                     decrypted_values[k] = v
diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py
index 5efd502a7d..83eb228216 100644
--- a/awx/main/models/jobs.py
+++ b/awx/main/models/jobs.py
@@ -2,21 +2,22 @@
 # All Rights Reserved.
 
 # Python
+import codecs
 import datetime
 import logging
+import os
 import time
 import json
-import base64
 from urlparse import urljoin
 
+import six
+
 # Django
 from django.conf import settings
 from django.db import models
 #from django.core.cache import cache
-import memcache
-from dateutil import parser
-from dateutil.tz import tzutc
 from django.utils.encoding import smart_str
+from django.utils.timezone import now
 from django.utils.translation import ugettext_lazy as _
 from django.core.exceptions import ValidationError, FieldDoesNotExist
@@ -738,86 +739,68 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskManagerJobMixin):
     def get_notification_friendly_name(self):
         return "Job"
 
-    @property
-    def memcached_fact_key(self):
-        return '{}'.format(self.inventory.id)
-
-    def memcached_fact_host_key(self, host_name):
-        return '{}-{}'.format(self.inventory.id, base64.b64encode(host_name.encode('utf-8')))
-
-    def memcached_fact_modified_key(self, host_name):
-        return '{}-{}-modified'.format(self.inventory.id, base64.b64encode(host_name.encode('utf-8')))
-
-    def _get_inventory_hosts(self, only=['name', 'ansible_facts', 'modified',]):
+    def _get_inventory_hosts(self, only=['name', 'ansible_facts', 'ansible_facts_modified', 'modified',]):
+        if not self.inventory:
+            return []
         return self.inventory.hosts.only(*only)
 
-    def _get_memcache_connection(self):
-        return memcache.Client([settings.CACHES['default']['LOCATION']], debug=0)
-
-    def start_job_fact_cache(self):
-        if not self.inventory:
-            return
-
-        cache = self._get_memcache_connection()
-
-        host_names = []
-
-        for host in self._get_inventory_hosts():
-            host_key = self.memcached_fact_host_key(host.name)
-            modified_key = self.memcached_fact_modified_key(host.name)
-
-            if cache.get(modified_key) is None:
-                if host.ansible_facts_modified:
-                    host_modified = host.ansible_facts_modified.replace(tzinfo=tzutc()).isoformat()
-                else:
-                    host_modified = datetime.datetime.now(tzutc()).isoformat()
-                cache.set(host_key, json.dumps(host.ansible_facts))
-                cache.set(modified_key, host_modified)
-
-            host_names.append(host.name)
-
-        cache.set(self.memcached_fact_key, host_names)
-
-    def finish_job_fact_cache(self):
-        if not self.inventory:
-            return
-
-        cache = self._get_memcache_connection()
-
+    def start_job_fact_cache(self, destination, modification_times, timeout=None):
+        destination = os.path.join(destination, 'facts')
+        os.makedirs(destination, mode=0700)
         hosts = self._get_inventory_hosts()
+        if timeout is None:
+            timeout = settings.ANSIBLE_FACT_CACHE_TIMEOUT
+        if timeout > 0:
+            # exclude hosts with fact data older than `settings.ANSIBLE_FACT_CACHE_TIMEOUT` seconds
+            timeout = now() - datetime.timedelta(seconds=timeout)
+            hosts = hosts.filter(ansible_facts_modified__gte=timeout)
         for host in hosts:
-            host_key = self.memcached_fact_host_key(host.name)
-            modified_key = self.memcached_fact_modified_key(host.name)
-
-            modified = cache.get(modified_key)
-            if modified is None:
-                cache.delete(host_key)
+            filepath = os.sep.join(map(six.text_type, [destination, host.name]))
+            if not os.path.realpath(filepath).startswith(destination):
+                system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
                 continue
+            with codecs.open(filepath, 'w', encoding='utf-8') as f:
+                os.chmod(f.name, 0600)
+                json.dump(host.ansible_facts, f)
+            # make note of the time we wrote the file so we can check if it changed later
+            modification_times[filepath] = os.path.getmtime(filepath)
 
-            # Save facts if cache is newer than DB
-            modified = parser.parse(modified, tzinfos=[tzutc()])
-            if not host.ansible_facts_modified or modified > host.ansible_facts_modified:
-                ansible_facts = cache.get(host_key)
-                try:
-                    ansible_facts = json.loads(ansible_facts)
-                except Exception:
-                    ansible_facts = None
-
-                if ansible_facts is None:
-                    cache.delete(host_key)
-                    continue
-                host.ansible_facts = ansible_facts
-                host.ansible_facts_modified = modified
-                if 'insights' in ansible_facts and 'system_id' in ansible_facts['insights']:
-                    host.insights_system_id = ansible_facts['insights']['system_id']
-                host.save()
+    def finish_job_fact_cache(self, destination, modification_times):
+        destination = os.path.join(destination, 'facts')
+        for host in self._get_inventory_hosts():
+            filepath = os.sep.join(map(six.text_type, [destination, host.name]))
+            if not os.path.realpath(filepath).startswith(destination):
+                system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
+                continue
+            if os.path.exists(filepath):
+                # If the file changed since we wrote it pre-playbook run...
+                modified = os.path.getmtime(filepath)
+                if modified > modification_times.get(filepath, 0):
+                    with codecs.open(filepath, 'r', encoding='utf-8') as f:
+                        try:
+                            ansible_facts = json.load(f)
+                        except ValueError:
+                            continue
+                        host.ansible_facts = ansible_facts
+                        host.ansible_facts_modified = now()
+                        if 'insights' in ansible_facts and 'system_id' in ansible_facts['insights']:
+                            host.insights_system_id = ansible_facts['insights']['system_id']
+                        host.save()
+                        system_tracking_logger.info(
+                            'New fact for inventory {} host {}'.format(
+                                smart_str(host.inventory.name), smart_str(host.name)),
+                            extra=dict(inventory_id=host.inventory.id, host_name=host.name,
+                                       ansible_facts=host.ansible_facts,
+                                       ansible_facts_modified=host.ansible_facts_modified.isoformat(),
+                                       job_id=self.id))
+            else:
+                # if the file goes missing, ansible removed it (likely via clear_facts)
+                host.ansible_facts = {}
+                host.ansible_facts_modified = now()
                 system_tracking_logger.info(
-                    'New fact for inventory {} host {}'.format(
-                        smart_str(host.inventory.name), smart_str(host.name)),
-                    extra=dict(inventory_id=host.inventory.id, host_name=host.name,
-                               ansible_facts=host.ansible_facts,
-                               ansible_facts_modified=host.ansible_facts_modified.isoformat(),
-                               job_id=self.id))
+                    'Facts cleared for inventory {} host {}'.format(
+                        smart_str(host.inventory.name), smart_str(host.name)))
+                host.save()
 
     # Add on aliases for the non-related-model fields
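Both cache methods above share a containment check: host names are user-controlled, so a name like `../foo` would otherwise write or read outside the per-job facts directory. Resolving `..` segments and symlinks with `realpath()` and then prefix-checking against the destination rejects such names. A standalone sketch of the idea (not part of the patch; the paths are made up):

```python
# Illustrative only: the realpath-prefix idiom used by
# start_job_fact_cache/finish_job_fact_cache to block path traversal.
import os


def is_safely_contained(directory, filename):
    filepath = os.path.join(directory, filename)
    # realpath() collapses '..' and resolves symlinks before the check
    return os.path.realpath(filepath).startswith(directory)


assert is_safely_contained('/data/job_42/facts', 'host1')
assert not is_safely_contained('/data/job_42/facts', '../foo')
```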
diff --git a/awx/main/notifications/slack_backend.py b/awx/main/notifications/slack_backend.py
index 3cea4bd44e..6e966e882b 100644
--- a/awx/main/notifications/slack_backend.py
+++ b/awx/main/notifications/slack_backend.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2016 Ansible, Inc.
 # All Rights Reserved.
 
+import time
 import logging
 
 from slackclient import SlackClient
@@ -9,6 +10,7 @@ from django.utils.translation import ugettext_lazy as _
 from awx.main.notifications.base import AWXBaseEmailBackend
 
 logger = logging.getLogger('awx.main.notifications.slack_backend')
+WEBSOCKET_TIMEOUT = 30
 
 
 class SlackBackend(AWXBaseEmailBackend):
@@ -30,7 +32,18 @@ class SlackBackend(AWXBaseEmailBackend):
         if not self.connection.rtm_connect():
             if not self.fail_silently:
                 raise Exception("Slack Notification Token is invalid")
-        return True
+
+        start = time.time()
+        elapsed = 0
+        while elapsed < WEBSOCKET_TIMEOUT:
+            events = self.connection.rtm_read()
+            if any(event['type'] == 'hello' for event in events):
+                return True
+            elapsed = time.time() - start
+            time.sleep(0.5)
+
+        raise RuntimeError("Slack Notification unable to establish websocket connection after {} seconds".format(WEBSOCKET_TIMEOUT))
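The loop above waits for the Slack RTM server's handshake: `rtm_read()` returns a list of JSON-decoded event dicts, and the server emits a `hello` event once the websocket is actually established, so `rtm_connect()` returning true is not by itself proof of a working connection. A hedged sketch of a slightly more defensive membership test (not part of the patch; RTM events are not guaranteed to carry a `type` key):

```python
# Illustrative only: same check as the diff, but using .get() so events
# without a 'type' key cannot raise KeyError mid-poll.
def got_hello(events):
    return any(event.get('type') == 'hello' for event in events)


assert got_hello([{'type': 'hello'}])
assert not got_hello([{'ok': True}, {'type': 'presence_change'}])
```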
diff --git a/awx/main/tasks.py b/awx/main/tasks.py
index b15d0b73bd..9aae7e6dbf 100644
--- a/awx/main/tasks.py
+++ b/awx/main/tasks.py
@@ -16,6 +16,7 @@ import stat
 import tempfile
 import time
 import traceback
+import six
 import urlparse
 from distutils.version import LooseVersion as Version
 import yaml
@@ -44,8 +45,6 @@ from django.core.exceptions import ObjectDoesNotExist
 # Django-CRUM
 from crum import impersonate
 
-import six
-
 # AWX
 from awx import __version__ as awx_application_version
 from awx import celery_app
@@ -781,6 +780,7 @@ class BaseTask(LogErrorsTask):
         # Derived class should call add_ansible_venv() or add_awx_venv()
         if self.should_use_proot(instance, **kwargs):
             env['PROOT_TMP_DIR'] = settings.AWX_PROOT_BASE_PATH
+        env['AWX_PRIVATE_DATA_DIR'] = kwargs['private_data_dir']
         return env
 
     def should_use_proot(self, instance, **kwargs):
@@ -898,6 +898,15 @@ class BaseTask(LogErrorsTask):
         # Fetch ansible version once here to support version-dependent features.
         kwargs['ansible_version'] = get_ansible_version()
         kwargs['private_data_dir'] = self.build_private_data_dir(instance, **kwargs)
+
+        # Fetch "cached" fact data from prior runs and put on the disk
+        # where ansible expects to find it
+        if getattr(instance, 'use_fact_cache', False) and not kwargs.get('isolated'):
+            instance.start_job_fact_cache(
+                os.path.join(kwargs['private_data_dir']),
+                kwargs.setdefault('fact_modification_times', {})
+            )
+
         # May have to serialize the value
         kwargs['private_data_files'] = self.build_private_data_files(instance, **kwargs)
         kwargs['passwords'] = self.build_passwords(instance, **kwargs)
@@ -1129,11 +1138,15 @@ class RunJob(BaseTask):
         env['JOB_ID'] = str(job.pk)
         env['INVENTORY_ID'] = str(job.inventory.pk)
         if job.use_fact_cache and not kwargs.get('isolated'):
-            env['ANSIBLE_LIBRARY'] = self.get_path_to('..', 'plugins', 'library')
-            env['ANSIBLE_CACHE_PLUGINS'] = self.get_path_to('..', 'plugins', 'fact_caching')
-            env['ANSIBLE_CACHE_PLUGIN'] = "awx"
-            env['ANSIBLE_CACHE_PLUGIN_TIMEOUT'] = str(settings.ANSIBLE_FACT_CACHE_TIMEOUT)
-            env['ANSIBLE_CACHE_PLUGIN_CONNECTION'] = settings.CACHES['default']['LOCATION'] if 'LOCATION' in settings.CACHES['default'] else ''
+            library_path = env.get('ANSIBLE_LIBRARY')
+            env['ANSIBLE_LIBRARY'] = ':'.join(
+                filter(None, [
+                    library_path,
+                    self.get_path_to('..', 'plugins', 'library')
+                ])
+            )
+            env['ANSIBLE_CACHE_PLUGIN'] = "jsonfile"
+            env['ANSIBLE_CACHE_PLUGIN_CONNECTION'] = os.path.join(kwargs['private_data_dir'], 'facts')
         if job.project:
             env['PROJECT_REVISION'] = job.project.scm_revision
         env['ANSIBLE_RETRY_FILES_ENABLED'] = "False"
@@ -1276,6 +1289,7 @@ class RunJob(BaseTask):
         for method in PRIVILEGE_ESCALATION_METHODS:
             d[re.compile(r'%s password.*:\s*?$' % (method[0]), re.M)] = 'become_password'
             d[re.compile(r'%s password.*:\s*?$' % (method[0].upper()), re.M)] = 'become_password'
+        d[re.compile(r'BECOME password.*:\s*?$', re.M)] = 'become_password'
         d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password'
         d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password'
         d[re.compile(r'Vault password:\s*?$', re.M)] = 'vault_password'
@@ -1329,14 +1343,29 @@ class RunJob(BaseTask):
                     ('project_update', local_project_sync.name, local_project_sync.id)))
                 raise
 
-        if job.use_fact_cache and not kwargs.get('isolated'):
-            job.start_job_fact_cache()
-
     def final_run_hook(self, job, status, **kwargs):
         super(RunJob, self).final_run_hook(job, status, **kwargs)
         if job.use_fact_cache and not kwargs.get('isolated'):
-            job.finish_job_fact_cache()
+            job.finish_job_fact_cache(
+                kwargs['private_data_dir'],
+                kwargs['fact_modification_times']
+            )
+
+        # persist artifacts set via `set_stat` (if any)
+        custom_stats_path = os.path.join(kwargs['private_data_dir'], 'artifacts', 'custom')
+        if os.path.exists(custom_stats_path):
+            with open(custom_stats_path, 'r') as f:
+                custom_stat_data = None
+                try:
+                    custom_stat_data = json.load(f)
+                except ValueError:
+                    logger.warning('Could not parse custom `set_stats` data for job {}'.format(job.id))
+
+                if custom_stat_data:
+                    job.artifacts = custom_stat_data
+                    job.save(update_fields=['artifacts'])
+
         try:
             inventory = job.inventory
         except Inventory.DoesNotExist:
@@ -1554,15 +1583,15 @@ class RunProjectUpdate(BaseTask):
             if not inv_src.update_on_project_update:
                 continue
             if inv_src.scm_last_revision == scm_revision:
-                logger.debug('Skipping SCM inventory update for `{}` because '
-                             'project has not changed.'.format(inv_src.name))
+                logger.debug(six.text_type('Skipping SCM inventory update for `{}` because '
                                            'project has not changed.').format(inv_src.name))
                 continue
-            logger.debug('Local dependent inventory update for `{}`.'.format(inv_src.name))
+            logger.debug(six.text_type('Local dependent inventory update for `{}`.').format(inv_src.name))
             with transaction.atomic():
                 if InventoryUpdate.objects.filter(inventory_source=inv_src,
                                                   status__in=ACTIVE_STATES).exists():
-                    logger.info('Skipping SCM inventory update for `{}` because '
-                                'another update is already active.'.format(inv_src.name))
+                    logger.info(six.text_type('Skipping SCM inventory update for `{}` because '
+                                              'another update is already active.').format(inv_src.name))
                     continue
                 local_inv_update = inv_src.create_inventory_update(
                     _eager_fields=dict(
@@ -2225,6 +2254,7 @@ class RunAdHocCommand(BaseTask):
         for method in PRIVILEGE_ESCALATION_METHODS:
             d[re.compile(r'%s password.*:\s*?$' % (method[0]), re.M)] = 'become_password'
             d[re.compile(r'%s password.*:\s*?$' % (method[0].upper()), re.M)] = 'become_password'
+        d[re.compile(r'BECOME password.*:\s*?$', re.M)] = 'become_password'
         d[re.compile(r'SSH password:\s*?$', re.M)] = 'ssh_password'
         d[re.compile(r'Password:\s*?$', re.M)] = 'ssh_password'
         return d
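The `ANSIBLE_LIBRARY` change above merges rather than overwrites: joining with `filter(None, ...)` preserves any operator-supplied module path (e.g. from `AWX_TASK_ENV`) while appending AWX's own plugin library, and avoids a stray leading `:` when no prior value exists. The idiom in isolation (illustrative only; the example path comes from the unit tests later in this patch):

```python
# Illustrative only: the ':'.join(filter(None, ...)) merge used for
# ANSIBLE_LIBRARY. filter(None, ...) drops the None/empty entry so the
# result has no leading separator.
def merge_library_path(existing, awx_library='/awx_devel/awx/plugins/library'):
    return ':'.join(filter(None, [existing, awx_library]))


assert merge_library_path(None) == '/awx_devel/awx/plugins/library'
assert merge_library_path('/foo/bar') == '/foo/bar:/awx_devel/awx/plugins/library'
```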
diff --git a/awx/main/tests/functional/api/test_credential.py b/awx/main/tests/functional/api/test_credential.py
index 4834e96cbd..486ba7cacc 100644
--- a/awx/main/tests/functional/api/test_credential.py
+++ b/awx/main/tests/functional/api/test_credential.py
@@ -1480,6 +1480,105 @@ def test_credential_type_mutability(patch, organization, admin, credentialtype_ssh,
     assert response.status_code == 200
 
 
+@pytest.mark.django_db
+def test_vault_credential_type_mutability(patch, organization, admin, credentialtype_ssh,
+                                          credentialtype_vault):
+    cred = Credential(
+        credential_type=credentialtype_vault,
+        name='Best credential ever',
+        organization=organization,
+        inputs={
+            'vault_password': u'some-vault',
+        }
+    )
+    cred.save()
+
+    jt = JobTemplate()
+    jt.save()
+    jt.credentials.add(cred)
+
+    def _change_credential_type():
+        return patch(
+            reverse('api:credential_detail', kwargs={'version': 'v2', 'pk': cred.pk}),
+            {
+                'credential_type': credentialtype_ssh.pk,
+                'inputs': {
+                    'username': u'jim',
+                    'password': u'pass'
+                }
+            },
+            admin
+        )
+
+    response = _change_credential_type()
+    assert response.status_code == 400
+    expected = ['You cannot change the credential type of the credential, '
+                'as it may break the functionality of the resources using it.']
+    assert response.data['credential_type'] == expected
+
+    response = patch(
+        reverse('api:credential_detail', kwargs={'version': 'v2', 'pk': cred.pk}),
+        {'name': 'Worst credential ever'},
+        admin
+    )
+    assert response.status_code == 200
+    assert Credential.objects.get(pk=cred.pk).name == 'Worst credential ever'
+
+    jt.delete()
+    response = _change_credential_type()
+    assert response.status_code == 200
+
+
+@pytest.mark.django_db
+def test_cloud_credential_type_mutability(patch, organization, admin, credentialtype_ssh,
+                                          credentialtype_aws):
+    cred = Credential(
+        credential_type=credentialtype_aws,
+        name='Best credential ever',
+        organization=organization,
+        inputs={
+            'username': u'jim',
+            'password': u'pass'
+        }
+    )
+    cred.save()
+
+    jt = JobTemplate()
+    jt.save()
+    jt.credentials.add(cred)
+
+    def _change_credential_type():
+        return patch(
+            reverse('api:credential_detail', kwargs={'version': 'v2', 'pk': cred.pk}),
+            {
+                'credential_type': credentialtype_ssh.pk,
+                'inputs': {
+                    'username': u'jim',
+                    'password': u'pass'
+                }
+            },
+            admin
+        )
+
+    response = _change_credential_type()
+    assert response.status_code == 400
+    expected = ['You cannot change the credential type of the credential, '
+                'as it may break the functionality of the resources using it.']
+    assert response.data['credential_type'] == expected
+
+    response = patch(
+        reverse('api:credential_detail', kwargs={'version': 'v2', 'pk': cred.pk}),
+        {'name': 'Worst credential ever'},
+        admin
+    )
+    assert response.status_code == 200
+    assert Credential.objects.get(pk=cred.pk).name == 'Worst credential ever'
+
+    jt.delete()
+    response = _change_credential_type()
+    assert response.status_code == 200
+
+
 @pytest.mark.django_db
 @pytest.mark.parametrize('version, params', [
     ['v1', {
diff --git a/awx/main/tests/unit/models/test_jobs.py b/awx/main/tests/unit/models/test_jobs.py
index e60b775066..516a6f076f 100644
--- a/awx/main/tests/unit/models/test_jobs.py
+++ b/awx/main/tests/unit/models/test_jobs.py
@@ -1,4 +1,7 @@
 # -*- coding: utf-8 -*-
+import json
+import os
+import time
 
 import pytest
 
@@ -8,51 +11,14 @@ from awx.main.models import (
     Host,
 )
 
-import datetime
-import json
-import base64
-from dateutil.tz import tzutc
-
-
-class CacheMock(object):
-    def __init__(self):
-        self.d = dict()
-
-    def get(self, key):
-        if key not in self.d:
-            return None
-        return self.d[key]
-
-    def set(self, key, val):
-        self.d[key] = val
-
-    def delete(self, key):
-        del self.d[key]
-
 
 @pytest.fixture
-def old_time():
-    return (datetime.datetime.now(tzutc()) - datetime.timedelta(minutes=60))
-
-
-@pytest.fixture()
-def new_time():
-    return (datetime.datetime.now(tzutc()))
-
-
-@pytest.fixture
-def hosts(old_time, inventory):
+def hosts(inventory):
     return [
-        Host(name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time, inventory=inventory),
-        Host(name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time, inventory=inventory),
-        Host(name='host3', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=old_time, inventory=inventory),
-    ]
-
-
-@pytest.fixture
-def hosts2(inventory):
-    return [
-        Host(name='host2', ansible_facts="foobar", ansible_facts_modified=old_time, inventory=inventory),
+        Host(name='host1', ansible_facts={"a": 1, "b": 2}, inventory=inventory),
+        Host(name='host2', ansible_facts={"a": 1, "b": 2}, inventory=inventory),
+        Host(name='host3', ansible_facts={"a": 1, "b": 2}, inventory=inventory),
+        Host(name=u'Iñtërnâtiônàlizætiøn', ansible_facts={"a": 1, "b": 2}, inventory=inventory),
     ]
 
 
@@ -62,87 +28,103 @@ def inventory():
 
 
 @pytest.fixture
-def mock_cache(mocker):
-    cache = CacheMock()
-    mocker.patch.object(cache, 'set', wraps=cache.set)
-    mocker.patch.object(cache, 'get', wraps=cache.get)
-    mocker.patch.object(cache, 'delete', wraps=cache.delete)
-    return cache
-
-
-@pytest.fixture
-def job(mocker, hosts, inventory, mock_cache):
+def job(mocker, hosts, inventory):
     j = Job(inventory=inventory, id=2)
     j._get_inventory_hosts = mocker.Mock(return_value=hosts)
-    j._get_memcache_connection = mocker.Mock(return_value=mock_cache)
     return j
 
 
-@pytest.fixture
-def job2(mocker, hosts2, inventory, mock_cache):
-    j = Job(inventory=inventory, id=3)
-    j._get_inventory_hosts = mocker.Mock(return_value=hosts2)
-    j._get_memcache_connection = mocker.Mock(return_value=mock_cache)
-    return j
+def test_start_job_fact_cache(hosts, job, inventory, tmpdir):
+    fact_cache = str(tmpdir)
+    modified_times = {}
+    job.start_job_fact_cache(fact_cache, modified_times, 0)
+
+    for host in hosts:
+        filepath = os.path.join(fact_cache, 'facts', host.name)
+        assert os.path.exists(filepath)
+        with open(filepath, 'r') as f:
+            assert f.read() == json.dumps(host.ansible_facts)
+        assert filepath in modified_times
 
 
-def test_start_job_fact_cache(hosts, job, inventory, mocker):
+def test_fact_cache_with_invalid_path_traversal(job, inventory, tmpdir, mocker):
+    job._get_inventory_hosts = mocker.Mock(return_value=[
+        Host(name='../foo', ansible_facts={"a": 1, "b": 2},),
+    ])
 
-    job.start_job_fact_cache()
-
-    job._get_memcache_connection().set.assert_any_call('5', [h.name for h in hosts])
-    for host in hosts:
-        job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, base64.b64encode(host.name)), json.dumps(host.ansible_facts))
-        job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, base64.b64encode(host.name)), host.ansible_facts_modified.isoformat())
+    fact_cache = str(tmpdir)
+    job.start_job_fact_cache(fact_cache, {}, 0)
+    # a file called "foo" should _not_ be written outside the facts dir
+    assert os.listdir(os.path.join(fact_cache, 'facts', '..')) == ['facts']
 
 
-def test_start_job_fact_cache_existing_host(hosts, hosts2, job, job2, inventory, mocker):
+def test_finish_job_fact_cache_with_existing_data(job, hosts, inventory, mocker, tmpdir):
+    fact_cache = str(tmpdir)
+    modified_times = {}
+    job.start_job_fact_cache(fact_cache, modified_times, 0)
 
-    job.start_job_fact_cache()
-
-    for host in hosts:
-        job._get_memcache_connection().set.assert_any_call('{}-{}'.format(5, base64.b64encode(host.name)), json.dumps(host.ansible_facts))
-        job._get_memcache_connection().set.assert_any_call('{}-{}-modified'.format(5, base64.b64encode(host.name)), host.ansible_facts_modified.isoformat())
-
-    job._get_memcache_connection().set.reset_mock()
-
-    job2.start_job_fact_cache()
-
-    # Ensure hosts2 ansible_facts didn't overwrite hosts ansible_facts
-    ansible_facts_cached = job._get_memcache_connection().get('{}-{}'.format(5, base64.b64encode(hosts2[0].name)))
-    assert ansible_facts_cached == json.dumps(hosts[1].ansible_facts)
-
-
-def test_memcached_fact_host_key_unicode(job):
-    host_name = u'Iñtërnâtiônàlizætiøn'
-    host_key = job.memcached_fact_host_key(host_name)
-    assert host_key == '5-ScOxdMOrcm7DonRpw7Ruw6BsaXrDpnRpw7hu'
-
-
-def test_memcached_fact_modified_key_unicode(job):
-    host_name = u'Iñtërnâtiônàlizætiøn'
-    host_key = job.memcached_fact_modified_key(host_name)
-    assert host_key == '5-ScOxdMOrcm7DonRpw7Ruw6BsaXrDpnRpw7hu-modified'
-
-
-def test_finish_job_fact_cache(job, hosts, inventory, mocker, new_time):
-
-    job.start_job_fact_cache()
     for h in hosts:
         h.save = mocker.Mock()
 
-    host_key = job.memcached_fact_host_key(hosts[1].name)
-    modified_key = job.memcached_fact_modified_key(hosts[1].name)
-
     ansible_facts_new = {"foo": "bar", "insights": {"system_id": "updated_by_scan"}}
-    job._get_memcache_connection().set(host_key, json.dumps(ansible_facts_new))
-    job._get_memcache_connection().set(modified_key, new_time.isoformat())
-
-    job.finish_job_fact_cache()
+    filepath = os.path.join(fact_cache, 'facts', hosts[1].name)
+    with open(filepath, 'w') as f:
+        f.write(json.dumps(ansible_facts_new))
+        f.flush()
+        # I feel kind of gross about calling `os.utime` by hand, but I noticed
+        # that in our container-based dev environment, the resolution for
+        # `os.stat()` after a file write was over a second, and I don't want to put
+        # a sleep() in this test
+        new_modification_time = time.time() + 3600
+        os.utime(filepath, (new_modification_time, new_modification_time))
 
-    hosts[0].save.assert_not_called()
-    hosts[2].save.assert_not_called()
+    job.finish_job_fact_cache(fact_cache, modified_times)
+
+    for host in (hosts[0], hosts[2], hosts[3]):
+        host.save.assert_not_called()
+        assert host.ansible_facts == {"a": 1, "b": 2}
+        assert host.ansible_facts_modified is None
     assert hosts[1].ansible_facts == ansible_facts_new
     assert hosts[1].insights_system_id == "updated_by_scan"
     hosts[1].save.assert_called_once_with()
+
+
+def test_finish_job_fact_cache_with_bad_data(job, hosts, inventory, mocker, tmpdir):
+    fact_cache = str(tmpdir)
+    modified_times = {}
+    job.start_job_fact_cache(fact_cache, modified_times, 0)
+
+    for h in hosts:
+        h.save = mocker.Mock()
+
+    for h in hosts:
+        filepath = os.path.join(fact_cache, 'facts', h.name)
+        with open(filepath, 'w') as f:
+            f.write('not valid json!')
+            f.flush()
+            new_modification_time = time.time() + 3600
+            os.utime(filepath, (new_modification_time, new_modification_time))
+
+    job.finish_job_fact_cache(fact_cache, modified_times)
+
+    for h in hosts:
+        h.save.assert_not_called()
+
+
+def test_finish_job_fact_cache_clear(job, hosts, inventory, mocker, tmpdir):
+    fact_cache = str(tmpdir)
+    modified_times = {}
+    job.start_job_fact_cache(fact_cache, modified_times, 0)
+
+    for h in hosts:
+        h.save = mocker.Mock()
+
+    os.remove(os.path.join(fact_cache, 'facts', hosts[1].name))
+    job.finish_job_fact_cache(fact_cache, modified_times)
+
+    for host in (hosts[0], hosts[2], hosts[3]):
+        host.save.assert_not_called()
+        assert host.ansible_facts == {"a": 1, "b": 2}
+        assert host.ansible_facts_modified is None
+    assert hosts[1].ansible_facts == {}
+    hosts[1].save.assert_called_once_with()
diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py
index f2a1617b8f..1d00fdbb94 100644
--- a/awx/main/tests/unit/test_tasks.py
+++ b/awx/main/tests/unit/test_tasks.py
@@ -392,6 +392,42 @@ class TestGenericRun(TestJobExecution):
         tb = self.task.update_model.call_args[-1]['result_traceback']
         assert 'a valid Python virtualenv does not exist at /venv/missing' in tb
 
+    def test_fact_cache_usage(self):
+        self.instance.use_fact_cache = True
+
+        start_mock = mock.Mock()
+        patch = mock.patch.object(Job, 'start_job_fact_cache', start_mock)
+        self.patches.append(patch)
+        patch.start()
+
+        self.task.run(self.pk)
+        call_args, _ = self.run_pexpect.call_args_list[0]
+        args, cwd, env, stdout = call_args
+        start_mock.assert_called_once()
+        tmpdir, _ = start_mock.call_args[0]
+
+        assert env['ANSIBLE_CACHE_PLUGIN'] == 'jsonfile'
+        assert env['ANSIBLE_CACHE_PLUGIN_CONNECTION'] == os.path.join(tmpdir, 'facts')
+
+    @pytest.mark.parametrize('task_env, ansible_library_env', [
+        [{}, '/awx_devel/awx/plugins/library'],
+        [{'ANSIBLE_LIBRARY': '/foo/bar'}, '/foo/bar:/awx_devel/awx/plugins/library'],
+    ])
+    def test_fact_cache_usage_with_ansible_library(self, task_env, ansible_library_env):
+        patch = mock.patch('awx.main.tasks.settings.AWX_TASK_ENV', task_env)
+        patch.start()
+
+        self.instance.use_fact_cache = True
+        start_mock = mock.Mock()
+        patch = mock.patch.object(Job, 'start_job_fact_cache', start_mock)
+        self.patches.append(patch)
+        patch.start()
+
+        self.task.run(self.pk)
+        call_args, _ = self.run_pexpect.call_args_list[0]
+        args, cwd, env, stdout = call_args
+        assert env['ANSIBLE_LIBRARY'] == ansible_library_env
+
 
 class TestAdhocRun(TestJobExecution):
diff --git a/awx/main/utils/handlers.py b/awx/main/utils/handlers.py
index 20a30b499e..8ed1127292 100644
--- a/awx/main/utils/handlers.py
+++ b/awx/main/utils/handlers.py
@@ -292,6 +292,21 @@ class UDPHandler(BaseHandler):
         payload = _encode_payload_for_socket(payload)
         return self.socket.sendto(payload, (self._get_host(hostname_only=True), self.port or 0))
 
+    @classmethod
+    def perform_test(cls, settings):
+        """
+        Tests logging connectivity for the current logging settings.
+        """
+        handler = cls.from_django_settings(settings)
+        handler.enabled_flag = True
+        handler.setFormatter(LogstashFormatter(settings_module=settings))
+        logger = logging.getLogger(__file__)
+        fn, lno, func = logger.findCaller()
+        record = logger.makeRecord('awx', 10, fn, lno,
+                                   'AWX Connection Test', tuple(),
+                                   None, func)
+        handler.emit(record)
+
 
 HANDLER_MAPPING = {
     'https': BaseHTTPSHandler,
diff --git a/awx/ui/client/src/inventories-hosts/shared/factories/set-status.factory.js b/awx/ui/client/src/inventories-hosts/shared/factories/set-status.factory.js
index 3e7b76fb4a..aea5b5671f 100644
--- a/awx/ui/client/src/inventories-hosts/shared/factories/set-status.factory.js
+++ b/awx/ui/client/src/inventories-hosts/shared/factories/set-status.factory.js
@@ -60,7 +60,7 @@ export default
                 html += "