mirror of
https://github.com/ansible/awx.git
synced 2026-02-02 01:58:09 -03:30
AAP-64221 Fix broken cancel logic with dispatcherd (#16247)
* Fix broken cancel logic with dispatcherd Update tests for UnifiedJob Update test assertion * Further simply cancel path
This commit is contained in:
@@ -10,7 +10,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
@@ -1488,40 +1487,17 @@ class UnifiedJob(
|
||||
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
|
||||
return None
|
||||
|
||||
def fallback_cancel(self):
|
||||
if not self.celery_task_id:
|
||||
self.refresh_from_db(fields=['celery_task_id'])
|
||||
self.cancel_dispatcher_process()
|
||||
|
||||
def cancel_dispatcher_process(self):
|
||||
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
|
||||
if not self.celery_task_id:
|
||||
return False
|
||||
|
||||
# Special case for task manager (used during workflow job cancellation)
|
||||
if not connection.get_autocommit():
|
||||
try:
|
||||
|
||||
ctl = get_control_from_settings()
|
||||
ctl.control('cancel', data={'uuid': self.celery_task_id})
|
||||
except Exception:
|
||||
logger.exception("Error sending cancel command to dispatcher")
|
||||
return True # task manager itself needs to act under assumption that cancel was received
|
||||
|
||||
# Standard case with reply
|
||||
try:
|
||||
timeout = 5
|
||||
|
||||
ctl = get_control_from_settings()
|
||||
results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout)
|
||||
# Check if cancel was successful by checking if we got any results
|
||||
return bool(results and len(results) > 0)
|
||||
except socket.timeout:
|
||||
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
|
||||
logger.info(f'Sending cancel message to pg_notify channel {self.controller_node} for task {self.celery_task_id}')
|
||||
ctl = get_control_from_settings(default_publish_channel=self.controller_node)
|
||||
ctl.control('cancel', data={'uuid': self.celery_task_id})
|
||||
except Exception:
|
||||
logger.exception("error encountered when checking task status")
|
||||
|
||||
return False # whether confirmation was obtained
|
||||
logger.exception("Error sending cancel command to dispatcher")
|
||||
|
||||
def cancel(self, job_explanation=None, is_chain=False):
|
||||
if self.can_cancel:
|
||||
@@ -1544,19 +1520,13 @@ class UnifiedJob(
|
||||
# the job control process will use the cancel_flag to distinguish a shutdown from a cancel
|
||||
self.save(update_fields=cancel_fields)
|
||||
|
||||
controller_notified = False
|
||||
if self.celery_task_id:
|
||||
controller_notified = self.cancel_dispatcher_process()
|
||||
# Be extra sure we have the task id, in case job is transitioning into running right now
|
||||
if not self.celery_task_id:
|
||||
self.refresh_from_db(fields=['celery_task_id', 'controller_node'])
|
||||
|
||||
# If a SIGTERM signal was sent to the control process, and acked by the dispatcher
|
||||
# then we want to let its own cleanup change status, otherwise change status now
|
||||
if not controller_notified:
|
||||
if self.status != 'canceled':
|
||||
self.status = 'canceled'
|
||||
self.save(update_fields=['status'])
|
||||
# Avoid race condition where we have stale model from pending state but job has already started,
|
||||
# its checking signal but not cancel_flag, so re-send signal after updating cancel fields
|
||||
self.fallback_cancel()
|
||||
# send pg_notify message to cancel, will not send until transaction completes
|
||||
if self.celery_task_id:
|
||||
self.cancel_dispatcher_process()
|
||||
|
||||
return self.cancel_flag
|
||||
|
||||
|
||||
@@ -785,7 +785,7 @@ class WorkflowJob(UnifiedJob, WorkflowJobOptions, SurveyJobMixin, JobNotificatio
|
||||
def cancel_dispatcher_process(self):
|
||||
# WorkflowJobs don't _actually_ run anything in the dispatcher, so
|
||||
# there's no point in asking the dispatcher if it knows about this task
|
||||
return True
|
||||
return
|
||||
|
||||
|
||||
class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):
|
||||
|
||||
@@ -69,7 +69,7 @@ def signal_callback():
|
||||
|
||||
def with_signal_handling(f):
|
||||
"""
|
||||
Change signal handling to make signal_callback return True in event of SIGTERM or SIGINT.
|
||||
Change signal handling to make signal_callback return True in event of SIGTERM, SIGINT, or SIGUSR1.
|
||||
"""
|
||||
|
||||
@functools.wraps(f)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import pytest
|
||||
from uuid import uuid4
|
||||
|
||||
# CRUM
|
||||
from crum import impersonate
|
||||
@@ -33,6 +34,64 @@ def test_soft_unique_together(post, project, admin_user):
|
||||
assert 'combination already exists' in str(r.data)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestJobCancel:
|
||||
"""
|
||||
Coverage for UnifiedJob.cancel, focused on interaction with dispatcherd objects.
|
||||
Using mocks for the dispatcherd objects, because tests by default use a no-op broker.
|
||||
"""
|
||||
|
||||
def test_cancel_sets_flag_and_clears_start_args(self, mocker):
|
||||
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo', start_args='{"secret": "value"}')
|
||||
job.websocket_emit_status = mocker.MagicMock()
|
||||
|
||||
assert job.can_cancel is True
|
||||
assert job.cancel_flag is False
|
||||
|
||||
job.cancel()
|
||||
job.refresh_from_db()
|
||||
|
||||
assert job.cancel_flag is True
|
||||
assert job.start_args == ''
|
||||
|
||||
def test_cancel_sets_job_explanation(self, mocker):
|
||||
job = Job.objects.create(status='running', name='foo-job', celery_task_id=str(uuid4()), controller_node='foo')
|
||||
job.websocket_emit_status = mocker.MagicMock()
|
||||
job_explanation = 'giggity giggity'
|
||||
|
||||
job.cancel(job_explanation=job_explanation)
|
||||
job.refresh_from_db()
|
||||
|
||||
assert job.job_explanation == job_explanation
|
||||
|
||||
def test_cancel_sends_control_message(self, mocker):
|
||||
celery_task_id = str(uuid4())
|
||||
job = Job.objects.create(status='running', name='foo-job', celery_task_id=celery_task_id, controller_node='foo')
|
||||
job.websocket_emit_status = mocker.MagicMock()
|
||||
control = mocker.MagicMock()
|
||||
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
|
||||
|
||||
job.cancel()
|
||||
|
||||
get_control.assert_called_once_with(default_publish_channel='foo')
|
||||
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
|
||||
|
||||
def test_cancel_refreshes_task_id_before_sending_control(self, mocker):
|
||||
job = Job.objects.create(status='pending', name='foo-job', celery_task_id='', controller_node='bar')
|
||||
job.websocket_emit_status = mocker.MagicMock()
|
||||
celery_task_id = str(uuid4())
|
||||
Job.objects.filter(pk=job.pk).update(status='running', celery_task_id=celery_task_id)
|
||||
control = mocker.MagicMock()
|
||||
get_control = mocker.patch('awx.main.models.unified_jobs.get_control_from_settings', return_value=control)
|
||||
refresh_spy = mocker.spy(job, 'refresh_from_db')
|
||||
|
||||
job.cancel()
|
||||
|
||||
refresh_spy.assert_called_once_with(fields=['celery_task_id', 'controller_node'])
|
||||
get_control.assert_called_once_with(default_publish_channel='bar')
|
||||
control.control.assert_called_once_with('cancel', data={'uuid': celery_task_id})
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
class TestCreateUnifiedJob:
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.models import UnifiedJob, UnifiedJobTemplate, WorkflowJob, WorkflowJobNode, WorkflowApprovalTemplate, Job, User, Project, JobTemplate, Inventory
|
||||
@@ -22,52 +21,6 @@ def test_unified_job_workflow_attributes():
|
||||
assert job.workflow_job_id == 1
|
||||
|
||||
|
||||
def mock_on_commit(f):
|
||||
f()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unified_job(mocker):
|
||||
mocker.patch.object(UnifiedJob, 'can_cancel', return_value=True)
|
||||
j = UnifiedJob()
|
||||
j.status = 'pending'
|
||||
j.cancel_flag = None
|
||||
j.save = mocker.MagicMock()
|
||||
j.websocket_emit_status = mocker.MagicMock()
|
||||
j.fallback_cancel = mocker.MagicMock()
|
||||
return j
|
||||
|
||||
|
||||
def test_cancel(unified_job):
|
||||
with mock.patch('awx.main.models.unified_jobs.connection.on_commit', wraps=mock_on_commit):
|
||||
unified_job.cancel()
|
||||
|
||||
assert unified_job.cancel_flag is True
|
||||
assert unified_job.status == 'canceled'
|
||||
assert unified_job.job_explanation == ''
|
||||
# Note: the websocket emit status check is just reflecting the state of the current code.
|
||||
# Some more thought may want to go into only emitting canceled if/when the job record
|
||||
# status is changed to canceled. Unlike, currently, where it's emitted unconditionally.
|
||||
unified_job.websocket_emit_status.assert_called_with("canceled")
|
||||
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
|
||||
((), {'update_fields': ['cancel_flag', 'start_args']}),
|
||||
((), {'update_fields': ['status']}),
|
||||
]
|
||||
|
||||
|
||||
def test_cancel_job_explanation(unified_job):
|
||||
job_explanation = 'giggity giggity'
|
||||
|
||||
with mock.patch('awx.main.models.unified_jobs.connection.on_commit'):
|
||||
unified_job.cancel(job_explanation=job_explanation)
|
||||
|
||||
assert unified_job.job_explanation == job_explanation
|
||||
assert [(args, kwargs) for args, kwargs in unified_job.save.call_args_list] == [
|
||||
((), {'update_fields': ['cancel_flag', 'start_args', 'job_explanation']}),
|
||||
((), {'update_fields': ['status']}),
|
||||
]
|
||||
|
||||
|
||||
def test_organization_copy_to_jobs():
|
||||
"""
|
||||
All unified job types should infer their organization from their template organization
|
||||
|
||||
@@ -12,6 +12,10 @@ def pytest_sigterm():
|
||||
pytest_sigterm.called_count += 1
|
||||
|
||||
|
||||
def pytest_sigusr1():
|
||||
pytest_sigusr1.called_count += 1
|
||||
|
||||
|
||||
def tmp_signals_for_test(func):
|
||||
"""
|
||||
When we run our internal signal handlers, it will call the original signal
|
||||
@@ -26,13 +30,17 @@ def tmp_signals_for_test(func):
|
||||
def wrapper():
|
||||
original_sigterm = signal.getsignal(signal.SIGTERM)
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
|
||||
signal.signal(signal.SIGTERM, pytest_sigterm)
|
||||
signal.signal(signal.SIGINT, pytest_sigint)
|
||||
signal.signal(signal.SIGUSR1, pytest_sigusr1)
|
||||
pytest_sigterm.called_count = 0
|
||||
pytest_sigint.called_count = 0
|
||||
pytest_sigusr1.called_count = 0
|
||||
func()
|
||||
signal.signal(signal.SIGTERM, original_sigterm)
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
signal.signal(signal.SIGUSR1, original_sigusr1)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -58,11 +66,13 @@ def test_outer_inner_signal_handling():
|
||||
assert signal_callback() is False
|
||||
assert pytest_sigterm.called_count == 0
|
||||
assert pytest_sigint.called_count == 0
|
||||
assert pytest_sigusr1.called_count == 0
|
||||
f1()
|
||||
assert signal_callback() is False
|
||||
assert signal.getsignal(signal.SIGTERM) is original_sigterm
|
||||
assert pytest_sigterm.called_count == 1
|
||||
assert pytest_sigint.called_count == 0
|
||||
assert pytest_sigusr1.called_count == 0
|
||||
|
||||
|
||||
@tmp_signals_for_test
|
||||
@@ -87,8 +97,31 @@ def test_inner_outer_signal_handling():
|
||||
assert signal_callback() is False
|
||||
assert pytest_sigterm.called_count == 0
|
||||
assert pytest_sigint.called_count == 0
|
||||
assert pytest_sigusr1.called_count == 0
|
||||
f1()
|
||||
assert signal_callback() is False
|
||||
assert signal.getsignal(signal.SIGTERM) is original_sigterm
|
||||
assert pytest_sigterm.called_count == 0
|
||||
assert pytest_sigint.called_count == 1
|
||||
assert pytest_sigusr1.called_count == 0
|
||||
|
||||
|
||||
@tmp_signals_for_test
|
||||
def test_sigusr1_signal_handling():
|
||||
@with_signal_handling
|
||||
def f1():
|
||||
assert signal_callback() is False
|
||||
signal_state.set_signal_flag(for_signal=signal.SIGUSR1)
|
||||
assert signal_callback()
|
||||
|
||||
original_sigusr1 = signal.getsignal(signal.SIGUSR1)
|
||||
assert signal_callback() is False
|
||||
assert pytest_sigterm.called_count == 0
|
||||
assert pytest_sigint.called_count == 0
|
||||
assert pytest_sigusr1.called_count == 0
|
||||
f1()
|
||||
assert signal_callback() is False
|
||||
assert signal.getsignal(signal.SIGUSR1) is original_sigusr1
|
||||
assert pytest_sigterm.called_count == 0
|
||||
assert pytest_sigint.called_count == 0
|
||||
assert pytest_sigusr1.called_count == 1
|
||||
|
||||
Reference in New Issue
Block a user