From 36a00ec46bd47226923298309b3387759f82f581 Mon Sep 17 00:00:00 2001 From: Jake Jackson Date: Fri, 23 Jan 2026 15:49:32 -0500 Subject: [PATCH] AAP-58539 Move to dispatcherd (#16209) * WIP First pass * started removing feature flags and adjusting logic * Add decorator * moved to dispatcher decorator * updated as many as I could find * Keep callback receiver working * remove any code that is not used by the call back receiver * add back auto_max_workers * added back get_auto_max_workers into common utils * Remove control and hazmat (squash this not done) * moved status out and deleted control as no longer needed * removed unused imports * adjusted test import to pull correct method * fixed imports and addressed clusternode heartbeat test * Update function comments * Add back hazmat for config and remove baseworker * added back hazmat per @alancoding feedback around config * removed baseworker completely and refactored it into the callback worker * Fix dispatcher run call and remove dispatch setting * remove dispatcher mock publish setting * Adjust heartbeat arg and more formatting * fixed the call to cluster_node_heartbeat missing binder * Fix attribute error in server logs --- .github/workflows/ci.yml | 2 +- awx/main/analytics/analytics_tasks.py | 6 +- awx/main/dispatch/__init__.py | 5 +- awx/main/dispatch/config.py | 2 +- awx/main/dispatch/control.py | 77 --- awx/main/dispatch/hazmat.py | 2 +- awx/main/dispatch/periodic.py | 146 ----- awx/main/dispatch/pool.py | 532 +----------------- awx/main/dispatch/publish.py | 163 ------ awx/main/dispatch/worker/__init__.py | 3 +- awx/main/dispatch/worker/base.py | 280 +-------- awx/main/dispatch/worker/callback.py | 40 +- awx/main/dispatch/worker/task.py | 171 ++---- .../management/commands/run_cache_clear.py | 6 +- .../commands/run_callback_receiver.py | 11 +- .../management/commands/run_dispatcher.py | 109 +--- .../commands/run_rsyslog_configurer.py | 6 +- awx/main/models/unified_jobs.py | 38 +- 
awx/main/scheduler/task_manager.py | 24 +- awx/main/scheduler/tasks.py | 10 +- awx/main/tasks/host_indirect.py | 2 +- awx/main/tasks/host_metrics.py | 6 +- awx/main/tasks/jobs.py | 11 +- awx/main/tasks/receptor.py | 8 +- awx/main/tasks/system.py | 117 ++-- awx/main/tests/data/sleep_task.py | 3 +- .../test_feature_flags_api.py | 2 +- awx/main/tests/functional/models/test_ha.py | 2 +- awx/main/tests/functional/test_dispatch.py | 372 +----------- awx/main/tests/functional/test_jobs.py | 2 +- awx/main/tests/settings_for_test.py | 3 - awx/main/tests/unit/settings/test_defaults.py | 17 +- awx/main/tests/unit/test_settings.py | 7 +- awx/main/utils/common.py | 38 ++ awx/main/utils/external_logging.py | 4 +- awx/settings/defaults.py | 51 +- awx/settings/development_defaults.py | 1 - docs/tasks.md | 25 +- 38 files changed, 294 insertions(+), 2010 deletions(-) delete mode 100644 awx/main/dispatch/control.py delete mode 100644 awx/main/dispatch/periodic.py delete mode 100644 awx/main/dispatch/publish.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9f99f95b62..ffe76debcd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -212,7 +212,7 @@ jobs: continue-on-error: true run: | set +e - timeout 54m bash -elc ' + timeout 15m bash -elc ' python -m pip install -r molecule/requirements.txt python -m pip install PyYAML # for awx/tools/scripts/rewrite-awx-operator-requirements.py $(realpath ../awx/tools/scripts/rewrite-awx-operator-requirements.py) molecule/requirements.yml $(realpath ../awx) diff --git a/awx/main/analytics/analytics_tasks.py b/awx/main/analytics/analytics_tasks.py index 3ab6d4bad1..89c2a27de0 100644 --- a/awx/main/analytics/analytics_tasks.py +++ b/awx/main/analytics/analytics_tasks.py @@ -1,15 +1,17 @@ # Python import logging +# Dispatcherd +from dispatcherd.publish import task + # AWX from awx.main.analytics.subsystem_metrics import DispatcherMetrics, CallbackReceiverMetrics -from awx.main.dispatch.publish import task as 
task_awx from awx.main.dispatch import get_task_queuename logger = logging.getLogger('awx.main.scheduler') -@task_awx(queue=get_task_queuename, timeout=300, on_duplicate='discard') +@task(queue=get_task_queuename, timeout=300, on_duplicate='discard') def send_subsystem_metrics(): DispatcherMetrics().send_metrics() CallbackReceiverMetrics().send_metrics() diff --git a/awx/main/dispatch/__init__.py b/awx/main/dispatch/__init__.py index 97ec6774f2..a2b9a39058 100644 --- a/awx/main/dispatch/__init__.py +++ b/awx/main/dispatch/__init__.py @@ -77,14 +77,13 @@ class PubSub(object): n = psycopg.connection.Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n - def events(self, yield_timeouts=False): + def events(self): if not self.conn.autocommit: raise RuntimeError('Listening for events can only be done in autocommit mode') while True: if select.select([self.conn], [], [], self.select_timeout) == NOT_READY: - if yield_timeouts: - yield None + yield None else: notification_generator = self.current_notifies(self.conn) for notification in notification_generator: diff --git a/awx/main/dispatch/config.py b/awx/main/dispatch/config.py index 9f5773e153..3809c93599 100644 --- a/awx/main/dispatch/config.py +++ b/awx/main/dispatch/config.py @@ -2,7 +2,7 @@ from django.conf import settings from ansible_base.lib.utils.db import get_pg_notify_params from awx.main.dispatch import get_task_queuename -from awx.main.dispatch.pool import get_auto_max_workers +from awx.main.utils.common import get_auto_max_workers def get_dispatcherd_config(for_service: bool = False, mock_publish: bool = False) -> dict: diff --git a/awx/main/dispatch/control.py b/awx/main/dispatch/control.py deleted file mode 100644 index d20c6317ac..0000000000 --- a/awx/main/dispatch/control.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import uuid -import json - -from django.db import connection - -from awx.main.dispatch import get_task_queuename -from awx.main.utils.redis import get_redis_client 
- -from . import pg_bus_conn - -logger = logging.getLogger('awx.main.dispatch') - - -class Control(object): - services = ('dispatcher', 'callback_receiver') - result = None - - def __init__(self, service, host=None): - if service not in self.services: - raise RuntimeError('{} must be in {}'.format(service, self.services)) - self.service = service - self.queuename = host or get_task_queuename() - - def status(self, *args, **kwargs): - r = get_redis_client() - if self.service == 'dispatcher': - stats = r.get(f'awx_{self.service}_statistics') or b'' - return stats.decode('utf-8') - else: - workers = [] - for key in r.keys('awx_callback_receiver_statistics_*'): - workers.append(r.get(key).decode('utf-8')) - return '\n'.join(workers) - - def running(self, *args, **kwargs): - return self.control_with_reply('running', *args, **kwargs) - - def cancel(self, task_ids, with_reply=True): - if with_reply: - return self.control_with_reply('cancel', extra_data={'task_ids': task_ids}) - else: - self.control({'control': 'cancel', 'task_ids': task_ids, 'reply_to': None}, extra_data={'task_ids': task_ids}) - - def schedule(self, *args, **kwargs): - return self.control_with_reply('schedule', *args, **kwargs) - - @classmethod - def generate_reply_queue_name(cls): - return f"reply_to_{str(uuid.uuid4()).replace('-','_')}" - - def control_with_reply(self, command, timeout=5, extra_data=None): - logger.warning('checking {} {} for {}'.format(self.service, command, self.queuename)) - reply_queue = Control.generate_reply_queue_name() - self.result = None - - if not connection.get_autocommit(): - raise RuntimeError('Control-with-reply messages can only be done in autocommit mode') - - with pg_bus_conn(select_timeout=timeout) as conn: - conn.listen(reply_queue) - send_data = {'control': command, 'reply_to': reply_queue} - if extra_data: - send_data.update(extra_data) - conn.notify(self.queuename, json.dumps(send_data)) - - for reply in conn.events(yield_timeouts=True): - if reply is None: - 
logger.error(f'{self.service} did not reply within {timeout}s') - raise RuntimeError(f"{self.service} did not reply within {timeout}s") - break - - return json.loads(reply.payload) - - def control(self, msg, **kwargs): - with pg_bus_conn() as conn: - conn.notify(self.queuename, json.dumps(msg)) diff --git a/awx/main/dispatch/hazmat.py b/awx/main/dispatch/hazmat.py index b1b30db090..79baefa9d2 100644 --- a/awx/main/dispatch/hazmat.py +++ b/awx/main/dispatch/hazmat.py @@ -18,7 +18,7 @@ django.setup() # noqa from django.conf import settings # Preload all periodic tasks so their imports will be in shared memory -for name, options in settings.CELERYBEAT_SCHEDULE.items(): +for name, options in settings.DISPATCHER_SCHEDULE.items(): resolve_callable(options['task']) diff --git a/awx/main/dispatch/periodic.py b/awx/main/dispatch/periodic.py deleted file mode 100644 index 0d3229cd91..0000000000 --- a/awx/main/dispatch/periodic.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging -import time -import yaml -from datetime import datetime - -logger = logging.getLogger('awx.main.dispatch.periodic') - - -class ScheduledTask: - """ - Class representing schedules, very loosely modeled after python schedule library Job - the idea of this class is to: - - only deal in relative times (time since the scheduler global start) - - only deal in integer math for target runtimes, but float for current relative time - - Missed schedule policy: - Invariant target times are maintained, meaning that if interval=10s offset=0 - and it runs at t=7s, then it calls for next run in 3s. - However, if a complete interval has passed, that is counted as a missed run, - and missed runs are abandoned (no catch-up runs). 
- """ - - def __init__(self, name: str, data: dict): - # parameters need for schedule computation - self.interval = int(data['schedule'].total_seconds()) - self.offset = 0 # offset relative to start time this schedule begins - self.index = 0 # number of periods of the schedule that has passed - - # parameters that do not affect scheduling logic - self.last_run = None # time of last run, only used for debug - self.completed_runs = 0 # number of times schedule is known to run - self.name = name - self.data = data # used by caller to know what to run - - @property - def next_run(self): - "Time until the next run with t=0 being the global_start of the scheduler class" - return (self.index + 1) * self.interval + self.offset - - def due_to_run(self, relative_time): - return bool(self.next_run <= relative_time) - - def expected_runs(self, relative_time): - return int((relative_time - self.offset) / self.interval) - - def mark_run(self, relative_time): - self.last_run = relative_time - self.completed_runs += 1 - new_index = self.expected_runs(relative_time) - if new_index > self.index + 1: - logger.warning(f'Missed {new_index - self.index - 1} schedules of {self.name}') - self.index = new_index - - def missed_runs(self, relative_time): - "Number of times job was supposed to ran but failed to, only used for debug" - missed_ct = self.expected_runs(relative_time) - self.completed_runs - # if this is currently due to run do not count that as a missed run - if missed_ct and self.due_to_run(relative_time): - missed_ct -= 1 - return missed_ct - - -class Scheduler: - def __init__(self, schedule): - """ - Expects schedule in the form of a dictionary like - { - 'job1': {'schedule': timedelta(seconds=50), 'other': 'stuff'} - } - Only the schedule nearest-second value is used for scheduling, - the rest of the data is for use by the caller to know what to run. 
- """ - self.jobs = [ScheduledTask(name, data) for name, data in schedule.items()] - min_interval = min(job.interval for job in self.jobs) - num_jobs = len(self.jobs) - - # this is intentionally oppioniated against spammy schedules - # a core goal is to spread out the scheduled tasks (for worker management) - # and high-frequency schedules just do not work with that - if num_jobs > min_interval: - raise RuntimeError(f'Number of schedules ({num_jobs}) is more than the shortest schedule interval ({min_interval} seconds).') - - # even space out jobs over the base interval - for i, job in enumerate(self.jobs): - job.offset = (i * min_interval) // num_jobs - - # internally times are all referenced relative to startup time, add grace period - self.global_start = time.time() + 2.0 - - def get_and_mark_pending(self, reftime=None): - if reftime is None: - reftime = time.time() # mostly for tests - relative_time = reftime - self.global_start - to_run = [] - for job in self.jobs: - if job.due_to_run(relative_time): - to_run.append(job) - logger.debug(f'scheduler found {job.name} to run, {relative_time - job.next_run} seconds after target') - job.mark_run(relative_time) - return to_run - - def time_until_next_run(self, reftime=None): - if reftime is None: - reftime = time.time() # mostly for tests - relative_time = reftime - self.global_start - next_job = min(self.jobs, key=lambda j: j.next_run) - delta = next_job.next_run - relative_time - if delta <= 0.1: - # careful not to give 0 or negative values to the select timeout, which has unclear interpretation - logger.warning(f'Scheduler next run of {next_job.name} is {-delta} seconds in the past') - return 0.1 - elif delta > 20.0: - logger.warning(f'Scheduler next run unexpectedly over 20 seconds in future: {delta}') - return 20.0 - logger.debug(f'Scheduler next run is {next_job.name} in {delta} seconds') - return delta - - def debug(self, *args, **kwargs): - data = dict() - data['title'] = 'Scheduler status' - reftime = 
time.time() - - now = datetime.fromtimestamp(reftime).strftime('%Y-%m-%d %H:%M:%S UTC') - start_time = datetime.fromtimestamp(self.global_start).strftime('%Y-%m-%d %H:%M:%S UTC') - relative_time = reftime - self.global_start - data['started_time'] = start_time - data['current_time'] = now - data['current_time_relative'] = round(relative_time, 3) - data['total_schedules'] = len(self.jobs) - - data['schedule_list'] = dict( - [ - ( - job.name, - dict( - last_run_seconds_ago=round(relative_time - job.last_run, 3) if job.last_run else None, - next_run_in_seconds=round(job.next_run - relative_time, 3), - offset_in_seconds=job.offset, - completed_runs=job.completed_runs, - missed_runs=job.missed_runs(relative_time), - ), - ) - for job in sorted(self.jobs, key=lambda job: job.interval) - ] - ) - return yaml.safe_dump(data, default_flow_style=False, sort_keys=False) diff --git a/awx/main/dispatch/pool.py b/awx/main/dispatch/pool.py index 802a5d5da6..853c322e8d 100644 --- a/awx/main/dispatch/pool.py +++ b/awx/main/dispatch/pool.py @@ -1,251 +1,50 @@ import logging import os -import random -import signal -import sys import time -import traceback -from datetime import datetime, timezone -from uuid import uuid4 -import json -import collections from multiprocessing import Process -from multiprocessing import Queue as MPQueue -from queue import Full as QueueFull, Empty as QueueEmpty from django.conf import settings -from django.db import connection as django_connection, connections +from django.db import connection as django_connection from django.core.cache import cache as django_cache -from django.utils.timezone import now as tz_now -from django_guid import set_guid -from jinja2 import Template -import psutil -from ansible_base.lib.logging.runtime import log_excess_runtime - -from awx.main.models import UnifiedJob -from awx.main.dispatch import reaper -from awx.main.utils.common import get_mem_effective_capacity, get_corrected_memory, get_corrected_cpu, 
get_cpu_effective_capacity - -# ansible-runner -from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count - -if 'run_callback_receiver' in sys.argv: - logger = logging.getLogger('awx.main.commands.run_callback_receiver') -else: - logger = logging.getLogger('awx.main.dispatch') - - -RETIRED_SENTINEL_TASK = "[retired]" - - -class NoOpResultQueue(object): - def put(self, item): - pass +logger = logging.getLogger('awx.main.commands.run_callback_receiver') class PoolWorker(object): """ - Used to track a worker child process and its pending and finished messages. + A simple wrapper around a multiprocessing.Process that tracks a worker child process. - This class makes use of two distinct multiprocessing.Queues to track state: - - - self.queue: this is a queue which represents pending messages that should - be handled by this worker process; as new AMQP messages come - in, a pool will put() them into this queue; the child - process that is forked will get() from this queue and handle - received messages in an endless loop - - self.finished: this is a queue which the worker process uses to signal - that it has finished processing a message - - When a message is put() onto this worker, it is tracked in - self.managed_tasks. - - Periodically, the worker will call .calculate_managed_tasks(), which will - cause messages in self.finished to be removed from self.managed_tasks. - - In this way, self.managed_tasks represents a view of the messages assigned - to a specific process. The message at [0] is the least-recently inserted - message, and it represents what the worker is running _right now_ - (self.current_task). - - A worker is "busy" when it has at least one message in self.managed_tasks. - It is "idle" when self.managed_tasks is empty. + The worker process runs the provided target function and tracks its creation time. 
""" - track_managed_tasks = False - - def __init__(self, queue_size, target, args, **kwargs): - self.messages_sent = 0 - self.messages_finished = 0 - self.managed_tasks = collections.OrderedDict() - self.finished = MPQueue(queue_size) if self.track_managed_tasks else NoOpResultQueue() - self.queue = MPQueue(queue_size) - self.process = Process(target=target, args=(self.queue, self.finished) + args) + def __init__(self, target, args, **kwargs): + self.process = Process(target=target, args=args) self.process.daemon = True self.creation_time = time.monotonic() - self.retiring = False def start(self): self.process.start() - def put(self, body): - if self.retiring: - uuid = body.get('uuid', 'N/A') if isinstance(body, dict) else 'N/A' - logger.info(f"Worker pid:{self.pid} is retiring. Refusing new task {uuid}.") - raise QueueFull("Worker is retiring and not accepting new tasks") # AutoscalePool.write handles QueueFull - uuid = '?' - if isinstance(body, dict): - if not body.get('uuid'): - body['uuid'] = str(uuid4()) - uuid = body['uuid'] - if self.track_managed_tasks: - self.managed_tasks[uuid] = body - self.queue.put(body, block=True, timeout=5) - self.messages_sent += 1 - self.calculate_managed_tasks() - - def quit(self): - """ - Send a special control message to the worker that tells it to exit - gracefully. 
- """ - self.queue.put('QUIT') - - @property - def age(self): - """Returns the current age of the worker in seconds.""" - return time.monotonic() - self.creation_time - - @property - def pid(self): - return self.process.pid - - @property - def qsize(self): - return self.queue.qsize() - - @property - def alive(self): - return self.process.is_alive() - - @property - def mb(self): - if self.alive: - return '{:0.3f}'.format(psutil.Process(self.pid).memory_info().rss / 1024.0 / 1024.0) - return '0' - - @property - def exitcode(self): - return str(self.process.exitcode) - - def calculate_managed_tasks(self): - if not self.track_managed_tasks: - return - # look to see if any tasks were finished - finished = [] - for _ in range(self.finished.qsize()): - try: - finished.append(self.finished.get(block=False)) - except QueueEmpty: - break # qsize is not always _totally_ up to date - - # if any tasks were finished, removed them from the managed tasks for - # this worker - for uuid in finished: - try: - del self.managed_tasks[uuid] - self.messages_finished += 1 - except KeyError: - # ansible _sometimes_ appears to send events w/ duplicate UUIDs; - # UUIDs for ansible events are *not* actually globally unique - # when this occurs, it's _fine_ to ignore this KeyError because - # the purpose of self.managed_tasks is to just track internal - # state of which events are *currently* being processed. 
- logger.warning('Event UUID {} appears to be have been duplicated.'.format(uuid)) - if self.retiring: - self.managed_tasks[RETIRED_SENTINEL_TASK] = {'task': RETIRED_SENTINEL_TASK} - - @property - def current_task(self): - if not self.track_managed_tasks: - return None - self.calculate_managed_tasks() - # the task at [0] is the one that's running right now (or is about to - # be running) - if len(self.managed_tasks): - return self.managed_tasks[list(self.managed_tasks.keys())[0]] - - return None - - @property - def orphaned_tasks(self): - if not self.track_managed_tasks: - return [] - orphaned = [] - if not self.alive: - # if this process had a running task that never finished, - # requeue its error callbacks - current_task = self.current_task - if isinstance(current_task, dict): - orphaned.extend(current_task.get('errbacks', [])) - - # if this process has any pending messages requeue them - for _ in range(self.qsize): - try: - message = self.queue.get(block=False) - if message != 'QUIT': - orphaned.append(message) - except QueueEmpty: - break # qsize is not always _totally_ up to date - if len(orphaned): - logger.error('requeuing {} messages from gone worker pid:{}'.format(len(orphaned), self.pid)) - return orphaned - - @property - def busy(self): - self.calculate_managed_tasks() - return len(self.managed_tasks) > 0 - - @property - def idle(self): - return not self.busy - - -class StatefulPoolWorker(PoolWorker): - track_managed_tasks = True - class WorkerPool(object): """ Creates a pool of forked PoolWorkers. - As WorkerPool.write(...) is called (generally, by a kombu consumer - implementation when it receives an AMQP message), messages are passed to - one of the multiprocessing Queues where some work can be done on them. + Each worker process runs the provided target function in an isolated process. + The pool manages spawning, tracking, and stopping worker processes. 
- class MessagePrinter(awx.main.dispatch.worker.BaseWorker): - - def perform_work(self, body): - print(body) - - pool = WorkerPool(min_workers=4) # spawn four worker processes - pool.init_workers(MessagePrint().work_loop) - pool.write( - 0, # preferred worker 0 - 'Hello, World!' - ) + Example: + pool = WorkerPool(workers_num=4) # spawn four worker processes """ pool_cls = PoolWorker debug_meta = '' - def __init__(self, min_workers=None, queue_size=None): + def __init__(self, workers_num=None): self.name = settings.CLUSTER_HOST_ID self.pid = os.getpid() - self.min_workers = min_workers or settings.JOB_EVENT_WORKERS - self.queue_size = queue_size or settings.JOB_EVENT_MAX_QUEUE_SIZE + self.workers_num = workers_num or settings.JOB_EVENT_WORKERS self.workers = [] def __len__(self): @@ -254,7 +53,7 @@ class WorkerPool(object): def init_workers(self, target, *target_args): self.target = target self.target_args = target_args - for idx in range(self.min_workers): + for idx in range(self.workers_num): self.up() def up(self): @@ -264,320 +63,19 @@ class WorkerPool(object): # for the DB and cache connections (that way lies race conditions) django_connection.close() django_cache.close() - worker = self.pool_cls(self.queue_size, self.target, (idx,) + self.target_args) + worker = self.pool_cls(self.target, (idx,) + self.target_args) self.workers.append(worker) try: worker.start() except Exception: logger.exception('could not fork') else: - logger.debug('scaling up worker pid:{}'.format(worker.pid)) + logger.debug('scaling up worker pid:{}'.format(worker.process.pid)) return idx, worker - def debug(self, *args, **kwargs): - tmpl = Template( - 'Recorded at: {{ dt }} \n' - '{{ pool.name }}[pid:{{ pool.pid }}] workers total={{ workers|length }} {{ meta }} \n' - '{% for w in workers %}' - '. 
worker[pid:{{ w.pid }}]{% if not w.alive %} GONE exit={{ w.exitcode }}{% endif %}' - ' sent={{ w.messages_sent }}' - ' age={{ "%.0f"|format(w.age) }}s' - ' retiring={{ w.retiring }}' - '{% if w.messages_finished %} finished={{ w.messages_finished }}{% endif %}' - ' qsize={{ w.managed_tasks|length }}' - ' rss={{ w.mb }}MB' - '{% for task in w.managed_tasks.values() %}' - '\n - {% if loop.index0 == 0 %}running {% if "age" in task %}for: {{ "%.1f" % task["age"] }}s {% endif %}{% else %}queued {% endif %}' - '{{ task["uuid"] }} ' - '{% if "task" in task %}' - '{{ task["task"].rsplit(".", 1)[-1] }}' - # don't print kwargs, they often contain launch-time secrets - '(*{{ task.get("args", []) }})' - '{% endif %}' - '{% endfor %}' - '{% if not w.managed_tasks|length %}' - ' [IDLE]' - '{% endif %}' - '\n' - '{% endfor %}' - ) - now = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC') - return tmpl.render(pool=self, workers=self.workers, meta=self.debug_meta, dt=now) - - def write(self, preferred_queue, body): - queue_order = sorted(range(len(self.workers)), key=lambda x: -1 if x == preferred_queue else x) - write_attempt_order = [] - for queue_actual in queue_order: - try: - self.workers[queue_actual].put(body) - return queue_actual - except QueueFull: - pass - except Exception: - tb = traceback.format_exc() - logger.warning("could not write to queue %s" % preferred_queue) - logger.warning("detail: {}".format(tb)) - write_attempt_order.append(preferred_queue) - logger.error("could not write payload to any queue, attempted order: {}".format(write_attempt_order)) - return None - def stop(self, signum): try: for worker in self.workers: os.kill(worker.pid, signum) except Exception: logger.exception('could not kill {}'.format(worker.pid)) - - -def get_auto_max_workers(): - """Method we normally rely on to get max_workers - - Uses almost same logic as Instance.local_health_check - The important thing is to be MORE than Instance.capacity - so that the task-manager does 
not over-schedule this node - - Ideally we would just use the capacity from the database plus reserve workers, - but this poses some bootstrap problems where OCP task containers - register themselves after startup - """ - # Get memory from ansible-runner - total_memory_gb = get_mem_in_bytes() - - # This may replace memory calculation with a user override - corrected_memory = get_corrected_memory(total_memory_gb) - - # Get same number as max forks based on memory, this function takes memory as bytes - mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True) - - # Follow same process for CPU capacity constraint - cpu_count = get_cpu_count() - corrected_cpu = get_corrected_cpu(cpu_count) - cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True) - - # Here is what is different from health checks, - auto_max = max(mem_capacity, cpu_capacity) - - # add magic number of extra workers to ensure - # we have a few extra workers to run the heartbeat - auto_max += 7 - - return auto_max - - -class AutoscalePool(WorkerPool): - """ - An extended pool implementation that automatically scales workers up and - down based on demand - """ - - pool_cls = StatefulPoolWorker - - def __init__(self, *args, **kwargs): - self.max_workers = kwargs.pop('max_workers', None) - self.max_worker_lifetime_seconds = kwargs.pop( - 'max_worker_lifetime_seconds', getattr(settings, 'WORKER_MAX_LIFETIME_SECONDS', 14400) - ) # Default to 4 hours - super(AutoscalePool, self).__init__(*args, **kwargs) - - if self.max_workers is None: - self.max_workers = get_auto_max_workers() - - # max workers can't be less than min_workers - self.max_workers = max(self.min_workers, self.max_workers) - - # the task manager enforces settings.TASK_MANAGER_TIMEOUT on its own - # but if the task takes longer than the time defined here, we will force it to stop here - self.task_manager_timeout = settings.TASK_MANAGER_TIMEOUT + settings.TASK_MANAGER_TIMEOUT_GRACE_PERIOD - - # 
initialize some things for subsystem metrics periodic gathering - # the AutoscalePool class does not save these to redis directly, but reports via produce_subsystem_metrics - self.scale_up_ct = 0 - self.worker_count_max = 0 - - # last time we wrote current tasks, to avoid too much log spam - self.last_task_list_log = time.monotonic() - - def produce_subsystem_metrics(self, metrics_object): - metrics_object.set('dispatcher_pool_scale_up_events', self.scale_up_ct) - metrics_object.set('dispatcher_pool_active_task_count', sum(len(w.managed_tasks) for w in self.workers)) - metrics_object.set('dispatcher_pool_max_worker_count', self.worker_count_max) - self.worker_count_max = len(self.workers) - - @property - def should_grow(self): - if len(self.workers) < self.min_workers: - # If we don't have at least min_workers, add more - return True - # If every worker is busy doing something, add more - return all([w.busy for w in self.workers]) - - @property - def full(self): - return len(self.workers) == self.max_workers - - @property - def debug_meta(self): - return 'min={} max={}'.format(self.min_workers, self.max_workers) - - @log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2) - def cleanup(self): - """ - Perform some internal account and cleanup. This is run on - every cluster node heartbeat: - - 1. Discover worker processes that exited, and recover messages they - were handling. - 2. Clean up unnecessary, idle workers. - - IMPORTANT: this function is one of the few places in the dispatcher - (aside from setting lookups) where we talk to the database. As such, - if there's an outage, this method _can_ throw various - django.db.utils.Error exceptions. Act accordingly. - """ - orphaned = [] - for w in self.workers[::]: - is_retirement_age = self.max_worker_lifetime_seconds is not None and w.age > self.max_worker_lifetime_seconds - if not w.alive: - # the worker process has exited - # 1. take the task it was running and enqueue the error - # callbacks - # 2. 
take any pending tasks delivered to its queue and - # send them to another worker - logger.error('worker pid:{} is gone (exit={})'.format(w.pid, w.exitcode)) - if w.current_task: - if w.current_task == {'task': RETIRED_SENTINEL_TASK}: - logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age)) - self.workers.remove(w) - continue - if w.current_task != 'QUIT': - try: - for j in UnifiedJob.objects.filter(celery_task_id=w.current_task['uuid']): - reaper.reap_job(j, 'failed') - except Exception: - logger.exception('failed to reap job UUID {}'.format(w.current_task['uuid'])) - else: - logger.warning(f'Worker was told to quit but has not, pid={w.pid}') - orphaned.extend(w.orphaned_tasks) - self.workers.remove(w) - - elif w.idle and len(self.workers) > self.min_workers: - # the process has an empty queue (it's idle) and we have - # more processes in the pool than we need (> min) - # send this process a message so it will exit gracefully - # at the next opportunity - logger.debug('scaling down worker pid:{}'.format(w.pid)) - w.quit() - self.workers.remove(w) - - elif w.idle and is_retirement_age: - logger.debug('scaling down worker pid:{} due to worker age: {}'.format(w.pid, w.age)) - w.quit() - self.workers.remove(w) - - elif is_retirement_age and not w.retiring and not w.idle: - logger.info( - f"Worker pid:{w.pid} (age: {w.age:.0f}s) exceeded max lifetime ({self.max_worker_lifetime_seconds:.0f}s). " - "Signaling for graceful retirement." - ) - # Send QUIT signal; worker will finish current task then exit. 
- w.quit() - # mark as retiring to reject any future tasks that might be assigned in meantime - w.retiring = True - - if w.alive: - # if we discover a task manager invocation that's been running - # too long, reap it (because otherwise it'll just hold the postgres - # advisory lock forever); the goal of this code is to discover - # deadlocks or other serious issues in the task manager that cause - # the task manager to never do more work - current_task = w.current_task - if current_task and isinstance(current_task, dict): - endings = ('tasks.task_manager', 'tasks.dependency_manager', 'tasks.workflow_manager') - current_task_name = current_task.get('task', '') - if current_task_name.endswith(endings): - if 'started' not in current_task: - w.managed_tasks[current_task['uuid']]['started'] = time.time() - age = time.time() - current_task['started'] - w.managed_tasks[current_task['uuid']]['age'] = age - if age > self.task_manager_timeout: - logger.error(f'{current_task_name} has held the advisory lock for {age}, sending SIGUSR1 to {w.pid}') - os.kill(w.pid, signal.SIGUSR1) - - for m in orphaned: - # if all the workers are dead, spawn at least one - if not len(self.workers): - self.up() - idx = random.choice(range(len(self.workers))) - self.write(idx, m) - - def add_bind_kwargs(self, body): - bind_kwargs = body.pop('bind_kwargs', []) - body.setdefault('kwargs', {}) - if 'dispatch_time' in bind_kwargs: - body['kwargs']['dispatch_time'] = tz_now().isoformat() - if 'worker_tasks' in bind_kwargs: - worker_tasks = {} - for worker in self.workers: - worker.calculate_managed_tasks() - worker_tasks[worker.pid] = list(worker.managed_tasks.keys()) - body['kwargs']['worker_tasks'] = worker_tasks - - def up(self): - if self.full: - # if we can't spawn more workers, just toss this message into a - # random worker's backlog - idx = random.choice(range(len(self.workers))) - return idx, self.workers[idx] - else: - self.scale_up_ct += 1 - ret = super(AutoscalePool, self).up() - 
new_worker_ct = len(self.workers) - if new_worker_ct > self.worker_count_max: - self.worker_count_max = new_worker_ct - return ret - - @staticmethod - def fast_task_serialization(current_task): - try: - return str(current_task.get('task')) + ' - ' + str(sorted(current_task.get('args', []))) + ' - ' + str(sorted(current_task.get('kwargs', {}))) - except Exception: - # just make sure this does not make things worse - return str(current_task) - - def write(self, preferred_queue, body): - if 'guid' in body: - set_guid(body['guid']) - try: - if isinstance(body, dict) and body.get('bind_kwargs'): - self.add_bind_kwargs(body) - if self.should_grow: - self.up() - # we don't care about "preferred queue" round robin distribution, just - # find the first non-busy worker and claim it - workers = self.workers[:] - random.shuffle(workers) - for w in workers: - if not w.busy: - w.put(body) - break - else: - task_name = 'unknown' - if isinstance(body, dict): - task_name = body.get('task') - logger.warning(f'Workers maxed, queuing {task_name}, load: {sum(len(w.managed_tasks) for w in self.workers)} / {len(self.workers)}') - # Once every 10 seconds write out task list for debugging - if time.monotonic() - self.last_task_list_log >= 10.0: - task_counts = {} - for worker in self.workers: - task_slug = self.fast_task_serialization(worker.current_task) - task_counts.setdefault(task_slug, 0) - task_counts[task_slug] += 1 - logger.info(f'Running tasks by count:\n{json.dumps(task_counts, indent=2)}') - self.last_task_list_log = time.monotonic() - return super(AutoscalePool, self).write(preferred_queue, body) - except Exception: - for conn in connections.all(): - # If the database connection has a hiccup, re-establish a new - # connection - conn.close_if_unusable_or_obsolete() - logger.exception('failed to write inbound message') diff --git a/awx/main/dispatch/publish.py b/awx/main/dispatch/publish.py deleted file mode 100644 index 4aef040a88..0000000000 --- a/awx/main/dispatch/publish.py 
+++ /dev/null @@ -1,163 +0,0 @@ -import inspect -import logging -import json -import time -from uuid import uuid4 - -from dispatcherd.publish import submit_task -from dispatcherd.processors.blocker import Blocker -from dispatcherd.utils import resolve_callable - -from django_guid import get_guid -from django.conf import settings - -from . import pg_bus_conn - -logger = logging.getLogger('awx.main.dispatch') - - -def serialize_task(f): - return '.'.join([f.__module__, f.__name__]) - - -class task: - """ - Used to decorate a function or class so that it can be run asynchronously - via the task dispatcher. Tasks can be simple functions: - - @task() - def add(a, b): - return a + b - - ...or classes that define a `run` method: - - @task() - class Adder: - def run(self, a, b): - return a + b - - # Tasks can be run synchronously... - assert add(1, 1) == 2 - assert Adder().run(1, 1) == 2 - - # ...or published to a queue: - add.apply_async([1, 1]) - Adder.apply_async([1, 1]) - - # Tasks can also define a specific target queue or use the special fan-out queue tower_broadcast: - - @task(queue='slow-tasks') - def snooze(): - time.sleep(10) - - @task(queue='tower_broadcast') - def announce(): - print("Run this everywhere!") - - # The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs - - @task(bind_kwargs=['dispatch_time']) - def print_time(dispatch_time=None): - print(f"Time I was dispatched: {dispatch_time}") - """ - - def __init__(self, queue=None, bind_kwargs=None, timeout=None, on_duplicate=None): - self.queue = queue - self.bind_kwargs = bind_kwargs - self.timeout = timeout - self.on_duplicate = on_duplicate - - def __call__(self, fn=None): - queue = self.queue - bind_kwargs = self.bind_kwargs - timeout = self.timeout - on_duplicate = self.on_duplicate - - class PublisherMixin(object): - queue = None - - @classmethod - def delay(cls, *args, **kwargs): - return cls.apply_async(args, kwargs) - - @classmethod - def get_async_body(cls, 
args=None, kwargs=None, uuid=None, **kw): - """ - Get the python dict to become JSON data in the pg_notify message - This same message gets passed over the dispatcher IPC queue to workers - If a task is submitted to a multiprocessing pool, skipping pg_notify, this might be used directly - """ - task_id = uuid or str(uuid4()) - args = args or [] - kwargs = kwargs or {} - obj = {'uuid': task_id, 'args': args, 'kwargs': kwargs, 'task': cls.name, 'time_pub': time.time()} - guid = get_guid() - if guid: - obj['guid'] = guid - if bind_kwargs: - obj['bind_kwargs'] = bind_kwargs - obj.update(**kw) - return obj - - @classmethod - def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw): - try: - from flags.state import flag_enabled - - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - # At this point we have the import string, and submit_task wants the method, so back to that - actual_task = resolve_callable(cls.name) - processor_options = () - if on_duplicate is not None: - processor_options = (Blocker.Params(on_duplicate=on_duplicate),) - return submit_task( - actual_task, - args=args, - kwargs=kwargs, - queue=queue, - uuid=uuid, - timeout=timeout, - processor_options=processor_options, - **kw, - ) - except Exception: - logger.exception(f"[DISPATCHER] Failed to check for alternative dispatcherd implementation for {cls.name}") - # Continue with original implementation if anything fails - pass - - # Original implementation follows - queue = queue or getattr(cls.queue, 'im_func', cls.queue) - if not queue: - msg = f'{cls.name}: Queue value required and may not be None' - logger.error(msg) - raise ValueError(msg) - obj = cls.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw) - if callable(queue): - queue = queue() - if not settings.DISPATCHER_MOCK_PUBLISH: - with pg_bus_conn() as conn: - conn.notify(queue, json.dumps(obj)) - return (obj, queue) - - # If the object we're wrapping *is* a class (e.g., RunJob), return - # a *new* class that inherits from 
the wrapped class *and* BaseTask - # In this way, the new class returned by our decorator is the class - # being decorated *plus* PublisherMixin so cls.apply_async() and - # cls.delay() work - bases = [] - ns = {'name': serialize_task(fn), 'queue': queue} - if inspect.isclass(fn): - bases = list(fn.__bases__) - ns.update(fn.__dict__) - cls = type(fn.__name__, tuple(bases + [PublisherMixin]), ns) - if inspect.isclass(fn): - return cls - - # if the object being decorated is *not* a class (it's a Python - # function), make fn.apply_async and fn.delay proxy through to the - # PublisherMixin we dynamically created above - setattr(fn, 'name', cls.name) - setattr(fn, 'apply_async', cls.apply_async) - setattr(fn, 'delay', cls.delay) - setattr(fn, 'get_async_body', cls.get_async_body) - return fn diff --git a/awx/main/dispatch/worker/__init__.py b/awx/main/dispatch/worker/__init__.py index 6fe8f64608..71203886eb 100644 --- a/awx/main/dispatch/worker/__init__.py +++ b/awx/main/dispatch/worker/__init__.py @@ -1,3 +1,2 @@ -from .base import AWXConsumerRedis, AWXConsumerPG, BaseWorker # noqa +from .base import AWXConsumerRedis # noqa from .callback import CallbackBrokerWorker # noqa -from .task import TaskWorker # noqa diff --git a/awx/main/dispatch/worker/base.py b/awx/main/dispatch/worker/base.py index db89fc923e..133d87c908 100644 --- a/awx/main/dispatch/worker/base.py +++ b/awx/main/dispatch/worker/base.py @@ -4,32 +4,15 @@ import os import logging import signal -import sys -import redis -import json -import psycopg import time -from uuid import UUID -from queue import Empty as QueueEmpty -from datetime import timedelta from django import db -from django.conf import settings -import redis.exceptions -from ansible_base.lib.logging.runtime import log_excess_runtime from awx.main.utils.redis import get_redis_client from awx.main.dispatch.pool import WorkerPool -from awx.main.dispatch.periodic import Scheduler -from awx.main.dispatch import pg_bus_conn -from awx.main.utils.db 
import set_connection_name -import awx.main.analytics.subsystem_metrics as s_metrics -if 'run_callback_receiver' in sys.argv: - logger = logging.getLogger('awx.main.commands.run_callback_receiver') -else: - logger = logging.getLogger('awx.main.dispatch') +logger = logging.getLogger('awx.main.commands.run_callback_receiver') def signame(sig): @@ -62,85 +45,6 @@ class AWXConsumerBase(object): self.pool.init_workers(self.worker.work_loop) self.redis = get_redis_client() - @property - def listening_on(self): - return f'listening on {self.queues}' - - def control(self, body): - logger.warning(f'Received control signal:\n{body}') - control = body.get('control') - if control in ('status', 'schedule', 'running', 'cancel'): - reply_queue = body['reply_to'] - if control == 'status': - msg = '\n'.join([self.listening_on, self.pool.debug()]) - if control == 'schedule': - msg = self.scheduler.debug() - elif control == 'running': - msg = [] - for worker in self.pool.workers: - worker.calculate_managed_tasks() - msg.extend(worker.managed_tasks.keys()) - elif control == 'cancel': - msg = [] - task_ids = set(body['task_ids']) - for worker in self.pool.workers: - task = worker.current_task - if task and task['uuid'] in task_ids: - logger.warn(f'Sending SIGTERM to task id={task["uuid"]}, task={task.get("task")}, args={task.get("args")}') - os.kill(worker.pid, signal.SIGTERM) - msg.append(task['uuid']) - if task_ids and not msg: - logger.info(f'Could not locate running tasks to cancel with ids={task_ids}') - - if reply_queue is not None: - with pg_bus_conn() as conn: - conn.notify(reply_queue, json.dumps(msg)) - elif control == 'reload': - for worker in self.pool.workers: - worker.quit() - else: - logger.error('unrecognized control message: {}'.format(control)) - - def dispatch_task(self, body): - """This will place the given body into a worker queue to run method decorated as a task""" - if isinstance(body, dict): - body['time_ack'] = time.time() - - if len(self.pool): - if "uuid" in 
body and body['uuid']: - try: - queue = UUID(body['uuid']).int % len(self.pool) - except Exception: - queue = self.total_messages % len(self.pool) - else: - queue = self.total_messages % len(self.pool) - else: - queue = 0 - self.pool.write(queue, body) - self.total_messages += 1 - - def process_task(self, body): - """Routes the task details in body as either a control task or a task-task""" - if 'control' in body: - try: - return self.control(body) - except Exception: - logger.exception(f"Exception handling control message: {body}") - return - self.dispatch_task(body) - - @log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2) - def record_statistics(self): - if time.time() - self.last_stats > 1: # buffer stat recording to once per second - save_data = self.pool.debug() - try: - self.redis.set(f'awx_{self.name}_statistics', save_data) - except redis.exceptions.ConnectionError as exc: - logger.warning(f'Redis connection error saving {self.name} status data:\n{exc}\nmissed data:\n{save_data}') - except Exception: - logger.exception(f"Unknown redis error saving {self.name} status data:\nmissed data:\n{save_data}") - self.last_stats = time.time() - def run(self, *args, **kwargs): signal.signal(signal.SIGINT, self.stop) signal.signal(signal.SIGTERM, self.stop) @@ -150,196 +54,14 @@ class AWXConsumerBase(object): def stop(self, signum, frame): self.should_stop = True logger.warning('received {}, stopping'.format(signame(signum))) - self.worker.on_stop() raise SystemExit() class AWXConsumerRedis(AWXConsumerBase): def run(self, *args, **kwargs): super(AWXConsumerRedis, self).run(*args, **kwargs) - self.worker.on_start() logger.info(f'Callback receiver started with pid={os.getpid()}') db.connection.close() # logs use database, so close connection while True: time.sleep(60) - - -class AWXConsumerPG(AWXConsumerBase): - def __init__(self, *args, schedule=None, **kwargs): - super().__init__(*args, **kwargs) - self.pg_max_wait = getattr(settings, 
'DISPATCHER_DB_DOWNTOWN_TOLLERANCE', settings.DISPATCHER_DB_DOWNTIME_TOLERANCE) - # if no successful loops have ran since startup, then we should fail right away - self.pg_is_down = True # set so that we fail if we get database errors on startup - init_time = time.time() - self.pg_down_time = init_time - self.pg_max_wait # allow no grace period - self.last_cleanup = init_time - self.subsystem_metrics = s_metrics.DispatcherMetrics(auto_pipe_execute=False) - self.last_metrics_gather = init_time - self.listen_cumulative_time = 0.0 - if schedule: - schedule = schedule.copy() - else: - schedule = {} - # add control tasks to be ran at regular schedules - # NOTE: if we run out of database connections, it is important to still run cleanup - # so that we scale down workers and free up connections - schedule['pool_cleanup'] = {'control': self.pool.cleanup, 'schedule': timedelta(seconds=60)} - # record subsystem metrics for the dispatcher - schedule['metrics_gather'] = {'control': self.record_metrics, 'schedule': timedelta(seconds=20)} - self.scheduler = Scheduler(schedule) - - @log_excess_runtime(logger, debug_cutoff=0.05, cutoff=0.2) - def record_metrics(self): - current_time = time.time() - self.pool.produce_subsystem_metrics(self.subsystem_metrics) - self.subsystem_metrics.set('dispatcher_availability', self.listen_cumulative_time / (current_time - self.last_metrics_gather)) - try: - self.subsystem_metrics.pipe_execute() - except redis.exceptions.ConnectionError as exc: - logger.warning(f'Redis connection error saving dispatcher metrics, error:\n{exc}') - self.listen_cumulative_time = 0.0 - self.last_metrics_gather = current_time - - def run_periodic_tasks(self): - """ - Run general periodic logic, and return maximum time in seconds before - the next requested run - This may be called more often than that when events are consumed - so this should be very efficient in that - """ - try: - self.record_statistics() # maintains time buffer in method - except Exception as exc: 
- logger.warning(f'Failed to save dispatcher statistics {exc}') - - # Everything benchmarks to the same original time, so that skews due to - # runtime of the actions, themselves, do not mess up scheduling expectations - reftime = time.time() - - for job in self.scheduler.get_and_mark_pending(reftime=reftime): - if 'control' in job.data: - try: - job.data['control']() - except Exception: - logger.exception(f'Error running control task {job.data}') - elif 'task' in job.data: - body = self.worker.resolve_callable(job.data['task']).get_async_body() - # bypasses pg_notify for scheduled tasks - self.dispatch_task(body) - - if self.pg_is_down: - logger.info('Dispatcher listener connection established') - self.pg_is_down = False - - self.listen_start = time.time() - - return self.scheduler.time_until_next_run(reftime=reftime) - - def run(self, *args, **kwargs): - super(AWXConsumerPG, self).run(*args, **kwargs) - - logger.info(f"Running {self.name}, workers min={self.pool.min_workers} max={self.pool.max_workers}, listening to queues {self.queues}") - init = False - - while True: - try: - with pg_bus_conn(new_connection=True) as conn: - for queue in self.queues: - conn.listen(queue) - if init is False: - self.worker.on_start() - init = True - # run_periodic_tasks run scheduled actions and gives time until next scheduled action - # this is saved to the conn (PubSub) object in order to modify read timeout in-loop - conn.select_timeout = self.run_periodic_tasks() - # this is the main operational loop for awx-manage run_dispatcher - for e in conn.events(yield_timeouts=True): - self.listen_cumulative_time += time.time() - self.listen_start # for metrics - if e is not None: - self.process_task(json.loads(e.payload)) - conn.select_timeout = self.run_periodic_tasks() - if self.should_stop: - return - except psycopg.InterfaceError: - logger.warning("Stale Postgres message bus connection, reconnecting") - continue - except (db.DatabaseError, psycopg.OperationalError): - # If we have 
attained stady state operation, tolerate short-term database hickups - if not self.pg_is_down: - logger.exception(f"Error consuming new events from postgres, will retry for {self.pg_max_wait} s") - self.pg_down_time = time.time() - self.pg_is_down = True - current_downtime = time.time() - self.pg_down_time - if current_downtime > self.pg_max_wait: - logger.exception(f"Postgres event consumer has not recovered in {current_downtime} s, exiting") - # Sending QUIT to multiprocess queue to signal workers to exit - for worker in self.pool.workers: - try: - worker.quit() - except Exception: - logger.exception(f"Error sending QUIT to worker {worker}") - raise - # Wait for a second before next attempt, but still listen for any shutdown signals - for i in range(10): - if self.should_stop: - return - time.sleep(0.1) - for conn in db.connections.all(): - conn.close_if_unusable_or_obsolete() - except Exception: - # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata - logger.exception('Encountered unhandled error in dispatcher main loop') - # Sending QUIT to multiprocess queue to signal workers to exit - for worker in self.pool.workers: - try: - worker.quit() - except Exception: - logger.exception(f"Error sending QUIT to worker {worker}") - raise - - -class BaseWorker(object): - def read(self, queue): - return queue.get(block=True, timeout=1) - - def work_loop(self, queue, finished, idx, *args): - ppid = os.getppid() - signal_handler = WorkerSignalHandler() - set_connection_name('worker') # set application_name to distinguish from other dispatcher processes - while not signal_handler.kill_now: - # if the parent PID changes, this process has been orphaned - # via e.g., segfault or sigkill, we should exit too - if os.getppid() != ppid: - break - try: - body = self.read(queue) - if body == 'QUIT': - break - except QueueEmpty: - continue - except Exception: - logger.exception("Exception on worker {}, reconnecting: ".format(idx)) - 
continue - try: - for conn in db.connections.all(): - # If the database connection has a hiccup during the prior message, close it - # so we can establish a new connection - conn.close_if_unusable_or_obsolete() - self.perform_work(body, *args) - except Exception: - logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}') - finally: - if 'uuid' in body: - uuid = body['uuid'] - finished.put(uuid) - logger.debug('worker exiting gracefully pid:{}'.format(os.getpid())) - - def perform_work(self, body): - raise NotImplementedError() - - def on_start(self): - pass - - def on_stop(self): - pass diff --git a/awx/main/dispatch/worker/callback.py b/awx/main/dispatch/worker/callback.py index 60f01e380e..b056cd5b92 100644 --- a/awx/main/dispatch/worker/callback.py +++ b/awx/main/dispatch/worker/callback.py @@ -4,10 +4,12 @@ import os import signal import time import datetime +from queue import Empty as QueueEmpty from django.conf import settings from django.utils.functional import cached_property from django.utils.timezone import now as tz_now +from django import db from django.db import transaction, connection as django_connection from django_guid import set_guid @@ -16,6 +18,7 @@ import psutil import redis from awx.main.utils.redis import get_redis_client +from awx.main.utils.db import set_connection_name from awx.main.consumers import emit_channel_notification from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob from awx.main.constants import ACTIVE_STATES @@ -23,7 +26,7 @@ from awx.main.models.events import emit_event_detail from awx.main.utils.profiling import AWXProfiler from awx.main.tasks.system import events_processed_hook import awx.main.analytics.subsystem_metrics as s_metrics -from .base import BaseWorker +from .base import WorkerSignalHandler logger = logging.getLogger('awx.main.commands.run_callback_receiver') @@ -54,7 +57,7 @@ def job_stats_wrapup(job_identifier, 
event=None): logger.exception('Worker failed to save stats or emit notifications: Job {}'.format(job_identifier)) -class CallbackBrokerWorker(BaseWorker): +class CallbackBrokerWorker: """ A worker implementation that deserializes callback event data and persists it into the database. @@ -86,7 +89,7 @@ class CallbackBrokerWorker(BaseWorker): """This needs to be obtained after forking, or else it will give the parent process""" return os.getpid() - def read(self, queue): + def read(self): has_redis_error = False try: res = self.redis.blpop(self.queue_name, timeout=1) @@ -149,10 +152,37 @@ class CallbackBrokerWorker(BaseWorker): filepath = self.prof.stop() logger.error(f'profiling is disabled, wrote {filepath}') - def work_loop(self, *args, **kw): + def work_loop(self, idx, *args): if settings.AWX_CALLBACK_PROFILE: signal.signal(signal.SIGUSR1, self.toggle_profiling) - return super(CallbackBrokerWorker, self).work_loop(*args, **kw) + + ppid = os.getppid() + signal_handler = WorkerSignalHandler() + set_connection_name('worker') # set application_name to distinguish from other dispatcher processes + while not signal_handler.kill_now: + # if the parent PID changes, this process has been orphaned + # via e.g., segfault or sigkill, we should exit too + if os.getppid() != ppid: + break + try: + body = self.read() # this is only for the callback, only reading from redis. 
+ if body == 'QUIT': + break + except QueueEmpty: + continue + except Exception: + logger.exception("Exception on worker {}, reconnecting: ".format(idx)) + continue + try: + for conn in db.connections.all(): + # If the database connection has a hiccup during the prior message, close it + # so we can establish a new connection + conn.close_if_unusable_or_obsolete() + self.perform_work(body, *args) + except Exception: + logger.exception(f'Unhandled exception in perform_work in worker pid={os.getpid()}') + + logger.debug('worker exiting gracefully pid:{}'.format(os.getpid())) def flush(self, force=False): now = tz_now() diff --git a/awx/main/dispatch/worker/task.py b/awx/main/dispatch/worker/task.py index 6726aaeae3..eda332c146 100644 --- a/awx/main/dispatch/worker/task.py +++ b/awx/main/dispatch/worker/task.py @@ -1,144 +1,55 @@ import inspect import logging import importlib -import sys -import traceback import time -from kubernetes.config import kube_config - -from django.conf import settings from django_guid import set_guid -from awx.main.tasks.system import dispatch_startup, inform_cluster_of_shutdown - -from .base import BaseWorker - logger = logging.getLogger('awx.main.dispatch') -class TaskWorker(BaseWorker): +def resolve_callable(task): """ - A worker implementation that deserializes task messages and runs native - Python code. - - The code that *builds* these types of messages is found in - `awx.main.dispatch.publish`. 
+ Transform a dotted notation task into an imported, callable function, e.g., + awx.main.tasks.system.delete_inventory + awx.main.tasks.jobs.RunProjectUpdate """ + if not task.startswith('awx.'): + raise ValueError('{} is not a valid awx task'.format(task)) + module, target = task.rsplit('.', 1) + module = importlib.import_module(module) + _call = None + if hasattr(module, target): + _call = getattr(module, target, None) + if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): + raise ValueError('{} is not decorated with @task()'.format(task)) + return _call - @staticmethod - def resolve_callable(task): - """ - Transform a dotted notation task into an imported, callable function, e.g., - awx.main.tasks.system.delete_inventory - awx.main.tasks.jobs.RunProjectUpdate - """ - if not task.startswith('awx.'): - raise ValueError('{} is not a valid awx task'.format(task)) - module, target = task.rsplit('.', 1) - module = importlib.import_module(module) - _call = None - if hasattr(module, target): - _call = getattr(module, target, None) - if not (hasattr(_call, 'apply_async') and hasattr(_call, 'delay')): - raise ValueError('{} is not decorated with @task()'.format(task)) - - return _call - - @staticmethod - def run_callable(body): - """ - Given some AMQP message, import the correct Python code and run it. 
- """ - task = body['task'] - uuid = body.get('uuid', '') - args = body.get('args', []) - kwargs = body.get('kwargs', {}) - if 'guid' in body: - set_guid(body.pop('guid')) - _call = TaskWorker.resolve_callable(task) - if inspect.isclass(_call): - # the callable is a class, e.g., RunJob; instantiate and - # return its `run()` method - _call = _call().run - - log_extra = '' - logger_method = logger.debug - if ('time_ack' in body) and ('time_pub' in body): - time_publish = body['time_ack'] - body['time_pub'] - time_waiting = time.time() - body['time_ack'] - if time_waiting > 5.0 or time_publish > 5.0: - # If task too a very long time to process, add this information to the log - log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher' - logger_method = logger.info - # don't print kwargs, they often contain launch-time secrets - logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') - - return _call(*args, **kwargs) - - def perform_work(self, body): - """ - Import and run code for a task e.g., - - body = { - 'args': [8], - 'callbacks': [{ - 'args': [], - 'kwargs': {} - 'task': u'awx.main.tasks.system.handle_work_success' - }], - 'errbacks': [{ - 'args': [], - 'kwargs': {}, - 'task': 'awx.main.tasks.system.handle_work_error' - }], - 'kwargs': {}, - 'task': u'awx.main.tasks.jobs.RunProjectUpdate' - } - """ - settings.__clean_on_fork__() - result = None - try: - result = self.run_callable(body) - except Exception as exc: - result = exc - - try: - if getattr(exc, 'is_awx_task_error', False): - # Error caused by user / tracked in job output - logger.warning("{}".format(exc)) - else: - task = body['task'] - args = body.get('args', []) - kwargs = body.get('kwargs', {}) - logger.exception('Worker failed to run task {}(*{}, **{}'.format(task, args, kwargs)) - except Exception: - # It's fairly critical that this code _not_ raise exceptions on logging - # If you configure external logging in a way that _it_ fails, there's - # not a lot we 
can do here; sys.stderr.write is a final hail mary - _, _, tb = sys.exc_info() - traceback.print_tb(tb) - - for callback in body.get('errbacks', []) or []: - callback['uuid'] = body['uuid'] - self.perform_work(callback) - finally: - # It's frustrating that we have to do this, but the python k8s - # client leaves behind cacert files in /tmp, so we must clean up - # the tmpdir per-dispatcher process every time a new task comes in - try: - kube_config._cleanup_temp_files() - except Exception: - logger.exception('failed to cleanup k8s client tmp files') - - for callback in body.get('callbacks', []) or []: - callback['uuid'] = body['uuid'] - self.perform_work(callback) - return result - - def on_start(self): - dispatch_startup() - - def on_stop(self): - inform_cluster_of_shutdown() +def run_callable(body): + """ + Given some AMQP message, import the correct Python code and run it. + """ + task = body['task'] + uuid = body.get('uuid', '') + args = body.get('args', []) + kwargs = body.get('kwargs', {}) + if 'guid' in body: + set_guid(body.pop('guid')) + _call = resolve_callable(task) + if inspect.isclass(_call): + # the callable is a class, e.g., RunJob; instantiate and + # return its `run()` method + _call = _call().run + log_extra = '' + logger_method = logger.debug + if ('time_ack' in body) and ('time_pub' in body): + time_publish = body['time_ack'] - body['time_pub'] + time_waiting = time.time() - body['time_ack'] + if time_waiting > 5.0 or time_publish > 5.0: + # If task too a very long time to process, add this information to the log + log_extra = f' took {time_publish:.4f} to ack, {time_waiting:.4f} in local dispatcher' + logger_method = logger.info + # don't print kwargs, they often contain launch-time secrets + logger_method(f'task {uuid} starting {task}(*{args}){log_extra}') + return _call(*args, **kwargs) diff --git a/awx/main/management/commands/run_cache_clear.py b/awx/main/management/commands/run_cache_clear.py index bba9cd8f68..d8f35ed5d5 100644 --- 
a/awx/main/management/commands/run_cache_clear.py +++ b/awx/main/management/commands/run_cache_clear.py @@ -4,7 +4,7 @@ import json from django.core.management.base import BaseCommand from awx.main.dispatch import pg_bus_conn -from awx.main.dispatch.worker.task import TaskWorker +from awx.main.dispatch.worker.task import run_callable logger = logging.getLogger('awx.main.cache_clear') @@ -21,11 +21,11 @@ class Command(BaseCommand): try: with pg_bus_conn() as conn: conn.listen("tower_settings_change") - for e in conn.events(yield_timeouts=True): + for e in conn.events(): if e is not None: body = json.loads(e.payload) logger.info(f"Cache clear request received. Clearing now, payload: {e.payload}") - TaskWorker.run_callable(body) + run_callable(body) except Exception: # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata diff --git a/awx/main/management/commands/run_callback_receiver.py b/awx/main/management/commands/run_callback_receiver.py index 8f67909dad..c450d6dc72 100644 --- a/awx/main/management/commands/run_callback_receiver.py +++ b/awx/main/management/commands/run_callback_receiver.py @@ -8,8 +8,8 @@ from django.core.management.base import BaseCommand, CommandError import redis.exceptions from awx.main.analytics.subsystem_metrics import CallbackReceiverMetricsServer -from awx.main.dispatch.control import Control from awx.main.dispatch.worker import AWXConsumerRedis, CallbackBrokerWorker +from awx.main.utils.redis import get_redis_client class Command(BaseCommand): @@ -26,7 +26,7 @@ class Command(BaseCommand): def handle(self, *arg, **options): if options.get('status'): - print(Control('callback_receiver').status()) + print(self.status()) return consumer = None @@ -46,3 +46,10 @@ class Command(BaseCommand): print('Terminating Callback Receiver') if consumer: consumer.stop() + + def status(self, *args, **kwargs): + r = get_redis_client() + workers = [] + for key in r.keys('awx_callback_receiver_statistics_*'): + 
workers.append(r.get(key).decode('utf-8')) + return '\n'.join(workers) diff --git a/awx/main/management/commands/run_dispatcher.py b/awx/main/management/commands/run_dispatcher.py index 5571b56a0b..599af413fe 100644 --- a/awx/main/management/commands/run_dispatcher.py +++ b/awx/main/management/commands/run_dispatcher.py @@ -5,25 +5,16 @@ import logging.config import yaml import copy -import redis - from django.conf import settings -from django.db import connection from django.core.management.base import BaseCommand, CommandError from django.core.cache import cache as django_cache - -from flags.state import flag_enabled +from django.db import connection from dispatcherd.factories import get_control_from_settings from dispatcherd import run_service from dispatcherd.config import setup as dispatcher_setup -from awx.main.dispatch import get_task_queuename from awx.main.dispatch.config import get_dispatcherd_config -from awx.main.dispatch.control import Control -from awx.main.dispatch.pool import AutoscalePool -from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker -from awx.main.analytics.subsystem_metrics import DispatcherMetricsServer logger = logging.getLogger('awx.main.dispatch') @@ -33,14 +24,7 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--status', dest='status', action='store_true', help='print the internal state of any running dispatchers') - parser.add_argument('--schedule', dest='schedule', action='store_true', help='print the current status of schedules being ran by dispatcher') parser.add_argument('--running', dest='running', action='store_true', help='print the UUIDs of any tasked managed by this dispatcher') - parser.add_argument( - '--reload', - dest='reload', - action='store_true', - help=('cause the dispatcher to recycle all of its worker processes; running jobs will run to completion first'), - ) parser.add_argument( '--cancel', dest='cancel', @@ -53,38 +37,17 @@ class Command(BaseCommand): def 
handle(self, *arg, **options): if options.get('status'): - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - ctl = get_control_from_settings() - running_data = ctl.control_with_reply('status') - if len(running_data) != 1: - raise CommandError('Did not receive expected number of replies') - print(yaml.dump(running_data[0], default_flow_style=False)) - return - else: - print(Control('dispatcher').status()) - return - if options.get('schedule'): - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - print('NOT YET IMPLEMENTED') - return - else: - print(Control('dispatcher').schedule()) + ctl = get_control_from_settings() + running_data = ctl.control_with_reply('status') + if len(running_data) != 1: + raise CommandError('Did not receive expected number of replies') + print(yaml.dump(running_data[0], default_flow_style=False)) return if options.get('running'): - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - ctl = get_control_from_settings() - running_data = ctl.control_with_reply('running') - print(yaml.dump(running_data, default_flow_style=False)) - return - else: - print(Control('dispatcher').running()) - return - if options.get('reload'): - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - print('NOT YET IMPLEMENTED') - return - else: - return Control('dispatcher').control({'control': 'reload'}) + ctl = get_control_from_settings() + running_data = ctl.control_with_reply('running') + print(yaml.dump(running_data, default_flow_style=False)) + return if options.get('cancel'): cancel_str = options.get('cancel') try: @@ -94,44 +57,24 @@ class Command(BaseCommand): if not isinstance(cancel_data, list): cancel_data = [cancel_str] - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - ctl = get_control_from_settings() - results = [] - for task_id in cancel_data: - # For each task UUID, send an individual cancel command - result = ctl.control_with_reply('cancel', data={'uuid': task_id}) - results.append(result) - print(yaml.dump(results, default_flow_style=False)) - return - 
else: - print(Control('dispatcher').cancel(cancel_data)) - return + ctl = get_control_from_settings() + results = [] + for task_id in cancel_data: + # For each task UUID, send an individual cancel command + result = ctl.control_with_reply('cancel', data={'uuid': task_id}) + results.append(result) + print(yaml.dump(results, default_flow_style=False)) + return - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - self.configure_dispatcher_logging() + self.configure_dispatcher_logging() + # Close the connection, because the pg_notify broker will create new async connection + connection.close() + django_cache.close() + dispatcher_setup(get_dispatcherd_config(for_service=True)) + run_service() - # Close the connection, because the pg_notify broker will create new async connection - connection.close() - django_cache.close() - - dispatcher_setup(get_dispatcherd_config(for_service=True)) - run_service() - else: - consumer = None - - try: - DispatcherMetricsServer().start() - except redis.exceptions.ConnectionError as exc: - raise CommandError(f'Dispatcher could not connect to redis, error: {exc}') - - try: - queues = ['tower_broadcast_all', 'tower_settings_change', get_task_queuename()] - consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4), schedule=settings.CELERYBEAT_SCHEDULE) - consumer.run() - except KeyboardInterrupt: - logger.debug('Terminating Task Dispatcher') - if consumer: - consumer.stop() + dispatcher_setup(get_dispatcherd_config(for_service=True)) + run_service() def configure_dispatcher_logging(self): # Apply special log rule for the parent process diff --git a/awx/main/management/commands/run_rsyslog_configurer.py b/awx/main/management/commands/run_rsyslog_configurer.py index bc68370987..8df5f84331 100644 --- a/awx/main/management/commands/run_rsyslog_configurer.py +++ b/awx/main/management/commands/run_rsyslog_configurer.py @@ -5,7 +5,7 @@ from django.core.management.base import BaseCommand from django.conf import 
settings from django.core.cache import cache from awx.main.dispatch import pg_bus_conn -from awx.main.dispatch.worker.task import TaskWorker +from awx.main.dispatch.worker.task import run_callable from awx.main.utils.external_logging import reconfigure_rsyslog logger = logging.getLogger('awx.main.rsyslog_configurer') @@ -26,7 +26,7 @@ class Command(BaseCommand): conn.listen("rsyslog_configurer") # reconfigure rsyslog on start up reconfigure_rsyslog() - for e in conn.events(yield_timeouts=True): + for e in conn.events(): if e is not None: logger.info("Change in logging settings found. Restarting rsyslogd") # clear the cache of relevant settings then restart @@ -34,7 +34,7 @@ class Command(BaseCommand): cache.delete_many(setting_keys) settings._awx_conf_memoizedcache.clear() body = json.loads(e.payload) - TaskWorker.run_callable(body) + run_callable(body) except Exception: # Log unanticipated exception in addition to writing to stderr to get timestamps and other metadata logger.exception('Encountered unhandled error in rsyslog_configurer main loop') diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index 3a3ce545a5..350abd908d 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -15,6 +15,9 @@ import subprocess import tempfile from collections import OrderedDict +# Dispatcher +from dispatcherd.factories import get_control_from_settings + # Django from django.conf import settings from django.db import models, connection, transaction @@ -24,7 +27,6 @@ from django.utils.translation import gettext_lazy as _ from django.utils.timezone import now from django.utils.encoding import smart_str from django.contrib.contenttypes.models import ContentType -from flags.state import flag_enabled # REST Framework from rest_framework.exceptions import ParseError @@ -39,7 +41,6 @@ from ansible_base.rbac.models import RoleEvaluation # AWX from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, 
NotificationFieldsModel from awx.main.dispatch import get_task_queuename -from awx.main.dispatch.control import Control as ControlDispatcher from awx.main.registrar import activity_stream_registrar from awx.main.models.mixins import TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin from awx.main.models.rbac import to_permissions @@ -1497,43 +1498,30 @@ class UnifiedJob( if not self.celery_task_id: return False - canceled = [] # Special case for task manager (used during workflow job cancellation) if not connection.get_autocommit(): - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - try: - from dispatcherd.factories import get_control_from_settings + try: - ctl = get_control_from_settings() - ctl.control('cancel', data={'uuid': self.celery_task_id}) - except Exception: - logger.exception("Error sending cancel command to new dispatcher") - else: - try: - ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id], with_reply=False) - except Exception: - logger.exception("Error sending cancel command to legacy dispatcher") + ctl = get_control_from_settings() + ctl.control('cancel', data={'uuid': self.celery_task_id}) + except Exception: + logger.exception("Error sending cancel command to dispatcher") return True # task manager itself needs to act under assumption that cancel was received # Standard case with reply try: timeout = 5 - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - from dispatcherd.factories import get_control_from_settings - ctl = get_control_from_settings() - results = ctl.control_with_reply('cancel', data={'uuid': self.celery_task_id}, expected_replies=1, timeout=timeout) - # Check if cancel was successful by checking if we got any results - return bool(results and len(results) > 0) - else: - # Original implementation - canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id]) + ctl = get_control_from_settings() + results = ctl.control_with_reply('cancel', data={'uuid': 
self.celery_task_id}, expected_replies=1, timeout=timeout) + # Check if cancel was successful by checking if we got any results + return bool(results and len(results) > 0) except socket.timeout: logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s') except Exception: logger.exception("error encountered when checking task status") - return bool(self.celery_task_id in canceled) # True or False, whether confirmation was obtained + return False # whether confirmation was obtained def cancel(self, job_explanation=None, is_chain=False): if self.can_cancel: diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index 5904c47d57..5fc1c0b51c 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -19,9 +19,6 @@ from django.utils.timezone import now as tz_now from django.conf import settings from django.contrib.contenttypes.models import ContentType -# django-flags -from flags.state import flag_enabled - from ansible_base.lib.utils.models import get_type_for_model # django-ansible-base @@ -523,19 +520,7 @@ class TaskManager(TaskBase): task.save() task.log_lifecycle("waiting") - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - self.control_nodes_to_notify.add(task.get_queue_name()) - else: - # apply_async does a NOTIFY to the channel dispatcher is listening to - # postgres will treat this as part of the transaction, which is what we want - if task.status != 'failed' and type(task) is not WorkflowJob: - task_cls = task._get_task_class() - task_cls.apply_async( - [task.pk], - opts, - queue=task.get_queue_name(), - uuid=task.celery_task_id, - ) + self.control_nodes_to_notify.add(task.get_queue_name()) # In exception cases, like a job failing pre-start checks, we send the websocket status message. 
# For jobs going into waiting, we omit this because of performance issues, as it should go to running quickly @@ -729,7 +714,6 @@ class TaskManager(TaskBase): for workflow_approval in self.get_expired_workflow_approvals(): self.timeout_approval_node(workflow_approval) - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - for controller_node in self.control_nodes_to_notify: - logger.info(f'Notifying node {controller_node} of new waiting jobs.') - dispatch_waiting_jobs.apply_async(queue=controller_node) + for controller_node in self.control_nodes_to_notify: + logger.info(f'Notifying node {controller_node} of new waiting jobs.') + dispatch_waiting_jobs.apply_async(queue=controller_node) diff --git a/awx/main/scheduler/tasks.py b/awx/main/scheduler/tasks.py index b2ea8608f7..346bcd0d67 100644 --- a/awx/main/scheduler/tasks.py +++ b/awx/main/scheduler/tasks.py @@ -4,10 +4,12 @@ import logging # Django from django.conf import settings +# Dispatcherd +from dispatcherd.publish import task + # AWX from awx import MODE from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager -from awx.main.dispatch.publish import task as task_awx from awx.main.dispatch import get_task_queuename logger = logging.getLogger('awx.main.scheduler') @@ -20,16 +22,16 @@ def run_manager(manager, prefix): manager().schedule() -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) def task_manager(): run_manager(TaskManager, "task") -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) def dependency_manager(): run_manager(DependencyManager, "dependency") -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) def workflow_manager(): run_manager(WorkflowManager, "workflow") diff --git a/awx/main/tasks/host_indirect.py b/awx/main/tasks/host_indirect.py index 11f32c248a..e271021962 100644 --- a/awx/main/tasks/host_indirect.py +++ b/awx/main/tasks/host_indirect.py @@ -12,7 +12,7 @@ from django.db import transaction # Django flags from 
flags.state import flag_enabled -from awx.main.dispatch.publish import task +from dispatcherd.publish import task from awx.main.dispatch import get_task_queuename from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit from awx.main.models.event_query import EventQuery diff --git a/awx/main/tasks/host_metrics.py b/awx/main/tasks/host_metrics.py index c5681f28d5..71572d7af9 100644 --- a/awx/main/tasks/host_metrics.py +++ b/awx/main/tasks/host_metrics.py @@ -6,8 +6,8 @@ from django.conf import settings from django.db.models import Count, F from django.db.models.functions import TruncMonth from django.utils.timezone import now +from dispatcherd.publish import task from awx.main.dispatch import get_task_queuename -from awx.main.dispatch.publish import task as task_awx from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly from awx.main.tasks.helpers import is_run_threshold_reached from awx.conf.license import get_license @@ -17,7 +17,7 @@ from awx.main.utils.db import bulk_update_sorted_by_id logger = logging.getLogger('awx.main.tasks.host_metrics') -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) def cleanup_host_metrics(): if is_run_threshold_reached(getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', None), getattr(settings, 'CLEANUP_HOST_METRICS_INTERVAL', 30) * 86400): logger.info(f"Executing cleanup_host_metrics, last ran at {getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', '---')}") @@ -28,7 +28,7 @@ def cleanup_host_metrics(): logger.info("Finished cleanup_host_metrics") -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) def host_metric_summary_monthly(): """Run cleanup host metrics summary monthly task each week""" if is_run_threshold_reached(getattr(settings, 'HOST_METRIC_SUMMARY_TASK_LAST_TS', None), getattr(settings, 'HOST_METRIC_SUMMARY_TASK_INTERVAL', 7) * 86400): diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index 1cd6205ba9..9e46f35500 100644 --- 
a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -36,7 +36,6 @@ from dispatcherd.publish import task from dispatcherd.utils import serialize_task # AWX -from awx.main.dispatch.publish import task as task_awx from awx.main.dispatch import get_task_queuename from awx.main.constants import ( PRIVILEGE_ESCALATION_METHODS, @@ -851,7 +850,7 @@ class SourceControlMixin(BaseTask): self.release_lock(project) -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) class RunJob(SourceControlMixin, BaseTask): """ Run a job using ansible-playbook. @@ -1174,7 +1173,7 @@ class RunJob(SourceControlMixin, BaseTask): update_inventory_computed_fields.delay(inventory.id) -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) class RunProjectUpdate(BaseTask): model = ProjectUpdate event_model = ProjectUpdateEvent @@ -1513,7 +1512,7 @@ class RunProjectUpdate(BaseTask): return [] -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) class RunInventoryUpdate(SourceControlMixin, BaseTask): model = InventoryUpdate event_model = InventoryUpdateEvent @@ -1776,7 +1775,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask): raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc()) -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) class RunAdHocCommand(BaseTask): """ Run an ad hoc command using ansible. 
@@ -1929,7 +1928,7 @@ class RunAdHocCommand(BaseTask): return d -@task_awx(queue=get_task_queuename) +@task(queue=get_task_queuename) class RunSystemJob(BaseTask): model = SystemJob event_model = SystemJobEvent diff --git a/awx/main/tasks/receptor.py b/awx/main/tasks/receptor.py index ad62b315be..e1ccf4d7c4 100644 --- a/awx/main/tasks/receptor.py +++ b/awx/main/tasks/receptor.py @@ -20,6 +20,9 @@ import ansible_runner # django-ansible-base from ansible_base.lib.utils.db import advisory_lock +# Dispatcherd +from dispatcherd.publish import task + # AWX from awx.main.utils.execution_environments import get_default_pod_spec from awx.main.exceptions import ReceptorNodeNotFound @@ -32,7 +35,6 @@ from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER from awx.main.tasks.signals import signal_state, signal_callback, SignalExit from awx.main.models import Instance, InstanceLink, UnifiedJob, ReceptorAddress from awx.main.dispatch import get_task_queuename -from awx.main.dispatch.publish import task as task_awx # Receptorctl from receptorctl.socket_interface import ReceptorControl @@ -852,7 +854,7 @@ def reload_receptor(): raise RuntimeError("Receptor reload failed") -@task_awx(on_duplicate='queue_one') +@task(on_duplicate='queue_one') def write_receptor_config(): """ This task runs async on each control node, K8S only. 
@@ -875,7 +877,7 @@ def write_receptor_config(): reload_receptor() -@task_awx(queue=get_task_queuename, on_duplicate='discard') +@task(queue=get_task_queuename, on_duplicate='discard') def remove_deprovisioned_node(hostname): InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) InstanceLink.objects.filter(target__instance__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py index 6d9656346a..905a2a235c 100644 --- a/awx/main/tasks/system.py +++ b/awx/main/tasks/system.py @@ -9,12 +9,12 @@ import shutil import time from collections import namedtuple from contextlib import redirect_stdout -from datetime import datetime from packaging.version import Version from io import StringIO # dispatcherd from dispatcherd.factories import get_control_from_settings +from dispatcherd.publish import task # Runner import ansible_runner.cleanup @@ -56,7 +56,6 @@ from awx.main.analytics.subsystem_metrics import DispatcherMetrics from awx.main.constants import ACTIVE_STATES, ERROR_STATES from awx.main.consumers import emit_channel_notification from awx.main.dispatch import get_task_queuename, reaper -from awx.main.dispatch.publish import task as task_awx from awx.main.models import ( Instance, InstanceGroup, @@ -74,7 +73,6 @@ from awx.main.tasks.host_indirect import save_indirect_host_entries from awx.main.tasks.receptor import administrative_workunit_reaper, get_receptor_ctl, worker_cleanup, worker_info, write_receptor_config from awx.main.utils.common import ignore_inventory_computed_fields, ignore_inventory_group_removal from awx.main.utils.reload import stop_local_services -from dispatcherd.publish import task logger = logging.getLogger('awx.main.tasks.system') @@ -95,7 +93,10 @@ def _run_dispatch_startup_common(): # TODO: Enable this on VM installs if settings.IS_K8S: - write_receptor_config() + try: + write_receptor_config() + except Exception: + 
logger.exception("Failed to write receptor config, skipping.") try: convert_jsonfields() @@ -125,20 +126,12 @@ def _run_dispatch_startup_common(): # no-op. # apply_cluster_membership_policies() - cluster_node_heartbeat() + cluster_node_heartbeat(None) reaper.startup_reaping() m = DispatcherMetrics() m.reset_values() -def _legacy_dispatch_startup(): - """ - Legacy branch for startup: simply performs reaping of waiting jobs with a zero grace period. - """ - logger.debug("Legacy dispatcher: calling reaper.reap_waiting with grace_period=0") - reaper.reap_waiting(grace_period=0) - - def _dispatcherd_dispatch_startup(): """ New dispatcherd branch for startup: uses the control API to re-submit waiting jobs. @@ -153,21 +146,16 @@ def dispatch_startup(): """ System initialization at startup. First, execute the common logic. - Then, if FEATURE_DISPATCHERD_ENABLED is enabled, re-submit waiting jobs via the control API; - otherwise, fall back to legacy reaping of waiting jobs. + Then, re-submit waiting jobs via the control API. """ _run_dispatch_startup_common() - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - _dispatcherd_dispatch_startup() - else: - _legacy_dispatch_startup() + _dispatcherd_dispatch_startup() def inform_cluster_of_shutdown(): """ Clean system shutdown that marks the current instance offline. - In legacy mode, it also reaps waiting jobs. - In dispatcherd mode, it relies on dispatcherd's built-in cleanup. + Relies on dispatcherd's built-in cleanup. 
""" try: inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID) @@ -176,18 +164,11 @@ def inform_cluster_of_shutdown(): logger.exception("Cluster host not found: %s", settings.CLUSTER_HOST_ID) return - if flag_enabled('FEATURE_DISPATCHERD_ENABLED'): - logger.debug("Dispatcherd mode: no extra reaping required for instance %s", inst.hostname) - else: - try: - logger.debug("Legacy mode: reaping waiting jobs for instance %s", inst.hostname) - reaper.reap_waiting(inst, grace_period=0) - except Exception: - logger.exception("Failed to reap waiting jobs for %s", inst.hostname) + logger.debug("No extra reaping required for instance %s", inst.hostname) logger.warning("Normal shutdown processed for instance %s; instance removed from capacity pool.", inst.hostname) -@task_awx(queue=get_task_queuename, timeout=3600 * 5) +@task(queue=get_task_queuename, timeout=3600 * 5) def migrate_jsonfield(table, pkfield, columns): batchsize = 10000 with advisory_lock(f'json_migration_{table}', wait=False) as acquired: @@ -233,7 +214,7 @@ def migrate_jsonfield(table, pkfield, columns): logger.warning(f"Migration of {table} to jsonb is finished.") -@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') def apply_cluster_membership_policies(): from awx.main.signals import disable_activity_stream @@ -345,7 +326,7 @@ def apply_cluster_membership_policies(): logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute)) -@task_awx(queue='tower_settings_change', timeout=600) +@task(queue='tower_settings_change', timeout=600) def clear_setting_cache(setting_keys): # log that cache is being cleared logger.info(f"clear_setting_cache of keys {setting_keys}") @@ -363,7 +344,7 @@ def clear_setting_cache(setting_keys): ctl.control('set_log_level', data={'level': settings.LOG_AGGREGATOR_LEVEL}) -@task_awx(queue='tower_broadcast_all', timeout=600) 
+@task(queue='tower_broadcast_all', timeout=600) def delete_project_files(project_path): # TODO: possibly implement some retry logic lock_file = project_path + '.lock' @@ -381,7 +362,7 @@ def delete_project_files(project_path): logger.exception('Could not remove lock file {}'.format(lock_file)) -@task_awx(queue='tower_broadcast_all') +@task(queue='tower_broadcast_all') def profile_sql(threshold=1, minutes=1): if threshold <= 0: cache.delete('awx-profile-sql-threshold') @@ -391,7 +372,7 @@ def profile_sql(threshold=1, minutes=1): logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes)) -@task_awx(queue=get_task_queuename, timeout=1800) +@task(queue=get_task_queuename, timeout=1800) def send_notifications(notification_list, job_id=None): if not isinstance(notification_list, list): raise TypeError("notification_list should be of type list") @@ -436,13 +417,13 @@ def events_processed_hook(unified_job): save_indirect_host_entries.delay(unified_job.id) -@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') +@task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') def gather_analytics(): if is_run_threshold_reached(getattr(settings, 'AUTOMATION_ANALYTICS_LAST_GATHER', None), settings.AUTOMATION_ANALYTICS_GATHER_INTERVAL): analytics.gather() -@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') def purge_old_stdout_files(): nowtime = time.time() for f in os.listdir(settings.JOBOUTPUT_ROOT): @@ -504,18 +485,18 @@ class CleanupImagesAndFiles: cls.run_remote(this_inst, **kwargs) -@task_awx(queue='tower_broadcast_all', timeout=3600) +@task(queue='tower_broadcast_all', timeout=3600) def handle_removed_image(remove_images=None): """Special broadcast invocation of this method to handle case of deleted EE""" CleanupImagesAndFiles.run(remove_images=remove_images, file_pattern='') -@task_awx(queue=get_task_queuename, 
timeout=3600, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') def cleanup_images_and_files(): CleanupImagesAndFiles.run(image_prune=True) -@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one') def execution_node_health_check(node): if node == '': logger.warning('Remote health check incorrectly called with blank string') @@ -640,44 +621,13 @@ def inspect_execution_and_hop_nodes(instance_list): execution_node_health_check.apply_async([hostname]) -@task_awx(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks']) -def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None): - """ - Original implementation for AWX dispatcher. - Uses worker_tasks from bind_kwargs to track running tasks. - """ - # Run common instance management logic - this_inst, instance_list, lost_instances = _heartbeat_instance_management() - if this_inst is None: - return # Early return case from instance management - - # Check versions - _heartbeat_check_versions(this_inst, instance_list) - - # Handle lost instances - _heartbeat_handle_lost_instances(lost_instances, this_inst) - - # Run local reaper - original implementation using worker_tasks - if worker_tasks is not None: - active_task_ids = [] - for task_list in worker_tasks.values(): - active_task_ids.extend(task_list) - - # Convert dispatch_time to datetime - ref_time = datetime.fromisoformat(dispatch_time) if dispatch_time else now() - - reaper.reap(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time) - - if max(len(task_list) for task_list in worker_tasks.values()) <= 1: - reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=ref_time) - - @task(queue=get_task_queuename, bind=True) -def adispatch_cluster_node_heartbeat(binder): +def cluster_node_heartbeat(binder): """ Dispatcherd implementation. Uses Control API to get running tasks. 
""" + # Run common instance management logic this_inst, instance_list, lost_instances = _heartbeat_instance_management() if this_inst is None: @@ -690,6 +640,9 @@ def adispatch_cluster_node_heartbeat(binder): _heartbeat_handle_lost_instances(lost_instances, this_inst) # Get running tasks using dispatcherd API + if binder is None: + logger.debug("Heartbeat finished in startup.") + return active_task_ids = _get_active_task_ids_from_dispatcherd(binder) if active_task_ids is None: logger.warning("No active task IDs retrieved from dispatcherd, skipping reaper") @@ -839,7 +792,7 @@ def _heartbeat_handle_lost_instances(lost_instances, this_inst): logger.exception('No SQL state available. Error marking {} as lost'.format(other_inst.hostname)) -@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') def awx_receptor_workunit_reaper(): """ When an AWX job is launched via receptor, files such as status, stdin, and stdout are created @@ -885,7 +838,7 @@ def awx_receptor_workunit_reaper(): administrative_workunit_reaper(receptor_work_list) -@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one') def awx_k8s_reaper(): if not settings.RECEPTOR_RELEASE_WORK: return @@ -908,7 +861,7 @@ def awx_k8s_reaper(): logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group)) -@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') +@task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard') def awx_periodic_scheduler(): lock_session_timeout_milliseconds = settings.TASK_MANAGER_LOCK_TIMEOUT * 1000 with advisory_lock('awx_periodic_scheduler_lock', lock_session_timeout_milliseconds=lock_session_timeout_milliseconds, wait=False) as acquired: @@ -965,7 +918,7 @@ def awx_periodic_scheduler(): emit_channel_notification('schedules-changed', 
dict(id=schedule.id, group_name="schedules")) -@task_awx(queue=get_task_queuename, timeout=3600) +@task(queue=get_task_queuename, timeout=3600) def handle_failure_notifications(task_ids): """A task-ified version of the method that sends notifications.""" found_task_ids = set() @@ -980,7 +933,7 @@ def handle_failure_notifications(task_ids): logger.warning(f'Could not send notifications for {deleted_tasks} because they were not found in the database') -@task_awx(queue=get_task_queuename, timeout=3600 * 5) +@task(queue=get_task_queuename, timeout=3600 * 5) def update_inventory_computed_fields(inventory_id): """ Signal handler and wrapper around inventory.update_computed_fields to @@ -1030,7 +983,7 @@ def update_smart_memberships_for_inventory(smart_inventory): return False -@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') +@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one') def update_host_smart_inventory_memberships(): smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False) changed_inventories = set([]) @@ -1046,7 +999,7 @@ def update_host_smart_inventory_memberships(): smart_inventory.update_computed_fields() -@task_awx(queue=get_task_queuename, timeout=3600 * 5) +@task(queue=get_task_queuename, timeout=3600 * 5) def delete_inventory(inventory_id, user_id, retries=5): # Delete inventory as user if user_id is None: @@ -1108,7 +1061,7 @@ def _reconstruct_relationships(copy_mapping): new_obj.save() -@task_awx(queue=get_task_queuename, timeout=600) +@task(queue=get_task_queuename, timeout=600) def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, permission_check_func=None): logger.debug('Deep copy {} from {} to {}.'.format(model_name, obj_pk, new_obj_pk)) @@ -1163,7 +1116,7 @@ def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, p update_inventory_computed_fields.delay(new_obj.id) -@task_awx(queue=get_task_queuename, 
timeout=3600, on_duplicate='discard') +@task(queue=get_task_queuename, timeout=3600, on_duplicate='discard') def periodic_resource_sync(): if not getattr(settings, 'RESOURCE_SERVER', None): logger.debug("Skipping periodic resource_sync, RESOURCE_SERVER not configured") diff --git a/awx/main/tests/data/sleep_task.py b/awx/main/tests/data/sleep_task.py index 1293db56dc..8582a73c79 100644 --- a/awx/main/tests/data/sleep_task.py +++ b/awx/main/tests/data/sleep_task.py @@ -6,14 +6,13 @@ from dispatcherd.publish import task from django.db import connection from awx.main.dispatch import get_task_queuename -from awx.main.dispatch.publish import task as old_task from ansible_base.lib.utils.db import advisory_lock logger = logging.getLogger(__name__) -@old_task(queue=get_task_queuename) +@task(queue=get_task_queuename) def sleep_task(seconds=10, log=False): if log: logger.info('starting sleep_task') diff --git a/awx/main/tests/functional/dab_feature_flags/test_feature_flags_api.py b/awx/main/tests/functional/dab_feature_flags/test_feature_flags_api.py index 8007ff7b84..fb483000fb 100644 --- a/awx/main/tests/functional/dab_feature_flags/test_feature_flags_api.py +++ b/awx/main/tests/functional/dab_feature_flags/test_feature_flags_api.py @@ -21,7 +21,7 @@ def test_feature_flags_list_endpoint_override(get, flag_val): bob = User.objects.create(username='bob', password='test_user', is_superuser=True) AAPFlag.objects.all().delete() - flag_name = "FEATURE_DISPATCHERD_ENABLED" + flag_name = "FEATURE_INDIRECT_NODE_COUNTING_ENABLED" setattr(settings, flag_name, flag_val) seed_feature_flags() url = "/api/v2/feature_flags/states/" diff --git a/awx/main/tests/functional/models/test_ha.py b/awx/main/tests/functional/models/test_ha.py index bf8c5309c7..c8ee9dc0a7 100644 --- a/awx/main/tests/functional/models/test_ha.py +++ b/awx/main/tests/functional/models/test_ha.py @@ -3,7 +3,7 @@ import pytest # AWX from awx.main.ha import is_ha_environment from awx.main.models.ha import Instance -from 
awx.main.dispatch.pool import get_auto_max_workers +from awx.main.utils.common import get_auto_max_workers # Django from django.test.utils import override_settings diff --git a/awx/main/tests/functional/test_dispatch.py b/awx/main/tests/functional/test_dispatch.py index 382f858c28..44fad8b733 100644 --- a/awx/main/tests/functional/test_dispatch.py +++ b/awx/main/tests/functional/test_dispatch.py @@ -1,20 +1,11 @@ import datetime -import multiprocessing -import random -import signal -import time -import yaml from unittest import mock -from flags.state import disable_flag, enable_flag from django.utils.timezone import now as tz_now import pytest from awx.main.models import Job, WorkflowJob, Instance from awx.main.dispatch import reaper -from awx.main.dispatch.pool import StatefulPoolWorker, WorkerPool, AutoscalePool -from awx.main.dispatch.publish import task -from awx.main.dispatch.worker import BaseWorker, TaskWorker -from awx.main.dispatch.periodic import Scheduler +from dispatcherd.publish import task ''' Prevent logger. 
calls from triggering database operations @@ -57,294 +48,6 @@ def multiply(a, b): return a * b -class SimpleWorker(BaseWorker): - def perform_work(self, body, *args): - pass - - -class ResultWriter(BaseWorker): - def perform_work(self, body, result_queue): - result_queue.put(body + '!!!') - - -class SlowResultWriter(BaseWorker): - def perform_work(self, body, result_queue): - time.sleep(3) - super(SlowResultWriter, self).perform_work(body, result_queue) - - -@pytest.mark.usefixtures("disable_database_settings") -class TestPoolWorker: - def setup_method(self, test_method): - self.worker = StatefulPoolWorker(1000, self.tick, tuple()) - - def tick(self): - self.worker.finished.put(self.worker.queue.get()['uuid']) - time.sleep(0.5) - - def test_qsize(self): - assert self.worker.qsize == 0 - for i in range(3): - self.worker.put({'task': 'abc123'}) - assert self.worker.qsize == 3 - - def test_put(self): - assert len(self.worker.managed_tasks) == 0 - assert self.worker.messages_finished == 0 - self.worker.put({'task': 'abc123'}) - - assert len(self.worker.managed_tasks) == 1 - assert self.worker.messages_sent == 1 - - def test_managed_tasks(self): - self.worker.put({'task': 'abc123'}) - self.worker.calculate_managed_tasks() - assert len(self.worker.managed_tasks) == 1 - - self.tick() - self.worker.calculate_managed_tasks() - assert len(self.worker.managed_tasks) == 0 - - def test_current_task(self): - self.worker.put({'task': 'abc123'}) - assert self.worker.current_task['task'] == 'abc123' - - def test_quit(self): - self.worker.quit() - assert self.worker.queue.get() == 'QUIT' - - def test_idle_busy(self): - assert self.worker.idle is True - assert self.worker.busy is False - self.worker.put({'task': 'abc123'}) - assert self.worker.busy is True - assert self.worker.idle is False - - -@pytest.mark.django_db -class TestWorkerPool: - def setup_method(self, test_method): - self.pool = WorkerPool(min_workers=3) - - def teardown_method(self, test_method): - 
self.pool.stop(signal.SIGTERM) - - def test_worker(self): - self.pool.init_workers(SimpleWorker().work_loop) - assert len(self.pool) == 3 - for worker in self.pool.workers: - assert worker.messages_sent == 0 - assert worker.alive is True - - def test_single_task(self): - self.pool.init_workers(SimpleWorker().work_loop) - self.pool.write(0, 'xyz') - assert self.pool.workers[0].messages_sent == 1 # worker at index 0 handled one task - assert self.pool.workers[1].messages_sent == 0 - assert self.pool.workers[2].messages_sent == 0 - - def test_queue_preference(self): - self.pool.init_workers(SimpleWorker().work_loop) - self.pool.write(2, 'xyz') - assert self.pool.workers[0].messages_sent == 0 - assert self.pool.workers[1].messages_sent == 0 - assert self.pool.workers[2].messages_sent == 1 # worker at index 2 handled one task - - def test_worker_processing(self): - result_queue = multiprocessing.Queue() - self.pool.init_workers(ResultWriter().work_loop, result_queue) - for i in range(10): - self.pool.write(random.choice(range(len(self.pool))), 'Hello, Worker {}'.format(i)) - all_messages = [result_queue.get(timeout=1) for i in range(10)] - all_messages.sort() - assert all_messages == ['Hello, Worker {}!!!'.format(i) for i in range(10)] - - total_handled = sum([worker.messages_sent for worker in self.pool.workers]) - assert total_handled == 10 - - -@pytest.mark.django_db -class TestAutoScaling: - def setup_method(self, test_method): - self.pool = AutoscalePool(min_workers=2, max_workers=10) - - def teardown_method(self, test_method): - self.pool.stop(signal.SIGTERM) - - def test_scale_up(self): - result_queue = multiprocessing.Queue() - self.pool.init_workers(SlowResultWriter().work_loop, result_queue) - - # start with two workers, write an event to each worker and make it busy - assert len(self.pool) == 2 - for i, w in enumerate(self.pool.workers): - w.put('Hello, Worker {}'.format(0)) - assert len(self.pool) == 2 - - # wait for the subprocesses to start working on 
their tasks and be marked busy - time.sleep(1) - assert self.pool.should_grow - - # write a third message, expect a new worker to spawn because all - # workers are busy - self.pool.write(0, 'Hello, Worker {}'.format(2)) - assert len(self.pool) == 3 - - def test_scale_down(self): - self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue()) - - # start with two workers, and scale up to 10 workers - assert len(self.pool) == 2 - for i in range(8): - self.pool.up() - assert len(self.pool) == 10 - - # cleanup should scale down to 8 workers - self.pool.cleanup() - assert len(self.pool) == 2 - - def test_max_scale_up(self): - self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue()) - - assert len(self.pool) == 2 - for i in range(25): - self.pool.up() - assert self.pool.max_workers == 10 - assert self.pool.full is True - assert len(self.pool) == 10 - - def test_equal_worker_distribution(self): - # if all workers are busy, spawn new workers *before* adding messages - # to an existing queue - self.pool.init_workers(SlowResultWriter().work_loop, multiprocessing.Queue) - - # start with two workers, write an event to each worker and make it busy - assert len(self.pool) == 2 - for i in range(10): - self.pool.write(0, 'Hello, World!') - assert len(self.pool) == 10 - for w in self.pool.workers: - assert w.busy - assert len(w.managed_tasks) == 1 - - # the queue is full at 10, the _next_ write should put the message into - # a worker's backlog - assert len(self.pool) == 10 - for w in self.pool.workers: - assert w.messages_sent == 1 - self.pool.write(0, 'Hello, World!') - assert len(self.pool) == 10 - assert self.pool.workers[0].messages_sent == 2 - - @pytest.mark.timeout(20) - def test_lost_worker_autoscale(self): - # if a worker exits, it should be replaced automatically up to min_workers - self.pool.init_workers(ResultWriter().work_loop, multiprocessing.Queue()) - - # start with two workers, kill one of them - assert len(self.pool) == 2 - assert not 
self.pool.should_grow - alive_pid = self.pool.workers[1].pid - self.pool.workers[0].process.kill() - self.pool.workers[0].process.join() # waits for process to full terminate - - # clean up and the dead worker - self.pool.cleanup() - assert len(self.pool) == 1 - assert self.pool.workers[0].pid == alive_pid - - # the next queue write should replace the lost worker - self.pool.write(0, 'Hello, Worker') - assert len(self.pool) == 2 - - -@pytest.mark.usefixtures("disable_database_settings") -class TestTaskDispatcher: - @property - def tm(self): - return TaskWorker() - - def test_function_dispatch(self): - result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.add', 'args': [2, 2]}) - assert result == 4 - - def test_function_dispatch_must_be_decorated(self): - result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.restricted', 'args': [2, 2]}) - assert isinstance(result, ValueError) - assert str(result) == 'awx.main.tests.functional.test_dispatch.restricted is not decorated with @task()' # noqa - - def test_method_dispatch(self): - result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.Adder', 'args': [2, 2]}) - assert result == 4 - - def test_method_dispatch_must_be_decorated(self): - result = self.tm.perform_work({'task': 'awx.main.tests.functional.test_dispatch.Restricted', 'args': [2, 2]}) - assert isinstance(result, ValueError) - assert str(result) == 'awx.main.tests.functional.test_dispatch.Restricted is not decorated with @task()' # noqa - - def test_python_function_cannot_be_imported(self): - result = self.tm.perform_work( - { - 'task': 'os.system', - 'args': ['ls'], - } - ) - assert isinstance(result, ValueError) - assert str(result) == 'os.system is not a valid awx task' # noqa - - def test_undefined_function_cannot_be_imported(self): - result = self.tm.perform_work({'task': 'awx.foo.bar'}) - assert isinstance(result, ModuleNotFoundError) - assert str(result) == "No module named 
'awx.foo'" # noqa - - -@pytest.mark.django_db -class TestTaskPublisher: - @pytest.fixture(autouse=True) - def _disable_dispatcherd(self): - flag_name = "FEATURE_DISPATCHERD_ENABLED" - disable_flag(flag_name) - yield - enable_flag(flag_name) - - def test_function_callable(self): - assert add(2, 2) == 4 - - def test_method_callable(self): - assert Adder().run(2, 2) == 4 - - def test_function_apply_async(self): - message, queue = add.apply_async([2, 2], queue='foobar') - assert message['args'] == [2, 2] - assert message['kwargs'] == {} - assert message['task'] == 'awx.main.tests.functional.test_dispatch.add' - assert queue == 'foobar' - - def test_method_apply_async(self): - message, queue = Adder.apply_async([2, 2], queue='foobar') - assert message['args'] == [2, 2] - assert message['kwargs'] == {} - assert message['task'] == 'awx.main.tests.functional.test_dispatch.Adder' - assert queue == 'foobar' - - def test_apply_async_queue_required(self): - with pytest.raises(ValueError) as e: - message, queue = add.apply_async([2, 2]) - assert "awx.main.tests.functional.test_dispatch.add: Queue value required and may not be None" == e.value.args[0] - - def test_queue_defined_in_task_decorator(self): - message, queue = multiply.apply_async([2, 2]) - assert queue == 'hard-math' - - def test_queue_overridden_from_task_decorator(self): - message, queue = multiply.apply_async([2, 2], queue='not-so-hard') - assert queue == 'not-so-hard' - - def test_apply_with_callable_queuename(self): - message, queue = add.apply_async([2, 2], queue=lambda: 'called') - assert queue == 'called' - - yesterday = tz_now() - datetime.timedelta(days=1) minute = tz_now() - datetime.timedelta(seconds=120) now = tz_now() @@ -448,76 +151,3 @@ class TestJobReaper(object): assert job.started > ref_time assert job.status == 'running' assert job.job_explanation == '' - - -@pytest.mark.django_db -class TestScheduler: - def test_too_many_schedules_freak_out(self): - with pytest.raises(RuntimeError): - 
Scheduler({'job1': {'schedule': datetime.timedelta(seconds=1)}, 'job2': {'schedule': datetime.timedelta(seconds=1)}}) - - def test_spread_out(self): - scheduler = Scheduler( - { - 'job1': {'schedule': datetime.timedelta(seconds=16)}, - 'job2': {'schedule': datetime.timedelta(seconds=16)}, - 'job3': {'schedule': datetime.timedelta(seconds=16)}, - 'job4': {'schedule': datetime.timedelta(seconds=16)}, - } - ) - assert [job.offset for job in scheduler.jobs] == [0, 4, 8, 12] - - def test_missed_schedule(self, mocker): - scheduler = Scheduler({'job1': {'schedule': datetime.timedelta(seconds=10)}}) - assert scheduler.jobs[0].missed_runs(time.time() - scheduler.global_start) == 0 - mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 50) - scheduler.get_and_mark_pending() - assert scheduler.jobs[0].missed_runs(50) > 1 - - def test_advance_schedule(self, mocker): - scheduler = Scheduler( - { - 'job1': {'schedule': datetime.timedelta(seconds=30)}, - 'joba': {'schedule': datetime.timedelta(seconds=20)}, - 'jobb': {'schedule': datetime.timedelta(seconds=20)}, - } - ) - for job in scheduler.jobs: - # HACK: the offsets automatically added make this a hard test to write... 
so remove offsets - job.offset = 0.0 - mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 29) - to_run = scheduler.get_and_mark_pending() - assert set(job.name for job in to_run) == set(['joba', 'jobb']) - mocker.patch('awx.main.dispatch.periodic.time.time', return_value=scheduler.global_start + 39) - to_run = scheduler.get_and_mark_pending() - assert len(to_run) == 1 - assert to_run[0].name == 'job1' - - @staticmethod - def get_job(scheduler, name): - for job in scheduler.jobs: - if job.name == name: - return job - - def test_scheduler_debug(self, mocker): - scheduler = Scheduler( - { - 'joba': {'schedule': datetime.timedelta(seconds=20)}, - 'jobb': {'schedule': datetime.timedelta(seconds=50)}, - 'jobc': {'schedule': datetime.timedelta(seconds=500)}, - 'jobd': {'schedule': datetime.timedelta(seconds=20)}, - } - ) - rel_time = 119.9 # slightly under the 6th 20-second bin, to avoid offset problems - current_time = scheduler.global_start + rel_time - mocker.patch('awx.main.dispatch.periodic.time.time', return_value=current_time - 1.0e-8) - self.get_job(scheduler, 'jobb').mark_run(rel_time) - self.get_job(scheduler, 'jobd').mark_run(rel_time - 20.0) - - output = scheduler.debug() - data = yaml.safe_load(output) - assert data['schedule_list']['jobc']['last_run_seconds_ago'] is None - assert data['schedule_list']['joba']['missed_runs'] == 4 - assert data['schedule_list']['jobd']['missed_runs'] == 3 - assert data['schedule_list']['jobd']['completed_runs'] == 1 - assert data['schedule_list']['jobb']['next_run_in_seconds'] > 25.0 diff --git a/awx/main/tests/functional/test_jobs.py b/awx/main/tests/functional/test_jobs.py index e71420f737..7d4a0ed5b7 100644 --- a/awx/main/tests/functional/test_jobs.py +++ b/awx/main/tests/functional/test_jobs.py @@ -50,7 +50,7 @@ def test_job_capacity_and_with_inactive_node(): i.save() with override_settings(CLUSTER_HOST_ID=i.hostname): with mock.patch.object(redis.client.Redis, 'ping', lambda self: 
True): - cluster_node_heartbeat() + cluster_node_heartbeat(None) i = Instance.objects.get(id=i.id) assert i.capacity == 0 diff --git a/awx/main/tests/settings_for_test.py b/awx/main/tests/settings_for_test.py index b7d5cdf023..5634494c33 100644 --- a/awx/main/tests/settings_for_test.py +++ b/awx/main/tests/settings_for_test.py @@ -7,9 +7,6 @@ from awx.settings.development import * # NOQA # Some things make decisions based on settings.SETTINGS_MODULE, so this is done for that SETTINGS_MODULE = 'awx.settings.development' -# Turn off task submission, because sqlite3 does not have pg_notify -DISPATCHER_MOCK_PUBLISH = True - # Use SQLite for unit tests instead of PostgreSQL. If the lines below are # commented out, Django will create the test_awx-dev database in PostgreSQL to # run unit tests. diff --git a/awx/main/tests/unit/settings/test_defaults.py b/awx/main/tests/unit/settings/test_defaults.py index a7f5eeeca8..10cb5561a7 100644 --- a/awx/main/tests/unit/settings/test_defaults.py +++ b/awx/main/tests/unit/settings/test_defaults.py @@ -1,20 +1,19 @@ import pytest from django.conf import settings -from datetime import timedelta @pytest.mark.parametrize( - "job_name,function_path", + "task_name", [ - ('tower_scheduler', 'awx.main.tasks.system.awx_periodic_scheduler'), + 'awx.main.tasks.system.awx_periodic_scheduler', ], ) -def test_CELERYBEAT_SCHEDULE(mocker, job_name, function_path): - assert job_name in settings.CELERYBEAT_SCHEDULE - assert 'schedule' in settings.CELERYBEAT_SCHEDULE[job_name] - assert type(settings.CELERYBEAT_SCHEDULE[job_name]['schedule']) is timedelta - assert settings.CELERYBEAT_SCHEDULE[job_name]['task'] == function_path +def test_DISPATCHER_SCHEDULE(mocker, task_name): + assert task_name in settings.DISPATCHER_SCHEDULE + assert 'schedule' in settings.DISPATCHER_SCHEDULE[task_name] + assert type(settings.DISPATCHER_SCHEDULE[task_name]['schedule']) in (int, float) + assert settings.DISPATCHER_SCHEDULE[task_name]['task'] == task_name # Ensures that 
the function exists - mocker.patch(function_path) + mocker.patch(task_name) diff --git a/awx/main/tests/unit/test_settings.py b/awx/main/tests/unit/test_settings.py index 42ad771b1a..ee517d6a87 100644 --- a/awx/main/tests/unit/test_settings.py +++ b/awx/main/tests/unit/test_settings.py @@ -8,9 +8,7 @@ LOCAL_SETTINGS = ( 'CACHES', 'DEBUG', 'NAMED_URL_GRAPH', - 'DISPATCHER_MOCK_PUBLISH', # Platform flags are managed by the platform flags system and have environment-specific defaults - 'FEATURE_DISPATCHERD_ENABLED', 'FEATURE_INDIRECT_NODE_COUNTING_ENABLED', ) @@ -87,12 +85,9 @@ def test_development_defaults_feature_flags(monkeypatch): spec.loader.exec_module(development_defaults) # Also import through the development settings to ensure both paths are tested - from awx.settings.development import FEATURE_INDIRECT_NODE_COUNTING_ENABLED, FEATURE_DISPATCHERD_ENABLED + from awx.settings.development import FEATURE_INDIRECT_NODE_COUNTING_ENABLED # Verify the feature flags are set correctly in both the module and settings assert hasattr(development_defaults, 'FEATURE_INDIRECT_NODE_COUNTING_ENABLED') assert development_defaults.FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True - assert hasattr(development_defaults, 'FEATURE_DISPATCHERD_ENABLED') - assert development_defaults.FEATURE_DISPATCHERD_ENABLED is True assert FEATURE_INDIRECT_NODE_COUNTING_ENABLED is True - assert FEATURE_DISPATCHERD_ENABLED is True diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index 2f45bb7c8f..4365887be2 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -43,6 +43,9 @@ from django.apps import apps # AWX from awx.conf.license import get_license +# ansible-runner +from ansible_runner.utils.capacity import get_mem_in_bytes, get_cpu_count + logger = logging.getLogger('awx.main.utils') __all__ = [ @@ -1220,3 +1223,38 @@ def unified_job_class_to_event_table_name(job_class): def load_all_entry_points_for(entry_point_subsections: list[str], /) -> dict[str, EntryPoint]: 
return {ep.name: ep for entry_point_category in entry_point_subsections for ep in entry_points(group=f'awx_plugins.{entry_point_category}')}
+
+
+def get_auto_max_workers():
+    """Method we normally rely on to get max_workers
+
+    Uses almost same logic as Instance.local_health_check
+    The important thing is to be MORE than Instance.capacity
+    so that the task-manager does not over-schedule this node
+
+    Ideally we would just use the capacity from the database plus reserve workers,
+    but this poses some bootstrap problems where OCP task containers
+    register themselves after startup
+    """
+    # Get total system memory in bytes from ansible-runner (NOTE: value is bytes despite the _gb name)
+    total_memory_gb = get_mem_in_bytes()
+
+    # This may replace memory calculation with a user override
+    corrected_memory = get_corrected_memory(total_memory_gb)
+
+    # Get same number as max forks based on memory, this function takes memory as bytes
+    mem_capacity = get_mem_effective_capacity(corrected_memory, is_control_node=True)
+
+    # Follow same process for CPU capacity constraint
+    cpu_count = get_cpu_count()
+    corrected_cpu = get_corrected_cpu(cpu_count)
+    cpu_capacity = get_cpu_effective_capacity(corrected_cpu, is_control_node=True)
+
+    # Here is where this differs from health checks: take the max, not the min
+    auto_max = max(mem_capacity, cpu_capacity)
+
+    # add magic number of extra workers to ensure
+    # we have a few extra workers to run the heartbeat
+    auto_max += 7
+
+    return auto_max
diff --git a/awx/main/utils/external_logging.py b/awx/main/utils/external_logging.py
index 21aa104a15..81983b85e6 100644
--- a/awx/main/utils/external_logging.py
+++ b/awx/main/utils/external_logging.py
@@ -4,9 +4,9 @@ import tempfile
 import urllib.parse as urlparse
 
 from django.conf import settings
+from dispatcherd.publish import task
 
 from awx.main.utils.reload import supervisor_service_command
-from awx.main.dispatch.publish import task as task_awx
 
 
 def construct_rsyslog_conf_template(settings=settings):
return tmpl -@task_awx(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one') +@task(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one') def reconfigure_rsyslog(): tmpl = construct_rsyslog_conf_template() # Write config to a temp file then move it to preserve atomicity diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index 9a2bc5204d..63810aca8b 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -7,7 +7,6 @@ import os import re # noqa import tempfile import socket -from datetime import timedelta DEBUG = True SQL_DEBUG = DEBUG @@ -416,49 +415,34 @@ EXECUTION_NODE_REMEDIATION_CHECKS = 60 * 30 # once every 30 minutes check if an # Amount of time dispatcher will try to reconnect to database for jobs and consuming new work DISPATCHER_DB_DOWNTIME_TOLERANCE = 40 -# If you set this, nothing will ever be sent to pg_notify -# this is not practical to use, although periodic schedules may still run slugish but functional tasks -# sqlite3 based tests will use this -DISPATCHER_MOCK_PUBLISH = False - BROKER_URL = 'unix:///var/run/redis/redis.sock' REDIS_RETRY_COUNT = 3 # Number of retries for Redis connection errors REDIS_BACKOFF_CAP = 1.0 # Maximum backoff delay in seconds for Redis retries REDIS_BACKOFF_BASE = 0.5 # Base for exponential backoff calculation for Redis retries -CELERYBEAT_SCHEDULE = { - 'tower_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': timedelta(seconds=30), 'options': {'expires': 20}}, - 'cluster_heartbeat': { + +DISPATCHER_SCHEDULE = { + 'awx.main.tasks.system.awx_periodic_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': 30, 'options': {'expires': 20}}, + 'awx.main.tasks.system.cluster_node_heartbeat': { 'task': 'awx.main.tasks.system.cluster_node_heartbeat', - 'schedule': timedelta(seconds=CLUSTER_NODE_HEARTBEAT_PERIOD), + 'schedule': CLUSTER_NODE_HEARTBEAT_PERIOD, 'options': {'expires': 50}, }, - 'gather_analytics': {'task': 
'awx.main.tasks.system.gather_analytics', 'schedule': timedelta(minutes=5)}, - 'task_manager': {'task': 'awx.main.scheduler.tasks.task_manager', 'schedule': timedelta(seconds=20), 'options': {'expires': 20}}, - 'dependency_manager': {'task': 'awx.main.scheduler.tasks.dependency_manager', 'schedule': timedelta(seconds=20), 'options': {'expires': 20}}, - 'k8s_reaper': {'task': 'awx.main.tasks.system.awx_k8s_reaper', 'schedule': timedelta(seconds=60), 'options': {'expires': 50}}, - 'receptor_reaper': {'task': 'awx.main.tasks.system.awx_receptor_workunit_reaper', 'schedule': timedelta(seconds=60)}, - 'send_subsystem_metrics': {'task': 'awx.main.analytics.analytics_tasks.send_subsystem_metrics', 'schedule': timedelta(seconds=20)}, - 'cleanup_images': {'task': 'awx.main.tasks.system.cleanup_images_and_files', 'schedule': timedelta(hours=3)}, - 'cleanup_host_metrics': {'task': 'awx.main.tasks.host_metrics.cleanup_host_metrics', 'schedule': timedelta(hours=3, minutes=30)}, - 'host_metric_summary_monthly': {'task': 'awx.main.tasks.host_metrics.host_metric_summary_monthly', 'schedule': timedelta(hours=4)}, - 'periodic_resource_sync': {'task': 'awx.main.tasks.system.periodic_resource_sync', 'schedule': timedelta(minutes=15)}, - 'cleanup_and_save_indirect_host_entries_fallback': { + 'awx.main.tasks.system.gather_analytics': {'task': 'awx.main.tasks.system.gather_analytics', 'schedule': 300}, + 'awx.main.scheduler.tasks.task_manager': {'task': 'awx.main.scheduler.tasks.task_manager', 'schedule': 20, 'options': {'expires': 20}}, + 'awx.main.scheduler.tasks.dependency_manager': {'task': 'awx.main.scheduler.tasks.dependency_manager', 'schedule': 20, 'options': {'expires': 20}}, + 'awx.main.tasks.system.awx_k8s_reaper': {'task': 'awx.main.tasks.system.awx_k8s_reaper', 'schedule': 60, 'options': {'expires': 50}}, + 'awx.main.tasks.system.awx_receptor_workunit_reaper': {'task': 'awx.main.tasks.system.awx_receptor_workunit_reaper', 'schedule': 60}, + 
'awx.main.analytics.analytics_tasks.send_subsystem_metrics': {'task': 'awx.main.analytics.analytics_tasks.send_subsystem_metrics', 'schedule': 20}, + 'awx.main.tasks.system.cleanup_images_and_files': {'task': 'awx.main.tasks.system.cleanup_images_and_files', 'schedule': 10800}, + 'awx.main.tasks.host_metrics.cleanup_host_metrics': {'task': 'awx.main.tasks.host_metrics.cleanup_host_metrics', 'schedule': 12600}, + 'awx.main.tasks.host_metrics.host_metric_summary_monthly': {'task': 'awx.main.tasks.host_metrics.host_metric_summary_monthly', 'schedule': 14400}, + 'awx.main.tasks.system.periodic_resource_sync': {'task': 'awx.main.tasks.system.periodic_resource_sync', 'schedule': 900}, + 'awx.main.tasks.host_indirect.cleanup_and_save_indirect_host_entries_fallback': { 'task': 'awx.main.tasks.host_indirect.cleanup_and_save_indirect_host_entries_fallback', - 'schedule': timedelta(minutes=60), + 'schedule': 3600, }, } -DISPATCHER_SCHEDULE = {} -for options in CELERYBEAT_SCHEDULE.values(): - new_options = options.copy() - task_name = options['task'] - # Handle the only one exception case of the heartbeat which has a new implementation - if task_name == 'awx.main.tasks.system.cluster_node_heartbeat': - task_name = 'awx.main.tasks.system.adispatch_cluster_node_heartbeat' - new_options['task'] = task_name - new_options['schedule'] = options['schedule'].total_seconds() - DISPATCHER_SCHEDULE[task_name] = new_options - # Django Caching Configuration DJANGO_REDIS_IGNORE_EXCEPTIONS = True CACHES = {'default': {'BACKEND': 'awx.main.cache.AWXRedisCache', 'LOCATION': 'unix:///var/run/redis/redis.sock?db=1'}} @@ -1149,7 +1133,6 @@ OPA_REQUEST_RETRIES = 2 # The number of retry attempts for connecting to the OP # feature flags FEATURE_INDIRECT_NODE_COUNTING_ENABLED = False -FEATURE_DISPATCHERD_ENABLED = False # Dispatcher worker lifetime. If set to None, workers will never be retired # based on age. 
Note workers will finish their last task before retiring if diff --git a/awx/settings/development_defaults.py b/awx/settings/development_defaults.py index c7f43e880e..49cb1a68b6 100644 --- a/awx/settings/development_defaults.py +++ b/awx/settings/development_defaults.py @@ -69,4 +69,3 @@ AWX_DISABLE_TASK_MANAGERS = False # ======================!!!!!!! FOR DEVELOPMENT ONLY !!!!!!!================================= FEATURE_INDIRECT_NODE_COUNTING_ENABLED = True -FEATURE_DISPATCHERD_ENABLED = True diff --git a/docs/tasks.md b/docs/tasks.md index aa91b90339..6c3b7e3c71 100644 --- a/docs/tasks.md +++ b/docs/tasks.md @@ -20,19 +20,18 @@ In this document, we will go into a bit of detail about how and when AWX runs Py - Every node in an AWX cluster runs a periodic task that serves as a heartbeat and capacity check -Transition to dispatcherd Library ---------------------------------- +dispatcherd Library +------------------- -The task system logic is being split out into a new library: +The task system logic has been split out into a separate library: https://github.com/ansible/dispatcherd -Currently AWX is in a transitionary period where this is put behind a feature flag. -The difference can be seen in how the task decorator is imported. +AWX now uses dispatcherd directly for all task management. Tasks are decorated using: - - old `from awx.main.dispatch.publish import task` - - transition `from awx.main.dispatch.publish import task as task_awx` - - new `from dispatcherd.publish import task` +```python +from dispatcherd.publish import task +``` Tasks, Queues and Workers @@ -74,7 +73,7 @@ Defining and Running Tasks Tasks are defined in AWX's source code, and generally live in the `awx.main.tasks` module. 
Tasks can be defined as simple functions: - from awx.main.dispatch.publish import task as task_awx + from dispatcherd.publish import task @task() def add(a, b): @@ -145,14 +144,6 @@ This outputs running and queued task UUIDs handled by a specific dispatcher ['eb3b0a83-86da-413d-902a-16d7530a6b25', 'f447266a-23da-42b4-8025-fe379d2db96f'] ``` -Additionally, you can tell the local running dispatcher to recycle all of the -workers in its pool. It will wait for any running jobs to finish and exit when -work has completed, spinning up replacement workers. - -``` -awx-manage run_dispatcher --reload -``` - * * * In the following sections, we will go further into the details regarding AWX tasks. They are all decorated by `@task()` in [awx/awx/main/tasks.py](https://github.com/ansible/awx/blob/devel/awx/main/tasks.py)