AAP-58539 Move to dispatcherd (#16209)

* WIP First pass
* started removing feature flags and adjusting logic
* Add decorator
* moved to dispatcher decorator
* updated as many as I could find
* Keep callback receiver working
* remove any code that is not used by the callback receiver
* add back auto_max_workers
* added back get_auto_max_workers into common utils
* Remove control and hazmat (squash this not done)
* moved status out and deleted control as no longer needed
* removed unused imports
* adjusted test import to pull correct method
* fixed imports and addressed clusternode heartbeat test
* Update function comments
* Add back hazmat for config and remove baseworker
* added back hazmat per @alancoding feedback around config
* removed baseworker completely and refactored it into the callback
  worker
* Fix dispatcher run call and remove dispatch setting
* remove dispatcher mock publish setting
* Adjust heartbeat arg and more formatting
* fixed the call to cluster_node_heartbeat missing binder
* Fix attribute error in server logs
This commit is contained in:
Jake Jackson
2026-01-23 15:49:32 -05:00
committed by GitHub
parent 94d5769f32
commit 36a00ec46b
38 changed files with 294 additions and 2010 deletions

View File

@@ -77,14 +77,13 @@ class PubSub(object):
n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
def events(self, yield_timeouts=False):
def events(self):
if not self.conn.autocommit:
raise RuntimeError('Listening for events can only be done in autocommit mode')
while True:
if select.select([self.conn], [], [], self.select_timeout) == NOT_READY:
if yield_timeouts:
yield None
yield None
else:
notification_generator = self.current_notifies(self.conn)
for notification in notification_generator:

View File

@@ -2,7 +2,7 @@ from django.conf import settings
from ansible_base.lib.utils.db import get_pg_notify_params
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.pool import get_auto_max_workers
from awx.main.utils.common import get_auto_max_workers
def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False) -> dict:

View File

@@ -1,77 +0,0 @@
import logging
import uuid
import json
from django.db import connection
from awx.main.dispatch import get_task_queuename
from awx.main.utils.redis import get_redis_client
from . import pg_bus_conn
logger = logging.getLogger('awx.main.dispatch')
class Control(object):
    """
    Client for querying and controlling a running AWX service.

    Status is read from redis; every other command is sent as a JSON
    message over the pg_notify channel named ``self.queuename``.
    """

    # the only service names accepted by the constructor
    services = ('dispatcher', 'callback_receiver')
    result = None

    def __init__(self, service, host=None):
        """
        :param service: one of ``Control.services``
        :param host: channel name to notify; defaults to get_task_queuename()
        :raises RuntimeError: if ``service`` is not a known service
        """
        if service not in self.services:
            raise RuntimeError('{} must be in {}'.format(service, self.services))
        self.service = service
        self.queuename = host or get_task_queuename()

    def status(self, *args, **kwargs):
        """Return the statistics text stored in redis for this service."""
        r = get_redis_client()
        if self.service == 'dispatcher':
            stats = r.get(f'awx_{self.service}_statistics') or b''
            return stats.decode('utf-8')

        workers = []
        for key in r.keys('awx_callback_receiver_statistics_*'):
            stats = r.get(key)
            if stats is None:
                # the key went away between keys() and get(); skip it
                # rather than crash on None.decode()
                continue
            workers.append(stats.decode('utf-8'))
        return '\n'.join(workers)

    def running(self, *args, **kwargs):
        """Send the 'running' command and return the decoded reply."""
        return self.control_with_reply('running', *args, **kwargs)

    def cancel(self, task_ids, with_reply=True):
        """
        Send the 'cancel' command for ``task_ids``.

        Returns the decoded reply when ``with_reply`` is true; otherwise the
        message is sent without a reply channel and None is returned.
        """
        if with_reply:
            return self.control_with_reply('cancel', extra_data={'task_ids': task_ids})
        else:
            self.control({'control': 'cancel', 'task_ids': task_ids, 'reply_to': None}, extra_data={'task_ids': task_ids})

    def schedule(self, *args, **kwargs):
        """Send the 'schedule' command and return the decoded reply."""
        return self.control_with_reply('schedule', *args, **kwargs)

    @classmethod
    def generate_reply_queue_name(cls):
        """Return a unique channel name of the form ``reply_to_<uuid4>`` with dashes replaced by underscores."""
        return f"reply_to_{str(uuid.uuid4()).replace('-','_')}"

    def control_with_reply(self, command, timeout=5, extra_data=None):
        """
        Send ``command`` and wait for a reply on a one-off reply channel.

        :param command: value placed under the 'control' key of the message
        :param timeout: select timeout, in seconds, used while waiting
        :param extra_data: optional dict merged into the outgoing message
        :returns: the reply payload, decoded from JSON
        :raises RuntimeError: if the connection is not in autocommit mode, or
            if the first event received is a timeout
        """
        logger.warning('checking {} {} for {}'.format(self.service, command, self.queuename))
        reply_queue = Control.generate_reply_queue_name()
        self.result = None

        if not connection.get_autocommit():
            raise RuntimeError('Control-with-reply messages can only be done in autocommit mode')

        with pg_bus_conn(select_timeout=timeout) as conn:
            # listen before notifying so the reply cannot be missed
            conn.listen(reply_queue)
            send_data = {'control': command, 'reply_to': reply_queue}
            if extra_data:
                send_data.update(extra_data)
            conn.notify(self.queuename, json.dumps(send_data))

            for reply in conn.events(yield_timeouts=True):
                if reply is None:
                    logger.error(f'{self.service} did not reply within {timeout}s')
                    raise RuntimeError(f"{self.service} did not reply within {timeout}s")
                # only the first notification is of interest
                break

        return json.loads(reply.payload)

    def control(self, msg, **kwargs):
        """Send ``msg`` as JSON to ``self.queuename`` without waiting for a reply; extra kwargs are ignored."""
        with pg_bus_conn() as conn:
            conn.notify(self.queuename, json.dumps(msg))

View File

@@ -18,7 +18,7 @@ django.setup() # noqa
from django.conf import settings
# Preload all periodic tasks so their imports will be in shared memory
for name, options in settings.CELERYBEAT_SCHEDULE.items():
for name, options in settings.DISPATCHER_SCHEDULE.items():
resolve_callable(options['task'])

View File

@@ -1,146 +0,0 @@
import logging
import time
import yaml
from datetime import datetime
logger = logging.getLogger('awx.main.dispatch.periodic')
class ScheduledTask:
    """
    A single recurring schedule, loosely modeled after the Job class of the
    python schedule library.

    Design points:
    - times are always relative to the global start of the scheduler
    - target run times use integer math; the current relative time is a float

    Missed schedule policy:
    Target times never drift. With interval=10s and offset=0, a run that
    happens at t=7s asks for the next run 3s later. When an entire interval
    goes by without a run, that run counts as missed and is abandoned
    (there are no catch-up runs).
    """

    def __init__(self, name: str, data: dict):
        # caller-facing information, irrelevant to the scheduling math
        self.name = name
        self.data = data  # tells the caller what to run
        self.last_run = None  # relative time of the latest run, debug only
        self.completed_runs = 0  # how many runs have been recorded

        # inputs of the scheduling math
        self.interval = int(data['schedule'].total_seconds())
        self.offset = 0  # shift of this schedule relative to the start time
        self.index = 0  # count of schedule periods already passed

    @property
    def next_run(self):
        "Relative time of the next run, where t=0 is the global_start of the scheduler class"
        upcoming_period = self.index + 1
        return upcoming_period * self.interval + self.offset

    def due_to_run(self, relative_time):
        "True when the next target time is at or before relative_time"
        target = self.next_run
        return bool(target <= relative_time)

    def expected_runs(self, relative_time):
        "Number of whole intervals elapsed since this schedule's offset"
        elapsed = relative_time - self.offset
        return int(elapsed / self.interval)

    def mark_run(self, relative_time):
        "Record a run at relative_time and advance the period index, warning about skipped periods"
        new_index = self.expected_runs(relative_time)
        self.completed_runs += 1
        self.last_run = relative_time
        if new_index - self.index > 1:
            logger.warning(f'Missed {new_index - self.index - 1} schedules of {self.name}')
        self.index = new_index

    def missed_runs(self, relative_time):
        "Count of runs that should have happened but did not, only used for debug"
        backlog = self.expected_runs(relative_time) - self.completed_runs
        # a run that is due right now has not been missed yet
        if backlog != 0 and self.due_to_run(relative_time):
            return backlog - 1
        return backlog
