AAP-43117 Additional dispatcher removal simplifications and waiting reaper updates (#16243)

* Additional dispatcher removal simplifications and waiting reaper updates

* Fix double call and logging message

* Implement bugbot comment, should reap running on lost instances

* Add test case for new pending behavior
This commit is contained in:
Alan Rominger
2026-01-26 13:55:37 -05:00
committed by GitHub
parent 12a7229ee9
commit f80bbc57d8
11 changed files with 63 additions and 130 deletions

View File

@@ -30,7 +30,7 @@ def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False
}, },
"main_kwargs": {"node_id": settings.CLUSTER_HOST_ID}, "main_kwargs": {"node_id": settings.CLUSTER_HOST_ID},
"process_manager_cls": "ForkServerManager", "process_manager_cls": "ForkServerManager",
"process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.hazmat']}, "process_manager_kwargs": {"preload_modules": ['awx.main.dispatch.prefork']},
}, },
"brokers": {}, "brokers": {},
"publish": {}, "publish": {},

View File

@@ -1,6 +1,4 @@
import logging import logging
import os
import time
from multiprocessing import Process from multiprocessing import Process
@@ -15,13 +13,12 @@ class PoolWorker(object):
""" """
A simple wrapper around a multiprocessing.Process that tracks a worker child process. A simple wrapper around a multiprocessing.Process that tracks a worker child process.
The worker process runs the provided target function and tracks its creation time. The worker process runs the provided target function.
""" """
def __init__(self, target, args, **kwargs): def __init__(self, target, args):
self.process = Process(target=target, args=args) self.process = Process(target=target, args=args)
self.process.daemon = True self.process.daemon = True
self.creation_time = time.monotonic()
def start(self): def start(self):
self.process.start() self.process.start()
@@ -38,44 +35,20 @@ class WorkerPool(object):
pool = WorkerPool(workers_num=4) # spawn four worker processes pool = WorkerPool(workers_num=4) # spawn four worker processes
""" """
pool_cls = PoolWorker
debug_meta = ''
def __init__(self, workers_num=None): def __init__(self, workers_num=None):
self.name = settings.CLUSTER_HOST_ID
self.pid = os.getpid()
self.workers_num = workers_num or settings.JOB_EVENT_WORKERS self.workers_num = workers_num or settings.JOB_EVENT_WORKERS
self.workers = []
def __len__(self): def init_workers(self, target):
return len(self.workers)
def init_workers(self, target, *target_args):
self.target = target
self.target_args = target_args
for idx in range(self.workers_num): for idx in range(self.workers_num):
self.up() # It's important to close these because we're _about_ to fork, and we
# don't want the forked processes to inherit the open sockets
def up(self): # for the DB and cache connections (that way lies race conditions)
idx = len(self.workers) django_connection.close()
# It's important to close these because we're _about_ to fork, and we django_cache.close()
# don't want the forked processes to inherit the open sockets worker = PoolWorker(target, (idx,))
# for the DB and cache connections (that way lies race conditions) try:
django_connection.close() worker.start()
django_cache.close() except Exception:
worker = self.pool_cls(self.target, (idx,) + self.target_args) logger.exception('could not fork')
self.workers.append(worker) else:
try: logger.debug('scaling up worker pid:{}'.format(worker.process.pid))
worker.start()
except Exception:
logger.exception('could not fork')
else:
logger.debug('scaling up worker pid:{}'.format(worker.process.pid))
return idx, worker
def stop(self, signum):
try:
for worker in self.workers:
os.kill(worker.pid, signum)
except Exception:
logger.exception('could not kill {}'.format(worker.pid))

View File

@@ -1,9 +1,6 @@
from datetime import timedelta
import logging import logging
from django.db.models import Q from django.db.models import Q
from django.conf import settings
from django.utils.timezone import now as tz_now
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from awx.main.models import Instance, UnifiedJob, WorkflowJob from awx.main.models import Instance, UnifiedJob, WorkflowJob
@@ -50,26 +47,6 @@ def reap_job(j, status, job_explanation=None):
logger.error(f'{j.log_format} is no longer {status_before}; reaping') logger.error(f'{j.log_format} is no longer {status_before}; reaping')
def reap_waiting(instance=None, status='failed', job_explanation=None, grace_period=None, excluded_uuids=None, ref_time=None):
"""
Reap all jobs in waiting for this instance.
"""
if grace_period is None:
grace_period = settings.JOB_WAITING_GRACE_PERIOD + settings.TASK_MANAGER_TIMEOUT
if instance is None:
hostname = Instance.objects.my_hostname()
else:
hostname = instance.hostname
if ref_time is None:
ref_time = tz_now()
jobs = UnifiedJob.objects.filter(status='waiting', modified__lte=ref_time - timedelta(seconds=grace_period), controller_node=hostname)
if excluded_uuids:
jobs = jobs.exclude(celery_task_id__in=excluded_uuids)
for j in jobs:
reap_job(j, status, job_explanation=job_explanation)
def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None): def reap(instance=None, status='failed', job_explanation=None, excluded_uuids=None, ref_time=None):
""" """
Reap all jobs in running for this instance. Reap all jobs in running for this instance.

