From cddee29f2313c27cddfb6bf2cfb4b049aa168e88 Mon Sep 17 00:00:00 2001 From: thedoubl3j Date: Thu, 15 Jan 2026 19:55:54 -0500 Subject: [PATCH] More chainsaw work * fixed imports and addressed clusternode heartbeat test * took a chainsaw to task.py as well --- awx/main/dispatch/__init__.py | 5 +- awx/main/dispatch/pool.py | 6 - awx/main/dispatch/worker/__init__.py | 2 +- awx/main/dispatch/worker/base.py | 5 +- awx/main/dispatch/worker/callback.py | 34 +++- awx/main/dispatch/worker/task.py | 170 +++++------------- .../management/commands/run_cache_clear.py | 6 +- .../commands/run_rsyslog_configurer.py | 6 +- awx/main/tasks/system.py | 10 +- 9 files changed, 90 insertions(+), 154 deletions(-) diff --git a/awx/main/dispatch/__init__.py b/awx/main/dispatch/__init__.py index 97ec6774f2..a2b9a39058 100644 --- a/awx/main/dispatch/__init__.py +++ b/awx/main/dispatch/__init__.py @@ -77,14 +77,13 @@ class PubSub(object): n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n - def events(self, yield_timeouts=False): + def events(self): if not self.conn.autocommit: raise RuntimeError('Listening for events can only be done in autocommit mode') while True: if select.select([self.conn], [], [], self.select_timeout) == NOT_READY: - if yield_timeouts: - yield None + yield None else: notification_generator = self.current_notifies(self.conn) for notification in notification_generator: diff --git a/awx/main/dispatch/pool.py b/awx/main/dispatch/pool.py index d873bef5bf..b34d90d7d2 100644 --- a/awx/main/dispatch/pool.py +++ b/awx/main/dispatch/pool.py @@ -3,7 +3,6 @@ import os import time from multiprocessing import Process -from multiprocessing import Queue as MPQueue from django.conf import settings from django.db import connection as django_connection @@ -58,11 +57,6 @@ class WorkerPool(object): implementation when it receives an AMQP message), messages are passed to one of the multiprocessing Queues where some work can be done on them. - class MessagePrinter(awx.main.dispatch.worker.BaseWorker): - - def perform_work(self, body): - print(body) - pool = WorkerPool(min_workers=4) # spawn four worker processes pool.init_workers(MessagePrint().work_loop) pool.write( diff --git a/awx/main/dispatch/worker/__init__.py b/awx/main/dispatch/worker/__init__.py index 3d04184ef5..1ea197bf9d 100644 --- a/awx/main/dispatch/worker/__init__.py +++ b/awx/main/dispatch/worker/__init__.py @@ -1,3 +1,3 @@ -from .base import AWXConsumerRedis, BaseWorker # noqa +from .base import AWXConsumerRedis # noqa from .callback import CallbackBrokerWorker # noqa from .task import TaskWorker # noqa diff --git a/awx/main/dispatch/worker/base.py b/awx/main/dispatch/worker/base.py index 477e5b5d3f..cb59adee98 100644 --- a/awx/main/dispatch/worker/base.py +++ b/awx/main/dispatch/worker/base.py @@ -5,7 +5,6 @@ import os import logging import signal import sys -import redis import time from queue import Empty as QueueEmpty @@ -61,14 +60,12 @@ class AWXConsumerBase(object): def stop(self, signum, frame): self.should_stop = True logger.warning('received {}, stopping'.format(signame(signum))) - self.worker.on_stop() raise SystemExit() class AWXConsumerRedis(AWXConsumerBase): def run(self, *args, **kwargs): super(AWXConsumerRedis, self).run(*args, **kwargs) - self.worker.on_start() logger.info(f'Callback receiver started with pid={os.getpid()}') db.connection.close() # logs use database, so close connection @@ -90,7 +87,7 @@ class BaseWorker(object): if os.getppid() != ppid: break try: - body = self.read() + body = self.read() # this is only for the callback, only reading from redis. if body == 'QUIT': break except QueueEmpty: diff --git a/awx/main/dispatch/worker/callback.py b/awx/main/dispatch/worker/callback.py index 503d978cac..37ce7aad39 100644 --- a/awx/main/dispatch/worker/callback.py +++ b/awx/main/dispatch/worker/callback.py @@ -4,10 +4,12 @@ import os import signal import time import datetime +from queue import Empty as QueueEmpty from django.conf import settings from django.utils.functional import cached_property from django.utils.timezone import now as tz_now +from django import db from django.db import transaction, connection as django_connection from django_guid import set_guid @@ -16,6 +18,7 @@ import psutil import redis from awx.main.utils.redis import get_redis_client +from awx.main.utils.db import set_connection_name from awx.main.consumers import emit_channel_notification from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob from awx.main.constants import ACTIVE_STATES @@ -23,7 +26,7 @@ from awx.main.models.events import emit_event_detail from awx.main.utils.profiling import AWXProfiler from awx.main.tasks.system import events_processed_hook import awx.main.analytics.subsystem_metrics as s_metrics -from .base import BaseWorker +from .base import BaseWorker, WorkerSignalHandler logger = logging.getLogger('awx.main.commands.run_callback_receiver') @@ -81,6 +84,35 @@ class CallbackBrokerWorker(BaseWorker): for key in self.redis.keys('awx_callback_receiver_statistics_*'): self.redis.delete(key) + def work_loop(self, idx, *args): + ppid = os.getppid() + signal_handler = WorkerSignalHandler() + set_connection_name('worker') # set application_name to distinguish from other dispatcher processes + while not signal_handler.kill_now: + # if the parent PID changes, this process has been orphaned + # via e.g., segfault or sigkill, we should exit too + if os.getppid() != ppid: + break + try: + body = self.read() # this is only for the callback, only reading from redis. + if body == 'QUIT': + break + except QueueEmpty: + continue + except Exception: + logger.exception("Exception on worker {}, reconnecting: ".format(idx)) + continue + try: + for conn in db.connections.all(): + # If the database connection has a hiccup during the prior message, close it + # so we can establish a new connection + conn.close_if_unusable_or_obsolete() + self.perform_work(body, *args) + except Exception: + logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}') + + logger.debug('worker exiting gracefully pid:{}'.format(os.getpid())) + @cached_property def pid(self): """This needs to be obtained after forking, or else it will give the parent process""" diff --git a/awx/main/dispatch/worker/task.py b/awx/main/dispatch/worker/task.py index 6726aaeae3..c9375804c2 100644 --- a/awx/main/dispatch/worker/task.py +++ b/awx/main/dispatch/worker/task.py @@ -1,144 +1,56 @@ import inspect import logging import importlib -import sys -import traceback import time -from kubernetes.config import kube_config - -from django.conf import settings from django_guid import set_guid -from awx.main.tasks.system import dispatch_startup, inform_cluster_of_shutdown - -from .base import BaseWorker logger = logging.getLogger('awx.main.dispatch') -class TaskWorker(BaseWorker): +def resolve_callable(task): """ - A worker implementation that deserializes task messages and runs native - Python code. - - The code that *builds* these types of messages is found in - `awx.main.dispatch.publish`. + Transform a dotted notation task into an imported, callable function, e.g., + awx.main.tasks.system.delete_inventory + awx.main.tasks.jobs.RunProjectUpdate """ + if not task.startswith('awx.'): + raise ValueError('{} is not a valid awx task'.format(task)) + module, target = task.rsplit('.', 1) + module = importlib.import_module(module) + _call = None + if hasattr(module, target): + _call = getattr(module, target, None) + if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): + raise ValueError('{} is not decorated with @task()'.format(task)) + return _call - @staticmethod - def resolve_callable(task): - """ - Transform a dotted notation task into an imported, callable function, e.g., - awx.main.tasks.system.delete_inventory - awx.main.tasks.jobs.RunProjectUpdate - """ - if not task.startswith('awx.'): - raise ValueError('{} is not a valid awx task'.format(task)) - module, target = task.rsplit('.', 1) - module = importlib.import_module(module) - _call = None - if hasattr(module, target): - _call = getattr(module, target, None) - if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): - raise ValueError('{} is not decorated with @task()'.format(task)) - - return _call - - @staticmethod - def run_callable(body): - """ - Given some AMQP message, import the correct Python code and run it. - """ - task = body['task'] - uuid = body.get('uuid', '') - args = body.get('args', []) - kwargs = body.get('kwargs', {}) - if 'guid' in body: - set_guid(body.pop('guid')) - _call = TaskWorker.resolve_callable(task) - if inspect.isclass(_call): - # the callable is a class, e.g., RunJob; instantiate and - # return its `run()` method - _call = _call().run - - log_extra = '' - logger_method = logger.debug - if ('time_ack' in body) and ('time_pub' in body): - time_publish = body['time_ack'] - body['time_pub'] - time_waiting = time.time() - body['time_ack'] - if time_waiting > 5.0 or time_publish > 5.0: - # If task too a very long time to process, add this information to the log - log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher' - logger_method = logger.info - # don't print kwargs, they often contain launch-time secrets - logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') - - return _call(*args, **kwargs) - - def perform_work(self, body): - """ - Import and run code for a task e.g., - - body = { - 'args': [8], - 'callbacks': [{ - 'args': [], - 'kwargs': {} - 'task': u'awx.main.tasks.system.handle_work_success' - }], - 'errbacks': [{ - 'args': [], - 'kwargs': {}, - 'task': 'awx.main.tasks.system.handle_work_error' - }], - 'kwargs': {}, - 'task': u'awx.main.tasks.jobs.RunProjectUpdate' - } - """ - settings.__clean_on_fork__() - result = None - try: - result = self.run_callable(body) - except Exception as exc: - result = exc - - try: - if getattr(exc, 'is_awx_task_error', False): - # Error caused by user / tracked in job output - logger.warning("{}".format(exc)) - else: - task = body['task'] - args = body.get('args', []) - kwargs = body.get('kwargs', {}) - logger.exception('Worker failed to run task {}(*{}, **{}'.format(task, args, kwargs)) - except Exception: - # It's fairly critical that this code _not_ raise exceptions on logging - # If you configure external logging in a way that _it_ fails, there's - # not a lot we can do here; sys.stderr.write is a final hail mary - _, _, tb = sys.exc_info() - traceback.print_tb(tb) - - for callback in body.get('errbacks', []) or []: - callback['uuid'] = body['uuid'] - self.perform_work(callback) - finally: - # It's frustrating that we have to do this, but the python k8s - # client leaves behind cacert files in /tmp, so we must clean up - # the tmpdir per-dispatcher process every time a new task comes in - try: - kube_config._cleanup_temp_files() - except Exception: - logger.exception('failed to cleanup k8s client tmp files') - - for callback in body.get('callbacks', []) or []: - callback['uuid'] = body['uuid'] - self.perform_work(callback) - return result - - def on_start(self): - dispatch_startup() - - def on_stop(self): - inform_cluster_of_shutdown() +def run_callable(body): + """ + Given some AMQP message, import the correct Python code and run it. + """ + task = body['task'] + uuid = body.get('uuid', '') + args = body.get('args', []) + kwargs = body.get('kwargs', {}) + if 'guid' in body: + set_guid(body.pop('guid')) + _call = resolve_callable(task) + if inspect.isclass(_call): + # the callable is a class, e.g., RunJob; instantiate and + # return its `run()` method + _call = _call().run + log_extra = '' + logger_method = logger.debug + if ('time_ack' in body) and ('time_pub' in body): + time_publish = body['time_ack'] - body['time_pub'] + time_waiting = time.time() - body['time_ack'] + if time_waiting > 5.0 or time_publish > 5.0: + # If task too a very long time to process, add this information to the log + log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher' + logger_method = logger.info + # don't print kwargs, they often contain launch-time secrets + logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') + return _call(*args, **kwargs) diff --git a/awx/main/management/commands/run_cache_clear.py b/awx/main/management/commands/run_cache_clear.py index bba9cd8f68..d8f35ed5d5 100644 --- a/awx/main/management/commands/run_cache_clear.py +++ b/awx/main/management/commands/run_cache_clear.py @@ -4,7 +4,7 @@ import json from django.core.management.base import BaseCommand from awx.main.dispatch import pg_bus_conn -from awx.main.dispatch.worker.task import TaskWorker +from awx.main.dispatch.worker.task import run_callable logger = logging.getLogger('awx.main.cache_clear') @@ -21,11 +21,11 @@ class Command(BaseCommand): try: with pg_bus_conn() as conn: conn.listen("tower_settings_change") - for e in conn.events(yield_timeouts=True): + for e in conn.events(): if e is not None: body = json.loads(e.payload) logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}") - TaskWorker.run_callable(body) + run_callable(body) except Exception: # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata diff --git a/awx/main/management/commands/run_rsyslog_configurer.py b/awx/main/management/commands/run_rsyslog_configurer.py index bc68370987..8df5f84331 100644 --- a/awx/main/management/commands/run_rsyslog_configurer.py +++ b/awx/main/management/commands/run_rsyslog_configurer.py @@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand from django.conf import settings from django.core.cache import cache from awx.main.dispatch import pg_bus_conn -from awx.main.dispatch.worker.task import TaskWorker +from awx.main.dispatch.worker.task import run_callable from awx.main.utils.external_logging import reconfigure_rsyslog logger = logging.getLogger('awx.main.rsyslog_configurer') @@ -26,7 +26,7 @@ class Command(BaseCommand): conn.listen("rsyslog_configurer") # reconfigure rsyslog on start up reconfigure_rsyslog() - for e in conn.events(yield_timeouts=True): + for e in conn.events(): if e is not None: logger.info("Change in logging settings found. Restarting rsyslogd") # clear the cache of relevant settings then restart @@ -34,7 +34,7 @@ class Command(BaseCommand): cache.delete_many(setting_keys) settings._awx_conf_memoizedcache.clear() body = json.loads(e.payload) - TaskWorker.run_callable(body) + run_callable(body) except Exception: # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata logger.exception('Encountered unhandled error in rsyslog_configurer main loop') diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py index b3fe21d103..e218d130d9 100644 --- a/awx/main/tasks/system.py +++ b/awx/main/tasks/system.py @@ -14,6 +14,7 @@ from io import StringIO # dispatcherd from dispatcherd.factories import get_control_from_settings +from dispatcherd.publish import task # Runner import ansible_runner.cleanup @@ -46,9 +47,6 @@ from django.utils.translation import gettext_noop from flags.state import flag_enabled from rest_framework.exceptions import PermissionDenied -# Dispatcherd -from dispatcherd.publish import task - # AWX from awx import __version__ as awx_application_version from awx.conf import settings_registry @@ -125,7 +123,7 @@ def _run_dispatch_startup_common(): # no-op. # apply_cluster_membership_policies() - cluster_node_heartbeat() + cluster_node_heartbeat(None) reaper.startup_reaping() m = DispatcherMetrics() m.reset_values() @@ -626,6 +624,7 @@ def cluster_node_heartbeat(binder): Dispatcherd implementation. Uses Control API to get running tasks. """ + # Run common instance management logic this_inst, instance_list, lost_instances = _heartbeat_instance_management() if this_inst is None: @@ -638,6 +637,9 @@ def cluster_node_heartbeat(binder): _heartbeat_handle_lost_instances(lost_instances, this_inst) # Get running tasks using dispatcherd API + if binder is None: + logger.debug("Heartbeat finished in startup.") + return active_task_ids = _get_active_task_ids_from_dispatcherd(binder) if active_task_ids is None: logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper")