add get_task_queuename

get_local_queuename will return the pod name of the instance

Now that web and task are in different pods, when the web container queues a task it will be put into a queue without a task worker to execute the task.
This commit is contained in:
Hao Liu
2023-03-22 15:31:08 -04:00
parent 049fb4eff5
commit cd3f7666be
10 changed files with 48 additions and 39 deletions

View File

@@ -4,11 +4,11 @@ import logging
# AWX # AWX
from awx.main.analytics.subsystem_metrics import Metrics from awx.main.analytics.subsystem_metrics import Metrics
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler') logger = logging.getLogger('awx.main.scheduler')
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def send_subsystem_metrics(): def send_subsystem_metrics():
Metrics().send_metrics() Metrics().send_metrics()

View File

@@ -5,6 +5,7 @@ from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.db import connection as pg_connection from django.db import connection as pg_connection
import os
NOT_READY = ([], [], []) NOT_READY = ([], [], [])
@@ -14,6 +15,15 @@ def get_local_queuename():
return settings.CLUSTER_HOST_ID return settings.CLUSTER_HOST_ID
def get_task_queuename():
    """Return the name of the dispatcher queue a task should be published to.

    When the ``AWX_COMPONENT`` environment variable is ``'web'``, the hostname
    of a randomly chosen ``Instance`` whose ``node_type`` is ``'control'`` or
    ``'hybrid'`` is returned. In every other case the function returns
    ``settings.CLUSTER_HOST_ID``, exactly like ``get_local_queuename``.

    Returns:
        The queue name as stored in ``Instance.hostname`` or
        ``settings.CLUSTER_HOST_ID``.

    Raises:
        ValueError: if running as the web component and no control or hybrid
            instance exists to pick from.
    """
    if os.getenv('AWX_COMPONENT') != 'web':
        return settings.CLUSTER_HOST_ID

    # Kept as a function-scope import, as in the original.
    # NOTE(review): presumably this avoids a circular import between the
    # dispatch package and the models package -- confirm before hoisting.
    from awx.main.models.ha import Instance

    # order_by('?') asks the database for a random ordering, so first() yields
    # a randomly selected matching instance.
    task_instance = Instance.objects.filter(node_type__in=['control', 'hybrid']).order_by('?').first()

    # first() returns None on an empty queryset; fail with an explicit message
    # instead of an AttributeError on None.hostname.
    if task_instance is None:
        raise ValueError('No control or hybrid instance is available to receive tasks.')

    return task_instance.hostname
class PubSub(object): class PubSub(object):
def __init__(self, conn): def __init__(self, conn):
self.conn = conn self.conn = conn

View File

@@ -6,7 +6,7 @@ from django.conf import settings
from django.db import connection from django.db import connection
import redis import redis
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
from . import pg_bus_conn from . import pg_bus_conn
@@ -21,7 +21,7 @@ class Control(object):
if service not in self.services: if service not in self.services:
raise RuntimeError('{} must be in {}'.format(service, self.services)) raise RuntimeError('{} must be in {}'.format(service, self.services))
self.service = service self.service = service
self.queuename = host or get_local_queuename() self.queuename = host or get_task_queuename()
def status(self, *args, **kwargs): def status(self, *args, **kwargs):
r = redis.Redis.from_url(settings.BROKER_URL) r = redis.Redis.from_url(settings.BROKER_URL)

View File

@@ -8,7 +8,7 @@ from django.core.cache import cache as django_cache
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.db import connection as django_connection from django.db import connection as django_connection
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control from awx.main.dispatch.control import Control
from awx.main.dispatch.pool import AutoscalePool from awx.main.dispatch.pool import AutoscalePool
from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker
@@ -76,7 +76,7 @@ class Command(BaseCommand):
consumer = None consumer = None
try: try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_local_queuename()] queues = ['tower_broadcast_all', 'tower_settings_change', 'rsyslog_configurer', get_task_queuename()]
consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4)) consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4))
consumer.run() consumer.run()
except KeyboardInterrupt: except KeyboardInterrupt:

View File

@@ -32,7 +32,7 @@ from polymorphic.models import PolymorphicModel
# AWX # AWX
from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, NotificationFieldsModel, prevent_search from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, NotificationFieldsModel, prevent_search
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control as ControlDispatcher from awx.main.dispatch.control import Control as ControlDispatcher
from awx.main.registrar import activity_stream_registrar from awx.main.registrar import activity_stream_registrar
from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin
@@ -1567,7 +1567,7 @@ class UnifiedJob(
return r return r
def get_queue_name(self): def get_queue_name(self):
return self.controller_node or self.execution_node or get_local_queuename() return self.controller_node or self.execution_node or get_task_queuename()
@property @property
def is_container_group_task(self): def is_container_group_task(self):

View File

@@ -8,7 +8,7 @@ from django.conf import settings
from awx import MODE from awx import MODE
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler') logger = logging.getLogger('awx.main.scheduler')
@@ -20,16 +20,16 @@ def run_manager(manager, prefix):
manager().schedule() manager().schedule()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def task_manager(): def task_manager():
run_manager(TaskManager, "task") run_manager(TaskManager, "task")
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def dependency_manager(): def dependency_manager():
run_manager(DependencyManager, "dependency") run_manager(DependencyManager, "dependency")
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def workflow_manager(): def workflow_manager():
run_manager(WorkflowManager, "workflow") run_manager(WorkflowManager, "workflow")

View File

@@ -29,7 +29,7 @@ from gitdb.exc import BadName as BadGitName
# AWX # AWX
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
from awx.main.constants import ( from awx.main.constants import (
PRIVILEGE_ESCALATION_METHODS, PRIVILEGE_ESCALATION_METHODS,
STANDARD_INVENTORY_UPDATE_ENV, STANDARD_INVENTORY_UPDATE_ENV,
@@ -806,7 +806,7 @@ class SourceControlMixin(BaseTask):
self.release_lock(project) self.release_lock(project)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
class RunJob(SourceControlMixin, BaseTask): class RunJob(SourceControlMixin, BaseTask):
""" """
Run a job using ansible-playbook. Run a job using ansible-playbook.
@@ -1121,7 +1121,7 @@ class RunJob(SourceControlMixin, BaseTask):
update_inventory_computed_fields.delay(inventory.id) update_inventory_computed_fields.delay(inventory.id)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
class RunProjectUpdate(BaseTask): class RunProjectUpdate(BaseTask):
model = ProjectUpdate model = ProjectUpdate
event_model = ProjectUpdateEvent event_model = ProjectUpdateEvent
@@ -1443,7 +1443,7 @@ class RunProjectUpdate(BaseTask):
return params return params
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
class RunInventoryUpdate(SourceControlMixin, BaseTask): class RunInventoryUpdate(SourceControlMixin, BaseTask):
model = InventoryUpdate model = InventoryUpdate
event_model = InventoryUpdateEvent event_model = InventoryUpdateEvent
@@ -1706,7 +1706,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc()) raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc())
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
class RunAdHocCommand(BaseTask): class RunAdHocCommand(BaseTask):
""" """
Run an ad hoc command using ansible. Run an ad hoc command using ansible.
@@ -1859,7 +1859,7 @@ class RunAdHocCommand(BaseTask):
return d return d
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
class RunSystemJob(BaseTask): class RunSystemJob(BaseTask):
model = SystemJob model = SystemJob
event_model = SystemJobEvent event_model = SystemJobEvent

View File

@@ -28,7 +28,7 @@ from awx.main.utils.common import (
from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit from awx.main.tasks.signals import signal_state, signal_callback, SignalExit
from awx.main.models import Instance, InstanceLink, UnifiedJob from awx.main.models import Instance, InstanceLink, UnifiedJob
from awx.main.dispatch import get_local_queuename from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
# Receptorctl # Receptorctl
@@ -713,7 +713,7 @@ def write_receptor_config():
links.update(link_state=InstanceLink.States.ESTABLISHED) links.update(link_state=InstanceLink.States.ESTABLISHED)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def remove_deprovisioned_node(hostname): def remove_deprovisioned_node(hostname):
InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)
InstanceLink.objects.filter(target__hostname=hostname).update(link_state=InstanceLink.States.REMOVING) InstanceLink.objects.filter(target__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)

View File

@@ -50,7 +50,7 @@ from awx.main.models import (
) )
from awx.main.constants import ACTIVE_STATES from awx.main.constants import ACTIVE_STATES
from awx.main.dispatch.publish import task from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename, reaper from awx.main.dispatch import get_task_queuename, reaper
from awx.main.utils.common import ( from awx.main.utils.common import (
get_type_for_model, get_type_for_model,
ignore_inventory_computed_fields, ignore_inventory_computed_fields,
@@ -129,7 +129,7 @@ def inform_cluster_of_shutdown():
logger.exception('Encountered problem with normal shutdown signal.') logger.exception('Encountered problem with normal shutdown signal.')
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def apply_cluster_membership_policies(): def apply_cluster_membership_policies():
from awx.main.signals import disable_activity_stream from awx.main.signals import disable_activity_stream
@@ -282,7 +282,7 @@ def profile_sql(threshold=1, minutes=1):
logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes)) logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes))
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def send_notifications(notification_list, job_id=None): def send_notifications(notification_list, job_id=None):
if not isinstance(notification_list, list): if not isinstance(notification_list, list):
raise TypeError("notification_list should be of type list") raise TypeError("notification_list should be of type list")
@@ -313,7 +313,7 @@ def send_notifications(notification_list, job_id=None):
logger.exception('Error saving notification {} result.'.format(notification.id)) logger.exception('Error saving notification {} result.'.format(notification.id))
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def gather_analytics(): def gather_analytics():
from awx.conf.models import Setting from awx.conf.models import Setting
from rest_framework.fields import DateTimeField from rest_framework.fields import DateTimeField
@@ -326,7 +326,7 @@ def gather_analytics():
analytics.gather() analytics.gather()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def purge_old_stdout_files(): def purge_old_stdout_files():
nowtime = time.time() nowtime = time.time()
for f in os.listdir(settings.JOBOUTPUT_ROOT): for f in os.listdir(settings.JOBOUTPUT_ROOT):
@@ -374,12 +374,12 @@ def handle_removed_image(remove_images=None):
_cleanup_images_and_files(remove_images=remove_images, file_pattern='') _cleanup_images_and_files(remove_images=remove_images, file_pattern='')
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def cleanup_images_and_files(): def cleanup_images_and_files():
_cleanup_images_and_files() _cleanup_images_and_files()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def cluster_node_health_check(node): def cluster_node_health_check(node):
""" """
Used for the health check endpoint, refreshes the status of the instance, but must be ran on target node Used for the health check endpoint, refreshes the status of the instance, but must be ran on target node
@@ -398,7 +398,7 @@ def cluster_node_health_check(node):
this_inst.local_health_check() this_inst.local_health_check()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def execution_node_health_check(node): def execution_node_health_check(node):
if node == '': if node == '':
logger.warning('Remote health check incorrectly called with blank string') logger.warning('Remote health check incorrectly called with blank string')
@@ -492,7 +492,7 @@ def inspect_execution_nodes(instance_list):
execution_node_health_check.apply_async([hostname]) execution_node_health_check.apply_async([hostname])
@task(queue=get_local_queuename, bind_kwargs=['dispatch_time', 'worker_tasks']) @task(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None): def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
logger.debug("Cluster node heartbeat task.") logger.debug("Cluster node heartbeat task.")
nowtime = now() nowtime = now()
@@ -582,7 +582,7 @@ def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time)) reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def awx_receptor_workunit_reaper(): def awx_receptor_workunit_reaper():
""" """
When an AWX job is launched via receptor, files such as status, stdin, and stdout are created When an AWX job is launched via receptor, files such as status, stdin, and stdout are created
@@ -618,7 +618,7 @@ def awx_receptor_workunit_reaper():
administrative_workunit_reaper(receptor_work_list) administrative_workunit_reaper(receptor_work_list)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def awx_k8s_reaper(): def awx_k8s_reaper():
if not settings.RECEPTOR_RELEASE_WORK: if not settings.RECEPTOR_RELEASE_WORK:
return return
@@ -638,7 +638,7 @@ def awx_k8s_reaper():
logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group)) logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group))
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def awx_periodic_scheduler(): def awx_periodic_scheduler():
with advisory_lock('awx_periodic_scheduler_lock', wait=False) as acquired: with advisory_lock('awx_periodic_scheduler_lock', wait=False) as acquired:
if acquired is False: if acquired is False:
@@ -704,7 +704,7 @@ def schedule_manager_success_or_error(instance):
ScheduleWorkflowManager().schedule() ScheduleWorkflowManager().schedule()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def handle_work_success(task_actual): def handle_work_success(task_actual):
try: try:
instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id']) instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id'])
@@ -716,7 +716,7 @@ def handle_work_success(task_actual):
schedule_manager_success_or_error(instance) schedule_manager_success_or_error(instance)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def handle_work_error(task_actual): def handle_work_error(task_actual):
try: try:
instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id']) instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id'])
@@ -756,7 +756,7 @@ def handle_work_error(task_actual):
schedule_manager_success_or_error(instance) schedule_manager_success_or_error(instance)
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def update_inventory_computed_fields(inventory_id): def update_inventory_computed_fields(inventory_id):
""" """
Signal handler and wrapper around inventory.update_computed_fields to Signal handler and wrapper around inventory.update_computed_fields to
@@ -797,7 +797,7 @@ def update_smart_memberships_for_inventory(smart_inventory):
return False return False
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def update_host_smart_inventory_memberships(): def update_host_smart_inventory_memberships():
smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False) smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False)
changed_inventories = set([]) changed_inventories = set([])
@@ -813,7 +813,7 @@ def update_host_smart_inventory_memberships():
smart_inventory.update_computed_fields() smart_inventory.update_computed_fields()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def delete_inventory(inventory_id, user_id, retries=5): def delete_inventory(inventory_id, user_id, retries=5):
# Delete inventory as user # Delete inventory as user
if user_id is None: if user_id is None:
@@ -878,7 +878,7 @@ def _reconstruct_relationships(copy_mapping):
new_obj.save() new_obj.save()
@task(queue=get_local_queuename) @task(queue=get_task_queuename)
def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, uuid, permission_check_func=None): def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, uuid, permission_check_func=None):
sub_obj_list = cache.get(uuid) sub_obj_list = cache.get(uuid)
if sub_obj_list is None: if sub_obj_list is None:

View File

@@ -201,7 +201,6 @@ class WebsocketRelayConnection:
class WebSocketRelayManager(object): class WebSocketRelayManager(object):
def __init__(self): def __init__(self):
self.local_hostname = get_local_host() self.local_hostname = get_local_host()
self.relay_connections = dict() self.relay_connections = dict()
# hostname -> ip # hostname -> ip