View File

@@ -19,49 +19,24 @@ def signame(sig):
return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig] return dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG') and not v.startswith('SIG_'))[sig]
class WorkerSignalHandler: class AWXConsumerRedis(object):
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class AWXConsumerBase(object):
last_stats = time.time()
def __init__(self, name, worker, queues=[], pool=None):
self.should_stop = False
def __init__(self, name, worker):
self.name = name self.name = name
self.total_messages = 0 self.pool = WorkerPool()
self.queues = queues self.pool.init_workers(worker.work_loop)
self.worker = worker
self.pool = pool
if pool is None:
self.pool = WorkerPool()
self.pool.init_workers(self.worker.work_loop)
self.redis = get_redis_client() self.redis = get_redis_client()
def run(self, *args, **kwargs): def run(self):
signal.signal(signal.SIGINT, self.stop) signal.signal(signal.SIGINT, self.stop)
signal.signal(signal.SIGTERM, self.stop) signal.signal(signal.SIGTERM, self.stop)
# Child should implement other things here
def stop(self, signum, frame):
self.should_stop = True
logger.warning('received {}, stopping'.format(signame(signum)))
raise SystemExit()
class AWXConsumerRedis(AWXConsumerBase):
def run(self, *args, **kwargs):
super(AWXConsumerRedis, self).run(*args, **kwargs)
logger.info(f'Callback receiver started with pid={os.getpid()}') logger.info(f'Callback receiver started with pid={os.getpid()}')
db.connection.close() # logs use database, so close connection db.connection.close() # logs use database, so close connection
while True: while True:
time.sleep(60) time.sleep(60)
def stop(self, signum, frame):
logger.warning('received {}, stopping'.format(signame(signum)))
raise SystemExit()

View File

@@ -26,7 +26,6 @@ from awx.main.models.events import emit_event_detail
from awx.main.utils.profiling import AWXProfiler from awx.main.utils.profiling import AWXProfiler
from awx.main.tasks.system import events_processed_hook from awx.main.tasks.system import events_processed_hook
import awx.main.analytics.subsystem_metrics as s_metrics import awx.main.analytics.subsystem_metrics as s_metrics
from .base import WorkerSignalHandler
logger = logging.getLogger('awx.main.commands.run_callback_receiver') logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -57,6 +56,16 @@ def job_stats_wrapup(job_identifier, event=None):
logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier)) logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier))
class WorkerSignalHandler:
def __init__(self):
self.kill_now = False
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, self.exit_gracefully)
def exit_gracefully(self, *args, **kwargs):
self.kill_now = True
class CallbackBrokerWorker: class CallbackBrokerWorker:
""" """
A worker implementation that deserializes callback event data and persists A worker implementation that deserializes callback event data and persists

View File

@@ -1,4 +1,3 @@
import inspect
import logging import logging
import importlib import importlib
import time import time
@@ -37,18 +36,13 @@ def run_callable(body):
if 'guid' in body: if 'guid' in body:
set_guid(body.pop('guid')) set_guid(body.pop('guid'))
_call = resolve_callable(task) _call = resolve_callable(task)
if inspect.isclass(_call):
# the callable is a class, e.g., RunJob; instantiate and
# return its `run()` method
_call = _call().run
log_extra = '' log_extra = ''
logger_method = logger.debug logger_method = logger.debug
if ('time_ack' in body) and ('time_pub' in body): if 'time_pub' in body:
time_publish = body['time_ack'] - body['time_pub'] time_publish = time.time() - body['time_pub']
time_waiting = time.time() - body['time_ack'] if time_publish > 5.0:
if time_waiting > 5.0 or time_publish > 5.0:
# If task too a very long time to process, add this information to the log # If task too a very long time to process, add this information to the log
log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher' log_extra = f' took {time_publish:.4f} to send message'
logger_method = logger.info logger_method = logger.info
# don't print kwargs, they often contain launch-time secrets # don't print kwargs, they often contain launch-time secrets
logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')

View File

@@ -3,7 +3,6 @@
import redis import redis
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
import redis.exceptions import redis.exceptions
@@ -36,11 +35,7 @@ class Command(BaseCommand):
raise CommandError(f'Callback receiver could not connect to redis, error: {exc}') raise CommandError(f'Callback receiver could not connect to redis, error: {exc}')
try: try:
consumer = AWXConsumerRedis( consumer = AWXConsumerRedis('callback_receiver', CallbackBrokerWorker())
'callback_receiver',
CallbackBrokerWorker(),
queues=[getattr(settings, 'CALLBACK_QUEUE', '')],
)
consumer.run() consumer.run()
except KeyboardInterrupt: except KeyboardInterrupt:
print('Terminating Callback Receiver') print('Terminating Callback Receiver')

View File

