mirror of
https://github.com/ansible/awx.git
synced 2026-03-13 15:09:32 -02:30
POC postgres broker
This commit is contained in:
committed by
Ryan Petrello
parent
355fb125cb
commit
558e92806b
@@ -1,10 +1,14 @@
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from uuid import uuid4
|
||||
import psycopg2
|
||||
|
||||
from django.conf import settings
|
||||
from kombu import Exchange, Producer, Connection, Queue, Consumer
|
||||
from kombu import Exchange, Producer
|
||||
from django.db import connection
|
||||
from pgpubsub import PubSub
|
||||
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
@@ -85,26 +89,15 @@ class task:
|
||||
if callable(queue):
|
||||
queue = queue()
|
||||
if not settings.IS_TESTING(sys.argv):
|
||||
with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn:
|
||||
exchange = Exchange(queue, type=exchange_type or 'direct')
|
||||
|
||||
# HACK: With Redis as the broker declaring an exchange isn't enough to create the queue
|
||||
# Creating a Consumer _will_ create a queue so that publish will succeed. Note that we
|
||||
# don't call consume() on the consumer so we don't actually eat any messages
|
||||
Consumer(conn, queues=[Queue(queue, exchange, routing_key=queue)], accept=['json'])
|
||||
producer = Producer(conn)
|
||||
logger.debug('publish {}({}, queue={})'.format(
|
||||
cls.name,
|
||||
task_id,
|
||||
queue
|
||||
))
|
||||
producer.publish(obj,
|
||||
serializer='json',
|
||||
compression='bzip2',
|
||||
exchange=exchange,
|
||||
declare=[exchange],
|
||||
delivery_mode="persistent",
|
||||
routing_key=queue)
|
||||
conf = settings.DATABASES['default']
|
||||
conn = psycopg2.connect(dbname=conf['NAME'],
|
||||
host=conf['HOST'],
|
||||
user=conf['USER'],
|
||||
password=conf['PASSWORD'])
|
||||
conn.set_session(autocommit=True)
|
||||
logger.warn(f"Send message to queue {queue}")
|
||||
pubsub = PubSub(conn)
|
||||
pubsub.notify(queue, json.dumps(obj))
|
||||
return (obj, queue)
|
||||
|
||||
# If the object we're wrapping *is* a class (e.g., RunJob), return
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .base import AWXConsumer, AWXRedisConsumer, BaseWorker # noqa
|
||||
from .basepg import AWXConsumerPG, BaseWorkerPG # noqa
|
||||
from .callback import CallbackBrokerWorker # noqa
|
||||
from .task import TaskWorker # noqa
|
||||
|
||||
161
awx/main/dispatch/worker/basepg.py
Normal file
161
awx/main/dispatch/worker/basepg.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2018 Ansible by Red Hat
|
||||
# All Rights Reserved.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import json
|
||||
from uuid import UUID
|
||||
from queue import Empty as QueueEmpty
|
||||
|
||||
from django import db
|
||||
from django.db import connection as pg_connection
|
||||
|
||||
from pgpubsub import PubSub
|
||||
|
||||
from awx.main.dispatch.pool import WorkerPool
|
||||
|
||||
SHORT_CIRCUIT = False
|
||||
|
||||
if 'run_callback_receiver' in sys.argv:
|
||||
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
|
||||
else:
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
def signame(sig):
    """Return the symbolic name (e.g. 'SIGTERM') for signal number *sig*.

    Replaces the old implementation, which re-scraped ``signal.__dict__``
    into a fresh number->name dict on every call; ``signal.Signals`` is the
    stdlib enum that provides this mapping directly.
    """
    return signal.Signals(sig).name
|
||||
|
||||
|
||||
class WorkerSignalHandler:
    """Latches receipt of SIGINT so a worker loop can shut down cooperatively.

    The work loop polls ``kill_now`` between messages; once SIGINT arrives
    the flag flips and the loop exits on its next iteration.
    """

    def __init__(self):
        # Polled by the work loop; flipped to True once SIGINT arrives.
        self.kill_now = False
        signal.signal(signal.SIGINT, self.exit_gracefully)

    def exit_gracefully(self, *args, **kwargs):
        """Signal handler: request a graceful exit on the next loop check."""
        self.kill_now = True
