add get_task_queuename

get_local_queuename will return the pod name of the instance

now that web and task are in different pods when web container queue a task it will be put into a queue without as task worker to execute the task
This commit is contained in:
Hao Liu 2023-03-22 15:31:08 -04:00
parent 049fb4eff5
commit cd3f7666be
10 changed files with 48 additions and 39 deletions

View File

@ -4,11 +4,11 @@ import logging
# AWX
from awx.main.analytics.subsystem_metrics import Metrics
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler')
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def send_subsystem_metrics():
Metrics().send_metrics()

View File

@ -5,6 +5,7 @@ from contextlib import contextmanager
from django.conf import settings
from django.db import connection as pg_connection
import os
NOT_READY = ([], [], [])
@ -14,6 +15,15 @@ def get_local_queuename():
return settings.CLUSTER_HOST_ID
def get_task_queuename():
if os.getenv('AWX_COMPONENT') == 'web':
from awx.main.models.ha import Instance
return Instance.objects.filter(node_type__in=['control', 'hybrid']).order_by('?').first().hostname
else:
return settings.CLUSTER_HOST_ID
class PubSub(object):
def __init__(self, conn):
self.conn = conn

View File

@ -6,7 +6,7 @@ from django.conf import settings
from django.db import connection
import redis
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
from . import pg_bus_conn
@ -21,7 +21,7 @@ class Control(object):
if service not in self.services:
raise RuntimeError('{} must be in {}'.format(service, self.services))
self.service = service
self.queuename = host or get_local_queuename()
self.queuename = host or get_task_queuename()
def status(self, *args, **kwargs):
r = redis.Redis.from_url(settings.BROKER_URL)

View File

@ -8,7 +8,7 @@ from django.core.cache import cache as django_cache
from django.core.management.base import BaseCommand
from django.db import connection as django_connection
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control
from awx.main.dispatch.pool import AutoscalePool
from awx.main.dispatch.worker import AWXConsumerPG, TaskWorker
@ -76,7 +76,7 @@ class Command(BaseCommand):
consumer = None
try:
queues = ['tower_broadcast_all', 'tower_settings_change', get_local_queuename()]
queues = ['tower_broadcast_all', 'tower_settings_change', 'rsyslog_configurer', get_task_queuename()]
consumer = AWXConsumerPG('dispatcher', TaskWorker(), queues, AutoscalePool(min_workers=4))
consumer.run()
except KeyboardInterrupt:

View File