@@ -73,17 +73,16 @@ class Command(BaseCommand):
dispatcher_setup(get_dispatcherd_config(for_service=True)) dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service() run_service()
dispatcher_setup(get_dispatcherd_config(for_service=True))
run_service()
def configure_dispatcher_logging(self): def configure_dispatcher_logging(self):
# Apply special log rule for the parent process # Apply special log rule for the parent process
special_logging = copy.deepcopy(settings.LOGGING) special_logging = copy.deepcopy(settings.LOGGING)
changed_handlers = []
for handler_name, handler_config in special_logging.get('handlers', {}).items(): for handler_name, handler_config in special_logging.get('handlers', {}).items():
filters = handler_config.get('filters', []) filters = handler_config.get('filters', [])
if 'dynamic_level_filter' in filters: if 'dynamic_level_filter' in filters:
handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter'] handler_config['filters'] = [flt for flt in filters if flt != 'dynamic_level_filter']
logger.info(f'Dispatcherd main process replaced log level filter for {handler_name} handler') changed_handlers.append(handler_name)
logger.info(f'Dispatcherd main process replaced log level filter for handlers: {changed_handlers}')
# Apply the custom logging level here, before the asyncio code starts # Apply the custom logging level here, before the asyncio code starts
special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {}) special_logging.setdefault('loggers', {}).setdefault('dispatcherd', {})

View File

@@ -760,14 +760,16 @@ def _heartbeat_check_versions(this_inst, instance_list):
def _heartbeat_handle_lost_instances(lost_instances, this_inst): def _heartbeat_handle_lost_instances(lost_instances, this_inst):
"""Handle lost instances by reaping their jobs and marking them offline.""" """Handle lost instances by reaping their running jobs and marking them offline."""
for other_inst in lost_instances: for other_inst in lost_instances:
try: try:
# Any jobs marked as running will be marked as error
explanation = "Job reaped due to instance shutdown" explanation = "Job reaped due to instance shutdown"
reaper.reap(other_inst, job_explanation=explanation) reaper.reap(other_inst, job_explanation=explanation)
reaper.reap_waiting(other_inst, grace_period=0, job_explanation=explanation) # Any jobs that were waiting to be processed by this node will be handed back to task manager
UnifiedJob.objects.filter(status='waiting', controller_node=other_inst.hostname).update(status='pending', controller_node='', execution_node='')
except Exception: except Exception:
logger.exception('failed to reap jobs for {}'.format(other_inst.hostname)) logger.exception('failed to re-process jobs for lost instance {}'.format(other_inst.hostname))
try: try:
if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control": if settings.AWX_AUTO_DEPROVISION_INSTANCES and other_inst.node_type == "control":
deprovision_hostname = other_inst.hostname deprovision_hostname = other_inst.hostname

View File

@@ -5,6 +5,7 @@ import pytest
from awx.main.models import Job, WorkflowJob, Instance from awx.main.models import Job, WorkflowJob, Instance
from awx.main.dispatch import reaper from awx.main.dispatch import reaper
from awx.main.tasks import system
from dispatcherd.publish import task from dispatcherd.publish import task
''' '''
@@ -61,11 +62,6 @@ class TestJobReaper(object):
('running', '', '', None, False), # running, not assigned to the instance ('running', '', '', None, False), # running, not assigned to the instance
('running', 'awx', '', None, True), # running, has the instance as its execution_node ('running', 'awx', '', None, True), # running, has the instance as its execution_node
('running', '', 'awx', None, True), # running, has the instance as its controller_node ('running', '', 'awx', None, True), # running, has the instance as its controller_node
('waiting', '', '', None, False), # waiting, not assigned to the instance
('waiting', 'awx', '', None, False), # waiting, was edited less than a minute ago
('waiting', '', 'awx', None, False), # waiting, was edited less than a minute ago
('waiting', 'awx', '', yesterday, False), # waiting, managed by another node, ignore
('waiting', '', 'awx', yesterday, True), # waiting, assigned to the controller_node, stale
], ],
) )
def test_should_reap(self, status, fail, execution_node, controller_node, modified): def test_should_reap(self, status, fail, execution_node, controller_node, modified):
@@ -83,7 +79,6 @@ class TestJobReaper(object):
# (because .save() overwrites it to _now_) # (because .save() overwrites it to _now_)
Job.objects.filter(id=j.id).update(modified=modified) Job.objects.filter(id=j.id).update(modified=modified)
reaper.reap(i) reaper.reap(i)
reaper.reap_waiting(i)
job = Job.objects.first() job = Job.objects.first()
if fail: if fail:
assert job.status == 'failed' assert job.status == 'failed'
@@ -92,6 +87,20 @@ class TestJobReaper(object):
else: else:
assert job.status == status assert job.status == status
def test_waiting_job_sent_back_to_pending(self):
this_inst = Instance(hostname='awx')
this_inst.save()
lost_inst = Instance(hostname='lost', node_type=Instance.Types.EXECUTION, node_state=Instance.States.UNAVAILABLE)
lost_inst.save()
job = Job.objects.create(status='waiting', controller_node=lost_inst.hostname, execution_node='lost')
system._heartbeat_handle_lost_instances([lost_inst], this_inst)
job.refresh_from_db()
assert job.status == 'pending'
assert job.controller_node == ''
assert job.execution_node == ''
@pytest.mark.parametrize( @pytest.mark.parametrize(
'excluded_uuids, fail, started', 'excluded_uuids, fail, started',
[ [