Add decorator

* moved to dispatcherd decorator
* updated as many as I could find
This commit is contained in:
thedoubl3j
2025-12-19 15:35:57 -05:00
parent e55578b64e
commit f9f4bf2d1a
11 changed files with 59 additions and 193 deletions
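The change is mechanical at every call site: drop AWX's wrapper import and decorate with dispatcherd's own `task`. A minimal before/after sketch (decorator arguments taken from the hunks below):

```python
# Before: AWX's wrapper around dispatcherd
from awx.main.dispatch.publish import task as task_awx
from awx.main.dispatch import get_task_queuename

@task_awx(queue=get_task_queuename, timeout=300, on_duplicate='discard')
def send_subsystem_metrics():
    ...

# After: dispatcherd's decorator, used directly
from dispatcherd.publish import task

@task(queue=get_task_queuename, timeout=300, on_duplicate='discard')
def send_subsystem_metrics():
    ...
```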

View File

@@ -1,15 +1,17 @@
 # Python
 import logging

+# Dispatcherd
+from dispatcherd.publish import task
+
 # AWX
 from awx.main.analytics.subsystem_metrics import DispatcherMetrics, CallbackReceiverMetrics
-from awx.main.dispatch.publish import task as task_awx
 from awx.main.dispatch import get_task_queuename

 logger = logging.getLogger('awx.main.scheduler')

-@task_awx(queue=get_task_queuename, timeout=300, on_duplicate='discard')
+@task(queue=get_task_queuename, timeout=300, on_duplicate='discard')
 def send_subsystem_metrics():
     DispatcherMetrics().send_metrics()
     CallbackReceiverMetrics().send_metrics()

View File

@@ -1,137 +0,0 @@
-import inspect
-import logging
-import time
-from uuid import uuid4
-
-from dispatcherd.publish import submit_task
-from dispatcherd.processors.blocker import Blocker
-from dispatcherd.utils import resolve_callable
-
-from django_guid import get_guid
-
-logger = logging.getLogger('awx.main.dispatch')
-
-
-def serialize_task(f):
-    return '.'.join([f.__module__, f.__name__])
-
-
-class task:
-    """
-    Used to decorate a function or class so that it can be run asynchronously
-    via the task dispatcher. Tasks can be simple functions:
-
-    @task()
-    def add(a, b):
-        return a + b
-
-    ...or classes that define a `run` method:
-
-    @task()
-    class Adder:
-        def run(self, a, b):
-            return a + b
-
-    # Tasks can be run synchronously...
-    assert add(1, 1) == 2
-    assert Adder().run(1, 1) == 2
-
-    # ...or published to a queue:
-    add.apply_async([1, 1])
-    Adder.apply_async([1, 1])
-
-    # Tasks can also define a specific target queue or use the special fan-out queue tower_broadcast:
-    @task(queue='slow-tasks')
-    def snooze():
-        time.sleep(10)
-
-    @task(queue='tower_broadcast')
-    def announce():
-        print("Run this everywhere!")
-
-    # The special parameter bind_kwargs tells the main dispatcher process to add certain kwargs
-    @task(bind_kwargs=['dispatch_time'])
-    def print_time(dispatch_time=None):
-        print(f"Time I was dispatched: {dispatch_time}")
-    """
-
-    def __init__(self, queue=None, bind_kwargs=None, timeout=None, on_duplicate=None):
-        self.queue = queue
-        self.bind_kwargs = bind_kwargs
-        self.timeout = timeout
-        self.on_duplicate = on_duplicate
-
-    def __call__(self, fn=None):
-        queue = self.queue
-        bind_kwargs = self.bind_kwargs
-        timeout = self.timeout
-        on_duplicate = self.on_duplicate
-
-        class PublisherMixin(object):
-            queue = None
-
-            @classmethod
-            def delay(cls, *args, **kwargs):
-                return cls.apply_async(args, kwargs)
-
-            @classmethod
-            def get_async_body(cls, args=None, kwargs=None, uuid=None, **kw):
-                """
-                Get the python dict to become JSON data in the pg_notify message
-                This same message gets passed over the dispatcher IPC queue to workers
-                If a task is submitted to a multiprocessing pool, skipping pg_notify, this might be used directly
-                """
-                task_id = uuid or str(uuid4())
-                args = args or []
-                kwargs = kwargs or {}
-                obj = {'uuid': task_id, 'args': args, 'kwargs': kwargs, 'task': cls.name, 'time_pub': time.time()}
-                guid = get_guid()
-                if guid:
-                    obj['guid'] = guid
-                if bind_kwargs:
-                    obj['bind_kwargs'] = bind_kwargs
-                obj.update(**kw)
-                return obj
-
-            @classmethod
-            def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
-                # At this point we have the import string, and submit_task wants the method, so back to that
-                actual_task = resolve_callable(cls.name)
-                processor_options = ()
-                if on_duplicate is not None:
-                    processor_options = (Blocker.Params(on_duplicate=on_duplicate),)
-                return submit_task(
-                    actual_task,
-                    args=args,
-                    kwargs=kwargs,
-                    queue=queue,
-                    uuid=uuid,
-                    timeout=timeout,
-                    processor_options=processor_options,
-                    **kw,
-                )
-
-        # If the object we're wrapping *is* a class (e.g., RunJob), return
-        # a *new* class that inherits from the wrapped class *and* BaseTask
-        # In this way, the new class returned by our decorator is the class
-        # being decorated *plus* PublisherMixin so cls.apply_async() and
-        # cls.delay() work
-        bases = []
-        ns = {'name': serialize_task(fn), 'queue': queue}
-        if inspect.isclass(fn):
-            bases = list(fn.__bases__)
-            ns.update(fn.__dict__)
-        cls = type(fn.__name__, tuple(bases + [PublisherMixin]), ns)
-        if inspect.isclass(fn):
-            return cls
-
-        # if the object being decorated is *not* a class (it's a Python
-        # function), make fn.apply_async and fn.delay proxy through to the
-        # PublisherMixin we dynamically created above
-        setattr(fn, 'name', cls.name)
-        setattr(fn, 'apply_async', cls.apply_async)
-        setattr(fn, 'delay', cls.delay)
-        setattr(fn, 'get_async_body', cls.get_async_body)
-        return fn
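The deleted shim shows what the decorator swap relies on: its `apply_async` already delegated to dispatcherd's `submit_task`, translating `on_duplicate` into a `Blocker.Params` processor option. A hedged sketch of that direct equivalence, reusing only the calls visible above and assuming `submit_task` keeps the signature the shim used:

```python
from dispatcherd.publish import submit_task
from dispatcherd.processors.blocker import Blocker

def add(a, b):
    return a + b

# Roughly what add.apply_async([1, 1]) did under the old shim:
submit_task(
    add,
    args=[1, 1],
    queue='example-queue',  # hypothetical queue name, for illustration only
    timeout=600,
    processor_options=(Blocker.Params(on_duplicate='discard'),),
)
```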

View File

@@ -4,10 +4,12 @@ import logging
 # Django
 from django.conf import settings

+# Dispatcherd
+from dispatcherd.publish import task
+
 # AWX
 from awx import MODE
 from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
-from awx.main.dispatch.publish import task as task_awx
 from awx.main.dispatch import get_task_queuename

 logger = logging.getLogger('awx.main.scheduler')

@@ -20,16 +22,16 @@ def run_manager(manager, prefix):
     manager().schedule()

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def task_manager():
     run_manager(TaskManager, "task")

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def dependency_manager():
     run_manager(DependencyManager, "dependency")

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def workflow_manager():
     run_manager(WorkflowManager, "workflow")

View File

@@ -12,7 +12,7 @@ from django.db import transaction
 # Django flags
 from flags.state import flag_enabled

-from awx.main.dispatch.publish import task
+from dispatcherd.publish import task
 from awx.main.dispatch import get_task_queuename
 from awx.main.models.indirect_managed_node_audit import IndirectManagedNodeAudit
 from awx.main.models.event_query import EventQuery

View File

@@ -6,8 +6,8 @@ from django.conf import settings
 from django.db.models import Count, F
 from django.db.models.functions import TruncMonth
 from django.utils.timezone import now
+from dispatcherd.publish import task

 from awx.main.dispatch import get_task_queuename
-from awx.main.dispatch.publish import task as task_awx
 from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly
 from awx.main.tasks.helpers import is_run_threshold_reached
 from awx.conf.license import get_license

@@ -17,7 +17,7 @@ from awx.main.utils.db import bulk_update_sorted_by_id
 logger = logging.getLogger('awx.main.tasks.host_metrics')

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def cleanup_host_metrics():
     if is_run_threshold_reached(getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', None), getattr(settings, 'CLEANUP_HOST_METRICS_INTERVAL', 30) * 86400):
         logger.info(f"Executing cleanup_host_metrics, last ran at {getattr(settings, 'CLEANUP_HOST_METRICS_LAST_TS', '---')}")

@@ -28,7 +28,7 @@ def cleanup_host_metrics():
     logger.info("Finished cleanup_host_metrics")

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def host_metric_summary_monthly():
     """Run cleanup host metrics summary monthly task each week"""
     if is_run_threshold_reached(getattr(settings, 'HOST_METRIC_SUMMARY_TASK_LAST_TS', None), getattr(settings, 'HOST_METRIC_SUMMARY_TASK_INTERVAL', 7) * 86400):

View File

@@ -36,7 +36,6 @@ from dispatcherd.publish import task
 from dispatcherd.utils import serialize_task

 # AWX
-from awx.main.dispatch.publish import task as task_awx
 from awx.main.dispatch import get_task_queuename
 from awx.main.constants import (
     PRIVILEGE_ESCALATION_METHODS,

@@ -851,7 +850,7 @@ class SourceControlMixin(BaseTask):
         self.release_lock(project)

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 class RunJob(SourceControlMixin, BaseTask):
     """
     Run a job using ansible-playbook.

@@ -1174,7 +1173,7 @@ class RunJob(SourceControlMixin, BaseTask):
            update_inventory_computed_fields.delay(inventory.id)

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 class RunProjectUpdate(BaseTask):
     model = ProjectUpdate
     event_model = ProjectUpdateEvent

@@ -1513,7 +1512,7 @@ class RunProjectUpdate(BaseTask):
         return []

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 class RunInventoryUpdate(SourceControlMixin, BaseTask):
     model = InventoryUpdate
     event_model = InventoryUpdateEvent

@@ -1776,7 +1775,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
            raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc())

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 class RunAdHocCommand(BaseTask):
     """
     Run an ad hoc command using ansible.

@@ -1929,7 +1928,7 @@ class RunAdHocCommand(BaseTask):
         return d

-@task_awx(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 class RunSystemJob(BaseTask):
     model = SystemJob
     event_model = SystemJobEvent

View File

@@ -20,6 +20,9 @@ import ansible_runner
 # django-ansible-base
 from ansible_base.lib.utils.db import advisory_lock

+# Dispatcherd
+from dispatcherd.publish import task
+
 # AWX
 from awx.main.utils.execution_environments import get_default_pod_spec
 from awx.main.exceptions import ReceptorNodeNotFound

@@ -32,7 +35,6 @@ from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
 from awx.main.tasks.signals import signal_state, signal_callback, SignalExit
 from awx.main.models import Instance, InstanceLink, UnifiedJob, ReceptorAddress
 from awx.main.dispatch import get_task_queuename
-from awx.main.dispatch.publish import task as task_awx

 # Receptorctl
 from receptorctl.socket_interface import ReceptorControl

@@ -852,7 +854,7 @@ def reload_receptor():
        raise RuntimeError("Receptor reload failed")

-@task_awx(on_duplicate='queue_one')
+@task(on_duplicate='queue_one')
 def write_receptor_config():
     """
     This task runs async on each control node, K8S only.

@@ -875,7 +877,7 @@ def write_receptor_config():
     reload_receptor()

-@task_awx(queue=get_task_queuename, on_duplicate='discard')
+@task(queue=get_task_queuename, on_duplicate='discard')
 def remove_deprovisioned_node(hostname):
     InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)
     InstanceLink.objects.filter(target__instance__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)

View File

@@ -47,6 +47,9 @@ from django.utils.translation import gettext_noop
 from flags.state import flag_enabled
 from rest_framework.exceptions import PermissionDenied

+# Dispatcherd
+from dispatcherd.publish import task
+
 # AWX
 from awx import __version__ as awx_application_version
 from awx.conf import settings_registry

@@ -56,7 +59,6 @@ from awx.main.analytics.subsystem_metrics import DispatcherMetrics
 from awx.main.constants import ACTIVE_STATES, ERROR_STATES
 from awx.main.consumers import emit_channel_notification
 from awx.main.dispatch import get_task_queuename, reaper
-from awx.main.dispatch.publish import task as task_awx
 from awx.main.models import (
     Instance,
     InstanceGroup,

@@ -131,8 +133,6 @@ def _run_dispatch_startup_common():
     m.reset_values()

 def _dispatcherd_dispatch_startup():
     """
     New dispatcherd branch for startup: uses the control API to re-submit waiting jobs.

@@ -169,7 +169,7 @@ def inform_cluster_of_shutdown():
     logger.warning("Normal shutdown processed for instance %s; instance removed from capacity pool.", inst.hostname)

-@task_awx(queue=get_task_queuename, timeout=3600 * 5)
+@task(queue=get_task_queuename, timeout=3600 * 5)
 def migrate_jsonfield(table, pkfield, columns):
     batchsize = 10000
     with advisory_lock(f'json_migration_{table}', wait=False) as acquired:

@@ -215,7 +215,7 @@ def migrate_jsonfield(table, pkfield, columns):
     logger.warning(f"Migration of {table} to jsonb is finished.")

-@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
 def apply_cluster_membership_policies():
     from awx.main.signals import disable_activity_stream

@@ -327,7 +327,7 @@ def apply_cluster_membership_policies():
     logger.debug('Cluster policy computation finished in {} seconds'.format(time.time() - started_compute))

-@task_awx(queue='tower_settings_change', timeout=600)
+@task(queue='tower_settings_change', timeout=600)
 def clear_setting_cache(setting_keys):
     # log that cache is being cleared
     logger.info(f"clear_setting_cache of keys {setting_keys}")

@@ -345,7 +345,7 @@ def clear_setting_cache(setting_keys):
        ctl.control('set_log_level', data={'level': settings.LOG_AGGREGATOR_LEVEL})

-@task_awx(queue='tower_broadcast_all', timeout=600)
+@task(queue='tower_broadcast_all', timeout=600)
 def delete_project_files(project_path):
     # TODO: possibly implement some retry logic
     lock_file = project_path + '.lock'

@@ -363,7 +363,7 @@ def delete_project_files(project_path):
        logger.exception('Could not remove lock file {}'.format(lock_file))

-@task_awx(queue='tower_broadcast_all')
+@task(queue='tower_broadcast_all')
 def profile_sql(threshold=1, minutes=1):
     if threshold <= 0:
         cache.delete('awx-profile-sql-threshold')

@@ -373,7 +373,7 @@ def profile_sql(threshold=1, minutes=1):
        logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes))

-@task_awx(queue=get_task_queuename, timeout=1800)
+@task(queue=get_task_queuename, timeout=1800)
 def send_notifications(notification_list, job_id=None):
     if not isinstance(notification_list, list):
         raise TypeError("notification_list should be of type list")

@@ -418,13 +418,13 @@ def events_processed_hook(unified_job):
        save_indirect_host_entries.delay(unified_job.id)

-@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
+@task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
 def gather_analytics():
     if is_run_threshold_reached(getattr(settings, 'AUTOMATION_ANALYTICS_LAST_GATHER', None), settings.AUTOMATION_ANALYTICS_GATHER_INTERVAL):
         analytics.gather()

-@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
 def purge_old_stdout_files():
     nowtime = time.time()
     for f in os.listdir(settings.JOBOUTPUT_ROOT):

@@ -486,18 +486,18 @@ class CleanupImagesAndFiles:
        cls.run_remote(this_inst, **kwargs)

-@task_awx(queue='tower_broadcast_all', timeout=3600)
+@task(queue='tower_broadcast_all', timeout=3600)
 def handle_removed_image(remove_images=None):
     """Special broadcast invocation of this method to handle case of deleted EE"""
     CleanupImagesAndFiles.run(remove_images=remove_images, file_pattern='')

-@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
 def cleanup_images_and_files():
     CleanupImagesAndFiles.run(image_prune=True)

-@task_awx(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=600, on_duplicate='queue_one')
 def execution_node_health_check(node):
     if node == '':
         logger.warning('Remote health check incorrectly called with blank string')

@@ -622,7 +622,7 @@ def inspect_execution_and_hop_nodes(instance_list):
            execution_node_health_check.apply_async([hostname])

-@task_awx(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
+@task(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
 def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
     """
     Original implementation for AWX dispatcher.

@@ -821,7 +821,7 @@ def _heartbeat_handle_lost_instances(lost_instances, this_inst):
        logger.exception('No SQL state available. Error marking {} as lost'.format(other_inst.hostname))

-@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
 def awx_receptor_workunit_reaper():
     """
     When an AWX job is launched via receptor, files such as status, stdin, and stdout are created

@@ -867,7 +867,7 @@ def awx_receptor_workunit_reaper():
     administrative_workunit_reaper(receptor_work_list)

-@task_awx(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=1800, on_duplicate='queue_one')
 def awx_k8s_reaper():
     if not settings.RECEPTOR_RELEASE_WORK:
         return

@@ -890,7 +890,7 @@ def awx_k8s_reaper():
            logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group))

-@task_awx(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
+@task(queue=get_task_queuename, timeout=3600 * 5, on_duplicate='discard')
 def awx_periodic_scheduler():
     lock_session_timeout_milliseconds = settings.TASK_MANAGER_LOCK_TIMEOUT * 1000
     with advisory_lock('awx_periodic_scheduler_lock', lock_session_timeout_milliseconds=lock_session_timeout_milliseconds, wait=False) as acquired:

@@ -947,7 +947,7 @@ def awx_periodic_scheduler():
            emit_channel_notification('schedules-changed', dict(id=schedule.id, group_name="schedules"))

-@task_awx(queue=get_task_queuename, timeout=3600)
+@task(queue=get_task_queuename, timeout=3600)
 def handle_failure_notifications(task_ids):
     """A task-ified version of the method that sends notifications."""
     found_task_ids = set()

@@ -962,7 +962,7 @@ def handle_failure_notifications(task_ids):
        logger.warning(f'Could not send notifications for {deleted_tasks} because they were not found in the database')

-@task_awx(queue=get_task_queuename, timeout=3600 * 5)
+@task(queue=get_task_queuename, timeout=3600 * 5)
 def update_inventory_computed_fields(inventory_id):
     """
     Signal handler and wrapper around inventory.update_computed_fields to

@@ -1012,7 +1012,7 @@ def update_smart_memberships_for_inventory(smart_inventory):
     return False

-@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
+@task(queue=get_task_queuename, timeout=3600, on_duplicate='queue_one')
 def update_host_smart_inventory_memberships():
     smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False)
     changed_inventories = set([])

@@ -1028,7 +1028,7 @@ def update_host_smart_inventory_memberships():
        smart_inventory.update_computed_fields()

-@task_awx(queue=get_task_queuename, timeout=3600 * 5)
+@task(queue=get_task_queuename, timeout=3600 * 5)
 def delete_inventory(inventory_id, user_id, retries=5):
     # Delete inventory as user
     if user_id is None:

@@ -1090,7 +1090,7 @@ def _reconstruct_relationships(copy_mapping):
        new_obj.save()

-@task_awx(queue=get_task_queuename, timeout=600)
+@task(queue=get_task_queuename, timeout=600)
 def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, permission_check_func=None):
     logger.debug('Deep copy {} from {} to {}.'.format(model_name, obj_pk, new_obj_pk))

@@ -1145,7 +1145,7 @@ def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, p
     update_inventory_computed_fields.delay(new_obj.id)

-@task_awx(queue=get_task_queuename, timeout=3600, on_duplicate='discard')
+@task(queue=get_task_queuename, timeout=3600, on_duplicate='discard')
 def periodic_resource_sync():
     if not getattr(settings, 'RESOURCE_SERVER', None):
         logger.debug("Skipping periodic resource_sync, RESOURCE_SERVER not configured")

View File

@@ -6,14 +6,13 @@ from dispatcherd.publish import task
 from django.db import connection

 from awx.main.dispatch import get_task_queuename
-from awx.main.dispatch.publish import task as old_task

 from ansible_base.lib.utils.db import advisory_lock

 logger = logging.getLogger(__name__)

-@old_task(queue=get_task_queuename)
+@task(queue=get_task_queuename)
 def sleep_task(seconds=10, log=False):
     if log:
         logger.info('starting sleep_task')

View File

@@ -4,9 +4,9 @@ import tempfile
 import urllib.parse as urlparse

 from django.conf import settings
+from dispatcherd.publish import task

 from awx.main.utils.reload import supervisor_service_command
-from awx.main.dispatch.publish import task as task_awx

 def construct_rsyslog_conf_template(settings=settings):

@@ -139,7 +139,7 @@ def construct_rsyslog_conf_template(settings=settings):
     return tmpl

-@task_awx(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one')
+@task(queue='rsyslog_configurer', timeout=600, on_duplicate='queue_one')
 def reconfigure_rsyslog():
     tmpl = construct_rsyslog_conf_template()
     # Write config to a temp file then move it to preserve atomicity

View File

@@ -20,19 +20,18 @@ In this document, we will go into a bit of detail about how and when AWX runs Py
 - Every node in an AWX cluster runs a periodic task that serves as
   a heartbeat and capacity check

-Transition to dispatcherd Library
----------------------------------
+dispatcherd Library
+-------------------

-The task system logic is being split out into a new library:
+The task system logic has been split out into a separate library:
 https://github.com/ansible/dispatcherd

-Currently AWX is in a transitionary period where this is put behind a feature flag.
-The difference can be seen in how the task decorator is imported.
-
-- old `from awx.main.dispatch.publish import task`
-- transition `from awx.main.dispatch.publish import task as task_awx`
-- new `from dispatcherd.publish import task`
+AWX now uses dispatcherd directly for all task management. Tasks are decorated using:
+
+```python
+from dispatcherd.publish import task
+```

 Tasks, Queues and Workers

@@ -74,7 +73,7 @@ Defining and Running Tasks
 Tasks are defined in AWX's source code, and generally live in the
 `awx.main.tasks` module. Tasks can be defined as simple functions:

-    from awx.main.dispatch.publish import task as task_awx
+    from dispatcherd.publish import task

     @task()
     def add(a, b):
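The docs example is cut off by the hunk boundary. Per the old decorator's docstring (deleted above) and the call sites touched in this commit, a task defined this way can still be invoked synchronously in-process or published with `delay`/`apply_async`; a short sketch under that assumption:

```python
from dispatcherd.publish import task
from awx.main.dispatch import get_task_queuename

@task(queue=get_task_queuename)
def add(a, b):
    return a + b

assert add(1, 1) == 2         # runs synchronously, in-process

add.delay(1, 1)               # publish to the dispatcher, as call sites in this commit do
add.apply_async(args=[1, 1])  # same, with explicit args
```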