From 443bdc1234682dd0004bae372078512fcf37cce9 Mon Sep 17 00:00:00 2001
From: Amol Gautam
Date: Mon, 31 Jan 2022 06:04:23 -0500
Subject: [PATCH] Decoupled callback functions from BaseTask Class

--- Removed all callback functions from 'jobs.py' and put them in a new file '/awx/main/tasks/callback.py'
--- Modified unit tests to match the moved code
--- Moved 'update_model' from jobs.py to /awx/main/utils/update_model.py
---
 awx/main/tasks/callback.py        | 257 +++++++++++++++++++++++++
 awx/main/tasks/jobs.py            | 305 +++++-------------------
 awx/main/tasks/receptor.py        |   8 +-
 awx/main/tests/unit/test_tasks.py | 173 +++++++++--------
 awx/main/utils/update_model.py    |  40 ++++
 5 files changed, 433 insertions(+), 350 deletions(-)
 create mode 100644 awx/main/tasks/callback.py
 create mode 100644 awx/main/utils/update_model.py

diff --git a/awx/main/tasks/callback.py b/awx/main/tasks/callback.py
new file mode 100644
index 0000000000..ccd9c39815
--- /dev/null
+++ b/awx/main/tasks/callback.py
@@ -0,0 +1,257 @@
+import json
+import time
+import logging
+from collections import deque
+import os
+import stat
+
+# Django
+from django.utils.timezone import now
+from django.conf import settings
+from django_guid.middleware import GuidMiddleware
+
+# AWX
+from awx.main.redact import UriCleaner
+from awx.main.constants import MINIMAL_EVENTS
+from awx.main.utils.update_model import update_model
+from awx.main.queue import CallbackQueueDispatcher
+
+logger = logging.getLogger('awx.main.tasks.callback')
+
+
+class RunnerCallback:
+    event_data_key = 'job_id'
+
+    def __init__(self, model=None):
+        self.parent_workflow_job_id = None
+        self.host_map = {}
+        self.guid = GuidMiddleware.get_guid()
+        self.job_created = None
+        self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
+        self.dispatcher = CallbackQueueDispatcher()
+        self.safe_env = {}
+        self.event_ct = 0
+        self.model = model
+
+    def update_model(self, pk, _attempt=0, **updates):
+        return update_model(self.model, pk, _attempt=_attempt, **updates)
+
+    def event_handler(self, event_data):
+        #
+        # ⚠️ D-D-D-DANGER ZONE ⚠️
+        # This method is called once for *every event* emitted by Ansible
+        # Runner as a playbook runs. That means that changes to the code in
+        # this method are _very_ likely to introduce performance regressions.
+        #
+        # Even if this function is made on average .05s slower, it can have
+        # devastating performance implications for playbooks that emit
+        # tens or hundreds of thousands of events.
+        #
+        # Proceed with caution!
+        #
+        """
+        Ansible runner puts a parent_uuid on each event, no matter what the type.
+        AWX only saves the parent_uuid if the event is for a Job.
+        """
+        # cache end_line locally for RunInventoryUpdate tasks
+        # which generate job events from two 'streams':
+        # ansible-inventory and the awx.main.commands.inventory_import
+        # logger
+
+        if event_data.get(self.event_data_key, None):
+            if self.event_data_key != 'job_id':
+                event_data.pop('parent_uuid', None)
+        if self.parent_workflow_job_id:
+            event_data['workflow_job_id'] = self.parent_workflow_job_id
+        event_data['job_created'] = self.job_created
+        if self.host_map:
+            host = event_data.get('event_data', {}).get('host', '').strip()
+            if host:
+                event_data['host_name'] = host
+                if host in self.host_map:
+                    event_data['host_id'] = self.host_map[host]
+            else:
+                event_data['host_name'] = ''
+                event_data['host_id'] = ''
+            if event_data.get('event') == 'playbook_on_stats':
+                event_data['host_map'] = self.host_map
+
+        if isinstance(self, RunnerCallbackForProjectUpdate):
+            # need a better way to have this check.
+            # it's common for Ansible's SCM modules to print
+            # error messages on failure that contain the plaintext
+            # basic auth credentials (username + password)
+            # it's also common for the nested event data itself (['res']['...'])
+            # to contain unredacted text on failure
+            # this is a _little_ expensive to filter
+            # with regex, but project updates don't have many events,
+            # so it *should* have a negligible performance impact
+            task = event_data.get('event_data', {}).get('task_action')
+            try:
+                if task in ('git', 'svn'):
+                    event_data_json = json.dumps(event_data)
+                    event_data_json = UriCleaner.remove_sensitive(event_data_json)
+                    event_data = json.loads(event_data_json)
+            except json.JSONDecodeError:
+                pass
+
+        if 'event_data' in event_data:
+            event_data['event_data']['guid'] = self.guid
+
+        # To prevent overwhelming the broadcast queue, skip some websocket messages
+        if self.recent_event_timings:
+            cpu_time = time.time()
+            first_window_time = self.recent_event_timings[0]
+            last_window_time = self.recent_event_timings[-1]
+
+            if event_data.get('event') in MINIMAL_EVENTS:
+                should_emit = True  # always send some types like playbook_on_stats
+            elif event_data.get('stdout') == '' and event_data['start_line'] == event_data['end_line']:
+                should_emit = False  # exclude events with no output
+            else:
+                should_emit = any(
+                    [
+                        # if the 30th most recent websocket message was sent over 1 second ago
+                        cpu_time - first_window_time > 1.0,
+                        # if the very last websocket message came in over 1/30 seconds ago
+                        self.recent_event_timings.maxlen * (cpu_time - last_window_time) > 1.0,
+                        # if the queue is not yet full
+                        len(self.recent_event_timings) != self.recent_event_timings.maxlen,
+                    ]
+                )
+
+            if should_emit:
+                self.recent_event_timings.append(cpu_time)
+            else:
+                event_data.setdefault('event_data', {})
+                event_data['skip_websocket_message'] = True
+
+        elif self.recent_event_timings.maxlen:
+            self.recent_event_timings.append(time.time())
+
+        event_data.setdefault(self.event_data_key, self.instance.id)
+        self.dispatcher.dispatch(event_data)
+        self.event_ct += 1
+
+        '''
+        Handle artifacts
+        '''
+        if event_data.get('event_data', {}).get('artifact_data', {}):
+            self.instance.artifacts = event_data['event_data']['artifact_data']
+            self.instance.save(update_fields=['artifacts'])
+
+        return False
+
+    def cancel_callback(self):
+        """
+        Ansible runner callback to tell the job when/if it is canceled
+        """
+        unified_job_id = self.instance.pk
+        self.instance.refresh_from_db()
+        if not self.instance:
+            logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
+            return True
+        if self.instance.cancel_flag or self.instance.status == 'canceled':
+            cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
+            if cancel_wait > 5:
+                logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
+            return True
+        return False
+
+    def finished_callback(self, runner_obj):
+        """
+        Ansible runner callback triggered on finished run
+        """
+        event_data = {
+            'event': 'EOF',
+            'final_counter': self.event_ct,
+            'guid': self.guid,
+        }
+        event_data.setdefault(self.event_data_key, self.instance.id)
+        self.dispatcher.dispatch(event_data)
+
+    def status_handler(self, status_data, runner_config):
+        """
+        Ansible runner callback triggered on status transition
+        """
+        if status_data['status'] == 'starting':
+            job_env = dict(runner_config.env)
+            '''
+            Take the safe environment variables and overwrite the real values in job_env
+            '''
+            for k, v in self.safe_env.items():
+                if k in job_env:
+                    job_env[k] = v
+            from awx.main.signals import disable_activity_stream  # Circular import
+
+            with disable_activity_stream():
+                self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
+        elif status_data['status'] == 'failed':
+            # For encrypted ssh_key_data, ansible-runner worker will open and write the
+            # ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
+            # read from this named pipe so that the key can be used in ansible-playbook.
+            # Once the podman container exits, the named pipe is deleted.
+            # However, if the podman container fails to start in the first place, e.g. the image
+            # name is incorrect, then this pipe is not cleaned up. Eventually the ansible-runner
+            # processor will attempt to write artifacts to the private data dir via unstream_dir, requiring
+            # that it open this named pipe. This leads to a hang. Thus, before any artifacts
+            # are written by the processor, it's important to remove this ssh_key_data pipe.
+ private_data_dir = self.instance.job_env.get('AWX_PRIVATE_DATA_DIR', None) + if private_data_dir: + key_data_file = os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'ssh_key_data') + if os.path.exists(key_data_file) and stat.S_ISFIFO(os.stat(key_data_file).st_mode): + os.remove(key_data_file) + elif status_data['status'] == 'error': + result_traceback = status_data.get('result_traceback', None) + if result_traceback: + from awx.main.signals import disable_activity_stream # Circular import + + with disable_activity_stream(): + self.instance = self.update_model(self.instance.pk, result_traceback=result_traceback) + + +class RunnerCallbackForProjectUpdate(RunnerCallback): + + event_data_key = 'project_update_id' + + def __init__(self, *args, **kwargs): + super(RunnerCallbackForProjectUpdate, self).__init__(*args, **kwargs) + self.playbook_new_revision = None + self.host_map = {} + + def event_handler(self, event_data): + super_return_value = super(RunnerCallbackForProjectUpdate, self).event_handler(event_data) + returned_data = event_data.get('event_data', {}) + if returned_data.get('task_action', '') == 'set_fact': + returned_facts = returned_data.get('res', {}).get('ansible_facts', {}) + if 'scm_version' in returned_facts: + self.playbook_new_revision = returned_facts['scm_version'] + return super_return_value + + +class RunnerCallbackForInventoryUpdate(RunnerCallback): + + event_data_key = 'inventory_update_id' + + def __init__(self, *args, **kwargs): + super(RunnerCallbackForInventoryUpdate, self).__init__(*args, **kwargs) + self.end_line = 0 + + def event_handler(self, event_data): + self.end_line = event_data['end_line'] + + return super(RunnerCallbackForInventoryUpdate, self).event_handler(event_data) + + +class RunnerCallbackForAdHocCommand(RunnerCallback): + + event_data_key = 'ad_hoc_command_id' + + def __init__(self, *args, **kwargs): + super(RunnerCallbackForAdHocCommand, self).__init__(*args, **kwargs) + self.host_map = {} + + +class RunnerCallbackForSystemJob(RunnerCallback): + + event_data_key = 'system_job_id' diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index caf1d4e558..f31eb7084f 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -1,5 +1,5 @@ # Python -from collections import deque, OrderedDict +from collections import OrderedDict from distutils.dir_util import copy_tree import errno import functools @@ -19,10 +19,8 @@ from uuid import uuid4 # Django -from django_guid.middleware import GuidMiddleware from django.conf import settings -from django.db import transaction, DatabaseError -from django.utils.timezone import now +from django.db import transaction # Runner @@ -40,11 +38,9 @@ from awx.main.dispatch import get_local_queuename from awx.main.constants import ( PRIVILEGE_ESCALATION_METHODS, STANDARD_INVENTORY_UPDATE_ENV, - MINIMAL_EVENTS, JOB_FOLDER_PREFIX, MAX_ISOLATED_PATH_COLON_DELIMITER, ) -from awx.main.redact import UriCleaner from awx.main.models import ( Instance, Inventory, @@ -61,7 +57,13 @@ from awx.main.models import ( SystemJobEvent, build_safe_env, ) -from awx.main.queue import CallbackQueueDispatcher +from awx.main.tasks.callback import ( + RunnerCallback, + RunnerCallbackForAdHocCommand, + RunnerCallbackForInventoryUpdate, + RunnerCallbackForProjectUpdate, + RunnerCallbackForSystemJob, +) from awx.main.tasks.receptor import AWXReceptorJob from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound from awx.main.utils.ansible import read_ansible_config @@ -76,6 +78,7 @@ from 
awx.main.utils.common import (
 from awx.conf.license import get_license
 from awx.main.utils.handlers import SpecialInventoryHandler
 from awx.main.tasks.system import handle_success_and_failure_notifications, update_smart_memberships_for_inventory, update_inventory_computed_fields
+from awx.main.utils.update_model import update_model
 
 from rest_framework.exceptions import PermissionDenied
 from django.utils.translation import ugettext_lazy as _
@@ -105,46 +108,14 @@ class BaseTask(object):
     model = None
     event_model = None
     abstract = True
+    callback_class = RunnerCallback
 
     def __init__(self):
         self.cleanup_paths = []
-        self.parent_workflow_job_id = None
-        self.host_map = {}
-        self.guid = GuidMiddleware.get_guid()
-        self.job_created = None
-        self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
+        self.runner_callback = self.callback_class(model=self.model)
 
     def update_model(self, pk, _attempt=0, **updates):
-        """Reload the model instance from the database and update the
-        given fields.
-        """
-        try:
-            with transaction.atomic():
-                # Retrieve the model instance.
-                instance = self.model.objects.get(pk=pk)
-
-                # Update the appropriate fields and save the model
-                # instance, then return the new instance.
-                if updates:
-                    update_fields = ['modified']
-                    for field, value in updates.items():
-                        setattr(instance, field, value)
-                        update_fields.append(field)
-                        if field == 'status':
-                            update_fields.append('failed')
-                    instance.save(update_fields=update_fields)
-                return instance
-        except DatabaseError as e:
-            # Log out the error to the debug logger.
-            logger.debug('Database error updating %s, retrying in 5 ' 'seconds (retry #%d): %s', self.model._meta.object_name, _attempt + 1, e)
-
-            # Attempt to retry the update, assuming we haven't already
-            # tried too many times.
-            if _attempt < 5:
-                time.sleep(5)
-                return self.update_model(pk, _attempt=_attempt + 1, **updates)
-            else:
-                logger.error('Failed to update %s after %d retries.', self.model._meta.object_name, _attempt)
+        return update_model(self.model, pk, _attempt=_attempt, **updates)
 
     def get_path_to(self, *args):
         """
@@ -350,7 +321,7 @@ class BaseTask(object):
             script_data = instance.inventory.get_script_data(**script_params)
             # maintain a list of host_name --> host_id
             # so we can associate emitted events to Host objects
-            self.host_map = {hostname: hv.pop('remote_tower_id', '') for hostname, hv in script_data.get('_meta', {}).get('hostvars', {}).items()}
+            self.runner_callback.host_map = {hostname: hv.pop('remote_tower_id', '') for hostname, hv in script_data.get('_meta', {}).get('hostvars', {}).items()}
             json_data = json.dumps(script_data)
             path = os.path.join(private_data_dir, 'inventory')
             fn = os.path.join(path, 'hosts')
@@ -444,181 +415,6 @@ class BaseTask(object):
             instance.ansible_version = ansible_version_info
             instance.save(update_fields=['ansible_version'])
 
-    def event_handler(self, event_data):
-        #
-        # ⚠️ D-D-D-DANGER ZONE ⚠️
-        # This method is called once for *every event* emitted by Ansible
-        # Runner as a playbook runs. That means that changes to the code in
-        # this method are _very_ likely to introduce performance regressions.
-        #
-        # Even if this function is made on average .05s slower, it can have
-        # devastating performance implications for playbooks that emit
-        # tens or hundreds of thousands of events.
-        #
-        # Proceed with caution!
-        #
-        """
-        Ansible runner puts a parent_uuid on each event, no matter what the type.
-        AWX only saves the parent_uuid if the event is for a Job.
- """ - # cache end_line locally for RunInventoryUpdate tasks - # which generate job events from two 'streams': - # ansible-inventory and the awx.main.commands.inventory_import - # logger - if isinstance(self, RunInventoryUpdate): - self.end_line = event_data['end_line'] - - if event_data.get(self.event_data_key, None): - if self.event_data_key != 'job_id': - event_data.pop('parent_uuid', None) - if self.parent_workflow_job_id: - event_data['workflow_job_id'] = self.parent_workflow_job_id - event_data['job_created'] = self.job_created - if self.host_map: - host = event_data.get('event_data', {}).get('host', '').strip() - if host: - event_data['host_name'] = host - if host in self.host_map: - event_data['host_id'] = self.host_map[host] - else: - event_data['host_name'] = '' - event_data['host_id'] = '' - if event_data.get('event') == 'playbook_on_stats': - event_data['host_map'] = self.host_map - - if isinstance(self, RunProjectUpdate): - # it's common for Ansible's SCM modules to print - # error messages on failure that contain the plaintext - # basic auth credentials (username + password) - # it's also common for the nested event data itself (['res']['...']) - # to contain unredacted text on failure - # this is a _little_ expensive to filter - # with regex, but project updates don't have many events, - # so it *should* have a negligible performance impact - task = event_data.get('event_data', {}).get('task_action') - try: - if task in ('git', 'svn'): - event_data_json = json.dumps(event_data) - event_data_json = UriCleaner.remove_sensitive(event_data_json) - event_data = json.loads(event_data_json) - except json.JSONDecodeError: - pass - - if 'event_data' in event_data: - event_data['event_data']['guid'] = self.guid - - # To prevent overwhelming the broadcast queue, skip some websocket messages - if self.recent_event_timings: - cpu_time = time.time() - first_window_time = self.recent_event_timings[0] - last_window_time = self.recent_event_timings[-1] - - if event_data.get('event') in MINIMAL_EVENTS: - should_emit = True # always send some types like playbook_on_stats - elif event_data.get('stdout') == '' and event_data['start_line'] == event_data['end_line']: - should_emit = False # exclude events with no output - else: - should_emit = any( - [ - # if 30the most recent websocket message was sent over 1 second ago - cpu_time - first_window_time > 1.0, - # if the very last websocket message came in over 1/30 seconds ago - self.recent_event_timings.maxlen * (cpu_time - last_window_time) > 1.0, - # if the queue is not yet full - len(self.recent_event_timings) != self.recent_event_timings.maxlen, - ] - ) - - if should_emit: - self.recent_event_timings.append(cpu_time) - else: - event_data.setdefault('event_data', {}) - event_data['skip_websocket_message'] = True - - elif self.recent_event_timings.maxlen: - self.recent_event_timings.append(time.time()) - - event_data.setdefault(self.event_data_key, self.instance.id) - self.dispatcher.dispatch(event_data) - self.event_ct += 1 - - ''' - Handle artifacts - ''' - if event_data.get('event_data', {}).get('artifact_data', {}): - self.instance.artifacts = event_data['event_data']['artifact_data'] - self.instance.save(update_fields=['artifacts']) - - return False - - def cancel_callback(self): - """ - Ansible runner callback to tell the job when/if it is canceled - """ - unified_job_id = self.instance.pk - self.instance = self.update_model(unified_job_id) - if not self.instance: - logger.error('unified job {} was deleted while running, 
canceling'.format(unified_job_id)) - return True - if self.instance.cancel_flag or self.instance.status == 'canceled': - cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0 - if cancel_wait > 5: - logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait)) - return True - return False - - def finished_callback(self, runner_obj): - """ - Ansible runner callback triggered on finished run - """ - event_data = { - 'event': 'EOF', - 'final_counter': self.event_ct, - 'guid': self.guid, - } - event_data.setdefault(self.event_data_key, self.instance.id) - self.dispatcher.dispatch(event_data) - - def status_handler(self, status_data, runner_config): - """ - Ansible runner callback triggered on status transition - """ - if status_data['status'] == 'starting': - job_env = dict(runner_config.env) - ''' - Take the safe environment variables and overwrite - ''' - for k, v in self.safe_env.items(): - if k in job_env: - job_env[k] = v - from awx.main.signals import disable_activity_stream # Circular import - - with disable_activity_stream(): - self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env) - elif status_data['status'] == 'failed': - # For encrypted ssh_key_data, ansible-runner worker will open and write the - # ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will - # read from this named pipe so that the key can be used in ansible-playbook. - # Once the podman container exits, the named pipe is deleted. - # However, if the podman container fails to start in the first place, e.g. the image - # name is incorrect, then this pipe is not cleaned up. Eventually ansible-runner - # processor will attempt to write artifacts to the private data dir via unstream_dir, requiring - # that it open this named pipe. This leads to a hang. Thus, before any artifacts - # are written by the processor, it's important to remove this ssh_key_data pipe. 
- private_data_dir = self.instance.job_env.get('AWX_PRIVATE_DATA_DIR', None) - if private_data_dir: - key_data_file = os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'ssh_key_data') - if os.path.exists(key_data_file) and stat.S_ISFIFO(os.stat(key_data_file).st_mode): - os.remove(key_data_file) - - elif status_data['status'] == 'error': - result_traceback = status_data.get('result_traceback', None) - if result_traceback: - from awx.main.signals import disable_activity_stream # Circular import - - with disable_activity_stream(): - self.instance = self.update_model(self.instance.pk, result_traceback=result_traceback) - @with_path_cleanup def run(self, pk, **kwargs): """ @@ -638,22 +434,15 @@ class BaseTask(object): status, rc = 'error', None extra_update_fields = {} fact_modification_times = {} - self.event_ct = 0 + self.runner_callback.event_ct = 0 ''' Needs to be an object property because status_handler uses it in a callback context ''' - self.safe_env = {} + self.safe_cred_env = {} private_data_dir = None - # store a reference to the parent workflow job (if any) so we can include - # it in event data JSON - if self.instance.spawned_by_workflow: - self.parent_workflow_job_id = self.instance.get_workflow_job().id - - self.job_created = str(self.instance.created) - try: self.instance.send_notification_templates("running") private_data_dir = self.build_private_data_dir(self.instance) @@ -685,7 +474,16 @@ class BaseTask(object): self.build_extra_vars_file(self.instance, private_data_dir) args = self.build_args(self.instance, private_data_dir, passwords) env = self.build_env(self.instance, private_data_dir, private_data_files=private_data_files) - self.safe_env = build_safe_env(env) + self.runner_callback.safe_env = build_safe_env(env) + + self.runner_callback.instance = self.instance + + # store a reference to the parent workflow job (if any) so we can include + # it in event data JSON + if self.instance.spawned_by_workflow: + self.runner_callback.parent_workflow_job_id = self.instance.get_workflow_job().id + + self.runner_callback.job_created = str(self.instance.created) credentials = self.build_credentials_list(self.instance) @@ -693,7 +491,7 @@ class BaseTask(object): if credential: credential.credential_type.inject_credential(credential, env, self.safe_cred_env, args, private_data_dir) - self.safe_env.update(self.safe_cred_env) + self.runner_callback.safe_env.update(self.safe_cred_env) self.write_args_file(private_data_dir, args) @@ -739,16 +537,14 @@ class BaseTask(object): if not params[v]: del params[v] - self.dispatcher = CallbackQueueDispatcher() - self.instance.log_lifecycle("running_playbook") if isinstance(self.instance, SystemJob): res = ansible_runner.interface.run( project_dir=settings.BASE_DIR, - event_handler=self.event_handler, - finished_callback=self.finished_callback, - status_handler=self.status_handler, - cancel_callback=self.cancel_callback, + event_handler=self.runner_callback.event_handler, + finished_callback=self.runner_callback.finished_callback, + status_handler=self.runner_callback.status_handler, + cancel_callback=self.runner_callback.cancel_callback, **params, ) else: @@ -779,7 +575,7 @@ class BaseTask(object): extra_update_fields['result_traceback'] = traceback.format_exc() logger.exception('%s Exception occurred while running task', self.instance.log_format) finally: - logger.debug('%s finished running, producing %s events.', self.instance.log_format, self.event_ct) + logger.debug('%s finished running, producing %s events.', 
self.instance.log_format, self.runner_callback.event_ct) try: self.post_run_hook(self.instance, status) @@ -793,7 +589,7 @@ class BaseTask(object): logger.exception('{} Post run hook errored.'.format(self.instance.log_format)) self.instance = self.update_model(pk) - self.instance = self.update_model(pk, status=status, emitted_events=self.event_ct, **extra_update_fields) + self.instance = self.update_model(pk, status=status, emitted_events=self.runner_callback.event_ct, **extra_update_fields) try: self.final_run_hook(self.instance, status, private_data_dir, fact_modification_times) @@ -816,7 +612,6 @@ class RunJob(BaseTask): model = Job event_model = JobEvent - event_data_key = 'job_id' def build_private_data(self, job, private_data_dir): """ @@ -1215,22 +1010,13 @@ class RunProjectUpdate(BaseTask): model = ProjectUpdate event_model = ProjectUpdateEvent - event_data_key = 'project_update_id' + callback_class = RunnerCallbackForProjectUpdate def __init__(self, *args, job_private_data_dir=None, **kwargs): super(RunProjectUpdate, self).__init__(*args, **kwargs) - self.playbook_new_revision = None self.original_branch = None self.job_private_data_dir = job_private_data_dir - def event_handler(self, event_data): - super(RunProjectUpdate, self).event_handler(event_data) - returned_data = event_data.get('event_data', {}) - if returned_data.get('task_action', '') == 'set_fact': - returned_facts = returned_data.get('res', {}).get('ansible_facts', {}) - if 'scm_version' in returned_facts: - self.playbook_new_revision = returned_facts['scm_version'] - def build_private_data(self, project_update, private_data_dir): """ Return SSH private key data needed for this project update. @@ -1631,8 +1417,8 @@ class RunProjectUpdate(BaseTask): super(RunProjectUpdate, self).post_run_hook(instance, status) # To avoid hangs, very important to release lock even if errors happen here try: - if self.playbook_new_revision: - instance.scm_revision = self.playbook_new_revision + if self.runner_callback.playbook_new_revision: + instance.scm_revision = self.runner_callback.playbook_new_revision instance.save(update_fields=['scm_revision']) # Roles and collection folders copy to durable cache @@ -1672,8 +1458,8 @@ class RunProjectUpdate(BaseTask): 'failed', 'canceled', ): - if self.playbook_new_revision: - p.scm_revision = self.playbook_new_revision + if self.runner_callback.playbook_new_revision: + p.scm_revision = self.runner_callback.playbook_new_revision else: if status == 'successful': logger.error("{} Could not find scm revision in check".format(instance.log_format)) @@ -1709,7 +1495,7 @@ class RunInventoryUpdate(BaseTask): model = InventoryUpdate event_model = InventoryUpdateEvent - event_data_key = 'inventory_update_id' + callback_class = RunnerCallbackForInventoryUpdate def build_private_data(self, inventory_update, private_data_dir): """ @@ -1932,6 +1718,7 @@ class RunInventoryUpdate(BaseTask): if status != 'successful': return # nothing to save, step out of the way to allow error reporting + inventory_update.refresh_from_db() private_data_dir = inventory_update.job_env['AWX_PRIVATE_DATA_DIR'] expected_output = os.path.join(private_data_dir, 'artifacts', str(inventory_update.id), 'output.json') with open(expected_output) as f: @@ -1966,13 +1753,13 @@ class RunInventoryUpdate(BaseTask): options['verbosity'] = inventory_update.verbosity handler = SpecialInventoryHandler( - self.event_handler, - self.cancel_callback, + self.runner_callback.event_handler, + self.runner_callback.cancel_callback, 
verbosity=inventory_update.verbosity, job_timeout=self.get_instance_timeout(self.instance), start_time=inventory_update.started, - counter=self.event_ct, - initial_line=self.end_line, + counter=self.runner_callback.event_ct, + initial_line=self.runner_callback.end_line, ) inv_logger = logging.getLogger('awx.main.commands.inventory_import') formatter = inv_logger.handlers[0].formatter @@ -2006,7 +1793,7 @@ class RunAdHocCommand(BaseTask): model = AdHocCommand event_model = AdHocCommandEvent - event_data_key = 'ad_hoc_command_id' + callback_class = RunnerCallbackForAdHocCommand def build_private_data(self, ad_hoc_command, private_data_dir): """ @@ -2164,7 +1951,7 @@ class RunSystemJob(BaseTask): model = SystemJob event_model = SystemJobEvent - event_data_key = 'system_job_id' + callback_class = RunnerCallbackForSystemJob def build_execution_environment_params(self, system_job, private_data_dir): return {} diff --git a/awx/main/tasks/receptor.py b/awx/main/tasks/receptor.py index 7bc5c1110e..4cb0a543a2 100644 --- a/awx/main/tasks/receptor.py +++ b/awx/main/tasks/receptor.py @@ -411,9 +411,9 @@ class AWXReceptorJob: streamer='process', quiet=True, _input=resultfile, - event_handler=self.task.event_handler, - finished_callback=self.task.finished_callback, - status_handler=self.task.status_handler, + event_handler=self.task.runner_callback.event_handler, + finished_callback=self.task.runner_callback.finished_callback, + status_handler=self.task.runner_callback.status_handler, **self.runner_params, ) @@ -458,7 +458,7 @@ class AWXReceptorJob: if processor_future.done(): return processor_future.result() - if self.task.cancel_callback(): + if self.task.runner_callback.cancel_callback(): result = namedtuple('result', ['status', 'rc']) return result('canceled', 1) diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py index 76ec2ec337..f2d617abb8 100644 --- a/awx/main/tests/unit/test_tasks.py +++ b/awx/main/tests/unit/test_tasks.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- - import configparser import json import os @@ -34,7 +33,7 @@ from awx.main.models import ( ) from awx.main.models.credential import HIDDEN_PASSWORD, ManagedCredentialType -from awx.main import tasks +from awx.main.tasks import jobs, system from awx.main.utils import encrypt_field, encrypt_value from awx.main.utils.safe_yaml import SafeLoader from awx.main.utils.execution_environments import CONTAINER_ROOT, to_host_path @@ -113,12 +112,12 @@ def adhoc_update_model_wrapper(adhoc_job): def test_send_notifications_not_list(): with pytest.raises(TypeError): - tasks.system.send_notifications(None) + system.send_notifications(None) def test_send_notifications_job_id(mocker): with mocker.patch('awx.main.models.UnifiedJob.objects.get'): - tasks.system.send_notifications([], job_id=1) + system.send_notifications([], job_id=1) assert UnifiedJob.objects.get.called assert UnifiedJob.objects.get.called_with(id=1) @@ -127,7 +126,7 @@ def test_work_success_callback_missing_job(): task_data = {'type': 'project_update', 'id': 9999} with mock.patch('django.db.models.query.QuerySet.get') as get_mock: get_mock.side_effect = ProjectUpdate.DoesNotExist() - assert tasks.system.handle_work_success(task_data) is None + assert system.handle_work_success(task_data) is None @mock.patch('awx.main.models.UnifiedJob.objects.get') @@ -138,7 +137,7 @@ def test_send_notifications_list(mock_notifications_filter, mock_job_get, mocker mock_notifications = [mocker.MagicMock(spec=Notification, subject="test", body={'hello': 'world'})] 
mock_notifications_filter.return_value = mock_notifications - tasks.system.send_notifications([1, 2], job_id=1) + system.send_notifications([1, 2], job_id=1) assert Notification.objects.filter.call_count == 1 assert mock_notifications[0].status == "successful" assert mock_notifications[0].save.called @@ -168,7 +167,7 @@ def test_safe_env_returns_new_copy(): @pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)]) def test_openstack_client_config_generation(mocker, source, expected, private_data_dir): - update = tasks.jobs.RunInventoryUpdate() + update = jobs.RunInventoryUpdate() credential_type = CredentialType.defaults['openstack']() inputs = { 'host': 'https://keystone.openstack.example.org', @@ -208,7 +207,7 @@ def test_openstack_client_config_generation(mocker, source, expected, private_da @pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)]) def test_openstack_client_config_generation_with_project_domain_name(mocker, source, expected, private_data_dir): - update = tasks.jobs.RunInventoryUpdate() + update = jobs.RunInventoryUpdate() credential_type = CredentialType.defaults['openstack']() inputs = { 'host': 'https://keystone.openstack.example.org', @@ -250,7 +249,7 @@ def test_openstack_client_config_generation_with_project_domain_name(mocker, sou @pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)]) def test_openstack_client_config_generation_with_region(mocker, source, expected, private_data_dir): - update = tasks.jobs.RunInventoryUpdate() + update = jobs.RunInventoryUpdate() credential_type = CredentialType.defaults['openstack']() inputs = { 'host': 'https://keystone.openstack.example.org', @@ -294,7 +293,7 @@ def test_openstack_client_config_generation_with_region(mocker, source, expected @pytest.mark.parametrize("source,expected", [(False, False), (True, True)]) def test_openstack_client_config_generation_with_private_source_vars(mocker, source, expected, private_data_dir): - update = tasks.jobs.RunInventoryUpdate() + update = jobs.RunInventoryUpdate() credential_type = CredentialType.defaults['openstack']() inputs = { 'host': 'https://keystone.openstack.example.org', @@ -357,7 +356,7 @@ class TestExtraVarSanitation(TestJobExecution): job.created_by = User(pk=123, username='angry-spud') job.inventory = Inventory(pk=123, name='example-inv') - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) fd = open(os.path.join(private_data_dir, 'env', 'extravars')) @@ -393,7 +392,7 @@ class TestExtraVarSanitation(TestJobExecution): def test_launchtime_vars_unsafe(self, job, private_data_dir): job.extra_vars = json.dumps({'msg': self.UNSAFE}) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -404,7 +403,7 @@ class TestExtraVarSanitation(TestJobExecution): def test_nested_launchtime_vars_unsafe(self, job, private_data_dir): job.extra_vars = json.dumps({'msg': {'a': [self.UNSAFE]}}) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -415,7 +414,7 @@ class TestExtraVarSanitation(TestJobExecution): def test_allowed_jt_extra_vars(self, job, private_data_dir): job.job_template.extra_vars = job.extra_vars = json.dumps({'msg': self.UNSAFE}) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -427,7 +426,7 @@ class TestExtraVarSanitation(TestJobExecution): def test_nested_allowed_vars(self, 
job, private_data_dir): job.extra_vars = json.dumps({'msg': {'a': {'b': [self.UNSAFE]}}}) job.job_template.extra_vars = job.extra_vars - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -441,7 +440,7 @@ class TestExtraVarSanitation(TestJobExecution): # `other_var=SENSITIVE` job.job_template.extra_vars = json.dumps({'msg': self.UNSAFE}) job.extra_vars = json.dumps({'msg': 'other-value', 'other_var': self.UNSAFE}) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -456,7 +455,7 @@ class TestExtraVarSanitation(TestJobExecution): def test_overwritten_jt_extra_vars(self, job, private_data_dir): job.job_template.extra_vars = json.dumps({'msg': 'SAFE'}) job.extra_vars = json.dumps({'msg': self.UNSAFE}) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.build_extra_vars_file(job, private_data_dir) @@ -472,7 +471,7 @@ class TestGenericRun: job.websocket_emit_status = mock.Mock() job.execution_environment = execution_environment - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job task.update_model = mock.Mock(return_value=job) task.model.objects.get = mock.Mock(return_value=job) @@ -494,7 +493,7 @@ class TestGenericRun: job.send_notification_templates = mock.Mock() job.execution_environment = execution_environment - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job task.update_model = mock.Mock(wraps=update_model_wrapper) task.model.objects.get = mock.Mock(return_value=job) @@ -508,45 +507,45 @@ class TestGenericRun: assert c in task.update_model.call_args_list def test_event_count(self): - task = tasks.jobs.RunJob() - task.dispatcher = mock.MagicMock() - task.instance = Job() - task.event_ct = 0 + task = jobs.RunJob() + task.runner_callback.dispatcher = mock.MagicMock() + task.runner_callback.instance = Job() + task.runner_callback.event_ct = 0 event_data = {} - [task.event_handler(event_data) for i in range(20)] - assert 20 == task.event_ct + [task.runner_callback.event_handler(event_data) for i in range(20)] + assert 20 == task.runner_callback.event_ct def test_finished_callback_eof(self): - task = tasks.jobs.RunJob() - task.dispatcher = mock.MagicMock() - task.instance = Job(pk=1, id=1) - task.event_ct = 17 - task.finished_callback(None) - task.dispatcher.dispatch.assert_called_with({'event': 'EOF', 'final_counter': 17, 'job_id': 1, 'guid': None}) + task = jobs.RunJob() + task.runner_callback.dispatcher = mock.MagicMock() + task.runner_callback.instance = Job(pk=1, id=1) + task.runner_callback.event_ct = 17 + task.runner_callback.finished_callback(None) + task.runner_callback.dispatcher.dispatch.assert_called_with({'event': 'EOF', 'final_counter': 17, 'job_id': 1, 'guid': None}) def test_save_job_metadata(self, job, update_model_wrapper): class MockMe: pass - task = tasks.jobs.RunJob() - task.instance = job - task.safe_env = {'secret_key': 'redacted_value'} - task.update_model = mock.Mock(wraps=update_model_wrapper) + task = jobs.RunJob() + task.runner_callback.instance = job + task.runner_callback.safe_env = {'secret_key': 'redacted_value'} + task.runner_callback.update_model = mock.Mock(wraps=update_model_wrapper) runner_config = MockMe() runner_config.command = {'foo': 'bar'} runner_config.cwd = '/foobar' runner_config.env = {'switch': 'blade', 'foot': 'ball', 'secret_key': 'secret_value'} - task.status_handler({'status': 'starting'}, runner_config) + task.runner_callback.status_handler({'status': 'starting'}, runner_config) - 
task.update_model.assert_called_with( + task.runner_callback.update_model.assert_called_with( 1, job_args=json.dumps({'foo': 'bar'}), job_cwd='/foobar', job_env={'switch': 'blade', 'foot': 'ball', 'secret_key': 'redacted_value'} ) def test_created_by_extra_vars(self): job = Job(created_by=User(pk=123, username='angry-spud')) - task = tasks.jobs.RunJob() + task = jobs.RunJob() task._write_extra_vars_file = mock.Mock() task.build_extra_vars_file(job, None) @@ -563,7 +562,7 @@ class TestGenericRun: job.extra_vars = json.dumps({'super_secret': encrypt_value('CLASSIFIED', pk=None)}) job.survey_passwords = {'super_secret': '$encrypted$'} - task = tasks.jobs.RunJob() + task = jobs.RunJob() task._write_extra_vars_file = mock.Mock() task.build_extra_vars_file(job, None) @@ -576,7 +575,7 @@ class TestGenericRun: job = Job(project=Project(), inventory=Inventory()) job.execution_environment = execution_environment - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job task._write_extra_vars_file = mock.Mock() @@ -595,7 +594,7 @@ class TestAdhocRun(TestJobExecution): adhoc_job.websocket_emit_status = mock.Mock() adhoc_job.send_notification_templates = mock.Mock() - task = tasks.jobs.RunAdHocCommand() + task = jobs.RunAdHocCommand() task.update_model = mock.Mock(wraps=adhoc_update_model_wrapper) task.model.objects.get = mock.Mock(return_value=adhoc_job) task.build_inventory = mock.Mock() @@ -619,7 +618,7 @@ class TestAdhocRun(TestJobExecution): }) #adhoc_job.websocket_emit_status = mock.Mock() - task = tasks.jobs.RunAdHocCommand() + task = jobs.RunAdHocCommand() #task.update_model = mock.Mock(wraps=adhoc_update_model_wrapper) #task.build_inventory = mock.Mock(return_value='/tmp/something.inventory') task._write_extra_vars_file = mock.Mock() @@ -634,7 +633,7 @@ class TestAdhocRun(TestJobExecution): def test_created_by_extra_vars(self): adhoc_job = AdHocCommand(created_by=User(pk=123, username='angry-spud')) - task = tasks.jobs.RunAdHocCommand() + task = jobs.RunAdHocCommand() task._write_extra_vars_file = mock.Mock() task.build_extra_vars_file(adhoc_job, None) @@ -693,7 +692,7 @@ class TestJobCredentials(TestJobExecution): } def test_username_jinja_usage(self, job, private_data_dir): - task = tasks.jobs.RunJob() + task = jobs.RunJob() ssh = CredentialType.defaults['ssh']() credential = Credential(pk=1, credential_type=ssh, inputs={'username': '{{ ansible_ssh_pass }}'}) job.credentials.add(credential) @@ -704,7 +703,7 @@ class TestJobCredentials(TestJobExecution): @pytest.mark.parametrize("flag", ['become_username', 'become_method']) def test_become_jinja_usage(self, job, private_data_dir, flag): - task = tasks.jobs.RunJob() + task = jobs.RunJob() ssh = CredentialType.defaults['ssh']() credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'joe', flag: '{{ ansible_ssh_pass }}'}) job.credentials.add(credential) @@ -715,7 +714,7 @@ class TestJobCredentials(TestJobExecution): assert 'Jinja variables are not allowed' in str(e.value) def test_ssh_passwords(self, job, private_data_dir, field, password_name, expected_flag): - task = tasks.jobs.RunJob() + task = jobs.RunJob() ssh = CredentialType.defaults['ssh']() credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', field: 'secret'}) credential.inputs[field] = encrypt_field(credential, field) @@ -732,7 +731,7 @@ class TestJobCredentials(TestJobExecution): assert expected_flag in ' '.join(args) def test_net_ssh_key_unlock(self, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() net = 
CredentialType.defaults['net']() credential = Credential(pk=1, credential_type=net, inputs={'ssh_key_unlock': 'secret'}) credential.inputs['ssh_key_unlock'] = encrypt_field(credential, 'ssh_key_unlock') @@ -745,7 +744,7 @@ class TestJobCredentials(TestJobExecution): assert 'secret' in expect_passwords.values() def test_net_first_ssh_key_unlock_wins(self, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() for i in range(3): net = CredentialType.defaults['net']() credential = Credential(pk=i, credential_type=net, inputs={'ssh_key_unlock': 'secret{}'.format(i)}) @@ -759,7 +758,7 @@ class TestJobCredentials(TestJobExecution): assert 'secret0' in expect_passwords.values() def test_prefer_ssh_over_net_ssh_key_unlock(self, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() net = CredentialType.defaults['net']() net_credential = Credential(pk=1, credential_type=net, inputs={'ssh_key_unlock': 'net_secret'}) net_credential.inputs['ssh_key_unlock'] = encrypt_field(net_credential, 'ssh_key_unlock') @@ -778,7 +777,7 @@ class TestJobCredentials(TestJobExecution): assert 'ssh_secret' in expect_passwords.values() def test_vault_password(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() vault = CredentialType.defaults['vault']() credential = Credential(pk=1, credential_type=vault, inputs={'vault_password': 'vault-me'}) credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') @@ -793,7 +792,7 @@ class TestJobCredentials(TestJobExecution): assert '--ask-vault-pass' in ' '.join(args) def test_vault_password_ask(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() vault = CredentialType.defaults['vault']() credential = Credential(pk=1, credential_type=vault, inputs={'vault_password': 'ASK'}) credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password') @@ -808,7 +807,7 @@ class TestJobCredentials(TestJobExecution): assert '--ask-vault-pass' in ' '.join(args) def test_multi_vault_password(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() vault = CredentialType.defaults['vault']() for i, label in enumerate(['dev', 'prod', 'dotted.name']): credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'pass@{}'.format(label), 'vault_id': label}) @@ -831,7 +830,7 @@ class TestJobCredentials(TestJobExecution): assert '--vault-id dotted.name@prompt' in ' '.join(args) def test_multi_vault_id_conflict(self, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() vault = CredentialType.defaults['vault']() for i in range(2): credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'some-pass', 'vault_id': 'conflict'}) @@ -844,7 +843,7 @@ class TestJobCredentials(TestJobExecution): assert 'multiple vault credentials were specified with --vault-id' in str(e.value) def test_multi_vault_password_ask(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() vault = CredentialType.defaults['vault']() for i, label in enumerate(['dev', 'prod']): credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'ASK', 'vault_id': label}) @@ -999,7 +998,7 @@ class TestJobCredentials(TestJobExecution): assert safe_env['VMWARE_PASSWORD'] == HIDDEN_PASSWORD def test_openstack_credentials(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job openstack = CredentialType.defaults['openstack']() credential = Credential( @@ -1067,7 +1066,7 @@ class 
TestJobCredentials(TestJobExecution): ], ) def test_net_credentials(self, authorize, expected_authorize, job, private_data_dir): - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job net = CredentialType.defaults['net']() inputs = {'username': 'bob', 'password': 'secret', 'ssh_key_data': self.EXAMPLE_PRIVATE_KEY, 'authorize_password': 'authorizeme'} @@ -1135,7 +1134,7 @@ class TestJobCredentials(TestJobExecution): assert env['TURBO_BUTTON'] == str(True) def test_custom_environment_injectors_with_reserved_env_var(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job some_cloud = CredentialType( kind='cloud', @@ -1171,7 +1170,7 @@ class TestJobCredentials(TestJobExecution): assert safe_env['MY_CLOUD_PRIVATE_VAR'] == HIDDEN_PASSWORD def test_custom_environment_injectors_with_extra_vars(self, private_data_dir, job): - task = tasks.jobs.RunJob() + task = jobs.RunJob() some_cloud = CredentialType( kind='cloud', name='SomeCloud', @@ -1190,7 +1189,7 @@ class TestJobCredentials(TestJobExecution): assert hasattr(extra_vars["api_token"], '__UNSAFE__') def test_custom_environment_injectors_with_boolean_extra_vars(self, job, private_data_dir): - task = tasks.jobs.RunJob() + task = jobs.RunJob() some_cloud = CredentialType( kind='cloud', name='SomeCloud', @@ -1209,7 +1208,7 @@ class TestJobCredentials(TestJobExecution): return ['successful', 0] def test_custom_environment_injectors_with_complicated_boolean_template(self, job, private_data_dir): - task = tasks.jobs.RunJob() + task = jobs.RunJob() some_cloud = CredentialType( kind='cloud', name='SomeCloud', @@ -1230,7 +1229,7 @@ class TestJobCredentials(TestJobExecution): """ extra_vars that contain secret field values should be censored in the DB """ - task = tasks.jobs.RunJob() + task = jobs.RunJob() some_cloud = CredentialType( kind='cloud', name='SomeCloud', @@ -1335,7 +1334,7 @@ class TestJobCredentials(TestJobExecution): def test_awx_task_env(self, settings, private_data_dir, job): settings.AWX_TASK_ENV = {'FOO': 'BAR'} - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job env = task.build_env(job, private_data_dir) @@ -1362,7 +1361,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution): def test_galaxy_credentials_ignore_certs(self, private_data_dir, project_update, ignore): settings.GALAXY_IGNORE_CERTS = ignore - task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() task.instance = project_update env = task.build_env(project_update, private_data_dir) if ignore: @@ -1371,7 +1370,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution): assert 'ANSIBLE_GALAXY_IGNORE' not in env def test_galaxy_credentials_empty(self, private_data_dir, project_update): - class RunProjectUpdate(tasks.jobs.RunProjectUpdate): + class RunProjectUpdate(jobs.RunProjectUpdate): __vars__ = {} def _write_extra_vars_file(self, private_data_dir, extra_vars, *kw): @@ -1390,7 +1389,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution): assert not k.startswith('ANSIBLE_GALAXY_SERVER') def test_single_public_galaxy(self, private_data_dir, project_update): - class RunProjectUpdate(tasks.jobs.RunProjectUpdate): + class RunProjectUpdate(jobs.RunProjectUpdate): __vars__ = {} def _write_extra_vars_file(self, private_data_dir, extra_vars, *kw): @@ -1439,7 +1438,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution): ) project_update.project.organization.galaxy_credentials.add(public_galaxy) project_update.project.organization.galaxy_credentials.add(rh) - 
task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() task.instance = project_update env = task.build_env(project_update, private_data_dir) assert sorted([(k, v) for k, v in env.items() if k.startswith('ANSIBLE_GALAXY')]) == [ @@ -1481,7 +1480,7 @@ class TestProjectUpdateCredentials(TestJobExecution): } def test_username_and_password_auth(self, project_update, scm_type): - task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() ssh = CredentialType.defaults['ssh']() project_update.scm_type = scm_type project_update.credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', 'password': 'secret'}) @@ -1495,7 +1494,7 @@ class TestProjectUpdateCredentials(TestJobExecution): assert 'secret' in expect_passwords.values() def test_ssh_key_auth(self, project_update, scm_type): - task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() ssh = CredentialType.defaults['ssh']() project_update.scm_type = scm_type project_update.credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', 'ssh_key_data': self.EXAMPLE_PRIVATE_KEY}) @@ -1509,7 +1508,7 @@ class TestProjectUpdateCredentials(TestJobExecution): def test_awx_task_env(self, project_update, settings, private_data_dir, scm_type, execution_environment): project_update.execution_environment = execution_environment settings.AWX_TASK_ENV = {'FOO': 'BAR'} - task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() task.instance = project_update project_update.scm_type = scm_type @@ -1524,7 +1523,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): return InventoryUpdate(pk=1, execution_environment=execution_environment, inventory_source=InventorySource(pk=1, inventory=Inventory(pk=1))) def test_source_without_credential(self, mocker, inventory_update, private_data_dir): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update inventory_update.source = 'ec2' inventory_update.get_cloud_credential = mocker.Mock(return_value=None) @@ -1537,7 +1536,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert 'AWS_SECRET_ACCESS_KEY' not in env def test_ec2_source(self, private_data_dir, inventory_update, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update aws = CredentialType.defaults['aws']() inventory_update.source = 'ec2' @@ -1561,7 +1560,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert safe_env['AWS_SECRET_ACCESS_KEY'] == HIDDEN_PASSWORD def test_vmware_source(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update vmware = CredentialType.defaults['vmware']() inventory_update.source = 'vmware' @@ -1589,7 +1588,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): env["VMWARE_VALIDATE_CERTS"] == "False", def test_azure_rm_source_with_tenant(self, private_data_dir, inventory_update, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update azure_rm = CredentialType.defaults['azure_rm']() inventory_update.source = 'azure_rm' @@ -1625,7 +1624,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert safe_env['AZURE_SECRET'] == HIDDEN_PASSWORD def test_azure_rm_source_with_password(self, private_data_dir, inventory_update, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update 
azure_rm = CredentialType.defaults['azure_rm']() inventory_update.source = 'azure_rm' @@ -1654,7 +1653,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert safe_env['AZURE_PASSWORD'] == HIDDEN_PASSWORD def test_gce_source(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update gce = CredentialType.defaults['gce']() inventory_update.source = 'gce' @@ -1684,7 +1683,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert json_data['project_id'] == 'some-project' def test_openstack_source(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update openstack = CredentialType.defaults['openstack']() inventory_update.source = 'openstack' @@ -1724,7 +1723,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): ) def test_satellite6_source(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update satellite6 = CredentialType.defaults['satellite6']() inventory_update.source = 'satellite6' @@ -1747,7 +1746,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert safe_env["FOREMAN_PASSWORD"] == HIDDEN_PASSWORD def test_insights_source(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update insights = CredentialType.defaults['insights']() inventory_update.source = 'insights' @@ -1776,7 +1775,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): @pytest.mark.parametrize('verify', [True, False]) def test_tower_source(self, verify, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update tower = CredentialType.defaults['controller']() inventory_update.source = 'controller' @@ -1804,7 +1803,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert safe_env['CONTROLLER_PASSWORD'] == HIDDEN_PASSWORD def test_tower_source_ssl_verify_empty(self, inventory_update, private_data_dir, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update tower = CredentialType.defaults['controller']() inventory_update.source = 'controller' @@ -1832,7 +1831,7 @@ class TestInventoryUpdateCredentials(TestJobExecution): assert env['TOWER_VERIFY_SSL'] == 'False' def test_awx_task_env(self, inventory_update, private_data_dir, settings, mocker): - task = tasks.jobs.RunInventoryUpdate() + task = jobs.RunInventoryUpdate() task.instance = inventory_update gce = CredentialType.defaults['gce']() inventory_update.source = 'gce' @@ -1883,7 +1882,7 @@ def test_aquire_lock_open_fail_logged(logging_getLogger, os_open): logger = mock.Mock() logging_getLogger.return_value = logger - ProjectUpdate = tasks.jobs.RunProjectUpdate() + ProjectUpdate = jobs.RunProjectUpdate() with pytest.raises(OSError): ProjectUpdate.acquire_lock(instance) @@ -1910,7 +1909,7 @@ def test_aquire_lock_acquisition_fail_logged(fcntl_lockf, logging_getLogger, os_ fcntl_lockf.side_effect = err - ProjectUpdate = tasks.jobs.RunProjectUpdate() + ProjectUpdate = jobs.RunProjectUpdate() with pytest.raises(IOError): ProjectUpdate.acquire_lock(instance) os_close.assert_called_with(3) @@ -1947,7 +1946,7 @@ def test_notification_job_not_finished(logging_getLogger, 
mocker): logging_getLogger.return_value = logger with mocker.patch('awx.main.models.UnifiedJob.objects.get', uj): - tasks.system.handle_success_and_failure_notifications(1) + system.handle_success_and_failure_notifications(1) assert logger.warn.called_with(f"Failed to even try to send notifications for job '{uj}' due to job not being in finished state.") @@ -1955,7 +1954,7 @@ def test_notification_job_finished(mocker): uj = mocker.MagicMock(send_notification_templates=mocker.MagicMock(), finished=True) with mocker.patch('awx.main.models.UnifiedJob.objects.get', mocker.MagicMock(return_value=uj)): - tasks.system.handle_success_and_failure_notifications(1) + system.handle_success_and_failure_notifications(1) uj.send_notification_templates.assert_called() @@ -1964,7 +1963,7 @@ def test_job_run_no_ee(): proj = Project(pk=1, organization=org) job = Job(project=proj, organization=org, inventory=Inventory(pk=1)) job.execution_environment = None - task = tasks.jobs.RunJob() + task = jobs.RunJob() task.instance = job task.update_model = mock.Mock(return_value=job) task.model.objects.get = mock.Mock(return_value=job) @@ -1983,7 +1982,7 @@ def test_project_update_no_ee(): proj = Project(pk=1, organization=org) project_update = ProjectUpdate(pk=1, project=proj, scm_type='git') project_update.execution_environment = None - task = tasks.jobs.RunProjectUpdate() + task = jobs.RunProjectUpdate() task.instance = project_update with pytest.raises(RuntimeError) as e: diff --git a/awx/main/utils/update_model.py b/awx/main/utils/update_model.py new file mode 100644 index 0000000000..95c261cd6f --- /dev/null +++ b/awx/main/utils/update_model.py @@ -0,0 +1,40 @@ +from django.db import transaction, DatabaseError + +import logging +import time + + +logger = logging.getLogger('awx.main.tasks.utils') + + +def update_model(model, pk, _attempt=0, **updates): + """Reload the model instance from the database and update the + given fields. + """ + try: + with transaction.atomic(): + # Retrieve the model instance. + instance = model.objects.get(pk=pk) + + # Update the appropriate fields and save the model + # instance, then return the new instance. + if updates: + update_fields = ['modified'] + for field, value in updates.items(): + setattr(instance, field, value) + update_fields.append(field) + if field == 'status': + update_fields.append('failed') + instance.save(update_fields=update_fields) + return instance + except DatabaseError as e: + # Log out the error to the debug logger. + logger.debug('Database error updating %s, retrying in 5 seconds (retry #%d): %s', model._meta.object_name, _attempt + 1, e) + + # Attempt to retry the update, assuming we haven't already + # tried too many times. + if _attempt < 5: + time.sleep(5) + return update_model(model, pk, _attempt=_attempt + 1, **updates) + else: + logger.error('Failed to update %s after %d retries.', model._meta.object_name, _attempt)
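
A minimal sketch (not part of the commit) of how the decoupled pieces compose
after this change; it mirrors the updated test_finished_callback_eof assertions
in test_tasks.py above, with a mock standing in for the real
CallbackQueueDispatcher:

    from unittest import mock

    from awx.main.models import Job
    from awx.main.tasks import jobs
    from awx.main.tasks.callback import RunnerCallback

    # BaseTask subclasses no longer implement event_handler/cancel_callback/
    # status_handler; each task owns a callback object built from callback_class.
    task = jobs.RunJob()
    assert isinstance(task.runner_callback, RunnerCallback)

    # Ansible Runner is wired to the callback object, not the task itself.
    task.runner_callback.dispatcher = mock.MagicMock()  # stand-in for the dispatcher
    task.runner_callback.instance = Job(pk=1, id=1)
    task.runner_callback.event_ct = 17
    task.runner_callback.finished_callback(None)
    task.runner_callback.dispatcher.dispatch.assert_called_with(
        {'event': 'EOF', 'final_counter': 17, 'job_id': 1, 'guid': None}
    )

Swapping in a different callback_class (RunnerCallbackForInventoryUpdate on
RunInventoryUpdate, and so on) becomes the single point of customization per
task type, rather than overriding the handler methods on the task.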