POC postgres broker

chris meyers
2019-12-20 10:21:53 -05:00
committed by Ryan Petrello
parent 355fb125cb
commit 558e92806b
5 changed files with 180 additions and 37 deletions

View File

@@ -1,10 +1,14 @@
 import inspect
 import logging
 import sys
+import json
 from uuid import uuid4
+import psycopg2
 from django.conf import settings
-from kombu import Exchange, Producer, Connection, Queue, Consumer
+from kombu import Exchange, Producer
+from django.db import connection
+from pgpubsub import PubSub
 logger = logging.getLogger('awx.main.dispatch')
@@ -85,26 +89,15 @@ class task:
     if callable(queue):
         queue = queue()
     if not settings.IS_TESTING(sys.argv):
-        with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn:
-            exchange = Exchange(queue, type=exchange_type or 'direct')
-            # HACK: With Redis as the broker declaring an exchange isn't enough to create the queue
-            # Creating a Consumer _will_ create a queue so that publish will succeed. Note that we
-            # don't call consume() on the consumer so we don't actually eat any messages
-            Consumer(conn, queues=[Queue(queue, exchange, routing_key=queue)], accept=['json'])
-            producer = Producer(conn)
-            logger.debug('publish {}({}, queue={})'.format(
-                cls.name,
-                task_id,
-                queue
-            ))
-            producer.publish(obj,
-                             serializer='json',
-                             compression='bzip2',
-                             exchange=exchange,
-                             declare=[exchange],
-                             delivery_mode="persistent",
-                             routing_key=queue)
+        conf = settings.DATABASES['default']
+        conn = psycopg2.connect(dbname=conf['NAME'],
+                                host=conf['HOST'],
+                                user=conf['USER'],
+                                password=conf['PASSWORD'])
+        conn.set_session(autocommit=True)
+        logger.warn(f"Send message to queue {queue}")
+        pubsub = PubSub(conn)
+        pubsub.notify(queue, json.dumps(obj))
     return (obj, queue)
 # If the object we're wrapping *is* a class (e.g., RunJob), return
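The publish path above swaps kombu's exchange/producer machinery for a single Postgres NOTIFY sent through pgpubsub. A minimal standalone sketch of the same mechanism using psycopg2 directly (connection parameters, channel name, and task name are illustrative, not taken from this commit):

import json

import psycopg2

# Illustrative credentials; the commit reads these from settings.DATABASES.
conn = psycopg2.connect(dbname='awx', host='localhost', user='awx', password='secret')
# NOTIFY is transactional, so autocommit makes it fire immediately.
conn.set_session(autocommit=True)

payload = json.dumps({'uuid': 'abc-123', 'task': 'awx.main.tasks.example_task'})
with conn.cursor() as cur:
    # pg_notify() is the function form of NOTIFY; it allows the channel name
    # and payload to be passed as ordinary query parameters.
    cur.execute('SELECT pg_notify(%s, %s);', ('tower_task_queue', payload))

Note that NOTIFY payloads are capped at roughly 8kB by default, which is fine for the compact task envelope built here but rules out arbitrarily large messages.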

View File

@@ -1,3 +1,4 @@
 from .base import AWXConsumer, AWXRedisConsumer, BaseWorker # noqa
+from .basepg import AWXConsumerPG, BaseWorkerPG # noqa
 from .callback import CallbackBrokerWorker # noqa
 from .task import TaskWorker # noqa

View File

@@ -0,0 +1,161 @@
# Copyright (c) 2018 Ansible by Red Hat
# All Rights Reserved.
import os
import logging
import signal
import sys
import json
from uuid import UUID
from queue import Empty as QueueEmpty

from django import db
from django.db import connection as pg_connection
from pgpubsub import PubSub

from awx.main.dispatch.pool import WorkerPool

SHORT_CIRCUIT = False

if 'run_callback_receiver' in sys.argv:
    logger = logging.getLogger('awx.main.commands.run_callback_receiver')
else:
    logger = logging.getLogger('awx.main.dispatch')


def signame(sig):
    return dict(
        (k, v) for v, k in signal.__dict__.items()
        if v.startswith('SIG') and not v.startswith('SIG_')
    )[sig]
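# Worked example (illustrative, not in the commit): signame() reverse-maps a
# signal number to its name, e.g. signame(signal.SIGTERM) == 'SIGTERM'.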
class WorkerSignalHandler:

    def __init__(self):
        self.kill_now = False
        signal.signal(signal.SIGINT, self.exit_gracefully)

    def exit_gracefully(self, *args, **kwargs):
        self.kill_now = True


class AWXConsumerPG(object):

    def __init__(self, name, connection, worker, queues=[], pool=None):
        self.name = name
        self.connection = pg_connection
        self.total_messages = 0
        self.queues = queues
        self.worker = worker
        self.pool = pool
        # TODO, maybe get new connection and reconnect periodically
        self.pubsub = PubSub(pg_connection.cursor().connection)
        if pool is None:
            self.pool = WorkerPool()
        self.pool.init_workers(self.worker.work_loop)

    @property
    def listening_on(self):
        return 'listening on {}'.format([f'{q}' for q in self.queues])

    def control(self, body, message):
        logger.warn(body)
        control = body.get('control')
        if control in ('status', 'running'):
            if control == 'status':
                msg = '\n'.join([self.listening_on, self.pool.debug()])
            elif control == 'running':
                msg = []
                for worker in self.pool.workers:
                    worker.calculate_managed_tasks()
                    msg.extend(worker.managed_tasks.keys())
            # NOTE: this still assumes a kombu-style message; a pgpubsub event
            # has no .properties, so under this broker the reply channel would
            # need to travel in the JSON body instead.
            self.pubsub.notify(message.properties['reply_to'], msg)
        elif control == 'reload':
            for worker in self.pool.workers:
                worker.quit()
        else:
            logger.error('unrecognized control message: {}'.format(control))

    def process_task(self, body, message):
        if SHORT_CIRCUIT or 'control' in body:
            try:
                return self.control(body, message)
            except Exception:
                logger.exception("Exception handling control message:")
            return
        if len(self.pool):
            if "uuid" in body and body['uuid']:
                try:
                    queue = UUID(body['uuid']).int % len(self.pool)
                except Exception:
                    queue = self.total_messages % len(self.pool)
            else:
                queue = self.total_messages % len(self.pool)
        else:
            queue = 0
        self.pool.write(queue, body)
        self.total_messages += 1
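    # Worked example (illustrative values, not in the commit): with a pool of
    # 4 workers, UUID('00000000-0000-0000-0000-000000000007').int % 4 == 3, so
    # any message carrying that uuid is always written to worker 3; messages
    # without a uuid fall back to round-robin via total_messages.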
    def run(self, *args, **kwargs):
        signal.signal(signal.SIGINT, self.stop)
        signal.signal(signal.SIGTERM, self.stop)
        self.worker.on_start()
        logger.warn(f"Running worker {self.name} listening to queues {self.queues}")
        self.pubsub = PubSub(pg_connection.cursor().connection)
        for queue in self.queues:
            self.pubsub.listen(queue)
        for e in self.pubsub.events():
            logger.warn(f"Processing task {e}")
            self.process_task(json.loads(e.payload), e)

    def stop(self, signum, frame):
        logger.warn('received {}, stopping'.format(signame(signum)))
        for queue in self.queues:
            self.pubsub.unlisten(queue)
        self.worker.on_stop()
        raise SystemExit()
class BaseWorkerPG(object):

    def work_loop(self, queue, finished, idx, *args):
        ppid = os.getppid()
        signal_handler = WorkerSignalHandler()
        while not signal_handler.kill_now:
            # if the parent PID changes, this process has been orphaned
            # via e.g., segfault or sigkill, we should exit too
            if os.getppid() != ppid:
                break
            try:
                body = queue.get(block=True, timeout=1)
                if body == 'QUIT':
                    break
            except QueueEmpty:
                continue
            except Exception as e:
                logger.error("Exception on worker {}, restarting: ".format(idx) + str(e))
                continue
            try:
                for conn in db.connections.all():
                    # If the database connection has a hiccup during the prior message, close it
                    # so we can establish a new connection
                    conn.close_if_unusable_or_obsolete()
                self.perform_work(body, *args)
            finally:
                if 'uuid' in body:
                    uuid = body['uuid']
                    logger.debug('task {} is finished'.format(uuid))
                    finished.put(uuid)
        logger.warn('worker exiting gracefully pid:{}'.format(os.getpid()))

    def perform_work(self, body):
        raise NotImplementedError()

    def on_start(self):
        pass

    def on_stop(self):
        pass
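For the receiving side shown above, pubsub.listen() and pubsub.events() correspond to Postgres LISTEN plus psycopg2's notification queue. A self-contained sketch of that loop in raw psycopg2, assuming pgpubsub is essentially a select()/poll() wrapper (credentials and channel name are illustrative):

import json
import select

import psycopg2

conn = psycopg2.connect(dbname='awx', host='localhost', user='awx', password='secret')
conn.set_session(autocommit=True)
cur = conn.cursor()
cur.execute('LISTEN tower_task_queue;')

while True:
    # Block until the connection's socket is readable, with a 5s timeout.
    if select.select([conn], [], [], 5) == ([], [], []):
        continue  # timed out; poll again
    conn.poll()  # consume input and fill conn.notifies
    while conn.notifies:
        notify = conn.notifies.pop(0)
        body = json.loads(notify.payload)
        print('channel={} task={}'.format(notify.channel, body.get('task')))

Unlike the AMQP/Redis brokers this replaces, LISTEN/NOTIFY offers no persistence or acknowledgement: notifications sent while no one is listening are simply dropped, one reason this is labeled a proof of concept.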

View File

@@ -8,12 +8,12 @@ from kubernetes.config import kube_config
 from awx.main.tasks import dispatch_startup, inform_cluster_of_shutdown
-from .base import BaseWorker
+from .basepg import BaseWorkerPG
 logger = logging.getLogger('awx.main.dispatch')
-class TaskWorker(BaseWorker):
+class TaskWorker(BaseWorkerPG):
     '''
     A worker implementation that deserializes task messages and runs native
     Python code.
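TaskWorker supplies the perform_work() implementation that BaseWorkerPG leaves abstract. A minimal hypothetical subclass (not part of this commit) illustrates the contract:

from awx.main.dispatch.worker import BaseWorkerPG


class EchoWorker(BaseWorkerPG):
    '''Hypothetical worker that just logs each decoded message body.'''

    def perform_work(self, body):
        # body is the JSON-decoded payload of a single NOTIFY message
        print('would run {} (uuid={})'.format(body.get('task'), body.get('uuid')))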

View File

@@ -64,20 +64,8 @@ class Command(BaseCommand):
 AWXProxyHandler.disable()
 with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn:
     try:
-        bcast = 'tower_broadcast_all'
-        queues = [
-            Queue(q, Exchange(q), routing_key=q)
-            for q in (settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()])
-        ]
-        queues.append(
-            Queue(
-                construct_bcast_queue_name(bcast),
-                exchange=Exchange(bcast, type='fanout'),
-                routing_key=bcast,
-                reply=True
-            )
-        )
-        consumer = AWXConsumer(
+        queues = ['tower_broadcast_all'] + settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()]
+        consumer = AWXConsumerPG(
             'dispatcher',
             conn,
             TaskWorker(),