class Scheduler:
    """
    Holds a list of ScheduledTask objects and answers two questions:
    which tasks are due now, and how long until the next one is due.

    All internal times are relative to ``self.global_start``.
    """

    def __init__(self, schedule):
        """
        Expects schedule in the form of a dictionary like
        {
            'job1': {'schedule': timedelta(seconds=50), 'other': 'stuff'}
        }
        Only the schedule nearest-second value is used for scheduling,
        the rest of the data is for use by the caller to know what to run.

        Raises RuntimeError when there are more schedules than seconds in
        the shortest interval.
        """
        self.jobs = [ScheduledTask(name, data) for name, data in schedule.items()]
        min_interval = min(job.interval for job in self.jobs)
        num_jobs = len(self.jobs)

        # this is intentionally opinionated against spammy schedules
        # a core goal is to spread out the scheduled tasks (for worker management)
        # and high-frequency schedules just do not work with that
        if num_jobs > min_interval:
            raise RuntimeError(f'Number of schedules ({num_jobs}) is more than the shortest schedule interval ({min_interval} seconds).')

        # evenly space out jobs over the base interval
        for i, job in enumerate(self.jobs):
            job.offset = (i * min_interval) // num_jobs

        # internally times are all referenced relative to startup time, add grace period
        self.global_start = time.time() + 2.0

    def get_and_mark_pending(self, reftime=None):
        """
        Return the jobs that are due at ``reftime`` and mark each as run.

        ``reftime`` is an absolute timestamp; it defaults to time.time().
        """
        if reftime is None:
            reftime = time.time()  # mostly for tests
        relative_time = reftime - self.global_start
        to_run = []
        for job in self.jobs:
            if job.due_to_run(relative_time):
                to_run.append(job)
                logger.debug(f'scheduler found {job.name} to run, {relative_time - job.next_run} seconds after target')
                job.mark_run(relative_time)
        return to_run

    def time_until_next_run(self, reftime=None):
        """
        Return the seconds until the soonest job is due, clamped to the
        range [0.1, 20.0].

        ``reftime`` is an absolute timestamp; it defaults to time.time().
        """
        if reftime is None:
            reftime = time.time()  # mostly for tests
        relative_time = reftime - self.global_start
        next_job = min(self.jobs, key=lambda j: j.next_run)
        delta = next_job.next_run - relative_time
        if delta <= 0.1:
            # careful not to give 0 or negative values to the select timeout, which has unclear interpretation
            logger.warning(f'Scheduler next run of {next_job.name} is {-delta} seconds in the past')
            return 0.1
        elif delta > 20.0:
            logger.warning(f'Scheduler next run unexpectedly over 20 seconds in future: {delta}')
            return 20.0
        logger.debug(f'Scheduler next run is {next_job.name} in {delta} seconds')
        return delta

    def debug(self, *args, **kwargs):
        """Return a YAML document describing the scheduler and each schedule; arguments are ignored."""
        # imported here because only `datetime` is imported at module level
        from datetime import timezone

        time_format = '%Y-%m-%d %H:%M:%S UTC'
        reftime = time.time()
        relative_time = reftime - self.global_start

        def as_utc(timestamp):
            # the format string labels the value as UTC, so convert explicitly
            # instead of relying on the host's local timezone
            return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime(time_format)

        data = {
            'title': 'Scheduler status',
            'started_time': as_utc(self.global_start),
            'current_time': as_utc(reftime),
            'current_time_relative': round(relative_time, 3),
            'total_schedules': len(self.jobs),
            'schedule_list': {
                job.name: dict(
                    # compare against None so a last_run of exactly 0 is still reported
                    last_run_seconds_ago=round(relative_time - job.last_run, 3) if job.last_run is not None else None,
                    next_run_in_seconds=round(job.next_run - relative_time, 3),
                    offset_in_seconds=job.offset,
                    completed_runs=job.completed_runs,
                    missed_runs=job.missed_runs(relative_time),
                )
                for job in sorted(self.jobs, key=lambda job: job.interval)
            },
        }
        return yaml.safe_dump(data, default_flow_style=False, sort_keys=False)

View File

@@ -1,251 +1,50 @@
import logging
import os
import random
import signal
import sys
import time
import traceback
from datetime import datetime, timezone
from uuid import uuid4
import json
import collections
from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from queue import Full as QueueFull, Empty as QueueEmpty
from django.conf import settings
from django.db import connection as django_connection, connections
from django.db import connection as django_connection
from django.core.cache import cache as django_cache
from django.utils.timezone import now as tz_now
from django_guid import set_guid
from jinja2 import Template
import psutil
from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.models import UnifiedJob
from awx.main.dispatch import reaper
from awx.main.utils.common import get_mem_effective_capacity, get_corrected_memory, get_corrected_cpu, get_cpu_effective_capacity
# ansible-runner
from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count
if 'run_callback_receiver' in sys.argv:
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
else:
logger = logging.getLogger('awx.main.dispatch')
RETIRED_SENTINEL_TASK = "[retired]"
class NoOpResultQueue(object):
def put(self, item):
pass
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
class PoolWorker(object):
"""
Used to track a worker child process and its pending and finished messages.
A simple wrapper around a multiprocessing.Process that tracks a worker child process.
This class makes use of two distinct multiprocessing.Queues to track state:
- self.queue: this is a queue which represents pending messages that should
be handled by this worker process; as new AMQP messages come
in, a pool will put() them into this queue; the child
process that is forked will get() from this queue and handle
received messages in an endless loop
- self.finished: this is a queue which the worker process uses to signal
that it has finished processing a message
When a message is put() onto this worker, it is tracked in
self.managed_tasks.
Periodically, the worker will call .calculate_managed_tasks(), which will
cause messages in self.finished to be removed from self.managed_tasks.
In this way, self.managed_tasks represents a view of the messages assigned
to a specific process. The message at [0] is the least-recently inserted
message, and it represents what the worker is running _right now_
(self.current_task).
A worker is "busy" when it has at least one message in self.managed_tasks.
It is "idle" when self.managed_tasks is empty.
The worker process runs the provided target function and tracks its creation time.
"""
track_managed_tasks = False
def __init__(self, queue_size, target, args, **kwargs):
self.messages_sent = 0
self.messages_finished = 0
self.managed_tasks = collections.OrderedDict()
self.finished = MPQueue(queue_size) if self.track_managed_tasks else NoOpResultQueue()
self.queue = MPQueue(queue_size)
self.process = Process(target=target, args=(self.queue, self.finished) + args)
def __init__(self, target, args, **kwargs):
self.process = Process(target=target, args=args)
self.process.daemon = True
self.creation_time = time.monotonic()
self.retiring = False
def start(self):
self.process.start()
def put(self, body):
if self.retiring:
uuid = body.get('uuid', 'N/A') if isinstance(body, dict) else 'N/A'
logger.info(f"Worker pid:{self.pid} is retiring. Refusing new task {uuid}.")
raise QueueFull("Worker is retiring and not accepting new tasks") # AutoscalePool.write handles QueueFull
uuid = '?'
if isinstance(body, dict):
if not body.get('uuid'):
body['uuid'] = str(uuid4())
uuid = body['uuid']
if self.track_managed_tasks:
self.managed_tasks[uuid] = body
self.queue.put(body, block=True, timeout=5)
self.messages_sent += 1
self.calculate_managed_tasks()
def quit(self):
"""
Send a special control message to the worker that tells it to exit
gracefully.
"""
self.queue.put('QUIT')
@property
def age(self):
"""Returns the current age of the worker in seconds."""
return time.monotonic() - self.creation_time
@property
def pid(self):
return self.process.pid
@property
def qsize(self):
return self.queue.qsize()
@property
def alive(self):
return self.process.is_alive()
@property
def mb(self):
if self.alive:
return '{:0.3f}'.format(psutil.Process(self.pid).memory_info().rss / 1024.0 / 1024.0)
return '0'
@property
def exitcode(self):
return str(self.process.exitcode)
def calculate_managed_tasks(self):
if not self.track_managed_tasks:
return
# look to see if any tasks were finished
finished = []
for _ in range(self.finished.qsize()):
try:
finished.append(self.finished.get(block=False))
except QueueEmpty:
break # qsize is not always _totally_ up to date
# if any tasks were finished, removed them from the managed tasks for
# this worker
for uuid in finished:
try:
del self.managed_tasks[uuid]
self.messages_finished += 1
except KeyError:
# ansible _sometimes_ appears to send events w/ duplicate UUIDs;
# UUIDs for ansible events are *not* actually globally unique
# when this occurs, it's _fine_ to ignore this KeyError because
# the purpose of self.managed_tasks is to just track internal
# state of which events are *currently* being processed.
logger.warning('Event UUID {} appears to be have been duplicated.'.format(uuid))
if self.retiring:
self.managed_tasks[RETIRED_SENTINEL_TASK] = {'task': RETIRED_SENTINEL_TASK}
@property
def current_task(self):
if not self.track_managed_tasks:
return None
self.calculate_managed_tasks()
# the task at [0] is the one that's running right now (or is about to
# be running)
if len(self.managed_tasks):
return self.managed_tasks[list(self.managed_tasks.keys())[0]]
return None
@property
def orphaned_tasks(self):
if not self.track_managed_tasks:
return []
orphaned = []
if not self.alive:
# if this process had a running task that never finished,
# requeue its error callbacks
current_task = self.current_task
if isinstance(current_task, dict):
orphaned.extend(current_task.get('errbacks', []))
# if this process has any pending messages requeue them
for _ in range(self.qsize):
try:
message = self.queue.get(block=False)
if message != 'QUIT':
orphaned.append(message)
except QueueEmpty:
break # qsize is not always _totally_ up to date
if len(orphaned):
logger.error('requeuing {} messages from gone worker pid:{}'.format(len(orphaned), self.pid))
return orphaned
@property
def busy(self):
self.calculate_managed_tasks()
return len(self.managed_tasks) > 0
@property
def idle(self):
return not self.busy
class StatefulPoolWorker(PoolWorker):
track_managed_tasks = True
class WorkerPool(object):
"""
Creates a pool of forked PoolWorkers.
As WorkerPool.write(...) is called (generally, by a kombu consumer
implementation when it receives an AMQP message), messages are passed to
one of the multiprocessing Queues where some work can be done on them.
Each worker process runs the provided target function in an isolated process.
The pool manages spawning, tracking, and stopping worker processes.
class MessagePrinter(awx.main.dispatch.worker.BaseWorker):
def perform_work(self, body):
print(body)
pool = WorkerPool(min_workers=4) # spawn four worker processes
pool.init_workers(MessagePrint().work_loop)
pool.write(
0, # preferred worker 0
'Hello, World!'
)
Example:
pool = WorkerPool(workers_num=4) # spawn four worker processes
"""
pool_cls = PoolWorker
debug_meta = ''
def __init__(self, min_workers=None, queue_size=None):
def __init__(self, workers_num=None):
self.name = settings.CLUSTER_HOST_ID
self.pid = os.getpid()
self.min_workers = min_workers or settings.JOB_EVENT_WORKERS
self.queue_size = queue_size or settings.JOB_EVENT_MAX_QUEUE_SIZE
self.workers_num = workers_num or settings.JOB_EVENT_WORKERS
self.workers = []
def __len__(self):
@@ -254,7 +53,7 @@ class WorkerPool(object):
def init_workers(self, target, *target_args):
self.target = target
self.target_args = target_args
for idx in range(self.min_workers):
for idx in range(self.workers_num):
self.up()
def up(self):
@@ -264,320 +63,19 @@ class WorkerPool(object):
# for the DB and cache connections (that way lies race conditions)
django_connection.close()
django_cache.close()
worker = self.pool_cls(self.queue_size, self.target, (idx,) + self.target_args)
worker = self.pool_cls(self.target, (idx,) + self.target_args)
self.workers.append(worker)
try:
worker.start()
except Exception:
logger.exception('could not fork')
else:
logger.debug('scaling up worker pid:{}'.format(worker.pid))
logger.debug('scaling up worker pid:{}'.format(worker.process.pid))
return idx, worker
def debug(self, *args, **kwargs):
tmpl = Template(
'Recorded at: {{ dt }} \n'
'{{ pool.name }}[pid:{{ pool.pid }}] workers total={{ workers|length }} {{ meta }} \n'
'{% for w in workers %}'
'. worker[pid:{{ w.pid }}]{% if not w.alive %} GONE exit={{ w.exitcode }}{% endif %}'
' sent={{ w.messages_sent }}'
' age={{ "%.0f"|format(w.age) }}s'
' retiring={{ w.retiring }}'
'{% if w.messages_finished %} finished={{ w.messages_finished }}{% endif %}'
' qsize={{ w.managed_tasks|length }}'
' rss={{ w.mb }}MB'
'{% for task in w.managed_tasks.values() %}'
'\n - {% if loop.index0 == 0 %}running {% if "age" in task %}for: {{ "%.1f" % task["age"] }}s {% endif %}{% else %}queued {% endif %}'
'{{ task["uuid"] }} '
'{% if "task" in task %}'
'{{ task["task"].rsplit(".", 1)[-1] }}'
# don't print kwargs, they often contain launch-time secrets
'(*{{ task.get("args", []) }})'
'{% endif %}'
'{% endfor %}'
'{% if not w.managed_tasks|length %}'
' [IDLE]'
'{% endif %}'
'\n'
'{% endfor %}'
)
now = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
return tmpl.render(pool=self, workers=self.workers, meta=self.debug_meta, dt=now)
def write(self, preferred_queue, body):
queue_order = sorted(range(len(self.workers)), key=lambda x: -1 if x == preferred_queue else x)
write_attempt_order = []
for queue_actual in queue_order:
try:
self.workers[queue_actual].put(body)
return queue_actual
except QueueFull:
pass
except Exception:
tb = traceback.format_exc()
logger.warning("could not write to queue %s" % preferred_queue)
logger.warning("detail: {}".format(tb))
write_attempt_order.append(preferred_queue)
logger.error("could not write payload to any queue, attempted order: {}".format(write_attempt_order))
return None
def stop(self, signum):
try:
for worker in self.workers:
os.kill(worker.pid, signum)
except Exception:
logger.exception('could not kill {}'.format(worker.pid))
def get_auto_max_workers():
"""Method we normally rely on to get max_workers
Uses almost same logic as Instance.local_health_check
The important thing is to be MORE than Instance.capacity
so that the task-manager does not over-schedule this node
Ideally we would just use the capacity from the database plus reserve workers,
but this poses some bootstrap problems where OCP task containers
register themselves after startup
"""
# Get memory from ansible-runner
total_memory_gb = get_mem_in_bytes()
# This may replace memory calculation with a user override
corrected_memory = get_corrected_memory(total_memory_gb)
# Get same number as max forks based on memory, this function takes memory as bytes
mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True)
# Follow same process for CPU capacity constraint
cpu_count = get_cpu_count()
corrected_cpu = get_corrected_cpu(cpu_count)
cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True)
# Here is what is different from health checks,
auto_max = max(mem_capacity, cpu_capacity)
# add magic number of extra workers to ensure
# we have a few extra workers to run the heartbeat
auto_max += 7
return auto_max
class AutoscalePool(WorkerPool):
"""
An extended pool implementation that automatically scales workers up and
down based on demand
"""
pool_cls = StatefulPoolWorker
def __init__(self, *args, **kwargs):
self.max_workers = kwargs.pop('max_workers', None)
self.max_worker_lifetime_seconds = kwargs.pop(
'max_worker_lifetime_seconds', getattr(settings, 'WORKER_MAX_LIFETIME_SECONDS', 14400)
) # Default to 4 hours
super(AutoscalePool, self).__init__(*args, **kwargs)
if self.max_workers is None:
self.max_workers = get_auto_max_workers()
# max workers can't be less than min_workers
self.max_workers = max(self.min_workers, self.max_workers)
# the task manager enforces settings.TASK_MANAGER_TIMEOUT on its own
# but if the task takes longer than the time defined here, we will force it to stop here
self.task_manager_timeout = settings.TASK_MANAGER_TIMEOUT + settings.TASK_MANAGER_TIMEOUT_GRACE_PERIOD
# initialize some things for subsystem metrics periodic gathering
# the AutoscalePool class does not save these to redis directly, but reports via produce_subsystem_metrics
self.scale_up_ct = 0
self.worker_count_max = 0
# last time we wrote current tasks, to avoid too much log spam
self.last_task_list_log = time.monotonic()
def produce_subsystem_metrics(self, metrics_object):
metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct)
metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers))
metrics_object.set('dispatcher_pool_max_worker_count', self.worker_count_max)
self.worker_count_max = len(self.workers)
@property
def should_grow(self):
if len(self.workers) < self.min_workers:
# If we don't have at least min_workers, add more
return True
# If every worker is busy doing something, add more
return all([w.busy for w in self.workers])
@property
def full(self):
return len(self.workers) == self.max_workers
@property
def debug_meta(self):
return 'min={} max={}'.format(self.min_workers, self.max_workers)
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def cleanup(self):
"""
Perform some internal account and cleanup. This is run on
every cluster node heartbeat:
1. Discover worker processes that exited, and recover messages they
were handling.
2. Clean up unnecessary, idle workers.
IMPORTANT: this function is one of the few places in the dispatcher
(aside from setting lookups) where we talk to the database. As such,
if there's an outage, this method _can_ throw various
django.db.utils.Error exceptions. Act accordingly.
"""
orphaned = []
for w in self.workers[::]:
is_retirement_age = self.max_worker_lifetime_seconds is not None and w.age > self.max_worker_lifetime_seconds
if not w.alive:
# the worker process has exited
# 1. take the task it was running and enqueue the error
# callbacks
# 2. take any pending tasks delivered to its queue and
# send them to another worker
logger.error('worker pid:{} is gone (exit={})'.format(w.pid, w.exitcode))
if w.current_task:
if w.current_task == {'task': RETIRED_SENTINEL_TASK}:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
self.workers.remove(w)
continue
if w.current_task != 'QUIT':
try:
for j in UnifiedJob.objects.filter(celery_task_id=w.current_task['uuid']):
reaper.reap_job(j, 'failed')
except Exception:
logger.exception('failed to reap job UUID {}'.format(w.current_task['uuid']))
else:
logger.warning(f'Worker was told to quit but has not, pid={w.pid}')
orphaned.extend(w.orphaned_tasks)
self.workers.remove(w)
elif w.idle and len(self.workers) > self.min_workers:
# the process has an empty queue (it's idle) and we have
# more processes in the pool than we need (> min)
# send this process a message so it will exit gracefully
# at the next opportunity
logger.debug('scaling down worker pid:{}'.format(w.pid))
w.quit()
self.workers.remove(w)
elif w.idle and is_retirement_age:
logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age))
w.quit()
self.workers.remove(w)
elif is_retirement_age and not w.retiring and not w.idle:
logger.info(
f"Worker pid:{w.pid} (age: {w.age:.0f}s) exceeded max lifetime ({self.max_worker_lifetime_seconds:.0f}s). "
"Signaling for graceful retirement."
)
# Send QUIT signal; worker will finish current task then exit.
w.quit()
# mark as retiring to reject any future tasks that might be assigned in meantime
w.retiring = True
if w.alive:
# if we discover a task manager invocation that's been running
# too long, reap it (because otherwise it'll just hold the postgres
# advisory lock forever); the goal of this code is to discover
# deadlocks or other serious issues in the task manager that cause
# the task manager to never do more work
current_task = w.current_task
if current_task and isinstance(current_task, dict):
endings = ('tasks.task_manager', 'tasks.dependency_manager', 'tasks.workflow_manager')
current_task_name = current_task.get('task', '')
if current_task_name.endswith(endings):
if 'started' not in current_task:
w.managed_tasks[current_task['uuid']]['started'] = time.time()
age = time.time() - current_task['started']
w.managed_tasks[current_task['uuid']]['age'] = age
if age > self.task_manager_timeout:
logger.error(f'{current_task_name} has held the advisory lock for {age}, sending SIGUSR1 to {w.pid}')
os.kill(w.pid, signal.SIGUSR1)
for m in orphaned:
# if all the workers are dead, spawn at least one
if not len(self.workers):
self.up()
idx = random.choice(range(len(self.workers)))
self.write(idx, m)
def add_bind_kwargs(self, body):
bind_kwargs = body.pop('bind_kwargs', [])
body.setdefault('kwargs', {})
if 'dispatch_time' in bind_kwargs:
body['kwargs']['dispatch_time'] = tz_now().isoformat()
if 'worker_tasks' in bind_kwargs:
worker_tasks = {}
for worker in self.workers:
worker.calculate_managed_tasks()
worker_tasks[worker.pid] = list(worker.managed_tasks.keys())
body['kwargs']['worker_tasks'] = worker_tasks
def up(self):
if self.full:
# if we can't spawn more workers, just toss this message into a
# random worker's backlog
idx = random.choice(range(len(self.workers)))
return idx, self.workers[idx]
else:
self.scale_up_ct += 1
ret = super(AutoscalePool, self).up()
new_worker_ct = len(self.workers)
if new_worker_ct > self.worker_count_max:
self.worker_count_max = new_worker_ct
return ret
@staticmethod
def fast_task_serialization(current_task):
try:
return str(current_task.get('task')) + ' - ' + str(sorted(current_task.get('args', []))) + ' - ' + str(sorted(current_task.get('kwargs', {})))
except Exception:
# just make sure this does not make things worse
return str(current_task)
def write(self, preferred_queue, body):
if 'guid' in body:
set_guid(body['guid'])
try:
if isinstance(body, dict) and body.get('bind_kwargs'):
self.add_bind_kwargs(body)
if self.should_grow:
self.up()
# we don't care about "preferred queue" round robin distribution, just
# find the first non-busy worker and claim it
workers = self.workers[:]
random.shuffle(workers)
for w in workers:
if not w.busy:
w.put(body)
break
else:
task_name = 'unknown'
if isinstance(body, dict):
task_name = body.get('task')
logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}')
# Once every 10 seconds write out task list for debugging
if time.monotonic() - self.last_task_list_log >= 10.0:
task_counts = {}
for worker in self.workers:
task_slug = self.fast_task_serialization(worker.current_task)
task_counts.setdefault(task_slug, 0)
task_counts[task_slug] += 1
logger.info(f'Running tasks by count:\n{json.dumps(task_counts, indent=2)}')
self.last_task_list_log = time.monotonic()
return super(AutoscalePool, self).write(preferred_queue, body)
except Exception:
for conn in connections.all():
# If the database connection has a hiccup, re-establish a new
# connection
conn.close_if_unusable_or_obsolete()
logger.exception('failed to write inbound message')

View File

@@ -1,163 +0,0 @@
import inspect
import logging
import json
import time
from uuid import uuid4
from dispatcherd.publish import submit_task
from dispatcherd.processors.blocker import Blocker
from dispatcherd.utils import resolve_callable
from django_guid import get_guid
from django.conf import settings
from . import pg_bus_conn
logger = logging.getLogger('awx.main.dispatch')
def serialize_task(f):
    """Return the dotted path ``<module>.<name>`` built from f.__module__ and f.__name__."""
    module_name, callable_name = f.__module__, f.__name__
    return '.'.join((module_name, callable_name))
class task:
    """
    Decorator that makes a function or class publishable to the task dispatcher.

    Plain functions can be decorated:

        @task()
        def add(a, b):
            return a + b

    ...and so can classes that expose a `run` method:

        @task()
        class Adder:
            def run(self, a, b):
                return a + b

    Decorated objects can still be invoked synchronously:

        assert add(1, 1) == 2
        assert Adder().run(1, 1) == 2

    ...and gain `apply_async` / `delay` for publishing to a queue:

        add.apply_async([1, 1])
        Adder.apply_async([1, 1])

    A target queue may be named, including the special fan-out queue
    tower_broadcast:

        @task(queue='slow-tasks')
        def snooze():
            time.sleep(10)

        @task(queue='tower_broadcast')
        def announce():
            print("Run this everywhere!")

    The bind_kwargs parameter tells the main dispatcher process to add
    certain kwargs:

        @task(bind_kwargs=['dispatch_time'])
        def print_time(dispatch_time=None):
            print(f"Time I was dispatched: {dispatch_time}")
    """

    def __init__(self, queue=None, bind_kwargs=None, timeout=None, on_duplicate=None):
        # Options are only stored here; they take effect when the decorator is applied.
        self.queue = queue
        self.bind_kwargs = bind_kwargs
        self.timeout = timeout
        self.on_duplicate = on_duplicate

    def __call__(self, fn=None):
        """
        Apply the decorator to `fn`.

        Returns a new class (when `fn` is a class) or `fn` itself (when it is a
        function), in both cases exposing `name`, `apply_async`, `delay` and
        `get_async_body`.
        """
        # Captured as closure variables so the classmethods below can read them.
        default_queue = self.queue
        bound_kwarg_names = self.bind_kwargs
        task_timeout = self.timeout
        duplicate_policy = self.on_duplicate

        class PublisherMixin(object):
            queue = None

            @classmethod
            def delay(cls, *args, **kwargs):
                """Shorthand for apply_async taking the task's own call signature."""
                return cls.apply_async(args, kwargs)

            @classmethod
            def get_async_body(cls, args=None, kwargs=None, uuid=None, **kw):
                """
                Get the python dict to become JSON data in the pg_notify message
                This same message gets passed over the dispatcher IPC queue to workers
                If a task is submitted to a multiprocessing pool, skipping pg_notify, this might be used directly
                """
                message = {
                    'uuid': uuid or str(uuid4()),
                    'args': args or [],
                    'kwargs': kwargs or {},
                    'task': cls.name,
                    'time_pub': time.time(),
                }
                guid = get_guid()
                if guid:
                    message['guid'] = guid
                if bound_kwarg_names:
                    message['bind_kwargs'] = bound_kwarg_names
                # caller-supplied extras win over the generated keys
                message.update(**kw)
                return message

            @classmethod
            def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
                """
                Submit the task for asynchronous execution.

                When the FEATURE_DISPATCHERD_ENABLED flag is on, the task goes
                through submit_task and its result is returned. Otherwise (or if
                that path raises) the message is published with pg_notify and
                an (message, queue) tuple is returned.
                """
                try:
                    from flags.state import flag_enabled

                    if flag_enabled('FEATURE_DISPATCHERD_ENABLED'):
                        # Only the import string is known here, but submit_task wants the callable itself
                        target = resolve_callable(cls.name)
                        if duplicate_policy is None:
                            processor_options = ()
                        else:
                            processor_options = (Blocker.Params(on_duplicate=duplicate_policy),)
                        return submit_task(
                            target,
                            args=args,
                            kwargs=kwargs,
                            queue=queue,
                            uuid=uuid,
                            timeout=task_timeout,
                            processor_options=processor_options,
                            **kw,
                        )
                except Exception:
                    # Any failure above falls through to the original implementation below
                    logger.exception(f"[DISPATCHER] Failed to check for alternative dispatcherd implementation for {cls.name}")

                # Original implementation follows
                destination = queue or getattr(cls.queue, 'im_func', cls.queue)
                if not destination:
                    error_text = f'{cls.name}: Queue value required and may not be None'
                    logger.error(error_text)
                    raise ValueError(error_text)
                payload = cls.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw)
                if callable(destination):
                    # the queue may be given as a function that computes the queue name
                    destination = destination()
                if not settings.DISPATCHER_MOCK_PUBLISH:
                    with pg_bus_conn() as conn:
                        conn.notify(destination, json.dumps(payload))
                return (payload, destination)

        # When the decorated object *is* a class (e.g., RunJob), build a *new*
        # class carrying the wrapped class's bases and attributes *plus*
        # PublisherMixin, so cls.apply_async() and cls.delay() work.
        namespace = {'name': serialize_task(fn), 'queue': default_queue}
        wrapping_class = inspect.isclass(fn)
        parent_classes = ()
        if wrapping_class:
            parent_classes = tuple(fn.__bases__)
            namespace.update(fn.__dict__)
        cls = type(fn.__name__, parent_classes + (PublisherMixin,), namespace)
        if wrapping_class:
            return cls

        # Otherwise the decorated object is a plain function: make its
        # publishing attributes proxy to the class created above.
        fn.name = cls.name
        fn.apply_async = cls.apply_async
        fn.delay = cls.delay
        fn.get_async_body = cls.get_async_body
        return fn

