diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index 350abd908d..d5fb23056d 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -10,7 +10,6 @@ import json import logging import os import re -import socket import subprocess import tempfile from collections import OrderedDict @@ -1488,40 +1487,17 @@ class UnifiedJob( return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id) return None - def fallback_cancel(self): - if not self.celery_task_id: - self.refresh_from_db(fields=['celery_task_id']) - self.cancel_dispatcher_process() - def cancel_dispatcher_process(self): """Returns True if dispatcher running this job acknowledged request and sent SIGTERM""" if not self.celery_task_id: return False - # Special case for task manager (used during workflow job cancellation) - if not connection.get_autocommit(): - try: - - ctl = get_control_from_settings() - ctl.control('cancel', data={'uuid': self.celery_task_id}) - except Exception: - logger.exception("Error sending cancel command to dispatcher") - return True # task manager itself needs to act under assumption that cancel was received - - # Standard case with reply try: - timeout = 5 - - ctl = get_control_from_settings() - results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout) - # Check if cancel was successful by checking if we got any results - return bool(results and len(results) > 0) - except socket.timeout: - logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s') + logger.info(f'Sending cancel message to pg_notify channel {self.controller_node} for task {self.celery_task_id}') + ctl = get_control_from_settings(default_publish_channel=self.controller_node) + ctl.control('cancel', data={'uuid': self.celery_task_id}) except Exception: - logger.exception("error encountered when checking task status") - - return False # whether confirmation was obtained + logger.exception("Error sending cancel command to dispatcher") def cancel(self, job_explanation=None, is_chain=False): if self.can_cancel: @@ -1544,19 +1520,13 @@ class UnifiedJob( # the job control process will use the cancel_flag to distinguish a shutdown from a cancel self.save(update_fields=cancel_fields) - controller_notified = False - if self.celery_task_id: - controller_notified = self.cancel_dispatcher_process() + # Be extra sure we have the task id, in case job is transitioning into running right now + if not self.celery_task_id: + self.refresh_from_db(fields=['celery_task_id', 'controller_node']) - # If a SIGTERM signal was sent to the control process, and acked by the dispatcher - # then we want to let its own cleanup change status, otherwise change status now - if not controller_notified: - if self.status != 'canceled': - self.status = 'canceled' - self.save(update_fields=['status']) - # Avoid race condition where we have stale model from pending state but job has already started, - # its checking signal but not cancel_flag, so re-send signal after updating cancel fields - self.fallback_cancel() + # send pg_notify message to cancel, will not send until transaction completes + if self.celery_task_id: + self.cancel_dispatcher_process() return self.cancel_flag diff --git a/awx/main/models/workflow.py b/awx/main/models/workflow.py index 136f49f86a..1753dff26d 100644 --- a/awx/main/models/workflow.py +++ b/awx/main/models/workflow.py @@ -785,7 +785,7 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio def cancel_dispatcher_process(self): # WorkflowJobs don't _actually_ run anything in the dispatcher, so # there's no point in asking the dispatcher if it knows about this task - return True + return class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin): diff --git a/awx/main/tasks/signals.py b/awx/main/tasks/signals.py index 7bcb57e0e5..a1607bfc99 100644 --- a/awx/main/tasks/signals.py +++ b/awx/main/tasks/signals.py @@ -69,7 +69,7 @@ def signal_callback(): def with_signal_handling(f): """ - Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT. + Change signal handling to make signal_callback return True in event of SIGTERM, SIGINT, or SIGUSR1. """ @functools.wraps(f) diff --git a/awx/main/tests/functional/models/test_unified_job.py b/awx/main/tests/functional/models/test_unified_job.py index 5e37fc985c..0618085cef 100644 --- a/awx/main/tests/functional/models/test_unified_job.py +++ b/awx/main/tests/functional/models/test_unified_job.py @@ -1,5 +1,6 @@ import itertools import pytest +from uuid import uuid4 # CRUM from crum import impersonate @@ -33,6 +34,64 @@ def test_soft_unique_together(post, project, admin_user): assert 'combination already exists' in str(r.data) +@pytest.mark.django_db +class TestJobCancel: + """ + Coverage for UnifiedJob.cancel, focused on interaction with dispatcherd objects. + Using mocks for the dispatcherd objects, because tests by default use a no-op broker. + """ + + def test_cancel_sets_flag_and_clears_start_args(self, mocker): + job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo', start_args='{"secret": "value"}') + job.websocket_emit_status = mocker.MagicMock() + + assert job.can_cancel is True + assert job.cancel_flag is False + + job.cancel() + job.refresh_from_db() + + assert job.cancel_flag is True + assert job.start_args == '' + + def test_cancel_sets_job_explanation(self, mocker): + job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo') + job.websocket_emit_status = mocker.MagicMock() + job_explanation = 'giggity giggity' + + job.cancel(job_explanation=job_explanation) + job.refresh_from_db() + + assert job.job_explanation == job_explanation + + def test_cancel_sends_control_message(self, mocker): + celery_task_id = str(uuid4()) + job = Job.objects.create(status='running', name='foo-job', celery_task_id=celery_task_id, controller_node='foo') + job.websocket_emit_status = mocker.MagicMock() + control = mocker.MagicMock() + get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control) + + job.cancel() + + get_control.assert_called_once_with(default_publish_channel='foo') + control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id}) + + def test_cancel_refreshes_task_id_before_sending_control(self, mocker): + job = Job.objects.create(status='pending', name='foo-job', celery_task_id='', controller_node='bar') + job.websocket_emit_status = mocker.MagicMock() + celery_task_id = str(uuid4()) + Job.objects.filter(pk=job.pk).update(status='running', celery_task_id=celery_task_id) + control = mocker.MagicMock() + get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control) + refresh_spy = mocker.spy(job, 'refresh_from_db') + + job.cancel() + + refresh_spy.assert_called_once_with(fields=['celery_task_id', 'controller_node']) + get_control.assert_called_once_with(default_publish_channel='bar') + control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id}) + + @pytest.mark.django_db class TestCreateUnifiedJob: """ diff --git a/awx/main/tests/unit/models/test_unified_job_unit.py b/awx/main/tests/unit/models/test_unified_job_unit.py index b6080f55f7..2fa8807dff 100644 --- a/awx/main/tests/unit/models/test_unified_job_unit.py +++ b/awx/main/tests/unit/models/test_unified_job_unit.py @@ -1,4 +1,3 @@ -import pytest from unittest import mock from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory @@ -22,52 +21,6 @@ def test_unified_job_workflow_attributes(): assert job.workflow_job_id == 1 -def mock_on_commit(f): - f() - - -@pytest.fixture -def unified_job(mocker): - mocker.patch.object(UnifiedJob, 'can_cancel', return_value=True) - j = UnifiedJob() - j.status = 'pending' - j.cancel_flag = None - j.save = mocker.MagicMock() - j.websocket_emit_status = mocker.MagicMock() - j.fallback_cancel = mocker.MagicMock() - return j - - -def test_cancel(unified_job): - with mock.patch('awx.main.models.unified_jobs.connection.on_commit', wraps=mock_on_commit): - unified_job.cancel() - - assert unified_job.cancel_flag is True - assert unified_job.status == 'canceled' - assert unified_job.job_explanation == '' - # Note: the websocket emit status check is just reflecting the state of the current code. - # Some more thought may want to go into only emitting canceled if/when the job record - # status is changed to canceled. Unlike, currently, where it's emitted unconditionally. - unified_job.websocket_emit_status.assert_called_with("canceled") - assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [ - ((), {'update_fields': ['cancel_flag', 'start_args']}), - ((), {'update_fields': ['status']}), - ] - - -def test_cancel_job_explanation(unified_job): - job_explanation = 'giggity giggity' - - with mock.patch('awx.main.models.unified_jobs.connection.on_commit'): - unified_job.cancel(job_explanation=job_explanation) - - assert unified_job.job_explanation == job_explanation - assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [ - ((), {'update_fields': ['cancel_flag', 'start_args', 'job_explanation']}), - ((), {'update_fields': ['status']}), - ] - - def test_organization_copy_to_jobs(): """ All unified job types should infer their organization from their template organization diff --git a/awx/main/tests/unit/tasks/test_signals.py b/awx/main/tests/unit/tasks/test_signals.py index 2a63b30d38..f089ea749d 100644 --- a/awx/main/tests/unit/tasks/test_signals.py +++ b/awx/main/tests/unit/tasks/test_signals.py @@ -12,6 +12,10 @@ def pytest_sigterm(): pytest_sigterm.called_count += 1 +def pytest_sigusr1(): + pytest_sigusr1.called_count += 1 + + def tmp_signals_for_test(func): """ When we run our internal signal handlers, it will call the original signal @@ -26,13 +30,17 @@ def tmp_signals_for_test(func): def wrapper(): original_sigterm = signal.getsignal(signal.SIGTERM) original_sigint = signal.getsignal(signal.SIGINT) + original_sigusr1 = signal.getsignal(signal.SIGUSR1) signal.signal(signal.SIGTERM, pytest_sigterm) signal.signal(signal.SIGINT, pytest_sigint) + signal.signal(signal.SIGUSR1, pytest_sigusr1) pytest_sigterm.called_count = 0 pytest_sigint.called_count = 0 + pytest_sigusr1.called_count = 0 func() signal.signal(signal.SIGTERM, original_sigterm) signal.signal(signal.SIGINT, original_sigint) + signal.signal(signal.SIGUSR1, original_sigusr1) return wrapper @@ -58,11 +66,13 @@ def test_outer_inner_signal_handling(): assert signal_callback() is False assert pytest_sigterm.called_count == 0 assert pytest_sigint.called_count == 0 + assert pytest_sigusr1.called_count == 0 f1() assert signal_callback() is False assert signal.getsignal(signal.SIGTERM) is original_sigterm assert pytest_sigterm.called_count == 1 assert pytest_sigint.called_count == 0 + assert pytest_sigusr1.called_count == 0 @tmp_signals_for_test @@ -87,8 +97,31 @@ def test_inner_outer_signal_handling(): assert signal_callback() is False assert pytest_sigterm.called_count == 0 assert pytest_sigint.called_count == 0 + assert pytest_sigusr1.called_count == 0 f1() assert signal_callback() is False assert signal.getsignal(signal.SIGTERM) is original_sigterm assert pytest_sigterm.called_count == 0 assert pytest_sigint.called_count == 1 + assert pytest_sigusr1.called_count == 0 + + +@tmp_signals_for_test +def test_sigusr1_signal_handling(): + @with_signal_handling + def f1(): + assert signal_callback() is False + signal_state.set_signal_flag(for_signal=signal.SIGUSR1) + assert signal_callback() + + original_sigusr1 = signal.getsignal(signal.SIGUSR1) + assert signal_callback() is False + assert pytest_sigterm.called_count == 0 + assert pytest_sigint.called_count == 0 + assert pytest_sigusr1.called_count == 0 + f1() + assert signal_callback() is False + assert signal.getsignal(signal.SIGUSR1) is original_sigusr1 + assert pytest_sigterm.called_count == 0 + assert pytest_sigint.called_count == 0 + assert pytest_sigusr1.called_count == 1