@ -32,7 +32,7 @@ from polymorphic.models import PolymorphicModel
# AWX
from awx.main.models.base import CommonModelNameNotUnique, PasswordFieldsModel, NotificationFieldsModel, prevent_search
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.control import Control as ControlDispatcher
from awx.main.registrar import activity_stream_registrar
from awx.main.models.mixins import ResourceMixin, TaskManagerUnifiedJobMixin, ExecutionEnvironmentMixin
@ -1567,7 +1567,7 @@ class UnifiedJob(
return r
def get_queue_name(self):
return self.controller_node or self.execution_node or get_local_queuename()
return self.controller_node or self.execution_node or get_task_queuename()
@property
def is_container_group_task(self):

View File

@ -8,7 +8,7 @@ from django.conf import settings
from awx import MODE
from awx.main.scheduler import TaskManager, DependencyManager, WorkflowManager
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
logger = logging.getLogger('awx.main.scheduler')
@ -20,16 +20,16 @@ def run_manager(manager, prefix):
manager().schedule()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def task_manager():
run_manager(TaskManager, "task")
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def dependency_manager():
run_manager(DependencyManager, "dependency")
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def workflow_manager():
run_manager(WorkflowManager, "workflow")

View File

@ -29,7 +29,7 @@ from gitdb.exc import BadName as BadGitName
# AWX
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
from awx.main.constants import (
PRIVILEGE_ESCALATION_METHODS,
STANDARD_INVENTORY_UPDATE_ENV,
@ -806,7 +806,7 @@ class SourceControlMixin(BaseTask):
self.release_lock(project)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
class RunJob(SourceControlMixin, BaseTask):
"""
Run a job using ansible-playbook.
@ -1121,7 +1121,7 @@ class RunJob(SourceControlMixin, BaseTask):
update_inventory_computed_fields.delay(inventory.id)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
class RunProjectUpdate(BaseTask):
model = ProjectUpdate
event_model = ProjectUpdateEvent
@ -1443,7 +1443,7 @@ class RunProjectUpdate(BaseTask):
return params
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
class RunInventoryUpdate(SourceControlMixin, BaseTask):
model = InventoryUpdate
event_model = InventoryUpdateEvent
@ -1706,7 +1706,7 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask):
raise PostRunError('Error occured while saving inventory data, see traceback or server logs', status='error', tb=traceback.format_exc())
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
class RunAdHocCommand(BaseTask):
"""
Run an ad hoc command using ansible.
@ -1859,7 +1859,7 @@ class RunAdHocCommand(BaseTask):
return d
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
class RunSystemJob(BaseTask):
model = SystemJob
event_model = SystemJobEvent

View File

@ -28,7 +28,7 @@ from awx.main.utils.common import (
from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit
from awx.main.models import Instance, InstanceLink, UnifiedJob
from awx.main.dispatch import get_local_queuename
from awx.main.dispatch import get_task_queuename
from awx.main.dispatch.publish import task
# Receptorctl
@ -713,7 +713,7 @@ def write_receptor_config():
links.update(link_state=InstanceLink.States.ESTABLISHED)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def remove_deprovisioned_node(hostname):
InstanceLink.objects.filter(source__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)
InstanceLink.objects.filter(target__hostname=hostname).update(link_state=InstanceLink.States.REMOVING)

View File

@ -50,7 +50,7 @@ from awx.main.models import (
)
from awx.main.constants import ACTIVE_STATES
from awx.main.dispatch.publish import task
from awx.main.dispatch import get_local_queuename, reaper
from awx.main.dispatch import get_task_queuename, reaper
from awx.main.utils.common import (
get_type_for_model,
ignore_inventory_computed_fields,
@ -129,7 +129,7 @@ def inform_cluster_of_shutdown():
logger.exception('Encountered problem with normal shutdown signal.')
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def apply_cluster_membership_policies():
from awx.main.signals import disable_activity_stream
@ -282,7 +282,7 @@ def profile_sql(threshold=1, minutes=1):
logger.error('SQL QUERIES >={}s ENABLED FOR {} MINUTE(S)'.format(threshold, minutes))
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def send_notifications(notification_list, job_id=None):
if not isinstance(notification_list, list):
raise TypeError("notification_list should be of type list")
@ -313,7 +313,7 @@ def send_notifications(notification_list, job_id=None):
logger.exception('Error saving notification {} result.'.format(notification.id))
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def gather_analytics():
from awx.conf.models import Setting
from rest_framework.fields import DateTimeField
@ -326,7 +326,7 @@ def gather_analytics():
analytics.gather()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def purge_old_stdout_files():
nowtime = time.time()
for f in os.listdir(settings.JOBOUTPUT_ROOT):
@ -374,12 +374,12 @@ def handle_removed_image(remove_images=None):
_cleanup_images_and_files(remove_images=remove_images, file_pattern='')
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def cleanup_images_and_files():
_cleanup_images_and_files()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def cluster_node_health_check(node):
"""
Used for the health check endpoint, refreshes the status of the instance, but must be ran on target node
@ -398,7 +398,7 @@ def cluster_node_health_check(node):
this_inst.local_health_check()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def execution_node_health_check(node):
if node == '':
logger.warning('Remote health check incorrectly called with blank string')
@ -492,7 +492,7 @@ def inspect_execution_nodes(instance_list):
execution_node_health_check.apply_async([hostname])
@task(queue=get_local_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
@task(queue=get_task_queuename, bind_kwargs=['dispatch_time', 'worker_tasks'])
def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
logger.debug("Cluster node heartbeat task.")
nowtime = now()
@ -582,7 +582,7 @@ def cluster_node_heartbeat(dispatch_time=None, worker_tasks=None):
reaper.reap_waiting(instance=this_inst, excluded_uuids=active_task_ids, ref_time=datetime.fromisoformat(dispatch_time))
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def awx_receptor_workunit_reaper():
"""
When an AWX job is launched via receptor, files such as status, stdin, and stdout are created
@ -618,7 +618,7 @@ def awx_receptor_workunit_reaper():
administrative_workunit_reaper(receptor_work_list)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def awx_k8s_reaper():
if not settings.RECEPTOR_RELEASE_WORK:
return
@ -638,7 +638,7 @@ def awx_k8s_reaper():
logger.exception("Failed to delete orphaned pod {} from {}".format(job.log_format, group))
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def awx_periodic_scheduler():
with advisory_lock('awx_periodic_scheduler_lock', wait=False) as acquired:
if acquired is False:
@ -704,7 +704,7 @@ def schedule_manager_success_or_error(instance):
ScheduleWorkflowManager().schedule()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def handle_work_success(task_actual):
try:
instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id'])
@ -716,7 +716,7 @@ def handle_work_success(task_actual):
schedule_manager_success_or_error(instance)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def handle_work_error(task_actual):
try:
instance = UnifiedJob.get_instance_by_type(task_actual['type'], task_actual['id'])
@ -756,7 +756,7 @@ def handle_work_error(task_actual):
schedule_manager_success_or_error(instance)
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def update_inventory_computed_fields(inventory_id):
"""
Signal handler and wrapper around inventory.update_computed_fields to
@ -797,7 +797,7 @@ def update_smart_memberships_for_inventory(smart_inventory):
return False
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def update_host_smart_inventory_memberships():
smart_inventories = Inventory.objects.filter(kind='smart', host_filter__isnull=False, pending_deletion=False)
changed_inventories = set([])
@ -813,7 +813,7 @@ def update_host_smart_inventory_memberships():
smart_inventory.update_computed_fields()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def delete_inventory(inventory_id, user_id, retries=5):
# Delete inventory as user
if user_id is None:
@ -878,7 +878,7 @@ def _reconstruct_relationships(copy_mapping):
new_obj.save()
@task(queue=get_local_queuename)
@task(queue=get_task_queuename)
def deep_copy_model_obj(model_module, model_name, obj_pk, new_obj_pk, user_pk, uuid, permission_check_func=None):
sub_obj_list = cache.get(uuid)
if sub_obj_list is None:

View File

@ -201,7 +201,6 @@ class WebsocketRelayConnection:
class WebSocketRelayManager(object):
def __init__(self):
self.local_hostname = get_local_host()
self.relay_connections = dict()
# hostname -> ip