diff --git a/awx/main/dispatch/pool.py b/awx/main/dispatch/pool.py
index ce6c297861..527b4b52fb 100644
--- a/awx/main/dispatch/pool.py
+++ b/awx/main/dispatch/pool.py
@@ -16,6 +16,7 @@ from queue import Full as QueueFull, Empty as QueueEmpty
 from django.conf import settings
 from django.db import connection as django_connection, connections
 from django.core.cache import cache as django_cache
+from django.utils.timezone import now as tz_now
 from django_guid import set_guid
 from jinja2 import Template
 import psutil
@@ -377,8 +378,6 @@ class AutoscalePool(WorkerPool):
         1.  Discover worker processes that exited, and recover messages they
             were handling.
         2.  Clean up unnecessary, idle workers.
-        3.  Check to see if the database says this node is running any tasks
-            that aren't actually running.  If so, reap them.
 
         IMPORTANT: this function is one of the few places in the dispatcher
         (aside from setting lookups) where we talk to the database.  As such,
@@ -437,18 +436,17 @@ class AutoscalePool(WorkerPool):
             idx = random.choice(range(len(self.workers)))
             self.write(idx, m)
 
-        # if the database says a job is running or queued on this node, but it's *not*,
-        # then reap it
-        running_uuids = []
-        for worker in self.workers:
-            worker.calculate_managed_tasks()
-            running_uuids.extend(list(worker.managed_tasks.keys()))
-
-        # if we are not in the dangerous situation of queue backup then clear old waiting jobs
-        if self.workers and max(len(w.managed_tasks) for w in self.workers) <= 1:
-            reaper.reap_waiting(excluded_uuids=running_uuids)
-
-        reaper.reap(excluded_uuids=running_uuids)
+    def add_bind_kwargs(self, body):
+        bind_kwargs = body.pop('bind_kwargs', [])
+        body.setdefault('kwargs', {})
+        if 'dispatch_time' in bind_kwargs:
+            body['kwargs']['dispatch_time'] = tz_now().isoformat()
+        if 'worker_tasks' in bind_kwargs:
+            worker_tasks = {}
+            for worker in self.workers:
+                worker.calculate_managed_tasks()
+                worker_tasks[worker.pid] = list(worker.managed_tasks.keys())
+            body['kwargs']['worker_tasks'] = worker_tasks
 
     def up(self):
         if self.full:
@@ -463,6 +461,8 @@ class AutoscalePool(WorkerPool):
         if 'guid' in body:
             set_guid(body['guid'])
         try:
+            if isinstance(body, dict) and body.get('bind_kwargs'):
+                self.add_bind_kwargs(body)
             # when the cluster heartbeat occurs, clean up internally
             if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
                 self.cleanup()
diff --git a/awx/main/dispatch/publish.py b/awx/main/dispatch/publish.py
index dd19c1338c..bc496496d5 100644
--- a/awx/main/dispatch/publish.py
+++ b/awx/main/dispatch/publish.py
@@ -50,13 +50,21 @@ class task:
         @task(queue='tower_broadcast')
         def announce():
             print("Run this everywhere!")
+
+    # The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs
+
+        @task(bind_kwargs=['dispatch_time'])
+        def print_time(dispatch_time=None):
+            print(f"Time I was dispatched: {dispatch_time}")
     """
 
-    def __init__(self, queue=None):
+    def __init__(self, queue=None, bind_kwargs=None):
         self.queue = queue
+        self.bind_kwargs = bind_kwargs
 
     def __call__(self, fn=None):
         queue = self.queue
+        bind_kwargs = self.bind_kwargs
 
         class PublisherMixin(object):
 
@@ -80,6 +88,8 @@ class task:
                 guid = get_guid()
                 if guid:
                     obj['guid'] = guid
+                if bind_kwargs:
+                    obj['bind_kwargs'] = bind_kwargs
                 obj.update(**kw)
                 if callable(queue):
                     queue = queue()
diff --git a/awx/main/dispatch/reaper.py b/awx/main/dispatch/reaper.py
index 7a0ae1b884..4248eac3f6 100644
--- a/awx/main/dispatch/reaper.py
+++ b/awx/main/dispatch/reaper.py
@@ -55,7 +55,7 @@ def reap_job(j, status, job_explanation=None):
     logger.error(f'{j.log_format} is no longer {status_before}; reaping')
 
 
-def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None):
+def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
     """
     Reap all jobs in waiting for this instance.
     """
@@ -69,8 +69,9 @@ def reap_waiting(instance=None, status='failed', job_explanation=None, grace_per
         except RuntimeError as e:
             logger.warning(f'Local instance is not registered, not running reaper: {e}')
             return
-    now = tz_now()
-    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=now - timedelta(seconds=grace_period), controller_node=me.hostname)
+    if ref_time is None:
+        ref_time = tz_now()
+    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=me.hostname)
     if excluded_uuids:
         jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
     for j in jobs:
diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py
index e36c502400..d4f067115e 100644
--- a/awx/main/tasks/system.py
+++ b/awx/main/tasks/system.py
@@ -10,6 +10,7 @@ from contextlib import redirect_stdout
 import shutil
 import time
 from distutils.version import LooseVersion as Version
+from datetime import datetime
 
 # Django
 from django.conf import settings
@@ -482,8 +483,8 @@ def inspect_execution_nodes(instance_list):
             execution_node_health_check.apply_async([hostname])
 
 
-@task(queue=get_local_queuename)
-def cluster_node_heartbeat():
+@task(queue=get_local_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
+def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
     logger.debug("Cluster node heartbeat task.")
     nowtime = now()
     instance_list = list(Instance.objects.all())
@@ -562,6 +563,15 @@ def cluster_node_heartbeat():
                 else:
                     logger.exception('Error marking {} as lost'.format(other_inst.hostname))
 
+    # Run local reaper
+    if worker_tasks is not None:
+        active_task_ids = []
+        for task_list in worker_tasks.values():
+            active_task_ids.extend(task_list)
+        reaper.reap(instance=this_inst, excluded_uuids=active_task_ids)
+        if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
+            reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
+
 
 @task(queue=get_local_queuename)
 def awx_receptor_workunit_reaper():
diff --git a/awx/main/tests/functional/test_dispatch.py b/awx/main/tests/functional/test_dispatch.py
index 4b65726b84..f3c9afe58b 100644
--- a/awx/main/tests/functional/test_dispatch.py
+++ b/awx/main/tests/functional/test_dispatch.py
@@ -199,10 +199,7 @@ class TestAutoScaling:
         assert len(self.pool) == 10
 
         # cleanup should scale down to 8 workers
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-        reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 2
 
     def test_max_scale_up(self):
@@ -250,10 +247,7 @@ class TestAutoScaling:
         time.sleep(1)  # wait a moment for sigterm
 
         # clean up and the dead worker
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-        reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 1
         assert self.pool.workers[0].pid == alive_pid
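
For reference, a minimal sketch (not part of the patch) of how a task opts into the new bind_kwargs mechanism. The function name below is hypothetical; the kwarg names, the 'bind_kwargs' message key, and AutoscalePool.add_bind_kwargs() come from the publish.py and pool.py hunks above.

# Hypothetical example, not part of this change
from awx.main.dispatch.publish import task

@task(bind_kwargs=['dispatch_time', 'worker_tasks'])
def report_dispatch_state(dispatch_time=None, worker_tasks=None):
    # The parent dispatcher fills these kwargs in via AutoscalePool.add_bind_kwargs():
    #   dispatch_time -> tz_now().isoformat() at the moment the message is handed to a worker
    #   worker_tasks  -> {worker pid: [UUIDs of the tasks that worker is managing]}
    print(f"dispatched at {dispatch_time}; per-worker tasks: {worker_tasks}")

# Publishing adds 'bind_kwargs': ['dispatch_time', 'worker_tasks'] to the message body
# (publish.py hunk); before the pool writes the message to a worker it sees that key
# and calls add_bind_kwargs() to inject the values (pool.py hunk).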