From 94b34b801cdbefd0ce001b8dfcc2a6184dcc3060 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Mon, 19 Dec 2022 11:19:38 -0500 Subject: [PATCH] Avoid unbounded kwargs by fetching subtasks inside handle_work_error Update tests to new handle_work_error call pattern Handle blame correctly with multiple serial deps add new test case corresponding to this scenario --- awx/main/scheduler/task_manager.py | 21 +++---- awx/main/tasks/system.py | 63 +++++++++---------- .../task_management/test_rampart_groups.py | 12 ++-- .../task_management/test_scheduler.py | 28 ++++----- awx/main/tests/functional/test_tasks.py | 18 +++++- 5 files changed, 72 insertions(+), 70 deletions(-) diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index 3610ecb89a..880d1e9e71 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -507,7 +507,7 @@ class TaskManager(TaskBase): return None @timeit - def start_task(self, task, instance_group, dependent_tasks=None, instance=None): + def start_task(self, task, instance_group, instance=None): # Just like for process_running_tasks, add the job to the dependency graph and # ask the TaskManagerInstanceGroups object to update consumed capacity on all # implicated instances and container groups. @@ -524,14 +524,6 @@ class TaskManager(TaskBase): ScheduleTaskManager().schedule() from awx.main.tasks.system import handle_work_error, handle_work_success - dependent_tasks = dependent_tasks or [] - - task_actual = { - 'type': get_type_for_model(type(task)), - 'id': task.id, - } - dependencies = [{'type': get_type_for_model(type(t)), 'id': t.id} for t in dependent_tasks] - task.status = 'waiting' (start_status, opts) = task.pre_start() @@ -563,6 +555,7 @@ class TaskManager(TaskBase): # apply_async does a NOTIFY to the channel dispatcher is listening to # postgres will treat this as part of the transaction, which is what we want if task.status != 'failed' and type(task) is not WorkflowJob: + task_actual = {'type': get_type_for_model(type(task)), 'id': task.id} task_cls = task._get_task_class() task_cls.apply_async( [task.pk], @@ -570,7 +563,7 @@ class TaskManager(TaskBase): queue=task.get_queue_name(), uuid=task.celery_task_id, callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}], - errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}], + errbacks=[{'task': handle_work_error.name, 'kwargs': {'task_actual': task_actual}}], ) # In exception cases, like a job failing pre-start checks, we send the websocket status message @@ -609,7 +602,7 @@ class TaskManager(TaskBase): if isinstance(task, WorkflowJob): # Previously we were tracking allow_simultaneous blocking both here and in DependencyGraph. # Double check that using just the DependencyGraph works for Workflows and Sliced Jobs. - self.start_task(task, None, task.get_jobs_fail_chain(), None) + self.start_task(task, None, None) continue found_acceptable_queue = False @@ -637,7 +630,7 @@ class TaskManager(TaskBase): execution_instance = self.tm_models.instances[control_instance.hostname].obj task.log_lifecycle("controller_node_chosen") task.log_lifecycle("execution_node_chosen") - self.start_task(task, self.controlplane_ig, task.get_jobs_fail_chain(), execution_instance) + self.start_task(task, self.controlplane_ig, execution_instance) found_acceptable_queue = True continue @@ -645,7 +638,7 @@ class TaskManager(TaskBase): if not self.tm_models.instance_groups[instance_group.name].has_remaining_capacity(task): continue if instance_group.is_container_group: - self.start_task(task, instance_group, task.get_jobs_fail_chain(), None) + self.start_task(task, instance_group, None) found_acceptable_queue = True break @@ -670,7 +663,7 @@ class TaskManager(TaskBase): ) ) execution_instance = self.tm_models.instances[execution_instance.hostname].obj - self.start_task(task, instance_group, task.get_jobs_fail_chain(), execution_instance) + self.start_task(task, instance_group, execution_instance) found_acceptable_queue = True break else: diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py index ee3293beae..482c168af2 100644 --- a/awx/main/tasks/system.py +++ b/awx/main/tasks/system.py @@ -52,6 +52,7 @@ from awx.main.constants import ACTIVE_STATES from awx.main.dispatch.publish import task from awx.main.dispatch import get_local_queuename, reaper from awx.main.utils.common import ( + get_type_for_model, ignore_inventory_computed_fields, ignore_inventory_group_removal, ScheduleWorkflowManager, @@ -720,45 +721,43 @@ def handle_work_success(task_actual): @task(queue=get_local_queuename) -def handle_work_error(task_id, *args, **kwargs): - subtasks = kwargs.get('subtasks', None) - logger.debug('Executing error task id %s, subtasks: %s' % (task_id, str(subtasks))) - first_instance = None - first_instance_type = '' - if subtasks is not None: - for each_task in subtasks: - try: - instance = UnifiedJob.get_instance_by_type(each_task['type'], each_task['id']) - if not instance: - # Unknown task type - logger.warning("Unknown task type: {}".format(each_task['type'])) - continue - except ObjectDoesNotExist: - logger.warning('Missing {} `{}` in error callback.'.format(each_task['type'], each_task['id'])) - continue +def handle_work_error(task_actual): + try: + instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id']) + except ObjectDoesNotExist: + logger.warning('Missing {} `{}` in error callback.'.format(task_actual['type'], task_actual['id'])) + return + if not instance: + return - if first_instance is None: - first_instance = instance - first_instance_type = each_task['type'] + subtasks = instance.get_jobs_fail_chain() # reverse of dependent_jobs mostly + logger.debug(f'Executing error task id {task_actual["id"]}, subtasks: {[subtask.id for subtask in subtasks]}') - if instance.celery_task_id != task_id and not instance.cancel_flag and not instance.status in ('successful', 'failed'): - instance.status = 'failed' - instance.failed = True - if not instance.job_explanation: - instance.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % ( - first_instance_type, - first_instance.name, - first_instance.id, - ) - instance.save() - instance.websocket_emit_status("failed") + deps_of_deps = {} + + for subtask in subtasks: + if subtask.celery_task_id != instance.celery_task_id and not subtask.cancel_flag and not subtask.status in ('successful', 'failed'): + # If there are multiple in the dependency chain, A->B->C, and this was called for A, blame B for clarity + blame_job = deps_of_deps.get(subtask.id, instance) + subtask.status = 'failed' + subtask.failed = True + if not subtask.job_explanation: + subtask.job_explanation = 'Previous Task Failed: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % ( + get_type_for_model(type(blame_job)), + blame_job.name, + blame_job.id, + ) + subtask.save() + subtask.websocket_emit_status("failed") + + for sub_subtask in subtask.get_jobs_fail_chain(): + deps_of_deps[sub_subtask.id] = subtask # We only send 1 job complete message since all the job completion message # handling does is trigger the scheduler. If we extend the functionality of # what the job complete message handler does then we may want to send a # completion event for each job here. - if first_instance: - schedule_manager_success_or_error(first_instance) + schedule_manager_success_or_error(instance) @task(queue=get_local_queuename) diff --git a/awx/main/tests/functional/task_management/test_rampart_groups.py b/awx/main/tests/functional/task_management/test_rampart_groups.py index 6bed591147..48ea9edb08 100644 --- a/awx/main/tests/functional/task_management/test_rampart_groups.py +++ b/awx/main/tests/functional/task_management/test_rampart_groups.py @@ -23,7 +23,7 @@ def test_multi_group_basic_job_launch(instance_factory, controlplane_instance_gr mock_task_impact.return_value = 500 with mocker.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_has_calls([mock.call(j1, ig1, [], i1), mock.call(j2, ig2, [], i2)]) + TaskManager.start_task.assert_has_calls([mock.call(j1, ig1, i1), mock.call(j2, ig2, i2)]) @pytest.mark.django_db @@ -54,7 +54,7 @@ def test_multi_group_with_shared_dependency(instance_factory, controlplane_insta DependencyManager().schedule() TaskManager().schedule() pu = p.project_updates.first() - TaskManager.start_task.assert_called_once_with(pu, controlplane_instance_group, [j1, j2], controlplane_instance_group.instances.all()[0]) + TaskManager.start_task.assert_called_once_with(pu, controlplane_instance_group, controlplane_instance_group.instances.all()[0]) pu.finished = pu.created + timedelta(seconds=1) pu.status = "successful" pu.save() @@ -62,8 +62,8 @@ def test_multi_group_with_shared_dependency(instance_factory, controlplane_insta DependencyManager().schedule() TaskManager().schedule() - TaskManager.start_task.assert_any_call(j1, ig1, [], i1) - TaskManager.start_task.assert_any_call(j2, ig2, [], i2) + TaskManager.start_task.assert_any_call(j1, ig1, i1) + TaskManager.start_task.assert_any_call(j2, ig2, i2) assert TaskManager.start_task.call_count == 2 @@ -75,7 +75,7 @@ def test_workflow_job_no_instancegroup(workflow_job_template_factory, controlpla wfj.save() with mocker.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(wfj, None, [], None) + TaskManager.start_task.assert_called_once_with(wfj, None, None) assert wfj.instance_group is None @@ -150,7 +150,7 @@ def test_failover_group_run(instance_factory, controlplane_instance_group, mocke mock_task_impact.return_value = 500 with mock.patch.object(TaskManager, "start_task", wraps=tm.start_task) as mock_job: tm.schedule() - mock_job.assert_has_calls([mock.call(j1, ig1, [], i1), mock.call(j1_1, ig2, [], i2)]) + mock_job.assert_has_calls([mock.call(j1, ig1, i1), mock.call(j1_1, ig2, i2)]) assert mock_job.call_count == 2 diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py index f362841033..42d144d5cc 100644 --- a/awx/main/tests/functional/task_management/test_scheduler.py +++ b/awx/main/tests/functional/task_management/test_scheduler.py @@ -18,7 +18,7 @@ def test_single_job_scheduler_launch(hybrid_instance, controlplane_instance_grou j = create_job(objects.job_template) with mocker.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, instance) @pytest.mark.django_db @@ -240,12 +240,12 @@ def test_multi_jt_capacity_blocking(hybrid_instance, job_template_factory, mocke mock_task_impact.return_value = 505 with mock.patch.object(TaskManager, "start_task", wraps=tm.start_task) as mock_job: tm.schedule() - mock_job.assert_called_once_with(j1, controlplane_instance_group, [], instance) + mock_job.assert_called_once_with(j1, controlplane_instance_group, instance) j1.status = "successful" j1.save() with mock.patch.object(TaskManager, "start_task", wraps=tm.start_task) as mock_job: tm.schedule() - mock_job.assert_called_once_with(j2, controlplane_instance_group, [], instance) + mock_job.assert_called_once_with(j2, controlplane_instance_group, instance) @pytest.mark.django_db @@ -337,12 +337,12 @@ def test_single_job_dependencies_project_launch(controlplane_instance_group, job pu = [x for x in p.project_updates.all()] assert len(pu) == 1 TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(pu[0], controlplane_instance_group, [j], instance) + TaskManager.start_task.assert_called_once_with(pu[0], controlplane_instance_group, instance) pu[0].status = "successful" pu[0].save() with mock.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, instance) @pytest.mark.django_db @@ -365,12 +365,12 @@ def test_single_job_dependencies_inventory_update_launch(controlplane_instance_g iu = [x for x in ii.inventory_updates.all()] assert len(iu) == 1 TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(iu[0], controlplane_instance_group, [j], instance) + TaskManager.start_task.assert_called_once_with(iu[0], controlplane_instance_group, instance) iu[0].status = "successful" iu[0].save() with mock.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, instance) @pytest.mark.django_db @@ -412,7 +412,7 @@ def test_job_dependency_with_already_updated(controlplane_instance_group, job_te mock_iu.assert_not_called() with mock.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(j, controlplane_instance_group, instance) @pytest.mark.django_db @@ -442,9 +442,7 @@ def test_shared_dependencies_launch(controlplane_instance_group, job_template_fa TaskManager().schedule() pu = p.project_updates.first() iu = ii.inventory_updates.first() - TaskManager.start_task.assert_has_calls( - [mock.call(iu, controlplane_instance_group, [j1, j2], instance), mock.call(pu, controlplane_instance_group, [j1, j2], instance)] - ) + TaskManager.start_task.assert_has_calls([mock.call(iu, controlplane_instance_group, instance), mock.call(pu, controlplane_instance_group, instance)]) pu.status = "successful" pu.finished = pu.created + timedelta(seconds=1) pu.save() @@ -453,9 +451,7 @@ def test_shared_dependencies_launch(controlplane_instance_group, job_template_fa iu.save() with mock.patch("awx.main.scheduler.TaskManager.start_task"): TaskManager().schedule() - TaskManager.start_task.assert_has_calls( - [mock.call(j1, controlplane_instance_group, [], instance), mock.call(j2, controlplane_instance_group, [], instance)] - ) + TaskManager.start_task.assert_has_calls([mock.call(j1, controlplane_instance_group, instance), mock.call(j2, controlplane_instance_group, instance)]) pu = [x for x in p.project_updates.all()] iu = [x for x in ii.inventory_updates.all()] assert len(pu) == 1 @@ -479,7 +475,7 @@ def test_job_not_blocking_project_update(controlplane_instance_group, job_templa project_update.status = "pending" project_update.save() TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(project_update, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(project_update, controlplane_instance_group, instance) @pytest.mark.django_db @@ -503,7 +499,7 @@ def test_job_not_blocking_inventory_update(controlplane_instance_group, job_temp DependencyManager().schedule() TaskManager().schedule() - TaskManager.start_task.assert_called_once_with(inventory_update, controlplane_instance_group, [], instance) + TaskManager.start_task.assert_called_once_with(inventory_update, controlplane_instance_group, instance) @pytest.mark.django_db diff --git a/awx/main/tests/functional/test_tasks.py b/awx/main/tests/functional/test_tasks.py index 6de551cf9f..8abe5579eb 100644 --- a/awx/main/tests/functional/test_tasks.py +++ b/awx/main/tests/functional/test_tasks.py @@ -5,8 +5,8 @@ import tempfile import shutil from awx.main.tasks.jobs import RunJob -from awx.main.tasks.system import execution_node_health_check, _cleanup_images_and_files -from awx.main.models import Instance, Job +from awx.main.tasks.system import execution_node_health_check, _cleanup_images_and_files, handle_work_error +from awx.main.models import Instance, Job, InventoryUpdate, ProjectUpdate @pytest.fixture @@ -74,3 +74,17 @@ def test_does_not_run_reaped_job(mocker, mock_me): job.refresh_from_db() assert job.status == 'failed' mock_run.assert_not_called() + + +@pytest.mark.django_db +def test_handle_work_error_nested(project, inventory_source): + pu = ProjectUpdate.objects.create(status='failed', project=project, celery_task_id='1234') + iu = InventoryUpdate.objects.create(status='pending', inventory_source=inventory_source, source='scm') + job = Job.objects.create(status='pending') + iu.dependent_jobs.add(pu) + job.dependent_jobs.add(pu, iu) + handle_work_error({'type': 'project_update', 'id': pu.id}) + iu.refresh_from_db() + job.refresh_from_db() + assert iu.job_explanation == f'Previous Task Failed: {{"job_type": "project_update", "job_name": "", "job_id": "{pu.id}"}}' + assert job.job_explanation == f'Previous Task Failed: {{"job_type": "inventory_update", "job_name": "", "job_id": "{iu.id}"}}'