Merge pull request #11571 from amolgautam25/tasks-refactor-2

Added new class for  Ansible Runner Callbacks
This commit is contained in:
Amol Gautam 2022-02-15 10:31:32 -05:00 committed by GitHub
commit 3f08e26881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 433 additions and 350 deletions

257
awx/main/tasks/callback.py Normal file
View File

@ -0,0 +1,257 @@
import json
import time
import logging
from collections import deque
import os
import stat
# Django
from django.utils.timezone import now
from django.conf import settings
from django_guid.middleware import GuidMiddleware
# AWX
from awx.main.redact import UriCleaner
from awx.main.constants import MINIMAL_EVENTS
from awx.main.utils.update_model import update_model
from awx.main.queue import CallbackQueueDispatcher
logger = logging.getLogger('awx.main.tasks.callback')
class RunnerCallback:
event_data_key = 'job_id'
def __init__(self, model=None):
self.parent_workflow_job_id = None
self.host_map = {}
self.guid = GuidMiddleware.get_guid()
self.job_created = None
self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
self.dispatcher = CallbackQueueDispatcher()
self.safe_env = {}
self.event_ct = 0
self.model = model
def update_model(self, pk, _attempt=0, **updates):
return update_model(self.model, pk, _attempt=0, **updates)
def event_handler(self, event_data):
#
# ⚠️ D-D-D-DANGER ZONE ⚠️
# This method is called once for *every event* emitted by Ansible
# Runner as a playbook runs. That means that changes to the code in
# this method are _very_ likely to introduce performance regressions.
#
# Even if this function is made on average .05s slower, it can have
# devastating performance implications for playbooks that emit
# tens or hundreds of thousands of events.
#
# Proceed with caution!
#
"""
Ansible runner puts a parent_uuid on each event, no matter what the type.
AWX only saves the parent_uuid if the event is for a Job.
"""
# cache end_line locally for RunInventoryUpdate tasks
# which generate job events from two 'streams':
# ansible-inventory and the awx.main.commands.inventory_import
# logger
if event_data.get(self.event_data_key, None):
if self.event_data_key != 'job_id':
event_data.pop('parent_uuid', None)
if self.parent_workflow_job_id:
event_data['workflow_job_id'] = self.parent_workflow_job_id
event_data['job_created'] = self.job_created
if self.host_map:
host = event_data.get('event_data', {}).get('host', '').strip()
if host:
event_data['host_name'] = host
if host in self.host_map:
event_data['host_id'] = self.host_map[host]
else:
event_data['host_name'] = ''
event_data['host_id'] = ''
if event_data.get('event') == 'playbook_on_stats':
event_data['host_map'] = self.host_map
if isinstance(self, RunnerCallbackForProjectUpdate):
# need a better way to have this check.
# it's common for Ansible's SCM modules to print
# error messages on failure that contain the plaintext
# basic auth credentials (username + password)
# it's also common for the nested event data itself (['res']['...'])
# to contain unredacted text on failure
# this is a _little_ expensive to filter
# with regex, but project updates don't have many events,
# so it *should* have a negligible performance impact
task = event_data.get('event_data', {}).get('task_action')
try:
if task in ('git', 'svn'):
event_data_json = json.dumps(event_data)
event_data_json = UriCleaner.remove_sensitive(event_data_json)
event_data = json.loads(event_data_json)
except json.JSONDecodeError:
pass
if 'event_data' in event_data:
event_data['event_data']['guid'] = self.guid
# To prevent overwhelming the broadcast queue, skip some websocket messages
if self.recent_event_timings:
cpu_time = time.time()
first_window_time = self.recent_event_timings[0]
last_window_time = self.recent_event_timings[-1]
if event_data.get('event') in MINIMAL_EVENTS:
should_emit = True # always send some types like playbook_on_stats
elif event_data.get('stdout') == '' and event_data['start_line'] == event_data['end_line']:
should_emit = False # exclude events with no output
else:
should_emit = any(
[
# if 30the most recent websocket message was sent over 1 second ago
cpu_time - first_window_time > 1.0,
# if the very last websocket message came in over 1/30 seconds ago
self.recent_event_timings.maxlen * (cpu_time - last_window_time) > 1.0,
# if the queue is not yet full
len(self.recent_event_timings) != self.recent_event_timings.maxlen,
]
)
if should_emit:
self.recent_event_timings.append(cpu_time)
else:
event_data.setdefault('event_data', {})
event_data['skip_websocket_message'] = True
elif self.recent_event_timings.maxlen:
self.recent_event_timings.append(time.time())
event_data.setdefault(self.event_data_key, self.instance.id)
self.dispatcher.dispatch(event_data)
self.event_ct += 1
'''
Handle artifacts
'''
if event_data.get('event_data', {}).get('artifact_data', {}):
self.instance.artifacts = event_data['event_data']['artifact_data']
self.instance.save(update_fields=['artifacts'])
return False
def cancel_callback(self):
"""
Ansible runner callback to tell the job when/if it is canceled
"""
unified_job_id = self.instance.pk
self.instance.refresh_from_db()
if not self.instance:
logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
return True
if self.instance.cancel_flag or self.instance.status == 'canceled':
cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
if cancel_wait > 5:
logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
return True
return False
def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
"""
event_data = {
'event': 'EOF',
'final_counter': self.event_ct,
'guid': self.guid,
}
event_data.setdefault(self.event_data_key, self.instance.id)
self.dispatcher.dispatch(event_data)
def status_handler(self, status_data, runner_config):
"""
Ansible runner callback triggered on status transition
"""
if status_data['status'] == 'starting':
job_env = dict(runner_config.env)
'''
Take the safe environment variables and overwrite
'''
for k, v in self.safe_env.items():
if k in job_env:
job_env[k] = v
from awx.main.signals import disable_activity_stream # Circular import
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
elif status_data['status'] == 'failed':
# For encrypted ssh_key_data, ansible-runner worker will open and write the
# ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
# read from this named pipe so that the key can be used in ansible-playbook.
# Once the podman container exits, the named pipe is deleted.
# However, if the podman container fails to start in the first place, e.g. the image
# name is incorrect, then this pipe is not cleaned up. Eventually ansible-runner
# processor will attempt to write artifacts to the private data dir via unstream_dir, requiring
# that it open this named pipe. This leads to a hang. Thus, before any artifacts
# are written by the processor, it's important to remove this ssh_key_data pipe.
private_data_dir = self.instance.job_env.get('AWX_PRIVATE_DATA_DIR', None)
if private_data_dir:
key_data_file = os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'ssh_key_data')
if os.path.exists(key_data_file) and stat.S_ISFIFO(os.stat(key_data_file).st_mode):
os.remove(key_data_file)
elif status_data['status'] == 'error':
result_traceback = status_data.get('result_traceback', None)
if result_traceback:
from awx.main.signals import disable_activity_stream # Circular import
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, result_traceback=result_traceback)
class RunnerCallbackForProjectUpdate(RunnerCallback):
event_data_key = 'project_update_id'
def __init__(self, *args, **kwargs):
super(RunnerCallbackForProjectUpdate, self).__init__(*args, **kwargs)
self.playbook_new_revision = None
self.host_map = {}
def event_handler(self, event_data):
super_return_value = super(RunnerCallbackForProjectUpdate, self).event_handler(event_data)
returned_data = event_data.get('event_data', {})
if returned_data.get('task_action', '') == 'set_fact':
returned_facts = returned_data.get('res', {}).get('ansible_facts', {})
if 'scm_version' in returned_facts:
self.playbook_new_revision = returned_facts['scm_version']
return super_return_value
class RunnerCallbackForInventoryUpdate(RunnerCallback):
event_data_key = 'inventory_update_id'
def __init__(self, *args, **kwargs):
super(RunnerCallbackForInventoryUpdate, self).__init__(*args, **kwargs)
self.end_line = 0
def event_handler(self, event_data):
self.end_line = event_data['end_line']
return super(RunnerCallbackForInventoryUpdate, self).event_handler(event_data)
class RunnerCallbackForAdHocCommand(RunnerCallback):
event_data_key = 'ad_hoc_command_id'
def __init__(self, *args, **kwargs):
super(RunnerCallbackForAdHocCommand, self).__init__(*args, **kwargs)
self.host_map = {}
class RunnerCallbackForSystemJob(RunnerCallback):
event_data_key = 'system_job_id'

View File

@ -1,5 +1,5 @@
# Python
from collections import deque, OrderedDict
from collections import OrderedDict
from distutils.dir_util import copy_tree
import errno
import functools
@ -19,10 +19,8 @@ from uuid import uuid4
# Django
from django_guid.middleware import GuidMiddleware
from django.conf import settings
from django.db import transaction, DatabaseError
from django.utils.timezone import now
from django.db import transaction
# Runner
@ -40,11 +38,9 @@ from awx.main.dispatch import get_local_queuename
from awx.main.constants import (
PRIVILEGE_ESCALATION_METHODS,
STANDARD_INVENTORY_UPDATE_ENV,
MINIMAL_EVENTS,
JOB_FOLDER_PREFIX,
MAX_ISOLATED_PATH_COLON_DELIMITER,
)
from awx.main.redact import UriCleaner
from awx.main.models import (
Instance,
Inventory,
@ -61,7 +57,13 @@ from awx.main.models import (
SystemJobEvent,
build_safe_env,
)
from awx.main.queue import CallbackQueueDispatcher
from awx.main.tasks.callback import (
RunnerCallback,
RunnerCallbackForAdHocCommand,
RunnerCallbackForInventoryUpdate,
RunnerCallbackForProjectUpdate,
RunnerCallbackForSystemJob,
)
from awx.main.tasks.receptor import AWXReceptorJob
from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound
from awx.main.utils.ansible import read_ansible_config
@ -76,6 +78,7 @@ from awx.main.utils.common import (
from awx.conf.license import get_license
from awx.main.utils.handlers import SpecialInventoryHandler
from awx.main.tasks.system import handle_success_and_failure_notifications, update_smart_memberships_for_inventory, update_inventory_computed_fields
from awx.main.utils.update_model import update_model
from rest_framework.exceptions import PermissionDenied
from django.utils.translation import ugettext_lazy as _
@ -105,46 +108,14 @@ class BaseTask(object):
model = None
event_model = None
abstract = True
callback_class = RunnerCallback
def __init__(self):
self.cleanup_paths = []
self.parent_workflow_job_id = None
self.host_map = {}
self.guid = GuidMiddleware.get_guid()
self.job_created = None
self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
self.runner_callback = self.callback_class(model=self.model)
def update_model(self, pk, _attempt=0, **updates):
"""Reload the model instance from the database and update the
given fields.
"""
try:
with transaction.atomic():
# Retrieve the model instance.
instance = self.model.objects.get(pk=pk)
# Update the appropriate fields and save the model
# instance, then return the new instance.
if updates:
update_fields = ['modified']
for field, value in updates.items():
setattr(instance, field, value)
update_fields.append(field)
if field == 'status':
update_fields.append('failed')
instance.save(update_fields=update_fields)
return instance
except DatabaseError as e:
# Log out the error to the debug logger.
logger.debug('Database error updating %s, retrying in 5 ' 'seconds (retry #%d): %s', self.model._meta.object_name, _attempt + 1, e)
# Attempt to retry the update, assuming we haven't already
# tried too many times.
if _attempt < 5:
time.sleep(5)
return self.update_model(pk, _attempt=_attempt + 1, **updates)
else:
logger.error('Failed to update %s after %d retries.', self.model._meta.object_name, _attempt)
return update_model(self.model, pk, _attempt=0, **updates)
def get_path_to(self, *args):
"""
@ -350,7 +321,7 @@ class BaseTask(object):
script_data = instance.inventory.get_script_data(**script_params)
# maintain a list of host_name --> host_id
# so we can associate emitted events to Host objects
self.host_map = {hostname: hv.pop('remote_tower_id', '') for hostname, hv in script_data.get('_meta', {}).get('hostvars', {}).items()}
self.runner_callback.host_map = {hostname: hv.pop('remote_tower_id', '') for hostname, hv in script_data.get('_meta', {}).get('hostvars', {}).items()}
json_data = json.dumps(script_data)
path = os.path.join(private_data_dir, 'inventory')
fn = os.path.join(path, 'hosts')
@ -444,181 +415,6 @@ class BaseTask(object):
instance.ansible_version = ansible_version_info
instance.save(update_fields=['ansible_version'])
def event_handler(self, event_data):
#
# ⚠️ D-D-D-DANGER ZONE ⚠️
# This method is called once for *every event* emitted by Ansible
# Runner as a playbook runs. That means that changes to the code in
# this method are _very_ likely to introduce performance regressions.
#
# Even if this function is made on average .05s slower, it can have
# devastating performance implications for playbooks that emit
# tens or hundreds of thousands of events.
#
# Proceed with caution!
#
"""
Ansible runner puts a parent_uuid on each event, no matter what the type.
AWX only saves the parent_uuid if the event is for a Job.
"""
# cache end_line locally for RunInventoryUpdate tasks
# which generate job events from two 'streams':
# ansible-inventory and the awx.main.commands.inventory_import
# logger
if isinstance(self, RunInventoryUpdate):
self.end_line = event_data['end_line']
if event_data.get(self.event_data_key, None):
if self.event_data_key != 'job_id':
event_data.pop('parent_uuid', None)
if self.parent_workflow_job_id:
event_data['workflow_job_id'] = self.parent_workflow_job_id
event_data['job_created'] = self.job_created
if self.host_map:
host = event_data.get('event_data', {}).get('host', '').strip()
if host:
event_data['host_name'] = host
if host in self.host_map:
event_data['host_id'] = self.host_map[host]
else:
event_data['host_name'] = ''
event_data['host_id'] = ''
if event_data.get('event') == 'playbook_on_stats':
event_data['host_map'] = self.host_map
if isinstance(self, RunProjectUpdate):
# it's common for Ansible's SCM modules to print
# error messages on failure that contain the plaintext
# basic auth credentials (username + password)
# it's also common for the nested event data itself (['res']['...'])
# to contain unredacted text on failure
# this is a _little_ expensive to filter
# with regex, but project updates don't have many events,
# so it *should* have a negligible performance impact
task = event_data.get('event_data', {}).get('task_action')
try:
if task in ('git', 'svn'):
event_data_json = json.dumps(event_data)
event_data_json = UriCleaner.remove_sensitive(event_data_json)
event_data = json.loads(event_data_json)
except json.JSONDecodeError:
pass
if 'event_data' in event_data:
event_data['event_data']['guid'] = self.guid
# To prevent overwhelming the broadcast queue, skip some websocket messages
if self.recent_event_timings:
cpu_time = time.time()
first_window_time = self.recent_event_timings[0]
last_window_time = self.recent_event_timings[-1]
if event_data.get('event') in MINIMAL_EVENTS:
should_emit = True # always send some types like playbook_on_stats
elif event_data.get('stdout') == '' and event_data['start_line'] == event_data['end_line']:
should_emit = False # exclude events with no output
else:
should_emit = any(
[
# if 30the most recent websocket message was sent over 1 second ago
cpu_time - first_window_time > 1.0,
# if the very last websocket message came in over 1/30 seconds ago
self.recent_event_timings.maxlen * (cpu_time - last_window_time) > 1.0,
# if the queue is not yet full
len(self.recent_event_timings) != self.recent_event_timings.maxlen,
]
)
if should_emit:
self.recent_event_timings.append(cpu_time)
else:
event_data.setdefault('event_data', {})
event_data['skip_websocket_message'] = True
elif self.recent_event_timings.maxlen:
self.recent_event_timings.append(time.time())
event_data.setdefault(self.event_data_key, self.instance.id)
self.dispatcher.dispatch(event_data)
self.event_ct += 1
'''
Handle artifacts
'''
if event_data.get('event_data', {}).get('artifact_data', {}):
self.instance.artifacts = event_data['event_data']['artifact_data']
self.instance.save(update_fields=['artifacts'])
return False
def cancel_callback(self):
"""
Ansible runner callback to tell the job when/if it is canceled
"""
unified_job_id = self.instance.pk
self.instance = self.update_model(unified_job_id)
if not self.instance:
logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
return True
if self.instance.cancel_flag or self.instance.status == 'canceled':
cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
if cancel_wait > 5:
logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
return True
return False
def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
"""
event_data = {
'event': 'EOF',
'final_counter': self.event_ct,
'guid': self.guid,
}
event_data.setdefault(self.event_data_key, self.instance.id)
self.dispatcher.dispatch(event_data)
def status_handler(self, status_data, runner_config):
"""
Ansible runner callback triggered on status transition
"""
if status_data['status'] == 'starting':
job_env = dict(runner_config.env)
'''
Take the safe environment variables and overwrite
'''
for k, v in self.safe_env.items():
if k in job_env:
job_env[k] = v
from awx.main.signals import disable_activity_stream # Circular import
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
elif status_data['status'] == 'failed':
# For encrypted ssh_key_data, ansible-runner worker will open and write the
# ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
# read from this named pipe so that the key can be used in ansible-playbook.
# Once the podman container exits, the named pipe is deleted.
# However, if the podman container fails to start in the first place, e.g. the image
# name is incorrect, then this pipe is not cleaned up. Eventually ansible-runner
# processor will attempt to write artifacts to the private data dir via unstream_dir, requiring
# that it open this named pipe. This leads to a hang. Thus, before any artifacts
# are written by the processor, it's important to remove this ssh_key_data pipe.
private_data_dir = self.instance.job_env.get('AWX_PRIVATE_DATA_DIR', None)
if private_data_dir:
key_data_file = os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'ssh_key_data')
if os.path.exists(key_data_file) and stat.S_ISFIFO(os.stat(key_data_file).st_mode):
os.remove(key_data_file)
elif status_data['status'] == 'error':
result_traceback = status_data.get('result_traceback', None)
if result_traceback:
from awx.main.signals import disable_activity_stream # Circular import
with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, result_traceback=result_traceback)
@with_path_cleanup
def run(self, pk, **kwargs):
"""
@ -638,22 +434,15 @@ class BaseTask(object):
status, rc = 'error', None
extra_update_fields = {}
fact_modification_times = {}
self.event_ct = 0
self.runner_callback.event_ct = 0
'''
Needs to be an object property because status_handler uses it in a callback context
'''
self.safe_env = {}
self.safe_cred_env = {}
private_data_dir = None
# store a reference to the parent workflow job (if any) so we can include
# it in event data JSON
if self.instance.spawned_by_workflow:
self.parent_workflow_job_id = self.instance.get_workflow_job().id
self.job_created = str(self.instance.created)
try:
self.instance.send_notification_templates("running")
private_data_dir = self.build_private_data_dir(self.instance)
@ -685,7 +474,16 @@ class BaseTask(object):
self.build_extra_vars_file(self.instance, private_data_dir)
args = self.build_args(self.instance, private_data_dir, passwords)
env = self.build_env(self.instance, private_data_dir, private_data_files=private_data_files)
self.safe_env = build_safe_env(env)
self.runner_callback.safe_env = build_safe_env(env)
self.runner_callback.instance = self.instance
# store a reference to the parent workflow job (if any) so we can include
# it in event data JSON
if self.instance.spawned_by_workflow:
self.runner_callback.parent_workflow_job_id = self.instance.get_workflow_job().id
self.runner_callback.job_created = str(self.instance.created)
credentials = self.build_credentials_list(self.instance)
@ -693,7 +491,7 @@ class BaseTask(object):
if credential:
credential.credential_type.inject_credential(credential, env, self.safe_cred_env, args, private_data_dir)
self.safe_env.update(self.safe_cred_env)
self.runner_callback.safe_env.update(self.safe_cred_env)
self.write_args_file(private_data_dir, args)
@ -739,16 +537,14 @@ class BaseTask(object):
if not params[v]:
del params[v]
self.dispatcher = CallbackQueueDispatcher()
self.instance.log_lifecycle("running_playbook")
if isinstance(self.instance, SystemJob):
res = ansible_runner.interface.run(
project_dir=settings.BASE_DIR,
event_handler=self.event_handler,
finished_callback=self.finished_callback,
status_handler=self.status_handler,
cancel_callback=self.cancel_callback,
event_handler=self.runner_callback.event_handler,
finished_callback=self.runner_callback.finished_callback,
status_handler=self.runner_callback.status_handler,
cancel_callback=self.runner_callback.cancel_callback,
**params,
)
else:
@ -779,7 +575,7 @@ class BaseTask(object):
extra_update_fields['result_traceback'] = traceback.format_exc()
logger.exception('%s Exception occurred while running task', self.instance.log_format)
finally:
logger.debug('%s finished running, producing %s events.', self.instance.log_format, self.event_ct)
logger.debug('%s finished running, producing %s events.', self.instance.log_format, self.runner_callback.event_ct)
try:
self.post_run_hook(self.instance, status)
@ -793,7 +589,7 @@ class BaseTask(object):
logger.exception('{} Post run hook errored.'.format(self.instance.log_format))
self.instance = self.update_model(pk)
self.instance = self.update_model(pk, status=status, emitted_events=self.event_ct, **extra_update_fields)
self.instance = self.update_model(pk, status=status, emitted_events=self.runner_callback.event_ct, **extra_update_fields)
try:
self.final_run_hook(self.instance, status, private_data_dir, fact_modification_times)
@ -816,7 +612,6 @@ class RunJob(BaseTask):
model = Job
event_model = JobEvent
event_data_key = 'job_id'
def build_private_data(self, job, private_data_dir):
"""
@ -1215,22 +1010,13 @@ class RunProjectUpdate(BaseTask):
model = ProjectUpdate
event_model = ProjectUpdateEvent
event_data_key = 'project_update_id'
callback_class = RunnerCallbackForProjectUpdate
def __init__(self, *args, job_private_data_dir=None, **kwargs):
super(RunProjectUpdate, self).__init__(*args, **kwargs)
self.playbook_new_revision = None
self.original_branch = None
self.job_private_data_dir = job_private_data_dir
def event_handler(self, event_data):
super(RunProjectUpdate, self).event_handler(event_data)
returned_data = event_data.get('event_data', {})
if returned_data.get('task_action', '') == 'set_fact':
returned_facts = returned_data.get('res', {}).get('ansible_facts', {})
if 'scm_version' in returned_facts:
self.playbook_new_revision = returned_facts['scm_version']
def build_private_data(self, project_update, private_data_dir):
"""
Return SSH private key data needed for this project update.
@ -1631,8 +1417,8 @@ class RunProjectUpdate(BaseTask):
super(RunProjectUpdate, self).post_run_hook(instance, status)
# To avoid hangs, very important to release lock even if errors happen here
try:
if self.playbook_new_revision:
instance.scm_revision = self.playbook_new_revision
if self.runner_callback.playbook_new_revision:
instance.scm_revision = self.runner_callback.playbook_new_revision
instance.save(update_fields=['scm_revision'])
# Roles and collection folders copy to durable cache
@ -1672,8 +1458,8 @@ class RunProjectUpdate(BaseTask):
'failed',
'canceled',
):
if self.playbook_new_revision:
p.scm_revision = self.playbook_new_revision
if self.runner_callback.playbook_new_revision:
p.scm_revision = self.runner_callback.playbook_new_revision
else:
if status == 'successful':
logger.error("{} Could not find scm revision in check".format(instance.log_format))
@ -1709,7 +1495,7 @@ class RunInventoryUpdate(BaseTask):
model = InventoryUpdate
event_model = InventoryUpdateEvent
event_data_key = 'inventory_update_id'
callback_class = RunnerCallbackForInventoryUpdate
def build_private_data(self, inventory_update, private_data_dir):
"""
@ -1932,6 +1718,7 @@ class RunInventoryUpdate(BaseTask):
if status != 'successful':
return # nothing to save, step out of the way to allow error reporting
inventory_update.refresh_from_db()
private_data_dir = inventory_update.job_env['AWX_PRIVATE_DATA_DIR']
expected_output = os.path.join(private_data_dir, 'artifacts', str(inventory_update.id), 'output.json')
with open(expected_output) as f:
@ -1966,13 +1753,13 @@ class RunInventoryUpdate(BaseTask):
options['verbosity'] = inventory_update.verbosity
handler = SpecialInventoryHandler(
self.event_handler,
self.cancel_callback,
self.runner_callback.event_handler,
self.runner_callback.cancel_callback,
verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance),
start_time=inventory_update.started,
counter=self.event_ct,
initial_line=self.end_line,
counter=self.runner_callback.event_ct,
initial_line=self.runner_callback.end_line,
)
inv_logger = logging.getLogger('awx.main.commands.inventory_import')
formatter = inv_logger.handlers[0].formatter
@ -2006,7 +1793,7 @@ class RunAdHocCommand(BaseTask):
model = AdHocCommand
event_model = AdHocCommandEvent
event_data_key = 'ad_hoc_command_id'
callback_class = RunnerCallbackForAdHocCommand
def build_private_data(self, ad_hoc_command, private_data_dir):
"""
@ -2164,7 +1951,7 @@ class RunSystemJob(BaseTask):
model = SystemJob
event_model = SystemJobEvent
event_data_key = 'system_job_id'
callback_class = RunnerCallbackForSystemJob
def build_execution_environment_params(self, system_job, private_data_dir):
return {}

View File

@ -411,9 +411,9 @@ class AWXReceptorJob:
streamer='process',
quiet=True,
_input=resultfile,
event_handler=self.task.event_handler,
finished_callback=self.task.finished_callback,
status_handler=self.task.status_handler,
event_handler=self.task.runner_callback.event_handler,
finished_callback=self.task.runner_callback.finished_callback,
status_handler=self.task.runner_callback.status_handler,
**self.runner_params,
)
@ -458,7 +458,7 @@ class AWXReceptorJob:
if processor_future.done():
return processor_future.result()
if self.task.cancel_callback():
if self.task.runner_callback.cancel_callback():
result = namedtuple('result', ['status', 'rc'])
return result('canceled', 1)

View File

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
import configparser
import json
import os
@ -34,7 +33,7 @@ from awx.main.models import (
)
from awx.main.models.credential import HIDDEN_PASSWORD, ManagedCredentialType
from awx.main import tasks
from awx.main.tasks import jobs, system
from awx.main.utils import encrypt_field, encrypt_value
from awx.main.utils.safe_yaml import SafeLoader
from awx.main.utils.execution_environments import CONTAINER_ROOT, to_host_path
@ -113,12 +112,12 @@ def adhoc_update_model_wrapper(adhoc_job):
def test_send_notifications_not_list():
with pytest.raises(TypeError):
tasks.system.send_notifications(None)
system.send_notifications(None)
def test_send_notifications_job_id(mocker):
with mocker.patch('awx.main.models.UnifiedJob.objects.get'):
tasks.system.send_notifications([], job_id=1)
system.send_notifications([], job_id=1)
assert UnifiedJob.objects.get.called
assert UnifiedJob.objects.get.called_with(id=1)
@ -127,7 +126,7 @@ def test_work_success_callback_missing_job():
task_data = {'type': 'project_update', 'id': 9999}
with mock.patch('django.db.models.query.QuerySet.get') as get_mock:
get_mock.side_effect = ProjectUpdate.DoesNotExist()
assert tasks.system.handle_work_success(task_data) is None
assert system.handle_work_success(task_data) is None
@mock.patch('awx.main.models.UnifiedJob.objects.get')
@ -138,7 +137,7 @@ def test_send_notifications_list(mock_notifications_filter, mock_job_get, mocker
mock_notifications = [mocker.MagicMock(spec=Notification, subject="test", body={'hello': 'world'})]
mock_notifications_filter.return_value = mock_notifications
tasks.system.send_notifications([1, 2], job_id=1)
system.send_notifications([1, 2], job_id=1)
assert Notification.objects.filter.call_count == 1
assert mock_notifications[0].status == "successful"
assert mock_notifications[0].save.called
@ -168,7 +167,7 @@ def test_safe_env_returns_new_copy():
@pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)])
def test_openstack_client_config_generation(mocker, source, expected, private_data_dir):
update = tasks.jobs.RunInventoryUpdate()
update = jobs.RunInventoryUpdate()
credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org',
@ -208,7 +207,7 @@ def test_openstack_client_config_generation(mocker, source, expected, private_da
@pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)])
def test_openstack_client_config_generation_with_project_domain_name(mocker, source, expected, private_data_dir):
update = tasks.jobs.RunInventoryUpdate()
update = jobs.RunInventoryUpdate()
credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org',
@ -250,7 +249,7 @@ def test_openstack_client_config_generation_with_project_domain_name(mocker, sou
@pytest.mark.parametrize("source,expected", [(None, True), (False, False), (True, True)])
def test_openstack_client_config_generation_with_region(mocker, source, expected, private_data_dir):
update = tasks.jobs.RunInventoryUpdate()
update = jobs.RunInventoryUpdate()
credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org',
@ -294,7 +293,7 @@ def test_openstack_client_config_generation_with_region(mocker, source, expected
@pytest.mark.parametrize("source,expected", [(False, False), (True, True)])
def test_openstack_client_config_generation_with_private_source_vars(mocker, source, expected, private_data_dir):
update = tasks.jobs.RunInventoryUpdate()
update = jobs.RunInventoryUpdate()
credential_type = CredentialType.defaults['openstack']()
inputs = {
'host': 'https://keystone.openstack.example.org',
@ -357,7 +356,7 @@ class TestExtraVarSanitation(TestJobExecution):
job.created_by = User(pk=123, username='angry-spud')
job.inventory = Inventory(pk=123, name='example-inv')
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
fd = open(os.path.join(private_data_dir, 'env', 'extravars'))
@ -393,7 +392,7 @@ class TestExtraVarSanitation(TestJobExecution):
def test_launchtime_vars_unsafe(self, job, private_data_dir):
job.extra_vars = json.dumps({'msg': self.UNSAFE})
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -404,7 +403,7 @@ class TestExtraVarSanitation(TestJobExecution):
def test_nested_launchtime_vars_unsafe(self, job, private_data_dir):
job.extra_vars = json.dumps({'msg': {'a': [self.UNSAFE]}})
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -415,7 +414,7 @@ class TestExtraVarSanitation(TestJobExecution):
def test_allowed_jt_extra_vars(self, job, private_data_dir):
job.job_template.extra_vars = job.extra_vars = json.dumps({'msg': self.UNSAFE})
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -427,7 +426,7 @@ class TestExtraVarSanitation(TestJobExecution):
def test_nested_allowed_vars(self, job, private_data_dir):
job.extra_vars = json.dumps({'msg': {'a': {'b': [self.UNSAFE]}}})
job.job_template.extra_vars = job.extra_vars
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -441,7 +440,7 @@ class TestExtraVarSanitation(TestJobExecution):
# `other_var=SENSITIVE`
job.job_template.extra_vars = json.dumps({'msg': self.UNSAFE})
job.extra_vars = json.dumps({'msg': 'other-value', 'other_var': self.UNSAFE})
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -456,7 +455,7 @@ class TestExtraVarSanitation(TestJobExecution):
def test_overwritten_jt_extra_vars(self, job, private_data_dir):
job.job_template.extra_vars = json.dumps({'msg': 'SAFE'})
job.extra_vars = json.dumps({'msg': self.UNSAFE})
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.build_extra_vars_file(job, private_data_dir)
@ -472,7 +471,7 @@ class TestGenericRun:
job.websocket_emit_status = mock.Mock()
job.execution_environment = execution_environment
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
@ -494,7 +493,7 @@ class TestGenericRun:
job.send_notification_templates = mock.Mock()
job.execution_environment = execution_environment
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(wraps=update_model_wrapper)
task.model.objects.get = mock.Mock(return_value=job)
@ -508,45 +507,45 @@ class TestGenericRun:
assert c in task.update_model.call_args_list
def test_event_count(self):
task = tasks.jobs.RunJob()
task.dispatcher = mock.MagicMock()
task.instance = Job()
task.event_ct = 0
task = jobs.RunJob()
task.runner_callback.dispatcher = mock.MagicMock()
task.runner_callback.instance = Job()
task.runner_callback.event_ct = 0
event_data = {}
[task.event_handler(event_data) for i in range(20)]
assert 20 == task.event_ct
[task.runner_callback.event_handler(event_data) for i in range(20)]
assert 20 == task.runner_callback.event_ct
def test_finished_callback_eof(self):
task = tasks.jobs.RunJob()
task.dispatcher = mock.MagicMock()
task.instance = Job(pk=1, id=1)
task.event_ct = 17
task.finished_callback(None)
task.dispatcher.dispatch.assert_called_with({'event': 'EOF', 'final_counter': 17, 'job_id': 1, 'guid': None})
task = jobs.RunJob()
task.runner_callback.dispatcher = mock.MagicMock()
task.runner_callback.instance = Job(pk=1, id=1)
task.runner_callback.event_ct = 17
task.runner_callback.finished_callback(None)
task.runner_callback.dispatcher.dispatch.assert_called_with({'event': 'EOF', 'final_counter': 17, 'job_id': 1, 'guid': None})
def test_save_job_metadata(self, job, update_model_wrapper):
class MockMe:
pass
task = tasks.jobs.RunJob()
task.instance = job
task.safe_env = {'secret_key': 'redacted_value'}
task.update_model = mock.Mock(wraps=update_model_wrapper)
task = jobs.RunJob()
task.runner_callback.instance = job
task.runner_callback.safe_env = {'secret_key': 'redacted_value'}
task.runner_callback.update_model = mock.Mock(wraps=update_model_wrapper)
runner_config = MockMe()
runner_config.command = {'foo': 'bar'}
runner_config.cwd = '/foobar'
runner_config.env = {'switch': 'blade', 'foot': 'ball', 'secret_key': 'secret_value'}
task.status_handler({'status': 'starting'}, runner_config)
task.runner_callback.status_handler({'status': 'starting'}, runner_config)
task.update_model.assert_called_with(
task.runner_callback.update_model.assert_called_with(
1, job_args=json.dumps({'foo': 'bar'}), job_cwd='/foobar', job_env={'switch': 'blade', 'foot': 'ball', 'secret_key': 'redacted_value'}
)
def test_created_by_extra_vars(self):
job = Job(created_by=User(pk=123, username='angry-spud'))
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task._write_extra_vars_file = mock.Mock()
task.build_extra_vars_file(job, None)
@ -563,7 +562,7 @@ class TestGenericRun:
job.extra_vars = json.dumps({'super_secret': encrypt_value('CLASSIFIED', pk=None)})
job.survey_passwords = {'super_secret': '$encrypted$'}
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task._write_extra_vars_file = mock.Mock()
task.build_extra_vars_file(job, None)
@ -576,7 +575,7 @@ class TestGenericRun:
job = Job(project=Project(), inventory=Inventory())
job.execution_environment = execution_environment
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
task._write_extra_vars_file = mock.Mock()
@ -595,7 +594,7 @@ class TestAdhocRun(TestJobExecution):
adhoc_job.websocket_emit_status = mock.Mock()
adhoc_job.send_notification_templates = mock.Mock()
task = tasks.jobs.RunAdHocCommand()
task = jobs.RunAdHocCommand()
task.update_model = mock.Mock(wraps=adhoc_update_model_wrapper)
task.model.objects.get = mock.Mock(return_value=adhoc_job)
task.build_inventory = mock.Mock()
@ -619,7 +618,7 @@ class TestAdhocRun(TestJobExecution):
})
#adhoc_job.websocket_emit_status = mock.Mock()
task = tasks.jobs.RunAdHocCommand()
task = jobs.RunAdHocCommand()
#task.update_model = mock.Mock(wraps=adhoc_update_model_wrapper)
#task.build_inventory = mock.Mock(return_value='/tmp/something.inventory')
task._write_extra_vars_file = mock.Mock()
@ -634,7 +633,7 @@ class TestAdhocRun(TestJobExecution):
def test_created_by_extra_vars(self):
adhoc_job = AdHocCommand(created_by=User(pk=123, username='angry-spud'))
task = tasks.jobs.RunAdHocCommand()
task = jobs.RunAdHocCommand()
task._write_extra_vars_file = mock.Mock()
task.build_extra_vars_file(adhoc_job, None)
@ -693,7 +692,7 @@ class TestJobCredentials(TestJobExecution):
}
def test_username_jinja_usage(self, job, private_data_dir):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
ssh = CredentialType.defaults['ssh']()
credential = Credential(pk=1, credential_type=ssh, inputs={'username': '{{ ansible_ssh_pass }}'})
job.credentials.add(credential)
@ -704,7 +703,7 @@ class TestJobCredentials(TestJobExecution):
@pytest.mark.parametrize("flag", ['become_username', 'become_method'])
def test_become_jinja_usage(self, job, private_data_dir, flag):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
ssh = CredentialType.defaults['ssh']()
credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'joe', flag: '{{ ansible_ssh_pass }}'})
job.credentials.add(credential)
@ -715,7 +714,7 @@ class TestJobCredentials(TestJobExecution):
assert 'Jinja variables are not allowed' in str(e.value)
def test_ssh_passwords(self, job, private_data_dir, field, password_name, expected_flag):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
ssh = CredentialType.defaults['ssh']()
credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', field: 'secret'})
credential.inputs[field] = encrypt_field(credential, field)
@ -732,7 +731,7 @@ class TestJobCredentials(TestJobExecution):
assert expected_flag in ' '.join(args)
def test_net_ssh_key_unlock(self, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
net = CredentialType.defaults['net']()
credential = Credential(pk=1, credential_type=net, inputs={'ssh_key_unlock': 'secret'})
credential.inputs['ssh_key_unlock'] = encrypt_field(credential, 'ssh_key_unlock')
@ -745,7 +744,7 @@ class TestJobCredentials(TestJobExecution):
assert 'secret' in expect_passwords.values()
def test_net_first_ssh_key_unlock_wins(self, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
for i in range(3):
net = CredentialType.defaults['net']()
credential = Credential(pk=i, credential_type=net, inputs={'ssh_key_unlock': 'secret{}'.format(i)})
@ -759,7 +758,7 @@ class TestJobCredentials(TestJobExecution):
assert 'secret0' in expect_passwords.values()
def test_prefer_ssh_over_net_ssh_key_unlock(self, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
net = CredentialType.defaults['net']()
net_credential = Credential(pk=1, credential_type=net, inputs={'ssh_key_unlock': 'net_secret'})
net_credential.inputs['ssh_key_unlock'] = encrypt_field(net_credential, 'ssh_key_unlock')
@ -778,7 +777,7 @@ class TestJobCredentials(TestJobExecution):
assert 'ssh_secret' in expect_passwords.values()
def test_vault_password(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
vault = CredentialType.defaults['vault']()
credential = Credential(pk=1, credential_type=vault, inputs={'vault_password': 'vault-me'})
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
@ -793,7 +792,7 @@ class TestJobCredentials(TestJobExecution):
assert '--ask-vault-pass' in ' '.join(args)
def test_vault_password_ask(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
vault = CredentialType.defaults['vault']()
credential = Credential(pk=1, credential_type=vault, inputs={'vault_password': 'ASK'})
credential.inputs['vault_password'] = encrypt_field(credential, 'vault_password')
@ -808,7 +807,7 @@ class TestJobCredentials(TestJobExecution):
assert '--ask-vault-pass' in ' '.join(args)
def test_multi_vault_password(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod', 'dotted.name']):
credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'pass@{}'.format(label), 'vault_id': label})
@ -831,7 +830,7 @@ class TestJobCredentials(TestJobExecution):
assert '--vault-id dotted.name@prompt' in ' '.join(args)
def test_multi_vault_id_conflict(self, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
vault = CredentialType.defaults['vault']()
for i in range(2):
credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'some-pass', 'vault_id': 'conflict'})
@ -844,7 +843,7 @@ class TestJobCredentials(TestJobExecution):
assert 'multiple vault credentials were specified with --vault-id' in str(e.value)
def test_multi_vault_password_ask(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
vault = CredentialType.defaults['vault']()
for i, label in enumerate(['dev', 'prod']):
credential = Credential(pk=i, credential_type=vault, inputs={'vault_password': 'ASK', 'vault_id': label})
@ -999,7 +998,7 @@ class TestJobCredentials(TestJobExecution):
assert safe_env['VMWARE_PASSWORD'] == HIDDEN_PASSWORD
def test_openstack_credentials(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
openstack = CredentialType.defaults['openstack']()
credential = Credential(
@ -1067,7 +1066,7 @@ class TestJobCredentials(TestJobExecution):
],
)
def test_net_credentials(self, authorize, expected_authorize, job, private_data_dir):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
net = CredentialType.defaults['net']()
inputs = {'username': 'bob', 'password': 'secret', 'ssh_key_data': self.EXAMPLE_PRIVATE_KEY, 'authorize_password': 'authorizeme'}
@ -1135,7 +1134,7 @@ class TestJobCredentials(TestJobExecution):
assert env['TURBO_BUTTON'] == str(True)
def test_custom_environment_injectors_with_reserved_env_var(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
some_cloud = CredentialType(
kind='cloud',
@ -1171,7 +1170,7 @@ class TestJobCredentials(TestJobExecution):
assert safe_env['MY_CLOUD_PRIVATE_VAR'] == HIDDEN_PASSWORD
def test_custom_environment_injectors_with_extra_vars(self, private_data_dir, job):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
some_cloud = CredentialType(
kind='cloud',
name='SomeCloud',
@ -1190,7 +1189,7 @@ class TestJobCredentials(TestJobExecution):
assert hasattr(extra_vars["api_token"], '__UNSAFE__')
def test_custom_environment_injectors_with_boolean_extra_vars(self, job, private_data_dir):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
some_cloud = CredentialType(
kind='cloud',
name='SomeCloud',
@ -1209,7 +1208,7 @@ class TestJobCredentials(TestJobExecution):
return ['successful', 0]
def test_custom_environment_injectors_with_complicated_boolean_template(self, job, private_data_dir):
task = tasks.jobs.RunJob()
task = jobs.RunJob()
some_cloud = CredentialType(
kind='cloud',
name='SomeCloud',
@ -1230,7 +1229,7 @@ class TestJobCredentials(TestJobExecution):
"""
extra_vars that contain secret field values should be censored in the DB
"""
task = tasks.jobs.RunJob()
task = jobs.RunJob()
some_cloud = CredentialType(
kind='cloud',
name='SomeCloud',
@ -1335,7 +1334,7 @@ class TestJobCredentials(TestJobExecution):
def test_awx_task_env(self, settings, private_data_dir, job):
settings.AWX_TASK_ENV = {'FOO': 'BAR'}
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
env = task.build_env(job, private_data_dir)
@ -1362,7 +1361,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
def test_galaxy_credentials_ignore_certs(self, private_data_dir, project_update, ignore):
settings.GALAXY_IGNORE_CERTS = ignore
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
task.instance = project_update
env = task.build_env(project_update, private_data_dir)
if ignore:
@ -1371,7 +1370,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
assert 'ANSIBLE_GALAXY_IGNORE' not in env
def test_galaxy_credentials_empty(self, private_data_dir, project_update):
class RunProjectUpdate(tasks.jobs.RunProjectUpdate):
class RunProjectUpdate(jobs.RunProjectUpdate):
__vars__ = {}
def _write_extra_vars_file(self, private_data_dir, extra_vars, *kw):
@ -1390,7 +1389,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
assert not k.startswith('ANSIBLE_GALAXY_SERVER')
def test_single_public_galaxy(self, private_data_dir, project_update):
class RunProjectUpdate(tasks.jobs.RunProjectUpdate):
class RunProjectUpdate(jobs.RunProjectUpdate):
__vars__ = {}
def _write_extra_vars_file(self, private_data_dir, extra_vars, *kw):
@ -1439,7 +1438,7 @@ class TestProjectUpdateGalaxyCredentials(TestJobExecution):
)
project_update.project.organization.galaxy_credentials.add(public_galaxy)
project_update.project.organization.galaxy_credentials.add(rh)
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
task.instance = project_update
env = task.build_env(project_update, private_data_dir)
assert sorted([(k, v) for k, v in env.items() if k.startswith('ANSIBLE_GALAXY')]) == [
@ -1481,7 +1480,7 @@ class TestProjectUpdateCredentials(TestJobExecution):
}
def test_username_and_password_auth(self, project_update, scm_type):
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
ssh = CredentialType.defaults['ssh']()
project_update.scm_type = scm_type
project_update.credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', 'password': 'secret'})
@ -1495,7 +1494,7 @@ class TestProjectUpdateCredentials(TestJobExecution):
assert 'secret' in expect_passwords.values()
def test_ssh_key_auth(self, project_update, scm_type):
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
ssh = CredentialType.defaults['ssh']()
project_update.scm_type = scm_type
project_update.credential = Credential(pk=1, credential_type=ssh, inputs={'username': 'bob', 'ssh_key_data': self.EXAMPLE_PRIVATE_KEY})
@ -1509,7 +1508,7 @@ class TestProjectUpdateCredentials(TestJobExecution):
def test_awx_task_env(self, project_update, settings, private_data_dir, scm_type, execution_environment):
project_update.execution_environment = execution_environment
settings.AWX_TASK_ENV = {'FOO': 'BAR'}
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
task.instance = project_update
project_update.scm_type = scm_type
@ -1524,7 +1523,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
return InventoryUpdate(pk=1, execution_environment=execution_environment, inventory_source=InventorySource(pk=1, inventory=Inventory(pk=1)))
def test_source_without_credential(self, mocker, inventory_update, private_data_dir):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
inventory_update.source = 'ec2'
inventory_update.get_cloud_credential = mocker.Mock(return_value=None)
@ -1537,7 +1536,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert 'AWS_SECRET_ACCESS_KEY' not in env
def test_ec2_source(self, private_data_dir, inventory_update, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
aws = CredentialType.defaults['aws']()
inventory_update.source = 'ec2'
@ -1561,7 +1560,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert safe_env['AWS_SECRET_ACCESS_KEY'] == HIDDEN_PASSWORD
def test_vmware_source(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
vmware = CredentialType.defaults['vmware']()
inventory_update.source = 'vmware'
@ -1589,7 +1588,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
env["VMWARE_VALIDATE_CERTS"] == "False",
def test_azure_rm_source_with_tenant(self, private_data_dir, inventory_update, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
azure_rm = CredentialType.defaults['azure_rm']()
inventory_update.source = 'azure_rm'
@ -1625,7 +1624,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert safe_env['AZURE_SECRET'] == HIDDEN_PASSWORD
def test_azure_rm_source_with_password(self, private_data_dir, inventory_update, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
azure_rm = CredentialType.defaults['azure_rm']()
inventory_update.source = 'azure_rm'
@ -1654,7 +1653,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert safe_env['AZURE_PASSWORD'] == HIDDEN_PASSWORD
def test_gce_source(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
gce = CredentialType.defaults['gce']()
inventory_update.source = 'gce'
@ -1684,7 +1683,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert json_data['project_id'] == 'some-project'
def test_openstack_source(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
openstack = CredentialType.defaults['openstack']()
inventory_update.source = 'openstack'
@ -1724,7 +1723,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
)
def test_satellite6_source(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
satellite6 = CredentialType.defaults['satellite6']()
inventory_update.source = 'satellite6'
@ -1747,7 +1746,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert safe_env["FOREMAN_PASSWORD"] == HIDDEN_PASSWORD
def test_insights_source(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
insights = CredentialType.defaults['insights']()
inventory_update.source = 'insights'
@ -1776,7 +1775,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
@pytest.mark.parametrize('verify', [True, False])
def test_tower_source(self, verify, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
tower = CredentialType.defaults['controller']()
inventory_update.source = 'controller'
@ -1804,7 +1803,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert safe_env['CONTROLLER_PASSWORD'] == HIDDEN_PASSWORD
def test_tower_source_ssl_verify_empty(self, inventory_update, private_data_dir, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
tower = CredentialType.defaults['controller']()
inventory_update.source = 'controller'
@ -1832,7 +1831,7 @@ class TestInventoryUpdateCredentials(TestJobExecution):
assert env['TOWER_VERIFY_SSL'] == 'False'
def test_awx_task_env(self, inventory_update, private_data_dir, settings, mocker):
task = tasks.jobs.RunInventoryUpdate()
task = jobs.RunInventoryUpdate()
task.instance = inventory_update
gce = CredentialType.defaults['gce']()
inventory_update.source = 'gce'
@ -1883,7 +1882,7 @@ def test_aquire_lock_open_fail_logged(logging_getLogger, os_open):
logger = mock.Mock()
logging_getLogger.return_value = logger
ProjectUpdate = tasks.jobs.RunProjectUpdate()
ProjectUpdate = jobs.RunProjectUpdate()
with pytest.raises(OSError):
ProjectUpdate.acquire_lock(instance)
@ -1910,7 +1909,7 @@ def test_aquire_lock_acquisition_fail_logged(fcntl_lockf, logging_getLogger, os_
fcntl_lockf.side_effect = err
ProjectUpdate = tasks.jobs.RunProjectUpdate()
ProjectUpdate = jobs.RunProjectUpdate()
with pytest.raises(IOError):
ProjectUpdate.acquire_lock(instance)
os_close.assert_called_with(3)
@ -1947,7 +1946,7 @@ def test_notification_job_not_finished(logging_getLogger, mocker):
logging_getLogger.return_value = logger
with mocker.patch('awx.main.models.UnifiedJob.objects.get', uj):
tasks.system.handle_success_and_failure_notifications(1)
system.handle_success_and_failure_notifications(1)
assert logger.warn.called_with(f"Failed to even try to send notifications for job '{uj}' due to job not being in finished state.")
@ -1955,7 +1954,7 @@ def test_notification_job_finished(mocker):
uj = mocker.MagicMock(send_notification_templates=mocker.MagicMock(), finished=True)
with mocker.patch('awx.main.models.UnifiedJob.objects.get', mocker.MagicMock(return_value=uj)):
tasks.system.handle_success_and_failure_notifications(1)
system.handle_success_and_failure_notifications(1)
uj.send_notification_templates.assert_called()
@ -1964,7 +1963,7 @@ def test_job_run_no_ee():
proj = Project(pk=1, organization=org)
job = Job(project=proj, organization=org, inventory=Inventory(pk=1))
job.execution_environment = None
task = tasks.jobs.RunJob()
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
@ -1983,7 +1982,7 @@ def test_project_update_no_ee():
proj = Project(pk=1, organization=org)
project_update = ProjectUpdate(pk=1, project=proj, scm_type='git')
project_update.execution_environment = None
task = tasks.jobs.RunProjectUpdate()
task = jobs.RunProjectUpdate()
task.instance = project_update
with pytest.raises(RuntimeError) as e:

View File

@ -0,0 +1,40 @@
from django.db import transaction, DatabaseError
import logging
import time
logger = logging.getLogger('awx.main.tasks.utils')
def update_model(model, pk, _attempt=0, **updates):
"""Reload the model instance from the database and update the
given fields.
"""
try:
with transaction.atomic():
# Retrieve the model instance.
instance = model.objects.get(pk=pk)
# Update the appropriate fields and save the model
# instance, then return the new instance.
if updates:
update_fields = ['modified']
for field, value in updates.items():
setattr(instance, field, value)
update_fields.append(field)
if field == 'status':
update_fields.append('failed')
instance.save(update_fields=update_fields)
return instance
except DatabaseError as e:
# Log out the error to the debug logger.
logger.debug('Database error updating %s, retrying in 5 seconds (retry #%d): %s', model._meta.object_name, _attempt + 1, e)
# Attempt to retry the update, assuming we haven't already
# tried too many times.
if _attempt < 5:
time.sleep(5)
return update_model(model, pk, _attempt=_attempt + 1, **updates)
else:
logger.error('Failed to update %s after %d retries.', model._meta.object_name, _attempt)