More chainsaw work

* fixed imports and addressed clusternode heartbeat test
* took a chainsaw to task.py as well
This commit is contained in:
thedoubl3j
2026-01-15 19:55:54 -05:00
parent 3b896a00a9
commit cddee29f23
9 changed files with 90 additions and 154 deletions

View File

@@ -77,14 +77,13 @@ class PubSub(object):
n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n yield n
def events(self, yield_timeouts=False): def events(self):
if not self.conn.autocommit: if not self.conn.autocommit:
raise RuntimeError('Listening for events can only be done in autocommit mode') raise RuntimeError('Listening for events can only be done in autocommit mode')
while True: while True:
if select.select([self.conn], [], [], self.select_timeout) == NOT_READY: if select.select([self.conn], [], [], self.select_timeout) == NOT_READY:
if yield_timeouts: yield None
yield None
else: else:
notification_generator = self.current_notifies(self.conn) notification_generator = self.current_notifies(self.conn)
for notification in notification_generator: for notification in notification_generator:

View File

@@ -3,7 +3,6 @@ import os
import time import time
from multiprocessing import Process from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from django.conf import settings from django.conf import settings
from django.db import connection as django_connection from django.db import connection as django_connection
@@ -58,11 +57,6 @@ class WorkerPool(object):
implementation when it receives an AMQP message), messages are passed to implementation when it receives an AMQP message), messages are passed to
one of the multiprocessing Queues where some work can be done on them. one of the multiprocessing Queues where some work can be done on them.
class MessagePrinter(awx.main.dispatch.worker.BaseWorker):
def perform_work(self, body):
print(body)
pool = WorkerPool(min_workers=4) # spawn four worker processes pool = WorkerPool(min_workers=4) # spawn four worker processes
pool.init_workers(MessagePrint().work_loop) pool.init_workers(MessagePrint().work_loop)
pool.write( pool.write(

View File

@@ -1,3 +1,3 @@
from .base import AWXConsumerRedis, BaseWorker # noqa from .base import AWXConsumerRedis # noqa
from .callback import CallbackBrokerWorker # noqa from .callback import CallbackBrokerWorker # noqa
from .task import TaskWorker # noqa from .task import TaskWorker # noqa

View File

@@ -5,7 +5,6 @@ import os
import logging import logging
import signal import signal
import sys import sys
import redis
import time import time
from queue import Empty as QueueEmpty from queue import Empty as QueueEmpty
@@ -61,14 +60,12 @@ class AWXConsumerBase(object):
def stop(self, signum, frame): def stop(self, signum, frame):
self.should_stop = True self.should_stop = True
logger.warning('received {}, stopping'.format(signame(signum))) logger.warning('received {}, stopping'.format(signame(signum)))
self.worker.on_stop()
raise SystemExit() raise SystemExit()
class AWXConsumerRedis(AWXConsumerBase): class AWXConsumerRedis(AWXConsumerBase):
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
super(AWXConsumerRedis, self).run(*args, **kwargs) super(AWXConsumerRedis, self).run(*args, **kwargs)
self.worker.on_start()
logger.info(f'Callback receiver started with pid={os.getpid()}') logger.info(f'Callback receiver started with pid={os.getpid()}')
db.connection.close() # logs use database, so close connection db.connection.close() # logs use database, so close connection
@@ -90,7 +87,7 @@ class BaseWorker(object):
if os.getppid() != ppid: if os.getppid() != ppid:
break break
try: try:
body = self.read() body = self.read() # this is only for the callback, only reading from redis.
if body == 'QUIT': if body == 'QUIT':
break break
except QueueEmpty: except QueueEmpty:

View File

@@ -4,10 +4,12 @@ import os
import signal import signal
import time import time
import datetime import datetime
from queue import Empty as QueueEmpty
from django.conf import settings from django.conf import settings
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.timezone import now as tz_now from django.utils.timezone import now as tz_now
from django import db
from django.db import transaction, connection as django_connection from django.db import transaction, connection as django_connection
from django_guid import set_guid from django_guid import set_guid
@@ -16,6 +18,7 @@ import psutil
import redis import redis
from awx.main.utils.redis import get_redis_client from awx.main.utils.redis import get_redis_client
from awx.main.utils.db import set_connection_name
from awx.main.consumers import emit_channel_notification from awx.main.consumers import emit_channel_notification
from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob
from awx.main.constants import ACTIVE_STATES from awx.main.constants import ACTIVE_STATES
@@ -23,7 +26,7 @@ from awx.main.models.events import emit_event_detail
from awx.main.utils.profiling import AWXProfiler from awx.main.utils.profiling import AWXProfiler
from awx.main.tasks.system import events_processed_hook from awx.main.tasks.system import events_processed_hook
import awx.main.analytics.subsystem_metrics as s_metrics import awx.main.analytics.subsystem_metrics as s_metrics
from .base import BaseWorker from .base import BaseWorker, WorkerSignalHandler
logger = logging.getLogger('awx.main.commands.run_callback_receiver') logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -81,6 +84,35 @@ class CallbackBrokerWorker(BaseWorker):
for key in self.redis.keys('awx_callback_receiver_statistics_*'): for key in self.redis.keys('awx_callback_receiver_statistics_*'):
self.redis.delete(key) self.redis.delete(key)
def work_loop(self, idx, *args):
ppid = os.getppid()
signal_handler = WorkerSignalHandler()
set_connection_name('worker') # set application_name to distinguish from other dispatcher processes
while not signal_handler.kill_now:
# if the parent PID changes, this process has been orphaned
# via e.g., segfault or sigkill, we should exit too
if os.getppid() != ppid:
break
try:
body = self.read() # this is only for the callback, only reading from redis.
if body == 'QUIT':
break
except QueueEmpty:
continue
except Exception:
logger.exception("Exception on worker {}, reconnecting: ".format(idx))
continue
try:
for conn in db.connections.all():
# If the database connection has a hiccup during the prior message, close it
# so we can establish a new connection
conn.close_if_unusable_or_obsolete()
self.perform_work(body, *args)
except Exception:
logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}')
logger.debug('worker exiting gracefully pid:{}'.format(os.getpid()))
@cached_property @cached_property
def pid(self): def pid(self):
"""This needs to be obtained after forking, or else it will give the parent process""" """This needs to be obtained after forking, or else it will give the parent process"""

View File

@@ -1,144 +1,56 @@
import inspect import inspect
import logging import logging
import importlib import importlib
import sys
import traceback
import time import time
from kubernetes.config import kube_config
from django.conf import settings
from django_guid import set_guid from django_guid import set_guid
from awx.main.tasks.system import dispatch_startup, inform_cluster_of_shutdown
from .base import BaseWorker
logger = logging.getLogger('awx.main.dispatch') logger = logging.getLogger('awx.main.dispatch')
class TaskWorker(BaseWorker): def resolve_callable(task):
""" """
A worker implementation that deserializes task messages and runs native Transform a dotted notation task into an imported, callable function, e.g.,
Python code. awx.main.tasks.system.delete_inventory
awx.main.tasks.jobs.RunProjectUpdate
The code that *builds* these types of messages is found in
`awx.main.dispatch.publish`.
""" """
if not task.startswith('awx.'):
raise ValueError('{} is not a valid awx task'.format(task))
module, target = task.rsplit('.', 1)
module = importlib.import_module(module)
_call = None
if hasattr(module, target):
_call = getattr(module, target, None)
if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')):
raise ValueError('{} is not decorated with @task()'.format(task))
return _call
@staticmethod
def resolve_callable(task):
"""
Transform a dotted notation task into an imported, callable function, e.g.,
awx.main.tasks.system.delete_inventory def run_callable(body):
awx.main.tasks.jobs.RunProjectUpdate """
""" Given some AMQP message, import the correct Python code and run it.
if not task.startswith('awx.'): """
raise ValueError('{} is not a valid awx task'.format(task)) task = body['task']
module, target = task.rsplit('.', 1) uuid = body.get('uuid', '<unknown>')
module = importlib.import_module(module) args = body.get('args', [])
_call = None kwargs = body.get('kwargs', {})
if hasattr(module, target): if 'guid' in body:
_call = getattr(module, target, None) set_guid(body.pop('guid'))
if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): _call = resolve_callable(task)
raise ValueError('{} is not decorated with @task()'.format(task)) if inspect.isclass(_call):
# the callable is a class, e.g., RunJob; instantiate and
return _call # return its `run()` method
_call = _call().run
@staticmethod log_extra = ''
def run_callable(body): logger_method = logger.debug
""" if ('time_ack' in body) and ('time_pub' in body):
Given some AMQP message, import the correct Python code and run it. time_publish = body['time_ack'] - body['time_pub']
""" time_waiting = time.time() - body['time_ack']
task = body['task'] if time_waiting > 5.0 or time_publish > 5.0:
uuid = body.get('uuid', '<unknown>') # If task too a very long time to process, add this information to the log
args = body.get('args', []) log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher'
kwargs = body.get('kwargs', {}) logger_method = logger.info
if 'guid' in body: # don't print kwargs, they often contain launch-time secrets
set_guid(body.pop('guid')) logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')
_call = TaskWorker.resolve_callable(task) return _call(*args, **kwargs)
if inspect.isclass(_call):
# the callable is a class, e.g., RunJob; instantiate and
# return its `run()` method
_call = _call().run
log_extra = ''
logger_method = logger.debug
if ('time_ack' in body) and ('time_pub' in body):
time_publish = body['time_ack'] - body['time_pub']
time_waiting = time.time() - body['time_ack']
if time_waiting > 5.0 or time_publish > 5.0:
# If task too a very long time to process, add this information to the log
log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher'
logger_method = logger.info
# don't print kwargs, they often contain launch-time secrets
logger_method(f'task {uuid} starting {task}(*{args}){log_extra}')
return _call(*args, **kwargs)
def perform_work(self, body):
"""
Import and run code for a task e.g.,
body = {
'args': [8],
'callbacks': [{
'args': [],
'kwargs': {}
'task': u'awx.main.tasks.system.handle_work_success'
}],
'errbacks': [{
'args': [],
'kwargs': {},
'task': 'awx.main.tasks.system.handle_work_error'
}],
'kwargs': {},
'task': u'awx.main.tasks.jobs.RunProjectUpdate'
}
"""
settings.__clean_on_fork__()
result = None
try:
result = self.run_callable(body)
except Exception as exc:
result = exc
try:
if getattr(exc, 'is_awx_task_error', False):
# Error caused by user / tracked in job output
logger.warning("{}".format(exc))
else:
task = body['task']
args = body.get('args', [])
kwargs = body.get('kwargs', {})
logger.exception('Worker failed to run task {}(*{}, **{}'.format(task, args, kwargs))
except Exception:
# It's fairly critical that this code _not_ raise exceptions on logging
# If you configure external logging in a way that _it_ fails, there's
# not a lot we can do here; sys.stderr.write is a final hail mary
_, _, tb = sys.exc_info()
traceback.print_tb(tb)
for callback in body.get('errbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
finally:
# It's frustrating that we have to do this, but the python k8s
# client leaves behind cacert files in /tmp, so we must clean up
# the tmpdir per-dispatcher process every time a new task comes in
try:
kube_config._cleanup_temp_files()
except Exception:
logger.exception('failed to cleanup k8s client tmp files')
for callback in body.get('callbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
return result
def on_start(self):
dispatch_startup()
def on_stop(self):
inform_cluster_of_shutdown()

View File

@@ -4,7 +4,7 @@ import json
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from awx.main.dispatch import pg_bus_conn from awx.main.dispatch import pg_bus_conn
from awx.main.dispatch.worker.task import TaskWorker from awx.main.dispatch.worker.task import run_callable
logger = logging.getLogger('awx.main.cache_clear') logger = logging.getLogger('awx.main.cache_clear')
@@ -21,11 +21,11 @@ class Command(BaseCommand):
try: try:
with pg_bus_conn() as conn: with pg_bus_conn() as conn:
conn.listen("tower_settings_change") conn.listen("tower_settings_change")
for e in conn.events(yield_timeouts=True): for e in conn.events():
if e is not None: if e is not None:
body = json.loads(e.payload) body = json.loads(e.payload)
logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}") logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}")
TaskWorker.run_callable(body) run_callable(body)
except Exception: except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata

View File

@@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from awx.main.dispatch import pg_bus_conn from awx.main.dispatch import pg_bus_conn
from awx.main.dispatch.worker.task import TaskWorker from awx.main.dispatch.worker.task import run_callable
from awx.main.utils.external_logging import reconfigure_rsyslog from awx.main.utils.external_logging import reconfigure_rsyslog
logger = logging.getLogger('awx.main.rsyslog_configurer') logger = logging.getLogger('awx.main.rsyslog_configurer')
@@ -26,7 +26,7 @@ class Command(BaseCommand):
conn.listen("rsyslog_configurer") conn.listen("rsyslog_configurer")
# reconfigure rsyslog on start up # reconfigure rsyslog on start up
reconfigure_rsyslog() reconfigure_rsyslog()
for e in conn.events(yield_timeouts=True): for e in conn.events():
if e is not None: if e is not None:
logger.info("Change in logging settings found. Restarting rsyslogd") logger.info("Change in logging settings found. Restarting rsyslogd")
# clear the cache of relevant settings then restart # clear the cache of relevant settings then restart
@@ -34,7 +34,7 @@ class Command(BaseCommand):
cache.delete_many(setting_keys) cache.delete_many(setting_keys)
settings._awx_conf_memoizedcache.clear() settings._awx_conf_memoizedcache.clear()
body = json.loads(e.payload) body = json.loads(e.payload)
TaskWorker.run_callable(body) run_callable(body)
except Exception: except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata
logger.exception('Encountered unhandled error in rsyslog_configurer main loop') logger.exception('Encountered unhandled error in rsyslog_configurer main loop')

