diff --git a/awx/main/dispatch/publish.py b/awx/main/dispatch/publish.py index be64594ee3..cf94572e75 100644 --- a/awx/main/dispatch/publish.py +++ b/awx/main/dispatch/publish.py @@ -1,10 +1,14 @@ import inspect import logging import sys +import json from uuid import uuid4 +import psycopg2 from django.conf import settings -from kombu import Exchange, Producer, Connection, Queue, Consumer +from kombu import Exchange, Producer +from django.db import connection +from pgpubsub import PubSub logger = logging.getLogger('awx.main.dispatch') @@ -85,26 +89,15 @@ class task: if callable(queue): queue = queue() if not settings.IS_TESTING(sys.argv): - with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn: - exchange = Exchange(queue, type=exchange_type or 'direct') - - # HACK: With Redis as the broker declaring an exchange isn't enough to create the queue - # Creating a Consumer _will_ create a queue so that publish will succeed. Note that we - # don't call consume() on the consumer so we don't actually eat any messages - Consumer(conn, queues=[Queue(queue, exchange, routing_key=queue)], accept=['json']) - producer = Producer(conn) - logger.debug('publish {}({}, queue={})'.format( - cls.name, - task_id, - queue - )) - producer.publish(obj, - serializer='json', - compression='bzip2', - exchange=exchange, - declare=[exchange], - delivery_mode="persistent", - routing_key=queue) + conf = settings.DATABASES['default'] + conn = psycopg2.connect(dbname=conf['NAME'], + host=conf['HOST'], + user=conf['USER'], + password=conf['PASSWORD']) + conn.set_session(autocommit=True) + logger.warn(f"Send message to queue {queue}") + pubsub = PubSub(conn) + pubsub.notify(queue, json.dumps(obj)) return (obj, queue) # If the object we're wrapping *is* a class (e.g., RunJob), return diff --git a/awx/main/dispatch/worker/__init__.py b/awx/main/dispatch/worker/__init__.py index 06d64c437c..5472f83579 100644 --- a/awx/main/dispatch/worker/__init__.py +++ b/awx/main/dispatch/worker/__init__.py @@ -1,3 +1,4 @@ from .base import AWXConsumer, AWXRedisConsumer, BaseWorker # noqa +from .basepg import AWXConsumerPG, BaseWorkerPG # noqa from .callback import CallbackBrokerWorker # noqa from .task import TaskWorker # noqa diff --git a/awx/main/dispatch/worker/basepg.py b/awx/main/dispatch/worker/basepg.py new file mode 100644 index 0000000000..7a35fc59d6 --- /dev/null +++ b/awx/main/dispatch/worker/basepg.py @@ -0,0 +1,161 @@ +# Copyright (c) 2018 Ansible by Red Hat +# All Rights Reserved. + +import os +import logging +import signal +import sys +import json +from uuid import UUID +from queue import Empty as QueueEmpty + +from django import db +from django.db import connection as pg_connection + +from pgpubsub import PubSub + +from awx.main.dispatch.pool import WorkerPool + +SHORT_CIRCUIT = False + +if 'run_callback_receiver' in sys.argv: + logger = logging.getLogger('awx.main.commands.run_callback_receiver') +else: + logger = logging.getLogger('awx.main.dispatch') + + +def signame(sig): + return dict( + (k, v) for v, k in signal.__dict__.items() + if v.startswith('SIG') and not v.startswith('SIG_') + )[sig] + + +class WorkerSignalHandler: + + def __init__(self): + self.kill_now = False + signal.signal(signal.SIGINT, self.exit_gracefully) + + def exit_gracefully(self, *args, **kwargs): + self.kill_now = True + + +class AWXConsumerPG(object): + + def __init__(self, name, connection, worker, queues=[], pool=None): + self.name = name + self.connection = pg_connection + self.total_messages = 0 + self.queues = queues + self.worker = worker + self.pool = pool + # TODO, maybe get new connection and reconnect periodically + self.pubsub = PubSub(pg_connection.cursor().connection) + if pool is None: + self.pool = WorkerPool() + self.pool.init_workers(self.worker.work_loop) + + @property + def listening_on(self): + return 'listening on {}'.format([f'{q}' for q in self.queues]) + + def control(self, body, message): + logger.warn(body) + control = body.get('control') + if control in ('status', 'running'): + if control == 'status': + msg = '\n'.join([self.listening_on, self.pool.debug()]) + elif control == 'running': + msg = [] + for worker in self.pool.workers: + worker.calculate_managed_tasks() + msg.extend(worker.managed_tasks.keys()) + self.pubsub.notify(message.properties['reply_to'], msg) + elif control == 'reload': + for worker in self.pool.workers: + worker.quit() + else: + logger.error('unrecognized control message: {}'.format(control)) + + def process_task(self, body, message): + if SHORT_CIRCUIT or 'control' in body: + try: + return self.control(body, message) + except Exception: + logger.exception("Exception handling control message:") + return + if len(self.pool): + if "uuid" in body and body['uuid']: + try: + queue = UUID(body['uuid']).int % len(self.pool) + except Exception: + queue = self.total_messages % len(self.pool) + else: + queue = self.total_messages % len(self.pool) + else: + queue = 0 + self.pool.write(queue, body) + self.total_messages += 1 + + def run(self, *args, **kwargs): + signal.signal(signal.SIGINT, self.stop) + signal.signal(signal.SIGTERM, self.stop) + self.worker.on_start() + + logger.warn(f"Running worker {self.name} listening to queues {self.queues}") + self.pubsub = PubSub(pg_connection.cursor().connection) + for queue in self.queues: + self.pubsub.listen(queue) + for e in self.pubsub.events(): + logger.warn(f"Processing task {e}") + self.process_task(json.loads(e.payload), e) + + def stop(self, signum, frame): + logger.warn('received {}, stopping'.format(signame(signum))) + for queue in self.queues: + self.pubsub.unlisten(queue) + self.worker.on_stop() + raise SystemExit() + + +class BaseWorkerPG(object): + + def work_loop(self, queue, finished, idx, *args): + ppid = os.getppid() + signal_handler = WorkerSignalHandler() + while not signal_handler.kill_now: + # if the parent PID changes, this process has been orphaned + # via e.g., segfault or sigkill, we should exit too + if os.getppid() != ppid: + break + try: + body = queue.get(block=True, timeout=1) + if body == 'QUIT': + break + except QueueEmpty: + continue + except Exception as e: + logger.error("Exception on worker {}, restarting: ".format(idx) + str(e)) + continue + try: + for conn in db.connections.all(): + # If the database connection has a hiccup during the prior message, close it + # so we can establish a new connection + conn.close_if_unusable_or_obsolete() + self.perform_work(body, *args) + finally: + if 'uuid' in body: + uuid = body['uuid'] + logger.debug('task {} is finished'.format(uuid)) + finished.put(uuid) + logger.warn('worker exiting gracefully pid:{}'.format(os.getpid())) + + def perform_work(self, body): + raise NotImplementedError() + + def on_start(self): + pass + + def on_stop(self): + pass diff --git a/awx/main/dispatch/worker/task.py b/awx/main/dispatch/worker/task.py index 7e7437d445..80c1907fc1 100644 --- a/awx/main/dispatch/worker/task.py +++ b/awx/main/dispatch/worker/task.py @@ -8,12 +8,12 @@ from kubernetes.config import kube_config from awx.main.tasks import dispatch_startup, inform_cluster_of_shutdown -from .base import BaseWorker +from .basepg import BaseWorkerPG logger = logging.getLogger('awx.main.dispatch') -class TaskWorker(BaseWorker): +class TaskWorker(BaseWorkerPG): ''' A worker implementation that deserializes task messages and runs native Python code. diff --git a/awx/main/management/commands/run_dispatcher.py b/awx/main/management/commands/run_dispatcher.py index 7e69897687..ea6098db4c 100644 --- a/awx/main/management/commands/run_dispatcher.py +++ b/awx/main/management/commands/run_dispatcher.py @@ -64,20 +64,8 @@ class Command(BaseCommand): AWXProxyHandler.disable() with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn: try: - bcast = 'tower_broadcast_all' - queues = [ - Queue(q, Exchange(q), routing_key=q) - for q in (settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()]) - ] - queues.append( - Queue( - construct_bcast_queue_name(bcast), - exchange=Exchange(bcast, type='fanout'), - routing_key=bcast, - reply=True - ) - ) - consumer = AWXConsumer( + queues = ['tower_broadcast_all'] + settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()] + consumer = AWXConsumerPG( 'dispatcher', conn, TaskWorker(),