From fd671ecc9d49845f1e0fa09ec178d1d6518061fe Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Thu, 30 Jun 2022 13:20:08 -0400
Subject: [PATCH] Give specific messages if job was killed due to SIGTERM or
 SIGKILL (#12435)

* Reap jobs on dispatcher startup to increase clarity, replacing the existing reaping logic

* Exit jobs if a SIGTERM signal is received

* Fix unwanted reaping on shutdown, letting the subprocess close out

* Add some sanity tests for the signal module

* Add a log for an unhandled dispatcher error

* Refine wording of error messages

Co-authored-by: Elijah DeLee
---
 awx/main/dispatch/reaper.py                   | 21 +++++++
 awx/main/dispatch/worker/base.py              |  9 ++-
 .../management/commands/run_dispatcher.py     |  2 +-
 awx/main/tasks/callback.py                    |  9 ++-
 awx/main/tasks/jobs.py                        |  9 ++-
 awx/main/tasks/signals.py                     | 63 +++++++++++++++++++
 awx/main/tasks/system.py                      |  4 --
 awx/main/tests/unit/tasks/test_signals.py     | 50 +++++++++++++++
 awx/main/utils/update_model.py                |  7 ++-
 9 files changed, 164 insertions(+), 10 deletions(-)
 create mode 100644 awx/main/tasks/signals.py
 create mode 100644 awx/main/tests/unit/tasks/test_signals.py

diff --git a/awx/main/dispatch/reaper.py b/awx/main/dispatch/reaper.py
index 1a1fb2a40e..a86664b80c 100644
--- a/awx/main/dispatch/reaper.py
+++ b/awx/main/dispatch/reaper.py
@@ -10,6 +10,27 @@ from awx.main.models import Instance, UnifiedJob, WorkflowJob
 logger = logging.getLogger('awx.main.dispatch')
 
 
+def startup_reaping():
+    """
+    If this instance is starting up, any jobs still marked as running on it
+    are known to be invalid, so reap those jobs here as a special startup action
+    """
+    me = Instance.objects.me()
+    jobs = UnifiedJob.objects.filter(status='running', controller_node=me.hostname)
+    job_ids = []
+    for j in jobs:
+        job_ids.append(j.id)
+        j.status = 'failed'
+        j.start_args = ''
+        j.job_explanation += 'Task was marked as running at system startup. The system must not have shut down cleanly, so the task has been marked as failed.'
+        j.save(update_fields=['status', 'start_args', 'job_explanation'])
+        if hasattr(j, 'send_notification_templates'):
+            j.send_notification_templates('failed')
+        j.websocket_emit_status('failed')
+    if job_ids:
+        logger.error(f'Unified jobs {job_ids} were reaped on dispatch startup')
+
+
 def reap_job(j, status):
     if UnifiedJob.objects.get(id=j.id).status not in ('running', 'waiting'):
         # just in case, don't reap jobs that aren't running
diff --git a/awx/main/dispatch/worker/base.py b/awx/main/dispatch/worker/base.py
index 34443b70b2..46418828b6 100644
--- a/awx/main/dispatch/worker/base.py
+++ b/awx/main/dispatch/worker/base.py
@@ -169,8 +169,9 @@ class AWXConsumerPG(AWXConsumerBase):
                     logger.exception(f"Error consuming new events from postgres, will retry for {self.pg_max_wait} s")
                     self.pg_down_time = time.time()
                     self.pg_is_down = True
-                if time.time() - self.pg_down_time > self.pg_max_wait:
-                    logger.warning(f"Postgres event consumer has not recovered in {self.pg_max_wait} s, exiting")
+                current_downtime = time.time() - self.pg_down_time
+                if current_downtime > self.pg_max_wait:
+                    logger.exception(f"Postgres event consumer has not recovered in {current_downtime} s, exiting")
                     raise
                 # Wait for a second before next attempt, but still listen for any shutdown signals
                 for i in range(10):
@@ -179,6 +180,10 @@ class AWXConsumerPG(AWXConsumerBase):
                         time.sleep(0.1)
                 for conn in db.connections.all():
                     conn.close_if_unusable_or_obsolete()
+        except Exception:
+            # Log the unanticipated exception, in addition to the stderr traceback, to get timestamps and other metadata
+            logger.exception('Encountered unhandled error in dispatcher main loop')
+            raise
 
 
 class BaseWorker(object):
diff --git a/awx/main/management/commands/run_dispatcher.py b/awx/main/management/commands/run_dispatcher.py
index bafe27cdaf..e4d17f2aed 100644
--- a/awx/main/management/commands/run_dispatcher.py
+++ b/awx/main/management/commands/run_dispatcher.py
@@ -53,7 +53,7 @@ class Command(BaseCommand):
         # (like the node heartbeat)
         periodic.run_continuously()
 
-        reaper.reap()
+        reaper.startup_reaping()
 
         consumer = None
         try:
diff --git a/awx/main/tasks/callback.py b/awx/main/tasks/callback.py
index fa37055ac2..a4a02421a0 100644
--- a/awx/main/tasks/callback.py
+++ b/awx/main/tasks/callback.py
@@ -16,6 +16,7 @@ from awx.main.redact import UriCleaner
 from awx.main.constants import MINIMAL_EVENTS, ANSIBLE_RUNNER_NEEDS_UPDATE_MESSAGE
 from awx.main.utils.update_model import update_model
 from awx.main.queue import CallbackQueueDispatcher
+from awx.main.tasks.signals import signal_callback
 
 logger = logging.getLogger('awx.main.tasks.callback')
 
@@ -179,7 +180,13 @@ class RunnerCallback:
         Ansible runner callback to tell the job when/if it is canceled
         """
         unified_job_id = self.instance.pk
-        self.instance = self.update_model(unified_job_id)
+        if signal_callback():
+            return True
+        try:
+            self.instance = self.update_model(unified_job_id)
+        except Exception:
+            logger.exception(f'Encountered error during cancel check for {unified_job_id}, canceling now')
+            return True
         if not self.instance:
             logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
             return True
diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py
index 9c58e1eebe..63a5681666 100644
--- a/awx/main/tasks/jobs.py
+++ b/awx/main/tasks/jobs.py
@@ -62,6 +62,7 @@ from awx.main.tasks.callback import (
     RunnerCallbackForProjectUpdate,
     RunnerCallbackForSystemJob,
 )
+from awx.main.tasks.signals import with_signal_handling, signal_callback
 from awx.main.tasks.receptor import AWXReceptorJob
 from awx.main.exceptions import AwxTaskError, PostRunError, ReceptorNodeNotFound
 from awx.main.utils.ansible import read_ansible_config
@@ -392,6 +393,7 @@ class BaseTask(object):
             instance.save(update_fields=['ansible_version'])
 
     @with_path_cleanup
+    @with_signal_handling
     def run(self, pk, **kwargs):
         """
         Run the job/task and capture its output.
@@ -423,7 +425,7 @@ class BaseTask(object):
             private_data_dir = self.build_private_data_dir(self.instance)
             self.pre_run_hook(self.instance, private_data_dir)
             self.instance.log_lifecycle("preparing_playbook")
-            if self.instance.cancel_flag:
+            if self.instance.cancel_flag or signal_callback():
                 self.instance = self.update_model(self.instance.pk, status='canceled')
             if self.instance.status != 'running':
                 # Stop the task chain and prevent starting the job if it has
@@ -545,6 +547,11 @@ class BaseTask(object):
                 self.runner_callback.delay_update(skip_if_already_set=True, job_explanation=f"Job terminated due to {status}")
                 if status == 'timeout':
                     status = 'failed'
+            elif status == 'canceled':
+                self.instance = self.update_model(pk)
+                if (getattr(self.instance, 'cancel_flag', False) is False) and signal_callback():
+                    self.runner_callback.delay_update(job_explanation="Task was canceled due to receiving a shutdown signal.")
+                    status = 'failed'
         except ReceptorNodeNotFound as exc:
             self.runner_callback.delay_update(job_explanation=str(exc))
         except Exception:
diff --git a/awx/main/tasks/signals.py b/awx/main/tasks/signals.py
new file mode 100644
index 0000000000..6f0c69ca4c
--- /dev/null
+++ b/awx/main/tasks/signals.py
@@ -0,0 +1,63 @@
+import signal
+import functools
+import logging
+
+
+logger = logging.getLogger('awx.main.tasks.signals')
+
+
+__all__ = ['with_signal_handling', 'signal_callback']
+
+
+class SignalState:
+    def reset(self):
+        self.sigterm_flag = False
+        self.is_active = False
+        self.original_sigterm = None
+        self.original_sigint = None
+
+    def __init__(self):
+        self.reset()
+
+    def set_flag(self, *args):
+        """Handler to pass to the Python signal.signal method to receive signals"""
+        self.sigterm_flag = True
+
+    def connect_signals(self):
+        self.original_sigterm = signal.getsignal(signal.SIGTERM)
+        self.original_sigint = signal.getsignal(signal.SIGINT)
+        signal.signal(signal.SIGTERM, self.set_flag)
+        signal.signal(signal.SIGINT, self.set_flag)
+        self.is_active = True
+
+    def restore_signals(self):
+        signal.signal(signal.SIGTERM, self.original_sigterm)
+        signal.signal(signal.SIGINT, self.original_sigint)
+        self.reset()
+
+
+signal_state = SignalState()
+
+
+def signal_callback():
+    return signal_state.sigterm_flag
+
+
+def with_signal_handling(f):
+    """
+    Change signal handling so that signal_callback returns True in the event of SIGTERM or SIGINT.
+ """ + + @functools.wraps(f) + def _wrapped(*args, **kwargs): + try: + this_is_outermost_caller = False + if not signal_state.is_active: + signal_state.connect_signals() + this_is_outermost_caller = True + return f(*args, **kwargs) + finally: + if this_is_outermost_caller: + signal_state.restore_signals() + + return _wrapped diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py index 541415f2b8..b828326339 100644 --- a/awx/main/tasks/system.py +++ b/awx/main/tasks/system.py @@ -114,10 +114,6 @@ def inform_cluster_of_shutdown(): try: this_inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID) this_inst.mark_offline(update_last_seen=True, errors=_('Instance received normal shutdown signal')) - try: - reaper.reap(this_inst) - except Exception: - logger.exception('failed to reap jobs for {}'.format(this_inst.hostname)) logger.warning('Normal shutdown signal for instance {}, ' 'removed self from capacity pool.'.format(this_inst.hostname)) except Exception: logger.exception('Encountered problem with normal shutdown signal.') diff --git a/awx/main/tests/unit/tasks/test_signals.py b/awx/main/tests/unit/tasks/test_signals.py new file mode 100644 index 0000000000..a435b8a660 --- /dev/null +++ b/awx/main/tests/unit/tasks/test_signals.py @@ -0,0 +1,50 @@ +import signal + +from awx.main.tasks.signals import signal_state, signal_callback, with_signal_handling + + +def test_outer_inner_signal_handling(): + """ + Even if the flag is set in the outer context, its value should persist in the inner context + """ + + @with_signal_handling + def f2(): + assert signal_callback() + + @with_signal_handling + def f1(): + assert signal_callback() is False + signal_state.set_flag() + assert signal_callback() + f2() + + original_sigterm = signal.getsignal(signal.SIGTERM) + assert signal_callback() is False + f1() + assert signal_callback() is False + assert signal.getsignal(signal.SIGTERM) is original_sigterm + + +def test_inner_outer_signal_handling(): + """ + Even if the flag is set in the inner context, its value should persist in the outer context + """ + + @with_signal_handling + def f2(): + assert signal_callback() is False + signal_state.set_flag() + assert signal_callback() + + @with_signal_handling + def f1(): + assert signal_callback() is False + f2() + assert signal_callback() + + original_sigterm = signal.getsignal(signal.SIGTERM) + assert signal_callback() is False + f1() + assert signal_callback() is False + assert signal.getsignal(signal.SIGTERM) is original_sigterm diff --git a/awx/main/utils/update_model.py b/awx/main/utils/update_model.py index 7d03b3964b..80d930e2c5 100644 --- a/awx/main/utils/update_model.py +++ b/awx/main/utils/update_model.py @@ -3,6 +3,8 @@ from django.db import transaction, DatabaseError, InterfaceError import logging import time +from awx.main.tasks.signals import signal_callback + logger = logging.getLogger('awx.main.tasks.utils') @@ -37,7 +39,10 @@ def update_model(model, pk, _attempt=0, _max_attempts=5, select_for_update=False # Attempt to retry the update, assuming we haven't already # tried too many times. if _attempt < _max_attempts: - time.sleep(5) + for i in range(5): + time.sleep(1) + if signal_callback(): + raise RuntimeError(f'Could not fetch {pk} because of receiving abort signal') return update_model(model, pk, _attempt=_attempt + 1, _max_attempts=_max_attempts, **updates) else: logger.error('Failed to update %s after %d retries.', model._meta.object_name, _attempt)