diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index 0d2d0fdd21..b0ddd0cde1 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -87,7 +87,6 @@ from awx.api.renderers import * # noqa from awx.api.serializers import * # noqa from awx.api.metadata import RoleMetadata, JobTypeMetadata from awx.main.constants import ACTIVE_STATES -from awx.main.scheduler.tasks import run_job_complete from awx.api.views.mixin import ( ActivityStreamEnforcementMixin, SystemTrackingEnforcementMixin, @@ -3262,8 +3261,7 @@ class WorkflowJobCancel(WorkflowsEnforcementMixin, RetrieveAPIView): obj = self.get_object() if obj.can_cancel: obj.cancel() - #TODO: Figure out whether an immediate schedule is needed. - run_job_complete.delay(obj.id) + schedule_task_manager() return Response(status=status.HTTP_202_ACCEPTED) else: return self.http_method_not_allowed(request, *args, **kwargs) diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index ba9d51258e..23a2190aba 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -43,7 +43,7 @@ from awx.main.utils import ( copy_model_by_class, copy_m2m_relationships, get_type_for_model, parse_yaml_or_json, getattr_dne ) -from awx.main.utils import polymorphic +from awx.main.utils import polymorphic, schedule_task_manager from awx.main.constants import ACTIVE_STATES, CAN_CANCEL from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.consumers import emit_channel_notification @@ -1251,8 +1251,7 @@ class UnifiedJob(PolymorphicModel, PasswordFieldsModel, CommonModelNameNotUnique self.update_fields(start_args=json.dumps(kwargs), status='pending') self.websocket_emit_status("pending") - from awx.main.scheduler.tasks import run_job_launch - connection.on_commit(lambda: run_job_launch.delay(self.id)) + schedule_task_manager() # Each type of unified job has a different Task class; get the # appropirate one. 
diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index b527610a26..dd6e620747 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -30,7 +30,7 @@ from awx.main.models import ( ) from awx.main.scheduler.dag_workflow import WorkflowDAG from awx.main.utils.pglock import advisory_lock -from awx.main.utils import get_type_for_model +from awx.main.utils import get_type_for_model, task_manager_bulk_reschedule, schedule_task_manager from awx.main.signals import disable_activity_stream from awx.main.scheduler.dependency_graph import DependencyGraph from awx.main.utils import decrypt_field @@ -161,6 +161,7 @@ class TaskManager(): result = [] for workflow_job in workflow_jobs: dag = WorkflowDAG(workflow_job) + status_changed = False if workflow_job.cancel_flag: logger.debug('Canceling spawned jobs of %s due to cancel flag.', workflow_job.log_format) cancel_finished = dag.cancel_node_jobs() @@ -169,7 +170,7 @@ class TaskManager(): workflow_job.status = 'canceled' workflow_job.start_args = '' # blank field to remove encrypted passwords workflow_job.save(update_fields=['status', 'start_args']) - workflow_job.websocket_emit_status(workflow_job.status) + status_changed = True else: is_done, has_failed = dag.is_workflow_done() if not is_done: @@ -181,7 +182,11 @@ class TaskManager(): workflow_job.status = new_status workflow_job.start_args = '' # blank field to remove encrypted passwords workflow_job.save(update_fields=['status', 'start_args']) + status_changed = True + if status_changed: workflow_job.websocket_emit_status(workflow_job.status) + if workflow_job.spawned_by_workflow: + schedule_task_manager() return result def get_dependent_jobs_for_inv_and_proj_update(self, job_obj): @@ -221,6 +226,7 @@ class TaskManager(): if type(task) is WorkflowJob: task.status = 'running' logger.info('Transitioning %s to running status.', task.log_format) + schedule_task_manager() elif not task.supports_isolation() and 
rampart_group.controller_id: # non-Ansible jobs on isolated instances run on controller task.instance_group = rampart_group.controller @@ -556,7 +562,8 @@ class TaskManager(): return logger.debug("Starting Scheduler") - finished_wfjs = self._schedule() + with task_manager_bulk_reschedule(): + finished_wfjs = self._schedule() # Operations whose queries rely on modifications made during the atomic scheduling session for wfj in WorkflowJob.objects.filter(id__in=finished_wfjs): diff --git a/awx/main/scheduler/tasks.py b/awx/main/scheduler/tasks.py index cef2d52c60..c0d3dd842e 100644 --- a/awx/main/scheduler/tasks.py +++ b/awx/main/scheduler/tasks.py @@ -9,16 +9,6 @@ from awx.main.dispatch.publish import task logger = logging.getLogger('awx.main.scheduler') -@task() -def run_job_launch(job_id): - TaskManager().schedule() - - -@task() -def run_job_complete(job_id): - TaskManager().schedule() - - @task() def run_task_manager(): logger.debug("Running Tower task manager.") diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 3b9d9bace4..08a6c54548 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -56,7 +56,7 @@ from awx.main.dispatch import get_local_queuename, reaper from awx.main.utils import (get_ansible_version, get_ssh_version, decrypt_field, update_scm_url, check_proot_installed, build_proot_temp_dir, get_licenser, wrap_args_with_proot, OutputEventFilter, OutputVerboseFilter, ignore_inventory_computed_fields, - ignore_inventory_group_removal, extract_ansible_vars) + ignore_inventory_group_removal, extract_ansible_vars, schedule_task_manager) from awx.main.utils.safe_yaml import safe_dump, sanitize_jinja from awx.main.utils.reload import stop_local_services from awx.main.utils.pglock import advisory_lock @@ -493,8 +493,7 @@ def handle_work_success(task_actual): if not instance: return - from awx.main.scheduler.tasks import run_job_complete - run_job_complete.delay(instance.id) + schedule_task_manager() @task() @@ -533,8 +532,7 @@ def 
handle_work_error(task_id, *args, **kwargs): # what the job complete message handler does then we may want to send a # completion event for each job here. if first_instance: - from awx.main.scheduler.tasks import run_job_complete - run_job_complete.delay(first_instance.id) + schedule_task_manager() pass diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py index 55b64e3a38..299c47a7c6 100644 --- a/awx/main/tests/functional/task_management/test_scheduler.py +++ b/awx/main/tests/functional/task_management/test_scheduler.py @@ -5,6 +5,7 @@ from datetime import timedelta from awx.main.scheduler import TaskManager from awx.main.utils import encrypt_field +from awx.main.models import WorkflowJobTemplate, JobTemplate @pytest.mark.django_db @@ -21,6 +22,95 @@ def test_single_job_scheduler_launch(default_instance_group, job_template_factor TaskManager.start_task.assert_called_once_with(j, default_instance_group, [], instance) +@pytest.mark.django_db +class TestJobLifeCycle: + + def run_tm(self, tm, expect_channel=None, expect_schedule=None, expect_commit=None): + """Test helper method that runs the task manager once and takes parameters to assert against + expect_channel - list of expected websocket emit channel message calls + expect_schedule - list of expected calls to reschedule itself + expect_commit - list of expected on_commit calls + If any of these are None, then the assertion is not made. 
+ """ + if expect_schedule and len(expect_schedule) > 1: + raise RuntimeError('Task manager should reschedule itself one time, at most.') + with mock.patch('awx.main.models.unified_jobs.UnifiedJob.websocket_emit_status') as mock_channel: + with mock.patch('awx.main.utils.common._schedule_task_manager') as tm_sch: + # Jobs are ultimately submitted in an on_commit hook, but this will not + # actually run, because it waits for the outer transaction, which is the test + # itself in this case + with mock.patch('django.db.connection.on_commit') as mock_commit: + tm.schedule() + if expect_channel is not None: + assert mock_channel.mock_calls == expect_channel + if expect_schedule is not None: + assert tm_sch.mock_calls == expect_schedule + if expect_commit is not None: + assert mock_commit.mock_calls == expect_commit + + def test_task_manager_workflow_rescheduling(self, job_template_factory, inventory, project, default_instance_group): + jt = JobTemplate.objects.create( + allow_simultaneous=True, + inventory=inventory, + project=project, + playbook='helloworld.yml' + ) + wfjt = WorkflowJobTemplate.objects.create(name='foo') + for i in range(2): + wfjt.workflow_nodes.create( + unified_job_template=jt + ) + wj = wfjt.create_unified_job() + assert wj.workflow_nodes.count() == 2 + wj.signal_start() + tm = TaskManager() + + # Transitions workflow job to running + # needs to re-schedule so it spawns jobs next round + self.run_tm(tm, [mock.call('running')], [mock.call()]) + + # Spawns jobs + # needs re-schedule to submit jobs next round + self.run_tm(tm, [mock.call('pending'), mock.call('pending')], [mock.call()]) + + assert jt.jobs.count() == 2 # task manager spawned jobs + + # Submits jobs + # intermission - jobs will run and reschedule TM when finished + self.run_tm(tm, [mock.call('waiting'), mock.call('waiting')], []) + + # I am the job runner + for job in jt.jobs.all(): + job.status = 'successful' + job.save() + + # Finishes workflow + # no further action is necessary, so 
rescheduling should not happen + self.run_tm(tm, [mock.call('successful')], []) + + def test_task_manager_workflow_workflow_rescheduling(self): + wfjts = [WorkflowJobTemplate.objects.create(name='foo')] + for i in range(5): + wfjt = WorkflowJobTemplate.objects.create(name='foo{}'.format(i)) + wfjts[-1].workflow_nodes.create( + unified_job_template=wfjt + ) + wfjts.append(wfjt) + + wj = wfjts[0].create_unified_job() + wj.signal_start() + tm = TaskManager() + + while wfjts[0].status != 'successful': + wfjts[1].refresh_from_db() + if wfjts[1].status == 'successful': + # final run, no more work to do + self.run_tm(tm, expect_schedule=[]) + else: + self.run_tm(tm, expect_schedule=[mock.call()]) + wfjts[0].refresh_from_db() + + @pytest.mark.django_db def test_single_jt_multi_job_launch_blocks_last(default_instance_group, job_template_factory, mocker): instance = default_instance_group.instances.all()[0] diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index 69864ba7ae..8452c4d15b 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -49,7 +49,8 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity', 'get_cpu_capacity', 'get_mem_capacity', 'wrap_args_with_proot', 'build_proot_temp_dir', 'check_proot_installed', 'model_to_dict', 'model_instance_diff', 'timestamp_apiformat', 'parse_yaml_or_json', 'RequireDebugTrueOrTest', - 'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError', 'get_custom_venv_choices', 'get_external_account'] + 'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError', 'get_custom_venv_choices', 'get_external_account', + 'task_manager_bulk_reschedule', 'schedule_task_manager'] def get_object_or_400(klass, *args, **kwargs): @@ -727,6 +728,7 @@ def get_system_task_capacity(scale=Decimal(1.0), cpu_capacity=None, mem_capacity _inventory_updates = threading.local() +_task_manager = threading.local() 
# Task-manager scheduling helpers added to awx/main/utils/common.py. They keep
# per-thread state on the module-level thread-local ``_task_manager``.


def _schedule_task_manager():
    """Submit ``run_task_manager`` once the current DB transaction commits."""
    # Function-scope imports are kept from the original (presumably to avoid
    # an import cycle with the scheduler package -- confirm).
    from awx.main.scheduler.tasks import run_task_manager
    from django.db import connection
    # runs right away if not in transaction
    connection.on_commit(lambda: run_task_manager.delay())


@contextlib.contextmanager
def task_manager_bulk_reschedule():
    """Context manager to avoid submitting the task manager multiple times.

    While the context is active, ``schedule_task_manager()`` only records
    that a run was requested. On exit the previous thread-local state is
    restored and, if a run was requested, a single request is issued through
    ``schedule_task_manager()``. When an enclosing
    ``task_manager_bulk_reschedule`` context is still active, that request
    is folded into it instead of being submitted immediately.
    """
    # Snapshot the state before entering the try block so the finally clause
    # never refers to names that were not assigned.
    previous_flag = getattr(_task_manager, 'bulk_reschedule', False)
    previous_value = getattr(_task_manager, 'needs_scheduling', False)
    _task_manager.bulk_reschedule = True
    _task_manager.needs_scheduling = False
    try:
        yield
    finally:
        requested = _task_manager.needs_scheduling
        # Restore the state first so it is not left stale if submission raises.
        _task_manager.bulk_reschedule = previous_flag
        _task_manager.needs_scheduling = previous_value
        if requested:
            # Goes through the public entry point so that an outer bulk
            # context (if any) absorbs the request.
            schedule_task_manager()


def schedule_task_manager():
    """Request a task manager run.

    Inside ``task_manager_bulk_reschedule()`` the request is only flagged
    (the context issues it on exit); otherwise it is submitted right away
    via ``_schedule_task_manager()``.
    """
    if getattr(_task_manager, 'bulk_reschedule', False):
        _task_manager.needs_scheduling = True
        return
    _schedule_task_manager()