|
||||
|
||||
|
||||
class AWXConsumerPG(object):
    """Dispatcher consumer that receives task messages via Postgres
    LISTEN/NOTIFY (through pgpubsub) and fans them out to a WorkerPool.

    POC replacement for the AMQP (kombu) consumer: the application database
    itself acts as the message broker.
    """

    def __init__(self, name, connection, worker, queues=None, pool=None):
        """
        name       -- human-readable consumer name (used in logs)
        connection -- unused; see NOTE below
        worker     -- object providing work_loop/on_start/on_stop
        queues     -- channel names to LISTEN on (default: no queues)
        pool       -- optional WorkerPool; one is created if omitted
        """
        self.name = name
        # NOTE(review): the `connection` argument is ignored and the Django
        # database connection is always used -- confirm this is intentional.
        self.connection = pg_connection
        self.total_messages = 0
        # Was `queues=[]` (mutable default argument, shared across
        # instances); default is now None and normalized here.
        self.queues = queues if queues is not None else []
        self.worker = worker
        self.pool = pool
        # TODO, maybe get new connection and reconnect periodically
        self.pubsub = PubSub(pg_connection.cursor().connection)
        if pool is None:
            self.pool = WorkerPool()
        self.pool.init_workers(self.worker.work_loop)

    @property
    def listening_on(self):
        """Human-readable summary of the queues this consumer listens on."""
        return 'listening on {}'.format([f'{q}' for q in self.queues])

    def control(self, body, message):
        """Handle a control message: 'status', 'running', or 'reload'."""
        logger.warning(body)
        control = body.get('control')
        if control in ('status', 'running'):
            if control == 'status':
                msg = '\n'.join([self.listening_on, self.pool.debug()])
            elif control == 'running':
                msg = []
                for worker in self.pool.workers:
                    worker.calculate_managed_tasks()
                    msg.extend(worker.managed_tasks.keys())
            # NOTE(review): `message` here is a pgpubsub event (see run()),
            # which has no `properties` attribute -- this reply path looks
            # un-ported from the kombu consumer; confirm before relying on it.
            self.pubsub.notify(message.properties['reply_to'], msg)
        elif control == 'reload':
            for worker in self.pool.workers:
                worker.quit()
        else:
            logger.error('unrecognized control message: {}'.format(control))

    def process_task(self, body, message):
        """Route one decoded message body to a pool worker.

        Control messages are handled inline; task messages are assigned to a
        worker queue by uuid (so retries land on the same worker) or
        round-robin when no uuid is present.
        """
        if SHORT_CIRCUIT or 'control' in body:
            try:
                return self.control(body, message)
            except Exception:
                logger.exception("Exception handling control message:")
                return
        if len(self.pool):
            if "uuid" in body and body['uuid']:
                try:
                    queue = UUID(body['uuid']).int % len(self.pool)
                except Exception:
                    queue = self.total_messages % len(self.pool)
            else:
                queue = self.total_messages % len(self.pool)
        else:
            queue = 0
        self.pool.write(queue, body)
        self.total_messages += 1

    def run(self, *args, **kwargs):
        """Main loop: LISTEN on every queue and dispatch events until stopped."""
        signal.signal(signal.SIGINT, self.stop)
        signal.signal(signal.SIGTERM, self.stop)
        self.worker.on_start()

        logger.warning(f"Running worker {self.name} listening to queues {self.queues}")
        self.pubsub = PubSub(pg_connection.cursor().connection)
        for queue in self.queues:
            self.pubsub.listen(queue)
        for e in self.pubsub.events():
            logger.warning(f"Processing task {e}")
            self.process_task(json.loads(e.payload), e)

    def stop(self, signum, frame):
        """SIGINT/SIGTERM handler: UNLISTEN, stop the worker, and exit."""
        logger.warning('received {}, stopping'.format(signame(signum)))
        for queue in self.queues:
            self.pubsub.unlisten(queue)
        self.worker.on_stop()
        raise SystemExit()
|
||||
|
||||
|
||||
class BaseWorkerPG(object):
    """Base class for pool worker processes fed by AWXConsumerPG.

    Subclasses implement perform_work() to process each message body.
    """

    def work_loop(self, queue, finished, idx, *args):
        """Consume task bodies from *queue* until told to quit.

        queue    -- queue of task bodies (dicts), or the sentinel 'QUIT'
        finished -- queue used to report completed task uuids to the parent
        idx      -- this worker's index in the pool (used for logging)
        args     -- extra positional arguments forwarded to perform_work()
        """
        ppid = os.getppid()
        signal_handler = WorkerSignalHandler()
        while not signal_handler.kill_now:
            # if the parent PID changes, this process has been orphaned
            # via e.g., segfault or sigkill, we should exit too
            if os.getppid() != ppid:
                break
            try:
                body = queue.get(block=True, timeout=1)
                if body == 'QUIT':
                    break
            except QueueEmpty:
                continue
            except Exception as e:
                logger.error("Exception on worker {}, restarting: ".format(idx) + str(e))
                continue
            try:
                for conn in db.connections.all():
                    # If the database connection has a hiccup during the prior message, close it
                    # so we can establish a new connection
                    conn.close_if_unusable_or_obsolete()
                self.perform_work(body, *args)
            finally:
                # Always report completion so the parent can release the slot,
                # even if perform_work raised.
                if 'uuid' in body:
                    uuid = body['uuid']
                    logger.debug('task {} is finished'.format(uuid))
                    finished.put(uuid)
        # logger.warn is a deprecated alias; use warning()
        logger.warning('worker exiting gracefully pid:{}'.format(os.getpid()))

    def perform_work(self, body, *args):
        """Process one task body; must be implemented by subclasses.

        Accepts *args because work_loop() forwards its extra positional
        arguments here (the old signature took only `body`, which was
        inconsistent with the call site).
        """
        raise NotImplementedError()

    def on_start(self):
        """Hook called once before the consumer starts its main loop."""
        pass

    def on_stop(self):
        """Hook called once when the consumer is shutting down."""
        pass
|
||||
@@ -8,12 +8,12 @@ from kubernetes.config import kube_config
|
||||
|
||||
from awx.main.tasks import dispatch_startup, inform_cluster_of_shutdown
|
||||
|
||||
from .base import BaseWorker
|
||||
from .basepg import BaseWorkerPG
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
class TaskWorker(BaseWorker):
|
||||
class TaskWorker(BaseWorkerPG):
|
||||
'''
|
||||
A worker implementation that deserializes task messages and runs native
|
||||
Python code.
|
||||
|
||||
@@ -64,20 +64,8 @@ class Command(BaseCommand):
|
||||
AWXProxyHandler.disable()
|
||||
with Connection(settings.BROKER_URL, transport_options=settings.BROKER_TRANSPORT_OPTIONS) as conn:
|
||||
try:
|
||||
bcast = 'tower_broadcast_all'
|
||||
queues = [
|
||||
Queue(q, Exchange(q), routing_key=q)
|
||||
for q in (settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()])
|
||||
]
|
||||
queues.append(
|
||||
Queue(
|
||||
construct_bcast_queue_name(bcast),
|
||||
exchange=Exchange(bcast, type='fanout'),
|
||||
routing_key=bcast,
|
||||
reply=True
|
||||
)
|
||||
)
|
||||
consumer = AWXConsumer(
|
||||
queues = ['tower_broadcast_all'] + settings.AWX_CELERY_QUEUES_STATIC + [get_local_queuename()]
|
||||
consumer = AWXConsumerPG(
|
||||
'dispatcher',
|
||||
conn,
|
||||
TaskWorker(),
|
||||
|
||||
Reference in New Issue
Block a user