View File

@@ -1,3 +1,2 @@
from .base import AWXConsumerRedis, AWXConsumerPG, BaseWorker # noqa
from .base import AWXConsumerRedis # noqa
from .callback import CallbackBrokerWorker # noqa
from .task import TaskWorker # noqa

View File

@@ -4,32 +4,15 @@
import os
import logging
import signal
import sys
import redis
import json
import psycopg
import time
from uuid import UUID
from queue import Empty as QueueEmpty
from datetime import timedelta
from django import db
from django.conf import settings
import redis.exceptions
from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.utils.redis import get_redis_client
from awx.main.dispatch.pool import WorkerPool
from awx.main.dispatch.periodic import Scheduler
from awx.main.dispatch import pg_bus_conn
from awx.main.utils.db import set_connection_name
import awx.main.analytics.subsystem_metrics as s_metrics
if 'run_callback_receiver' in sys.argv:
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
else:
logger = logging.getLogger('awx.main.dispatch')
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
def signame(sig):
@@ -62,85 +45,6 @@ class AWXConsumerBase(object):
self.pool.init_workers(self.worker.work_loop)
self.redis = get_redis_client()
@property
def listening_on(self):
    """Human-readable description of the queues held in `self.queues`."""
    queues = self.queues
    return f'listening on {queues}'
