diff --git a/awx/main/scheduler/__init__.py b/awx/main/scheduler/__init__.py
index b81e672036..f16633cc49 100644
--- a/awx/main/scheduler/__init__.py
+++ b/awx/main/scheduler/__init__.py
@@ -69,14 +69,43 @@ class TaskManager():
 
     '''
     Tasks that are running and SHOULD have a celery task.
+    {
+        'execution_node': [j1, j2,...],
+        'execution_node': [j3],
+        ...
+    }
     '''
     def get_running_tasks(self):
+        execution_nodes = {}
         now = tz_now()
-        return list(UnifiedJob.objects.filter(Q(status='running') |
+        jobs = list(UnifiedJob.objects.filter(Q(status='running') |
                                               (Q(status='waiting', modified__lte=now - timedelta(seconds=60)))))
+        for j in jobs:
+            if j.execution_node in execution_nodes:
+                execution_nodes[j.execution_node].append(j)
+            else:
+                execution_nodes[j.execution_node] = [j]
+        return execution_nodes
 
     '''
    Tasks that are currently running in celery
+
+    Transform:
+    {
+        "celery@ec2-54-204-222-62.compute-1.amazonaws.com": [],
+        "celery@ec2-54-163-144-168.compute-1.amazonaws.com": [{
+            ...
+            "id": "5238466a-f8c7-43b3-9180-5b78e9da8304",
+            ...
+        }]
+    }
+
+    to:
+    {
+        "celery@ec2-54-204-222-62.compute-1.amazonaws.com": [
+            "5238466a-f8c7-43b3-9180-5b78e9da8304",
+        ]
+    }
     '''
     def get_active_tasks(self):
         inspector = inspect()
@@ -86,15 +115,26 @@ class TaskManager():
             logger.warn("Ignoring celery task inspector")
             active_task_queues = None
 
-        active_tasks = set()
+        queues = None
+
         if active_task_queues is not None:
+            queues = {}
             for queue in active_task_queues:
+                active_tasks = set()
                 map(lambda at: active_tasks.add(at['id']), active_task_queues[queue])
+
+                # queue is of the form celery@myhost.com
+                queue_name = queue.split('@')
+                if len(queue_name) > 1:
+                    queue_name = queue_name[1]
+                else:
+                    queue_name = queue_name[0]
+                queues[queue_name] = active_tasks
         else:
             if not hasattr(settings, 'CELERY_UNIT_TEST'):
                 return (None, None)
 
-        return (active_task_queues, active_tasks)
+        return (active_task_queues, queues)
 
     def get_latest_project_update_tasks(self, all_sorted_tasks):
         project_ids = Set()
@@ -380,32 +420,38 @@ class TaskManager():
         logger.debug("Failing inconsistent running jobs.")
 
         celery_task_start_time = tz_now()
-        active_task_queues, active_tasks = self.get_active_tasks()
+        active_task_queues, active_queues = self.get_active_tasks()
         cache.set('last_celery_task_cleanup', tz_now())
 
-        if active_tasks is None:
+        if active_queues is None:
             logger.error('Failed to retrieve active tasks from celery')
             return None
 
-        all_running_sorted_tasks = self.get_running_tasks()
-        for task in all_running_sorted_tasks:
-
-            if (task.celery_task_id not in active_tasks and not hasattr(settings, 'IGNORE_CELERY_INSPECTOR')):
-                # TODO: try catch the getting of the job. The job COULD have been deleted
-                if isinstance(task, WorkflowJob):
-                    continue
-                if task.modified > celery_task_start_time:
-                    continue
-                task.status = 'failed'
-                task.job_explanation += ' '.join((
-                    'Task was marked as running in Tower but was not present in',
-                    'Celery, so it has been marked as failed.',
-                ))
-                task.save()
-                awx_tasks._send_notification_templates(task, 'failed')
-                task.websocket_emit_status('failed')
-                logger.error("%s appears orphaned... marking as failed", task.log_format)
-
+        '''
+        Only consider failing tasks on instances for which we obtained a task
+        list from celery.
+        '''
+        execution_nodes_jobs = self.get_running_tasks()
+        for node, node_jobs in execution_nodes_jobs.iteritems():
+            if node not in active_queues:
+                continue
+            active_tasks = active_queues[node]
+            for task in node_jobs:
+                if (task.celery_task_id not in active_tasks and not hasattr(settings, 'IGNORE_CELERY_INSPECTOR')):
+                    # TODO: try catch the getting of the job. The job COULD have been deleted
+                    if isinstance(task, WorkflowJob):
+                        continue
+                    if task.modified > celery_task_start_time:
+                        continue
+                    task.status = 'failed'
+                    task.job_explanation += ' '.join((
+                        'Task was marked as running in Tower but was not present in',
+                        'Celery, so it has been marked as failed.',
+                    ))
+                    task.save()
+                    awx_tasks._send_notification_templates(task, 'failed')
+                    task.websocket_emit_status('failed')
+                    logger.error("Task %s appears orphaned... marking as failed" % task)
 
     def calculate_capacity_used(self, tasks):
         for rampart_group in self.graph:
diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py
index 7fe41ca781..a9fd694bce 100644
--- a/awx/main/tests/functional/task_management/test_scheduler.py
+++ b/awx/main/tests/functional/task_management/test_scheduler.py
@@ -11,46 +11,6 @@ from awx.main.models import (
 )
 
 
-@pytest.fixture
-def all_jobs(mocker):
-    now = tz_now()
-    j1 = Job.objects.create(status='pending')
-    j2 = Job.objects.create(status='waiting', celery_task_id='considered_j2')
-    j3 = Job.objects.create(status='waiting', celery_task_id='considered_j3')
-    j3.modified = now - timedelta(seconds=60)
-    j3.save(update_fields=['modified'])
-    j4 = Job.objects.create(status='running', celery_task_id='considered_j4')
-    j5 = Job.objects.create(status='waiting', celery_task_id='reapable_j5')
-    j5.modified = now - timedelta(seconds=60)
-    j5.save(update_fields=['modified'])
-
-    js = [j1, j2, j3, j4, j5]
-    for j in js:
-        j.save = mocker.Mock(wraps=j.save)
-        j.websocket_emit_status = mocker.Mock()
-    return js
-
-
-@pytest.fixture
-def considered_jobs(all_jobs):
-    return all_jobs[2:4] + [all_jobs[4]]
-
-
-@pytest.fixture
-def reapable_jobs(all_jobs):
-    return [all_jobs[4]]
-
-
-@pytest.fixture
-def unconsidered_jobs(all_jobs):
-    return all_jobs[0:1]
-
-
-@pytest.fixture
-def active_tasks():
-    return ([], ['considered_j2', 'considered_j3', 'considered_j4',])
-
-
 @pytest.mark.django_db
 def test_single_job_scheduler_launch(default_instance_group, job_template_factory, mocker):
     objects = job_template_factory('jt', organization='org1', project='proj',
@@ -258,41 +218,88 @@ def test_cleanup_interval():
     assert cache.get('last_celery_task_cleanup') == last_cleanup
 
 
-@pytest.mark.django_db
-@mock.patch('awx.main.tasks._send_notification_templates')
-@mock.patch.object(TaskManager, 'get_active_tasks', lambda self: [[], []])
-def test_cleanup_inconsistent_task(notify, active_tasks, considered_jobs, reapable_jobs, mocker):
-    tm = TaskManager()
+class TestReaper():
+    @pytest.fixture
+    def all_jobs(self, mocker):
+        now = tz_now()
 
-    tm.get_running_tasks = mocker.Mock(return_value=considered_jobs)
-    tm.get_active_tasks = mocker.Mock(return_value=active_tasks)
-
-    tm.cleanup_inconsistent_celery_tasks()
-
-    for j in considered_jobs:
-        if j not in reapable_jobs:
-            j.save.assert_not_called()
+        j1 = Job.objects.create(status='pending', execution_node='host1')
+        j2 = Job.objects.create(status='waiting', celery_task_id='considered_j2', execution_node='host1')
+        j3 = Job.objects.create(status='waiting', celery_task_id='considered_j3', execution_node='host1')
+        j3.modified = now - timedelta(seconds=60)
+        j3.save(update_fields=['modified'])
+        j4 = Job.objects.create(status='running', celery_task_id='considered_j4', execution_node='host1')
+        j5 = Job.objects.create(status='waiting', celery_task_id='reapable_j5', execution_node='host1')
+        j5.modified = now - timedelta(seconds=60)
+        j5.save(update_fields=['modified'])
+        j6 = Job.objects.create(status='waiting', celery_task_id='host2_j6', execution_node='host2_split')
+        j6.modified = now - timedelta(seconds=60)
+        j6.save(update_fields=['modified'])
+        j7 = Job.objects.create(status='running', celery_task_id='host2_j6', execution_node='host2_split')
 
-    for reaped_job in reapable_jobs:
-        notify.assert_called_once_with(reaped_job, 'failed')
-        reaped_job.websocket_emit_status.assert_called_once_with('failed')
-        assert reaped_job.status == 'failed'
-        assert reaped_job.job_explanation == (
-            'Task was marked as running in Tower but was not present in Celery, so it has been marked as failed.'
-        )
+        js = [j1, j2, j3, j4, j5, j6, j7]
+        for j in js:
+            j.save = mocker.Mock(wraps=j.save)
+            j.websocket_emit_status = mocker.Mock()
+        return js
+
+    @pytest.fixture
+    def considered_jobs(self, all_jobs):
+        return all_jobs[2:4] + [all_jobs[4]]
+
+    @pytest.fixture
+    def reapable_jobs(self, all_jobs):
+        return [all_jobs[4]]
+
+    @pytest.fixture
+    def unconsidered_jobs(self, all_jobs):
+        return all_jobs[0:1] + all_jobs[5:7]
+
+    @pytest.fixture
+    def active_tasks(self):
+        return ([], {
+            'host1': ['considered_j2', 'considered_j3', 'considered_j4',],
+            'host2_split': ['host2_j6', 'host2_j7'],
+        })
+
+    @pytest.mark.django_db
+    @mock.patch('awx.main.tasks._send_notification_templates')
+    @mock.patch.object(TaskManager, 'get_active_tasks', lambda self: ([], []))
+    def test_cleanup_inconsistent_task(self, notify, active_tasks, considered_jobs, reapable_jobs, mocker):
+        tm = TaskManager()
+
+        #tm.get_running_tasks = mocker.Mock(return_value=considered_jobs)
+        tm.get_active_tasks = mocker.Mock(return_value=active_tasks)
+
+        tm.cleanup_inconsistent_celery_tasks()
+
+        for j in considered_jobs:
+            if j not in reapable_jobs:
+                j.save.assert_not_called()
+
+        for reaped_job in reapable_jobs:
+            notify.assert_called_once_with(reaped_job, 'failed')
+            reaped_job.websocket_emit_status.assert_called_once_with('failed')
+            assert reaped_job.status == 'failed'
+            assert reaped_job.job_explanation == (
+                'Task was marked as running in Tower but was not present in Celery, so it has been marked as failed.'
+            )
 
-@pytest.mark.django_db
-def test_get_running_tasks(considered_jobs, reapable_jobs, unconsidered_jobs):
-    tm = TaskManager()
+    @pytest.mark.django_db
+    def test_get_running_tasks(self, all_jobs):
+        tm = TaskManager()
 
-    # Ensure the query grabs the expected jobs
-    rt = tm.get_running_tasks()
-    for j in considered_jobs:
-        assert j in rt
-    for j in reapable_jobs:
-        assert j in rt
-    for j in unconsidered_jobs:
-        assert j in unconsidered_jobs
+        # Ensure the query grabs the expected jobs
+        execution_nodes_jobs = tm.get_running_tasks()
+        assert 'host1' in execution_nodes_jobs
+        assert 'host2_split' in execution_nodes_jobs
+        assert all_jobs[1] in execution_nodes_jobs['host1']
+        assert all_jobs[2] in execution_nodes_jobs['host1']
+        assert all_jobs[3] in execution_nodes_jobs['host1']
+        assert all_jobs[4] in execution_nodes_jobs['host1']
+        assert all_jobs[5] in execution_nodes_jobs['host2_split']
+        assert all_jobs[6] in execution_nodes_jobs['host2_split']
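
For reference, here is a minimal standalone sketch of the reconciliation logic this patch introduces: group running jobs by execution node, map each celery worker name of the form celery@myhost.com to its set of active task ids, and only treat a job as orphaned when a task list was actually received for its node. The plain dicts and helper names below are illustrative stand-ins, not the actual AWX models, TaskManager methods, or celery inspection payloads.

# Illustrative sketch only -- plain dicts stand in for UnifiedJob rows and for
# the payload returned by celery's inspect().active().
from collections import defaultdict


def group_by_execution_node(running_jobs):
    """Mirror of get_running_tasks(): execution_node -> [job, job, ...]."""
    nodes = defaultdict(list)
    for job in running_jobs:
        nodes[job['execution_node']].append(job)
    return dict(nodes)


def active_ids_by_node(active_task_queues):
    """Mirror of get_active_tasks(): 'celery@myhost.com' -> {'task-id', ...}."""
    queues = {}
    for queue, tasks in active_task_queues.items():
        parts = queue.split('@')
        node = parts[1] if len(parts) > 1 else parts[0]
        queues[node] = set(task['id'] for task in tasks)
    return queues


def find_orphaned(running_jobs, active_task_queues):
    """Jobs whose node reported a task list but whose task id is missing from it."""
    queues = active_ids_by_node(active_task_queues)
    orphaned = []
    for node, jobs in group_by_execution_node(running_jobs).items():
        if node not in queues:
            # No inspection data for this node; do not reap its jobs.
            continue
        orphaned.extend(j for j in jobs if j['celery_task_id'] not in queues[node])
    return orphaned


if __name__ == '__main__':
    running = [
        {'execution_node': 'host1', 'celery_task_id': 'id-1'},
        {'execution_node': 'host1', 'celery_task_id': 'id-2'},
        {'execution_node': 'host2', 'celery_task_id': 'id-3'},
    ]
    active = {'celery@host1': [{'id': 'id-1'}]}
    # id-2 is orphaned; id-3 is skipped because host2 returned no task list.
    print([j['celery_task_id'] for j in find_orphaned(running, active)])  # ['id-2']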