Merge pull request #12582 from AlanCoding/clean_and_forget

Move reaper logic into worker, avoiding bottlenecks
Alan Rominger
2022-08-17 18:53:47 -04:00
committed by GitHub
5 changed files with 43 additions and 28 deletions

View File

@@ -16,6 +16,7 @@ from queue import Full as QueueFull, Empty as QueueEmpty
 from django.conf import settings
 from django.db import connection as django_connection, connections
 from django.core.cache import cache as django_cache
+from django.utils.timezone import now as tz_now
 from django_guid import set_guid
 from jinja2 import Template
 import psutil
@@ -377,8 +378,6 @@ class AutoscalePool(WorkerPool):
         1. Discover worker processes that exited, and recover messages they
            were handling.
         2. Clean up unnecessary, idle workers.
-        3. Check to see if the database says this node is running any tasks
-           that aren't actually running. If so, reap them.
 
         IMPORTANT: this function is one of the few places in the dispatcher
         (aside from setting lookups) where we talk to the database. As such,
@@ -437,18 +436,17 @@ class AutoscalePool(WorkerPool):
                 idx = random.choice(range(len(self.workers)))
                 self.write(idx, m)
 
-        # if the database says a job is running or queued on this node, but it's *not*,
-        # then reap it
-        running_uuids = []
-        for worker in self.workers:
-            worker.calculate_managed_tasks()
-            running_uuids.extend(list(worker.managed_tasks.keys()))
-
-        # if we are not in the dangerous situation of queue backup then clear old waiting jobs
-        if self.workers and max(len(w.managed_tasks) for w in self.workers) <= 1:
-            reaper.reap_waiting(excluded_uuids=running_uuids)
-
-        reaper.reap(excluded_uuids=running_uuids)
+    def add_bind_kwargs(self, body):
+        bind_kwargs = body.pop('bind_kwargs', [])
+        body.setdefault('kwargs', {})
+        if 'dispatch_time' in bind_kwargs:
+            body['kwargs']['dispatch_time'] = tz_now().isoformat()
+        if 'worker_tasks' in bind_kwargs:
+            worker_tasks = {}
+            for worker in self.workers:
+                worker.calculate_managed_tasks()
+                worker_tasks[worker.pid] = list(worker.managed_tasks.keys())
+            body['kwargs']['worker_tasks'] = worker_tasks
 
     def up(self):
         if self.full:
@@ -463,6 +461,8 @@ class AutoscalePool(WorkerPool):
         if 'guid' in body:
             set_guid(body['guid'])
         try:
+            if isinstance(body, dict) and body.get('bind_kwargs'):
+                self.add_bind_kwargs(body)
             # when the cluster heartbeat occurs, clean up internally
             if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
                 self.cleanup()
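
A minimal standalone sketch of the injection pattern in the two hunks above, assuming a plain-dict message body of the shape used in this diff; the helper and example values here are illustrative, not the real dispatcher code:

    from datetime import datetime, timezone

    def add_bind_kwargs(body, worker_tasks_by_pid):
        # Pop the requested bindings and inject the values as ordinary task
        # kwargs, mirroring AutoscalePool.add_bind_kwargs above.
        bind_kwargs = body.pop('bind_kwargs', [])
        body.setdefault('kwargs', {})
        if 'dispatch_time' in bind_kwargs:
            body['kwargs']['dispatch_time'] = datetime.now(timezone.utc).isoformat()
        if 'worker_tasks' in bind_kwargs:
            body['kwargs']['worker_tasks'] = worker_tasks_by_pid
        return body

    msg = {'task': 'awx.main.tasks.system.cluster_node_heartbeat',
           'bind_kwargs': ['dispatch_time', 'worker_tasks']}
    print(add_bind_kwargs(msg, {1234: ['uuid-a'], 1235: []}))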

View File

@@ -50,13 +50,21 @@ class task:
         @task(queue='tower_broadcast')
         def announce():
             print("Run this everywhere!")
+
+        # The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs
+        @task(bind_kwargs=['dispatch_time'])
+        def print_time(dispatch_time=None):
+            print(f"Time I was dispatched: {dispatch_time}")
     """
 
-    def __init__(self, queue=None):
+    def __init__(self, queue=None, bind_kwargs=None):
         self.queue = queue
+        self.bind_kwargs = bind_kwargs
 
     def __call__(self, fn=None):
         queue = self.queue
+        bind_kwargs = self.bind_kwargs
 
         class PublisherMixin(object):
@@ -80,6 +88,8 @@ class task:
                 guid = get_guid()
                 if guid:
                     obj['guid'] = guid
+                if bind_kwargs:
+                    obj['bind_kwargs'] = bind_kwargs
                 obj.update(**kw)
                 if callable(queue):
                     queue = queue()
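
For context, a toy sketch of the publish-side pattern shown in the two hunks above: the decorator records which bindings a task wants, and the outgoing message carries that list so the dispatcher pool can fill in the values later. This is a simplified stand-in, not the real decorator:

    import uuid

    def task(queue=None, bind_kwargs=None):
        # Toy decorator: attach an apply_async() that builds the message body.
        def decorate(fn):
            def apply_async(args=None, kwargs=None, **kw):
                obj = {'uuid': str(uuid.uuid4()), 'task': fn.__name__,
                       'args': args or [], 'kwargs': kwargs or {}}
                if bind_kwargs:
                    obj['bind_kwargs'] = bind_kwargs
                obj.update(**kw)
                return obj  # the real dispatcher publishes this to the queue
            fn.apply_async = apply_async
            return fn
        return decorate

    @task(bind_kwargs=['dispatch_time'])
    def print_time(dispatch_time=None):
        print(f"Time I was dispatched: {dispatch_time}")

    print(print_time.apply_async())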

View File

@@ -55,7 +55,7 @@ def reap_job(j, status, job_explanation=None):
         logger.error(f'{j.log_format} is no longer {status_before}; reaping')
 
 
-def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None):
+def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
     """
     Reap all jobs in waiting for this instance.
     """
@@ -69,8 +69,9 @@ def reap_waiting(instance=None, status='failed', job_explanation=None, grace_per
     except RuntimeError as e:
         logger.warning(f'Local instance is not registered, not running reaper: {e}')
         return
-    now = tz_now()
-    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=now - timedelta(seconds=grace_period), controller_node=me.hostname)
+    if ref_time is None:
+        ref_time = tz_now()
+    jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=me.hostname)
     if excluded_uuids:
         jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
     for j in jobs:
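
The new ref_time parameter lets the caller measure the grace period from when the heartbeat was dispatched rather than from when the reaper finally runs. A small illustration of that check, with illustrative names standing in for the ORM query:

    from datetime import datetime, timedelta, timezone

    def is_stale(modified, grace_period, ref_time=None):
        # Same shape as the reap_waiting() filter above: stale means the job
        # was last modified at least grace_period seconds before ref_time.
        if ref_time is None:
            ref_time = datetime.now(timezone.utc)
        return modified <= ref_time - timedelta(seconds=grace_period)

    dispatch_time = datetime(2022, 8, 17, 22, 53, tzinfo=timezone.utc)
    recent_job = dispatch_time - timedelta(seconds=30)
    # With a 60s grace period, a job touched 30s before dispatch is left
    # alone, even if the worker runs this check minutes later.
    print(is_stale(recent_job, 60, ref_time=dispatch_time))  # False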

View File

@@ -10,6 +10,7 @@ from contextlib import redirect_stdout
 import shutil
 import time
 from distutils.version import LooseVersion as Version
+from datetime import datetime
 
 # Django
 from django.conf import settings
@@ -482,8 +483,8 @@ def inspect_execution_nodes(instance_list):
             execution_node_health_check.apply_async([hostname])
 
 
-@task(queue=get_local_queuename)
-def cluster_node_heartbeat():
+@task(queue=get_local_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
+def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
     logger.debug("Cluster node heartbeat task.")
     nowtime = now()
     instance_list = list(Instance.objects.all())
@@ -562,6 +563,15 @@ def cluster_node_heartbeat():
             else:
                 logger.exception('Error marking {} as lost'.format(other_inst.hostname))
 
+    # Run local reaper
+    if worker_tasks is not None:
+        active_task_ids = []
+        for task_list in worker_tasks.values():
+            active_task_ids.extend(task_list)
+        reaper.reap(instance=this_inst, excluded_uuids=active_task_ids)
+        if max(len(task_list) for task_list in worker_tasks.values()) <= 1:
+            reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
+
 
 @task(queue=get_local_queuename)
 def awx_receptor_workunit_reaper():
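
Roughly how the heartbeat-side block above consumes the injected kwargs: worker_tasks maps each worker pid to the task UUIDs it is managing, and those UUIDs become the reaper's exclusion list (example values below are assumed):

    worker_tasks = {1234: ['3b9f-example-uuid'], 1235: []}  # pid -> managed task uuids (assumed)

    # Flatten to the exclusion list passed to reaper.reap()/reap_waiting().
    active_task_ids = [tid for task_list in worker_tasks.values() for tid in task_list]

    # reap_waiting() only runs when no worker holds more than one task,
    # i.e. the dispatch queue is not backed up.
    queue_not_backed_up = max(len(task_list) for task_list in worker_tasks.values()) <= 1
    print(active_task_ids, queue_not_backed_up)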

View File

@@ -199,10 +199,7 @@ class TestAutoScaling:
         assert len(self.pool) == 10
 
         # cleanup should scale down to 8 workers
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-                reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 2
 
     def test_max_scale_up(self):
def test_max_scale_up(self): def test_max_scale_up(self):
@@ -250,10 +247,7 @@ class TestAutoScaling:
         time.sleep(1)  # wait a moment for sigterm
 
         # clean up and the dead worker
-        with mock.patch('awx.main.dispatch.reaper.reap') as reap:
-            with mock.patch('awx.main.dispatch.reaper.reap_waiting') as reap:
-                self.pool.cleanup()
-                reap.assert_called()
+        self.pool.cleanup()
         assert len(self.pool) == 1
         assert self.pool.workers[0].pid == alive_pid