def control(self, body):
    """
    Handle a control message sent to the dispatcher.

    `body['control']` selects the action:

    - 'status', 'schedule', 'running', 'cancel': build a reply and, unless
      `body['reply_to']` is None, send it as JSON over pg_notify to that queue.
      'cancel' additionally sends SIGTERM to every worker whose current task
      uuid is listed in `body['task_ids']`.
    - 'reload': ask every worker in the pool to quit.
    - anything else: log an error.

    Raises KeyError if a reply-style message lacks 'reply_to' (or 'task_ids'
    for 'cancel').
    """
    logger.warning(f'Received control signal:\n{body}')
    control = body.get('control')
    if control in ('status', 'schedule', 'running', 'cancel'):
        reply_queue = body['reply_to']
        if control == 'status':
            msg = '\n'.join([self.listening_on, self.pool.debug()])
        elif control == 'schedule':
            msg = self.scheduler.debug()
        elif control == 'running':
            msg = []
            for worker in self.pool.workers:
                worker.calculate_managed_tasks()
                msg.extend(worker.managed_tasks.keys())
        elif control == 'cancel':
            msg = []
            task_ids = set(body['task_ids'])
            for worker in self.pool.workers:
                task = worker.current_task
                if task and task['uuid'] in task_ids:
                    # Logger.warn is a deprecated alias; use Logger.warning
                    logger.warning(f'Sending SIGTERM to task id={task["uuid"]}, task={task.get("task")}, args={task.get("args")}')
                    os.kill(worker.pid, signal.SIGTERM)
                    msg.append(task['uuid'])
            if task_ids and not msg:
                logger.info(f'Could not locate running tasks to cancel with ids={task_ids}')

        if reply_queue is not None:
            with pg_bus_conn() as conn:
                conn.notify(reply_queue, json.dumps(msg))
    elif control == 'reload':
        for worker in self.pool.workers:
            worker.quit()
    else:
        logger.error('unrecognized control message: {}'.format(control))
def dispatch_task(self, body):
    """
    Place the given body into a worker queue to run a method decorated as a task.

    Dict bodies are stamped with 'time_ack'. The worker index is 0 for an empty
    pool; otherwise it is derived from the body's uuid (so a given uuid always
    maps to the same index for a given pool size), falling back to round-robin
    on the running message count when there is no usable uuid.
    """
    if isinstance(body, dict):
        body['time_ack'] = time.time()

    if not len(self.pool):
        target = 0
    elif "uuid" in body and body['uuid']:
        try:
            target = UUID(body['uuid']).int % len(self.pool)
        except Exception:
            # uuid was not parseable; fall back to round-robin
            target = self.total_messages % len(self.pool)
    else:
        target = self.total_messages % len(self.pool)

    self.pool.write(target, body)
    self.total_messages += 1
def process_task(self, body):
    """
    Route a message body to the control handler or to a worker.

    Bodies containing a 'control' key go to `self.control` and its result is
    returned; an exception there is logged and None is returned. Every other
    body is handed to `self.dispatch_task`.
    """
    if 'control' not in body:
        self.dispatch_task(body)
        return None

    try:
        return self.control(body)
    except Exception:
        logger.exception(f"Exception handling control message: {body}")
        return None
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def record_statistics(self):
    """
    Save the pool's debug output to redis under `awx_<name>_statistics`.

    Recording is buffered to once per second; redis failures are logged
    (together with the data that could not be saved) and never raised.
    """
    if not (time.time() - self.last_stats > 1):
        # recorded less than a second ago, nothing to do yet
        return

    save_data = self.pool.debug()
    redis_key = f'awx_{self.name}_statistics'
    try:
        self.redis.set(redis_key, save_data)
    except redis.exceptions.ConnectionError as exc:
        logger.warning(f'Redis connection error saving {self.name} status data:\n{exc}\nmissed data:\n{save_data}')
    except Exception:
        logger.exception(f"Unknown redis error saving {self.name} status data:\nmissed data:\n{save_data}")
    self.last_stats = time.time()
def run(self, *args, **kwargs):
signal.signal(signal.SIGINT, self.stop)
signal.signal(signal.SIGTERM, self.stop)
@@ -150,196 +54,14 @@ class AWXConsumerBase(object):
# Signal handler (registered for SIGINT and SIGTERM in run): marks the consumer
# as stopping, runs the worker's on_stop hook, then exits via SystemExit.
def stop(self, signum, frame):
# set first so loops that poll should_stop see it even if a later step fails
self.should_stop = True
logger.warning('received {}, stopping'.format(signame(signum)))
# the worker's shutdown hook must run before SystemExit unwinds the process
self.worker.on_stop()
raise SystemExit()
class AWXConsumerRedis(AWXConsumerBase):
    """
    Consumer used by the callback receiver.

    After start-up the calling process does no message handling of its own;
    it only sleeps in a loop until a signal handler ends the process.
    """

    def run(self, *args, **kwargs):
        """Run base-class setup and the worker start hook, then idle forever."""
        super().run(*args, **kwargs)
        self.worker.on_start()
        logger.info(f'Callback receiver started with pid={os.getpid()}')
        db.connection.close()  # logs use database, so close connection

        # Keep the process alive; termination happens through the signal handlers.
        while True:
            time.sleep(60)
# Postgres-backed consumer: listens on pg_notify queues, passes each message to
# process_task, and runs scheduled jobs (pool cleanup, metrics, periodic tasks)
# between notifications.
class AWXConsumerPG(AWXConsumerBase):
# schedule: optional dict of named schedule entries; it is copied, never mutated.
def __init__(self, *args, schedule=None, **kwargs):
super().__init__(*args, **kwargs)
# NOTE(review): the first setting name looks like a misspelled legacy alias; when it is
# defined it takes precedence, otherwise DISPATCHER_DB_DOWNTIME_TOLERANCE is used.
self.pg_max_wait = getattr(settings, 'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE)
# if no successful loops have ran since startup, then we should fail right away
self.pg_is_down = True # set so that we fail if we get database errors on startup
init_time = time.time()
self.pg_down_time = init_time - self.pg_max_wait # allow no grace period
self.last_cleanup = init_time
self.subsystem_metrics = s_metrics.DispatcherMetrics(auto_pipe_execute=False)
self.last_metrics_gather = init_time
# seconds spent listening since the last metrics gather (reset in record_metrics)
self.listen_cumulative_time = 0.0
if schedule:
schedule = schedule.copy()
else:
schedule = {}
# add control tasks to be ran at regular schedules
# NOTE: if we run out of database connections, it is important to still run cleanup
# so that we scale down workers and free up connections
schedule['pool_cleanup'] = {'control': self.pool.cleanup, 'schedule': timedelta(seconds=60)}
# record subsystem metrics for the dispatcher
schedule['metrics_gather'] = {'control': self.record_metrics, 'schedule': timedelta(seconds=20)}
self.scheduler = Scheduler(schedule)
# Publish pool metrics plus 'dispatcher_availability', then start a new measurement window.
@log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2)
def record_metrics(self):
current_time = time.time()
self.pool.produce_subsystem_metrics(self.subsystem_metrics)
# availability = time spent listening / wall time since the previous gather
self.subsystem_metrics.set('dispatcher_availability', self.listen_cumulative_time / (current_time - self.last_metrics_gather))
try:
self.subsystem_metrics.pipe_execute()
except redis.exceptions.ConnectionError as exc:
logger.warning(f'Redis connection error saving dispatcher metrics, error:\n{exc}')
self.listen_cumulative_time = 0.0
self.last_metrics_gather = current_time
def run_periodic_tasks(self):
"""
Run general periodic logic, and return maximum time in seconds before
the next requested run
This may be called more often than that when events are consumed
so this should be very efficient in that
"""
try:
self.record_statistics() # maintains time buffer in method
except Exception as exc:
logger.warning(f'Failed to save dispatcher statistics {exc}')
# Everything benchmarks to the same original time, so that skews due to
# runtime of the actions, themselves, do not mess up scheduling expectations
reftime = time.time()
for job in self.scheduler.get_and_mark_pending(reftime=reftime):
# 'control' entries are callables run in this process; 'task' entries are
# dotted task names whose message body is dispatched to a worker
if 'control' in job.data:
try:
job.data['control']()
except Exception:
logger.exception(f'Error running control task {job.data}')
elif 'task' in job.data:
body = self.worker.resolve_callable(job.data['task']).get_async_body()
# bypasses pg_notify for scheduled tasks
self.dispatch_task(body)
# getting this far is treated as proof that the listener connection is working again
if self.pg_is_down:
logger.info('Dispatcher listener connection established')
self.pg_is_down = False
# start of the next listening interval, consumed by run() for availability metrics
self.listen_start = time.time()
return self.scheduler.time_until_next_run(reftime=reftime)
# Main loop: (re)connect to the message bus, listen on every queue, process events and
# periodic tasks; tolerates database outages for up to pg_max_wait seconds before re-raising.
def run(self, *args, **kwargs):
super(AWXConsumerPG, self).run(*args, **kwargs)
logger.info(f"Running {self.name}, workers min={self.pool.min_workers} max={self.pool.max_workers}, listening to queues {self.queues}")
# on_start must run only once, on the first successful connection
init = False
while True:
try:
with pg_bus_conn(new_connection=True) as conn:
for queue in self.queues:
conn.listen(queue)
if init is False:
self.worker.on_start()
init = True
# run_periodic_tasks run scheduled actions and gives time until next scheduled action
# this is saved to the conn (PubSub) object in order to modify read timeout in-loop
conn.select_timeout = self.run_periodic_tasks()
# this is the main operational loop for awx-manage run_dispatcher
for e in conn.events(yield_timeouts=True):
self.listen_cumulative_time += time.time() - self.listen_start # for metrics
# e is None when the wait timed out without a notification
if e is not None:
self.process_task(json.loads(e.payload))
conn.select_timeout = self.run_periodic_tasks()
if self.should_stop:
return
except psycopg.InterfaceError:
logger.warning("Stale Postgres message bus connection, reconnecting")
continue
except (db.DatabaseError, psycopg.OperationalError):
# If we have attained steady state operation, tolerate short-term database hiccups
if not self.pg_is_down:
logger.exception(f"Error consuming new events from postgres, will retry for {self.pg_max_wait} s")
self.pg_down_time = time.time()
self.pg_is_down = True
current_downtime = time.time() - self.pg_down_time
if current_downtime > self.pg_max_wait:
logger.exception(f"Postgres event consumer has not recovered in {current_downtime} s, exiting")
# Sending QUIT to multiprocess queue to signal workers to exit
for worker in self.pool.workers:
try:
worker.quit()
except Exception:
logger.exception(f"Error sending QUIT to worker {worker}")
raise
# Wait for a second before next attempt, but still listen for any shutdown signals
for i in range(10):
if self.should_stop:
return
time.sleep(0.1)
# drop connections Django considers broken so the retry starts from a clean state
for conn in db.connections.all():
conn.close_if_unusable_or_obsolete()
except Exception:
# Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata
logger.exception('Encountered unhandled error in dispatcher main loop')
# Sending QUIT to multiprocess queue to signal workers to exit
for worker in self.pool.workers:
try:
worker.quit()
except Exception:
logger.exception(f"Error sending QUIT to worker {worker}")
raise
class BaseWorker(object):
    """
    Skeleton for pool workers: repeatedly reads a message body from a queue and
    hands it to `perform_work`, which subclasses must implement.
    """

    def read(self, queue):
        """Return the next item from `queue`, blocking for at most one second."""
        return queue.get(block=True, timeout=1)

    def work_loop(self, queue, finished, idx, *args):
        """
        Process messages until told to stop.

        The loop ends when the signal handler reports kill_now, when the parent
        process changes, or when the literal message 'QUIT' is read. After each
        processed body that contains a 'uuid', that uuid is put on `finished`.
        """
        original_parent = os.getppid()
        signals = WorkerSignalHandler()
        set_connection_name('worker')  # set application_name to distinguish from other dispatcher processes
        while not signals.kill_now:
            # a different parent PID means this process has been orphaned
            # (e.g. the parent segfaulted or was killed), so exit as well
            if os.getppid() != original_parent:
                break

            try:
                body = self.read(queue)
                if body == 'QUIT':
                    break
            except QueueEmpty:
                continue
            except Exception:
                logger.exception("Exception on worker {}, reconnecting: ".format(idx))
                continue

            try:
                for conn in db.connections.all():
                    # a connection that had a hiccup during the prior message is
                    # closed here so that a new one can be established
                    conn.close_if_unusable_or_obsolete()
                self.perform_work(body, *args)
            except Exception:
                logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}')
            finally:
                if 'uuid' in body:
                    finished.put(body['uuid'])
        logger.debug('worker exiting gracefully pid:{}'.format(os.getpid()))

    def perform_work(self, body):
        """Handle one message body; must be overridden by subclasses."""
        raise NotImplementedError()

    def on_start(self):
        """Hook run when the consumer starts; no-op by default."""
        pass

    def on_stop(self):
        """Hook run when the consumer stops; no-op by default."""
        pass

View File

@@ -4,10 +4,12 @@ import os
import signal
import time
import datetime
from queue import Empty as QueueEmpty
from django.conf import settings
from django.utils.functional import cached_property
from django.utils.timezone import now as tz_now
from django import db
from django.db import transaction, connection as django_connection
from django_guid import set_guid
@@ -16,6 +18,7 @@ import psutil
import redis
from awx.main.utils.redis import get_redis_client
from awx.main.utils.db import set_connection_name
from awx.main.consumers import emit_channel_notification
from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob
from awx.main.constants import ACTIVE_STATES
@@ -23,7 +26,7 @@ from awx.main.models.events import emit_event_detail
from awx.main.utils.profiling import AWXProfiler
from awx.main.tasks.system import events_processed_hook
import awx.main.analytics.subsystem_metrics as s_metrics
from .base import BaseWorker
from .base import WorkerSignalHandler
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
@@ -54,7 +57,7 @@ def job_stats_wrapup(job_identifier, event=None):
logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier))
class CallbackBrokerWorker(BaseWorker):
class CallbackBrokerWorker:
"""
A worker implementation that deserializes callback event data and persists
it into the database.
@@ -86,7 +89,7 @@ class CallbackBrokerWorker(BaseWorker):
"""This needs to be obtained after forking, or else it will give the parent process"""
return os.getpid()
def read(self, queue):
def read(self):
has_redis_error = False
try:
res = self.redis.blpop(self.queue_name, timeout=1)
@@ -149,10 +152,37 @@ class CallbackBrokerWorker(BaseWorker):
filepath = self.prof.stop()
logger.error(f'profiling is disabled, wrote {filepath}')
def work_loop(self, *args, **kw):
def work_loop(self, idx, *args):
if settings.AWX_CALLBACK_PROFILE:
signal.signal(signal.SIGUSR1, self.toggle_profiling)
return super(CallbackBrokerWorker, self).work_loop(*args, **kw)
ppid = os.getppid()
signal_handler = WorkerSignalHandler()
set_connection_name('worker') # set application_name to distinguish from other dispatcher processes
while not signal_handler.kill_now:
# if the parent PID changes, this process has been orphaned
# via e.g., segfault or sigkill, we should exit too
if os.getppid() != ppid:
break
try:
body = self.read() # this is only for the callback, only reading from redis.
if body == 'QUIT':
break
except QueueEmpty:
continue
except Exception:
logger.exception("Exception on worker {}, reconnecting: ".format(idx))
continue
try:
for conn in db.connections.all():
# If the database connection has a hiccup during the prior message, close it
# so we can establish a new connection
conn.close_if_unusable_or_obsolete()
self.perform_work(body, *args)
except Exception:
logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}')
logger.debug('worker exiting gracefully pid:{}'.format(os.getpid()))
def flush(self, force=False):
now = tz_now()

View File

@@ -1,144 +1,55 @@
import inspect
import logging
import importlib
import sys
import traceback
import time
from kubernetes.config import kube_config
from django.conf import settings
from django_guid import set_guid
from awx.main.tasks.system import dispatch_startup, inform_cluster_of_shutdown
from .base import BaseWorker
logger = logging.getLogger('awx.main.dispatch')
class TaskWorker(BaseWorker):
def resolve_callable(task):
    """
    Transform a dotted notation task into an imported, callable function, e.g.,

    awx.main.tasks.system.delete_inventory
    awx.main.tasks.jobs.RunProjectUpdate

    :param task: dotted import path of the task; must start with ``awx.``
    :returns: the module attribute named by the last path component
    :raises ValueError: if the path does not start with ``awx.``, or the
        resolved attribute is missing or lacks ``apply_async`` / ``delay``
        (i.e. it is not decorated with @task())

    Import errors from loading the module are not caught here.
    """
    if not task.startswith('awx.'):
        raise ValueError('{} is not a valid awx task'.format(task))
    module_path, target = task.rsplit('.', 1)
    module = importlib.import_module(module_path)
    # a missing attribute resolves to None, which fails the decoration check below
    _call = getattr(module, target, None)
    if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')):
        raise ValueError('{} is not decorated with @task()'.format(task))
    return _call
@staticmethod
def resolve_callable(task):
    """
    Import and return the task object named by a dotted path such as

    awx.main.tasks.system.delete_inventory
    awx.main.tasks.jobs.RunProjectUpdate

    Raises ValueError when the path does not start with 'awx.' or when the
    resolved object does not carry both `apply_async` and `delay`.
    """
    if not task.startswith('awx.'):
        raise ValueError('{} is not a valid awx task'.format(task))

    module_path, attr_name = task.rsplit('.', 1)
    loaded_module = importlib.import_module(module_path)
    resolved = getattr(loaded_module, attr_name, None)

    if not hasattr(resolved, 'apply_async') or not hasattr(resolved, 'delay'):
        raise ValueError('{} is not decorated with @task()'.format(task))
    return resolved
@staticmethod
def run_callable(body):
    """
    Import the task named in a message body and run it.

    `body['task']` is required; 'uuid', 'args' and 'kwargs' are optional.
    A 'guid' entry, when present, is removed from the body and installed
    with set_guid. Returns whatever the task returns.
    """
    task_name = body['task']
    task_uuid = body.get('uuid', '<unknown>')
    call_args = body.get('args', [])
    call_kwargs = body.get('kwargs', {})
    if 'guid' in body:
        set_guid(body.pop('guid'))

    target = TaskWorker.resolve_callable(task_name)
    if inspect.isclass(target):
        # class-based task (e.g., RunJob): instantiate it and call its `run()` method
        target = target().run

    emit = logger.debug
    delay_note = ''
    if 'time_ack' in body and 'time_pub' in body:
        publish_delay = body['time_ack'] - body['time_pub']
        local_delay = time.time() - body['time_ack']
        if local_delay > 5.0 or publish_delay > 5.0:
            # the task took a very long time to reach this point, so say so in the log
            delay_note = f' took {publish_delay:.4f} to ack, {local_delay:.4f} in local dispatcher'
            emit = logger.info

    # don't print kwargs, they often contain launch-time secrets
    emit(f'task {task_uuid} starting {task_name}(*{call_args}){delay_note}')
    return target(*call_args, **call_kwargs)
# Runs one task message and then its follow-up messages; an exception raised by the
# task is captured and returned as the result instead of propagating.
def perform_work(self, body):
"""
Import and run code for a task e.g.,
body = {
'args': [8],
'callbacks': [{
'args': [],
'kwargs': {}
'task': u'awx.main.tasks.system.handle_work_success'
}],
'errbacks': [{
'args': [],
'kwargs': {},
'task': 'awx.main.tasks.system.handle_work_error'
}],
'kwargs': {},
'task': u'awx.main.tasks.jobs.RunProjectUpdate'
}
"""
settings.__clean_on_fork__()
result = None
try:
result = self.run_callable(body)
except Exception as exc:
# the exception object itself becomes the return value
result = exc
try:
# errors flagged is_awx_task_error get a one-line warning, everything else a full traceback
if getattr(exc, 'is_awx_task_error', False):
# Error caused by user / tracked in job output
logger.warning("{}".format(exc))
else:
task = body['task']
args = body.get('args', [])
kwargs = body.get('kwargs', {})
logger.exception('Worker failed to run task {}(*{}, **{}'.format(task, args, kwargs))
except Exception:
# It's fairly critical that this code _not_ raise exceptions on logging
# If you configure external logging in a way that _it_ fails, there's
# not a lot we can do here; sys.stderr.write is a final hail mary
_, _, tb = sys.exc_info()
traceback.print_tb(tb)
# each errback is run recursively as its own message, tagged with the parent's uuid
for callback in body.get('errbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
finally:
# It's frustrating that we have to do this, but the python k8s
# client leaves behind cacert files in /tmp, so we must clean up
# the tmpdir per-dispatcher process every time a new task comes in
try:
kube_config._cleanup_temp_files()
except Exception:
logger.exception('failed to cleanup k8s client tmp files')
# NOTE(review): presumably this loop sits after the try/finally rather than inside it —
# confirm against the indented source; callbacks reuse the parent's uuid like errbacks
for callback in body.get('callbacks', []) or []:
callback['uuid'] = body['uuid']
self.perform_work(callback)
return result
# Start hook: delegates to dispatch_startup (imported from awx.main.tasks.system).
def on_start(self):
dispatch_startup()
# Stop hook: delegates to inform_cluster_of_shutdown (imported from awx.main.tasks.system).
def on_stop(self):
inform_cluster_of_shutdown()
def run_callable(body):
    """
    Given a task message body, import the correct Python code and run it.

    `body['task']` is required; 'uuid', 'args' and 'kwargs' are optional.
    A 'guid' entry, when present, is removed from the body and installed
    with set_guid. Returns whatever the task returns.
    """
    name = body['task']
    ident = body.get('uuid', '<unknown>')
    positional = body.get('args', [])
    keyword = body.get('kwargs', {})
    if 'guid' in body:
        set_guid(body.pop('guid'))

    runner = resolve_callable(name)
    if inspect.isclass(runner):
        # the callable is a class, e.g., RunJob; instantiate it and
        # use its `run()` method
        runner = runner().run

    write_log = logger.debug
    timing = ''
    if ('time_ack' in body) and ('time_pub' in body):
        to_ack = body['time_ack'] - body['time_pub']
        waiting = time.time() - body['time_ack']
        if waiting > 5.0 or to_ack > 5.0:
            # If the task took a very long time to process, add this information to the log
            timing = f' took {to_ack:.4f} to ack, {waiting:.4f} in local dispatcher'
            write_log = logger.info

    # don't print kwargs, they often contain launch-time secrets
    write_log(f'task {ident} starting {name}(*{positional}){timing}')
    return runner(*positional, **keyword)