View File

@@ -14,6 +14,7 @@ from io import StringIO
# dispatcherd # dispatcherd
from dispatcherd.factories import get_control_from_settings from dispatcherd.factories import get_control_from_settings
from dispatcherd.publish import task
# Runner # Runner
import ansible_runner.cleanup import ansible_runner.cleanup
@@ -46,9 +47,6 @@ from django.utils.translation import gettext_noop
from flags.state import flag_enabled from flags.state import flag_enabled
from rest_framework.exceptions import PermissionDenied from rest_framework.exceptions import PermissionDenied
# Dispatcherd
from dispatcherd.publish import task
# AWX # AWX
from awx import __version__ as awx_application_version from awx import __version__ as awx_application_version
from awx.conf import settings_registry from awx.conf import settings_registry
@@ -125,7 +123,7 @@ def _run_dispatch_startup_common():
# no-op. # no-op.
# #
apply_cluster_membership_policies() apply_cluster_membership_policies()
cluster_node_heartbeat() cluster_node_heartbeat(None)
reaper.startup_reaping() reaper.startup_reaping()
m = DispatcherMetrics() m = DispatcherMetrics()
m.reset_values() m.reset_values()
@@ -626,6 +624,7 @@ def cluster_node_heartbeat(binder):
Dispatcherd implementation. Dispatcherd implementation.
Uses Control API to get running tasks. Uses Control API to get running tasks.
""" """
# Run common instance management logic # Run common instance management logic
this_inst, instance_list, lost_instances = _heartbeat_instance_management() this_inst, instance_list, lost_instances = _heartbeat_instance_management()
if this_inst is None: if this_inst is None:
@@ -638,6 +637,9 @@ def cluster_node_heartbeat(binder):
_heartbeat_handle_lost_instances(lost_instances, this_inst) _heartbeat_handle_lost_instances(lost_instances, this_inst)
# Get running tasks using dispatcherd API # Get running tasks using dispatcherd API
if binder is None:
logger.debug("Heartbeat finished in startup.")
return
active_task_ids = _get_active_task_ids_from_dispatcherd(binder) active_task_ids = _get_active_task_ids_from_dispatcherd(binder)
if active_task_ids is None: if active_task_ids is None:
logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper") logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper")