AAP-64221 Fix broken cancel logic with dispatcherd (#16247)

* Fix broken cancel logic with dispatcherd

Update tests for UnifiedJob

Update test assertion

* Further simply cancel path
This commit is contained in:
Alan Rominger
2026-01-27 14:39:08 -05:00
committed by GitHub
parent 823b736afe
commit 1128ad5a57
6 changed files with 104 additions and 89 deletions

View File

@@ -10,7 +10,6 @@ import json
import logging import logging
import os import os
import re import re
import socket
import subprocess import subprocess
import tempfile import tempfile
from collections import OrderedDict from collections import OrderedDict
@@ -1488,40 +1487,17 @@ class UnifiedJob(
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id) return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
return None return None
def fallback_cancel(self):
if not self.celery_task_id:
self.refresh_from_db(fields=['celery_task_id'])
self.cancel_dispatcher_process()
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM""" """Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
if not self.celery_task_id: if not self.celery_task_id:
return False return False
# Special case for task manager (used during workflow job cancellation)
if not connection.get_autocommit():
try:
ctl = get_control_from_settings()
ctl.control('cancel', data={'uuid': self.celery_task_id})
except Exception:
logger.exception("Error sending cancel command to dispatcher")
return True # task manager itself needs to act under assumption that cancel was received
# Standard case with reply
try: try:
timeout = 5 logger.info(f'Sending cancel message to pg_notify channel {self.controller_node} for task {self.celery_task_id}')
ctl = get_control_from_settings(default_publish_channel=self.controller_node)
ctl = get_control_from_settings() ctl.control('cancel', data={'uuid': self.celery_task_id})
results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout)
# Check if cancel was successful by checking if we got any results
return bool(results and len(results) > 0)
except socket.timeout:
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
except Exception: except Exception:
logger.exception("error encountered when checking task status") logger.exception("Error sending cancel command to dispatcher")
return False # whether confirmation was obtained
def cancel(self, job_explanation=None, is_chain=False): def cancel(self, job_explanation=None, is_chain=False):
if self.can_cancel: if self.can_cancel:
@@ -1544,19 +1520,13 @@ class UnifiedJob(
# the job control process will use the cancel_flag to distinguish a shutdown from a cancel # the job control process will use the cancel_flag to distinguish a shutdown from a cancel
self.save(update_fields=cancel_fields) self.save(update_fields=cancel_fields)
controller_notified = False # Be extra sure we have the task id, in case job is transitioning into running right now
if self.celery_task_id: if not self.celery_task_id:
controller_notified = self.cancel_dispatcher_process() self.refresh_from_db(fields=['celery_task_id', 'controller_node'])
# If a SIGTERM signal was sent to the control process, and acked by the dispatcher # send pg_notify message to cancel, will not send until transaction completes
# then we want to let its own cleanup change status, otherwise change status now if self.celery_task_id:
if not controller_notified: self.cancel_dispatcher_process()
if self.status != 'canceled':
self.status = 'canceled'
self.save(update_fields=['status'])
# Avoid race condition where we have stale model from pending state but job has already started,
# its checking signal but not cancel_flag, so re-send signal after updating cancel fields
self.fallback_cancel()
return self.cancel_flag return self.cancel_flag

View File

@@ -785,7 +785,7 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
def cancel_dispatcher_process(self): def cancel_dispatcher_process(self):
# WorkflowJobs don't _actually_ run anything in the dispatcher, so # WorkflowJobs don't _actually_ run anything in the dispatcher, so
# there's no point in asking the dispatcher if it knows about this task # there's no point in asking the dispatcher if it knows about this task
return True return
class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin): class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):

View File

@@ -69,7 +69,7 @@ def signal_callback():
def with_signal_handling(f): def with_signal_handling(f):
""" """
Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT. Change signal handling to make signal_callback return True in event of SIGTERM, SIGINT, or SIGUSR1.
""" """
@functools.wraps(f) @functools.wraps(f)

View File

@@ -1,5 +1,6 @@
import itertools import itertools
import pytest import pytest
from uuid import uuid4
# CRUM # CRUM
from crum import impersonate from crum import impersonate
@@ -33,6 +34,64 @@ def test_soft_unique_together(post, project, admin_user):
assert 'combination already exists' in str(r.data) assert 'combination already exists' in str(r.data)
@pytest.mark.django_db
class TestJobCancel:
"""
Coverage for UnifiedJob.cancel, focused on interaction with dispatcherd objects.
Using mocks for the dispatcherd objects, because tests by default use a no-op broker.
"""
def test_cancel_sets_flag_and_clears_start_args(self, mocker):
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo', start_args='{"secret": "value"}')
job.websocket_emit_status = mocker.MagicMock()
assert job.can_cancel is True
assert job.cancel_flag is False
job.cancel()
job.refresh_from_db()
assert job.cancel_flag is True
assert job.start_args == ''
def test_cancel_sets_job_explanation(self, mocker):
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo')
job.websocket_emit_status = mocker.MagicMock()
job_explanation = 'giggity giggity'
job.cancel(job_explanation=job_explanation)
job.refresh_from_db()
assert job.job_explanation == job_explanation
def test_cancel_sends_control_message(self, mocker):
celery_task_id = str(uuid4())
job = Job.objects.create(status='running', name='foo-job', celery_task_id=celery_task_id, controller_node='foo')
job.websocket_emit_status = mocker.MagicMock()
control = mocker.MagicMock()
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
job.cancel()
get_control.assert_called_once_with(default_publish_channel='foo')
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
def test_cancel_refreshes_task_id_before_sending_control(self, mocker):
job = Job.objects.create(status='pending', name='foo-job', celery_task_id='', controller_node='bar')
job.websocket_emit_status = mocker.MagicMock()
celery_task_id = str(uuid4())
Job.objects.filter(pk=job.pk).update(status='running', celery_task_id=celery_task_id)
control = mocker.MagicMock()
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
refresh_spy = mocker.spy(job, 'refresh_from_db')
job.cancel()
refresh_spy.assert_called_once_with(fields=['celery_task_id', 'controller_node'])
get_control.assert_called_once_with(default_publish_channel='bar')
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
@pytest.mark.django_db @pytest.mark.django_db
class TestCreateUnifiedJob: class TestCreateUnifiedJob:
""" """

View File

@@ -1,4 +1,3 @@
import pytest
from unittest import mock from unittest import mock
from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory
@@ -22,52 +21,6 @@ def test_unified_job_workflow_attributes():
assert job.workflow_job_id == 1 assert job.workflow_job_id == 1
def mock_on_commit(f):
f()
@pytest.fixture
def unified_job(mocker):
mocker.patch.object(UnifiedJob, 'can_cancel', return_value=True)
j = UnifiedJob()
j.status = 'pending'
j.cancel_flag = None
j.save = mocker.MagicMock()
j.websocket_emit_status = mocker.MagicMock()
j.fallback_cancel = mocker.MagicMock()
return j
def test_cancel(unified_job):
with mock.patch('awx.main.models.unified_jobs.connection.on_commit', wraps=mock_on_commit):
unified_job.cancel()
assert unified_job.cancel_flag is True
assert unified_job.status == 'canceled'
assert unified_job.job_explanation == ''
# Note: the websocket emit status check is just reflecting the state of the current code.
# Some more thought may want to go into only emitting canceled if/when the job record
# status is changed to canceled. Unlike, currently, where it's emitted unconditionally.
unified_job.websocket_emit_status.assert_called_with("canceled")
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args']}),
((), {'update_fields': ['status']}),
]
def test_cancel_job_explanation(unified_job):
job_explanation = 'giggity giggity'
with mock.patch('awx.main.models.unified_jobs.connection.on_commit'):
unified_job.cancel(job_explanation=job_explanation)
assert unified_job.job_explanation == job_explanation
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
((), {'update_fields': ['cancel_flag', 'start_args', 'job_explanation']}),
((), {'update_fields': ['status']}),
]
def test_organization_copy_to_jobs(): def test_organization_copy_to_jobs():
""" """
All unified job types should infer their organization from their template organization All unified job types should infer their organization from their template organization

View File

@@ -12,6 +12,10 @@ def pytest_sigterm():
pytest_sigterm.called_count += 1 pytest_sigterm.called_count += 1
def pytest_sigusr1():
pytest_sigusr1.called_count += 1
def tmp_signals_for_test(func): def tmp_signals_for_test(func):
""" """
When we run our internal signal handlers, it will call the original signal When we run our internal signal handlers, it will call the original signal
@@ -26,13 +30,17 @@ def tmp_signals_for_test(func):
def wrapper(): def wrapper():
original_sigterm = signal.getsignal(signal.SIGTERM) original_sigterm = signal.getsignal(signal.SIGTERM)
original_sigint = signal.getsignal(signal.SIGINT) original_sigint = signal.getsignal(signal.SIGINT)
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
signal.signal(signal.SIGTERM, pytest_sigterm) signal.signal(signal.SIGTERM, pytest_sigterm)
signal.signal(signal.SIGINT, pytest_sigint) signal.signal(signal.SIGINT, pytest_sigint)
signal.signal(signal.SIGUSR1, pytest_sigusr1)
pytest_sigterm.called_count = 0 pytest_sigterm.called_count = 0
pytest_sigint.called_count = 0 pytest_sigint.called_count = 0
pytest_sigusr1.called_count = 0
func() func()
signal.signal(signal.SIGTERM, original_sigterm) signal.signal(signal.SIGTERM, original_sigterm)
signal.signal(signal.SIGINT, original_sigint) signal.signal(signal.SIGINT, original_sigint)
signal.signal(signal.SIGUSR1, original_sigusr1)
return wrapper return wrapper
@@ -58,11 +66,13 @@ def test_outer_inner_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 1 assert pytest_sigterm.called_count == 1
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test @tmp_signals_for_test
@@ -87,8 +97,31 @@ def test_inner_outer_signal_handling():
assert signal_callback() is False assert signal_callback() is False
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0 assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1() f1()
assert signal_callback() is False assert signal_callback() is False
assert signal.getsignal(signal.SIGTERM) is original_sigterm assert signal.getsignal(signal.SIGTERM) is original_sigterm
assert pytest_sigterm.called_count == 0 assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 1 assert pytest_sigint.called_count == 1
assert pytest_sigusr1.called_count == 0
@tmp_signals_for_test
def test_sigusr1_signal_handling():
@with_signal_handling
def f1():
assert signal_callback() is False
signal_state.set_signal_flag(for_signal=signal.SIGUSR1)
assert signal_callback()
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
assert signal_callback() is False
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 0
f1()
assert signal_callback() is False
assert signal.getsignal(signal.SIGUSR1) is original_sigusr1
assert pytest_sigterm.called_count == 0
assert pytest_sigint.called_count == 0
assert pytest_sigusr1.called_count == 1