refactor and test the callback receiver as a base for a task dispatcher

This commit is contained in:
Ryan Petrello 2018-08-08 10:26:15 -04:00
parent 8ad46436df
commit da74f1d01f
No known key found for this signature in database
GPG Key ID: F2AA5F2122351777
7 changed files with 370 additions and 231 deletions

View File

97
awx/main/dispatch/pool.py Normal file
View File

@ -0,0 +1,97 @@
import errno
import logging
import os
import signal
import traceback
from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from Queue import Full as QueueFull
from django.conf import settings
from django.db import connection as django_connection
from django.core.cache import cache as django_cache
logger = logging.getLogger('awx.main.dispatch')
def signame(sig):
    """Translate a signal number into its symbolic name (e.g. ``'SIGTERM'``).

    Only the ``SIG*`` names exported by the ``signal`` module are considered;
    names starting with ``SIG_`` (such as ``SIG_DFL``) are skipped.  Raises
    ``KeyError`` when ``sig`` does not match any known signal.
    """
    names_by_number = {}
    for name, number in signal.__dict__.items():
        if not name.startswith('SIG') or name.startswith('SIG_'):
            continue
        # when several names share a number, the one seen last is kept
        names_by_number[number] = name
    return names_by_number[sig]
class WorkerPool(object):
    """A fixed-size pool of worker processes, each fed by its own queue.

    Every worker is a ``multiprocessing.Process`` paired with a bounded
    ``multiprocessing.Queue``.  ``write()`` hands a message to one of those
    queues; the ``target`` callable given to ``init_workers()`` runs in the
    child process and receives that queue.
    """

    def __init__(self, min_workers=None, queue_size=None):
        """
        :param min_workers: number of worker processes to start; falls back
                            to ``settings.JOB_EVENT_WORKERS`` when falsy.
        :param queue_size: capacity of each worker's queue; falls back to
                           ``settings.JOB_EVENT_MAX_QUEUE_SIZE`` when falsy.
        """
        self.min_workers = min_workers or settings.JOB_EVENT_WORKERS
        self.queue_size = queue_size or settings.JOB_EVENT_MAX_QUEUE_SIZE

        # self.workers tracks the state of the running worker processes:
        # [
        #   [total_messages_written, multiprocessing.Queue, multiprocessing.Process],
        #   [total_messages_written, multiprocessing.Queue, multiprocessing.Process],
        #   ...
        # ]
        # The counter is incremented by write() each time a message is
        # placed on that worker's queue.
        self.workers = []

    def __len__(self):
        """Return the number of workers currently tracked by the pool."""
        return len(self.workers)

    @staticmethod
    def _target_name(target):
        """Return a best-effort label for ``target`` for use in log messages."""
        owner = getattr(target, '__self__', None)
        if owner is not None:
            # bound method: report the class of the object it is bound to
            return owner.__class__.__name__
        return getattr(target, '__name__', repr(target))

    def init_workers(self, target, *target_args):
        """Start ``min_workers`` processes and install shutdown handlers.

        :param target: callable run in each child process as
                       ``target(queue, idx, *target_args)``.
        :param target_args: extra positional arguments forwarded to ``target``.
        """
        def shutdown_handler(active_workers):
            def _handler(signum, frame):
                logger.debug('received shutdown {}'.format(signame(signum)))
                try:
                    # NOTE(review): this loop only logs; nothing here signals
                    # the worker processes themselves -- confirm that is
                    # intended.
                    for active_worker in active_workers:
                        logger.debug('terminating worker')
                    signal.signal(signum, signal.SIG_DFL)
                    os.kill(os.getpid(), signum)  # Rethrow signal, this time without catching it
                except Exception:
                    logger.exception('error in shutdown_handler')
            return _handler

        # close the database and cache connections before any worker
        # process is started
        django_connection.close()
        django_cache.close()
        for idx in range(self.min_workers):
            queue_actual = MPQueue(self.queue_size)
            w = Process(target=target, args=(queue_actual, idx,) + target_args)
            w.start()
            # track the process first so that it is never lost, then log
            self.workers.append([0, queue_actual, w])
            logger.debug('started {}[{}]'.format(self._target_name(target), idx))

        signal.signal(signal.SIGINT, shutdown_handler([p[2] for p in self.workers]))
        signal.signal(signal.SIGTERM, shutdown_handler([p[2] for p in self.workers]))

    def write(self, preferred_queue, body):
        """Deliver ``body`` to a worker queue, trying ``preferred_queue`` first.

        Each queue is given up to five seconds to accept the message before
        the next queue is tried.

        :param preferred_queue: index of the queue to try first; an index
                                outside ``range(min_workers)`` leaves the
                                queues in their natural order.
        :param body: the message to enqueue.
        :returns: the index of the queue that accepted the message, or
                  ``None`` if no queue accepted it.
        """
        # preferred queue first, the remaining queues in index order
        queue_order = list(range(self.min_workers))
        if preferred_queue in queue_order:
            queue_order.insert(0, queue_order.pop(queue_order.index(preferred_queue)))

        write_attempt_order = []
        for queue_actual in queue_order:
            write_attempt_order.append(queue_actual)
            try:
                worker_actual = self.workers[queue_actual]
                worker_actual[1].put(body, block=True, timeout=5)
            except QueueFull:
                continue
            except Exception:
                tb = traceback.format_exc()
                logger.warning("could not write to queue %s" % queue_actual)
                logger.warning("detail: {}".format(tb))
                continue

            # The message has been delivered; nothing below may cause it to
            # be offered to another queue.
            worker_actual[0] += 1
            try:
                depth = worker_actual[1].qsize()
            except NotImplementedError:
                # multiprocessing queues cannot report their size on
                # platforms without sem_getvalue()
                depth = 'unknown'
            logger.debug('delivered to Worker[{}] qsize {}'.format(queue_actual, depth))
            return queue_actual

        logger.warning("could not write payload to any queue, attempted order: {}".format(write_attempt_order))
        return None

    def stop(self):
        """Send SIGTERM to every tracked worker process.

        Workers that no longer exist are ignored; any other ``OSError`` is
        re-raised.
        """
        for _, _, process in self.workers:
            try:
                os.kill(process.pid, signal.SIGTERM)
            except OSError as e:
                # ESRCH: no such process
                if e.errno != errno.ESRCH:
                    raise

173
awx/main/dispatch/worker.py Normal file
View File

@ -0,0 +1,173 @@
# Copyright (c) 2018 Ansible by Red Hat
# All Rights Reserved.
import logging
import os
import signal
import time
import traceback
from uuid import UUID
from Queue import Empty as QueueEmpty
from kombu.mixins import ConsumerMixin
from django.conf import settings
from django.db import DatabaseError, OperationalError, connection as django_connection
from django.db.utils import InterfaceError, InternalError
from awx.main.models import (JobEvent, AdHocCommandEvent, ProjectUpdateEvent,
InventoryUpdateEvent, SystemJobEvent, UnifiedJob)
from awx.main.consumers import emit_channel_notification
from awx.main.dispatch.pool import WorkerPool
logger = logging.getLogger('awx.main.dispatch')
class WorkerSignalHandler:
    """Records a shutdown request when SIGINT or SIGTERM is received.

    Creating an instance installs ``exit_gracefully`` as the handler for
    both signals; callers poll ``kill_now`` to find out when to stop.
    """

    def __init__(self):
        self.kill_now = False
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.exit_gracefully)

    def exit_gracefully(self, *args, **kwargs):
        """Signal-handler entry point: note that shutdown was requested."""
        self.kill_now = True
class AWXConsumer(ConsumerMixin):
    """Message consumer that fans incoming payloads out to a ``WorkerPool``."""

    def __init__(self, connection, worker, queues=None):
        """
        :param connection: broker connection, stored as ``self.connection``.
        :param worker: object whose ``work_loop`` method is run in every
                       pool process.
        :param queues: queues to consume from; defaults to an empty list.
        """
        self.connection = connection
        self.total_messages = 0
        # build the default per instance so that consumers never share
        # (and accidentally mutate) a single list object
        self.queues = [] if queues is None else queues
        self.pool = WorkerPool()
        self.pool.init_workers(worker.work_loop)

    def get_consumers(self, Consumer, channel):
        """Return one JSON-only consumer bound to ``self.queues``."""
        return [Consumer(queues=self.queues, accept=['json'],
                         callbacks=[self.process_task])]

    def process_task(self, body, message):
        """Route ``body`` to a worker queue, then acknowledge ``message``.

        Payloads with a parseable ``uuid`` select the preferred queue as
        ``UUID(uuid).int % len(pool)``, so the same uuid always prefers the
        same queue.  Every other payload is spread round-robin using the
        running message count.  The return value of ``pool.write()`` is not
        inspected before the message is acknowledged.
        """
        if "uuid" in body and body['uuid']:
            try:
                queue = UUID(body['uuid']).int % len(self.pool)
            except Exception:
                # unparseable uuid: fall back to round-robin
                queue = self.total_messages % len(self.pool)
        else:
            queue = self.total_messages % len(self.pool)
        self.pool.write(queue, body)
        self.total_messages += 1
        message.ack()
class BaseWorker(object):
    """Base class for pool workers.

    ``work_loop`` is the entry point run inside a worker process; subclasses
    implement ``perform_work`` to handle each message.
    """

    def work_loop(self, queue, idx, *args):
        """Consume messages from ``queue`` until SIGINT or SIGTERM arrives.

        :param queue: queue to read messages from.
        :param idx: index of this worker (not used by the loop itself).
        :param args: extra positional arguments forwarded to
                     ``perform_work`` with every message.
        """
        signal_handler = WorkerSignalHandler()
        while not signal_handler.kill_now:
            try:
                # time out every second so a shutdown request is noticed
                # even while the queue is idle
                body = queue.get(block=True, timeout=1)
            except QueueEmpty:
                continue
            except Exception as e:
                logger.error("Exception on worker, restarting: " + str(e))
                continue
            self.perform_work(body, *args)

    def perform_work(self, body, *args):
        """Handle a single message; subclasses must override this.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()
class CallbackBrokerWorker(BaseWorker):
    """Worker that persists job event payloads and reports job completion."""

    # number of times a save is retried after a database connectivity error
    MAX_RETRIES = 2

    def perform_work(self, body):
        """Process one callback payload.

        ``body`` must contain one of the job identifier keys listed in
        ``event_map``.  An ``EOF`` event triggers the end-of-job
        notifications and is not persisted; any other event is saved through
        the matching event model.  Exceptions are logged, never propagated.
        """
        try:
            event_map = {
                'job_id': JobEvent,
                'ad_hoc_command_id': AdHocCommandEvent,
                'project_update_id': ProjectUpdateEvent,
                'inventory_update_id': InventoryUpdateEvent,
                'system_job_id': SystemJobEvent,
            }

            if not any(key in body for key in event_map):
                raise Exception('Payload does not have a job identifier')
            if settings.DEBUG:
                self._log_body(body)

            # the first identifier key found in the payload names the job
            # in log messages
            job_identifier = 'unknown job'
            for key in event_map:
                if key in body:
                    job_identifier = body[key]
                    break

            if body.get('event') == 'EOF':
                self._handle_eof(body, job_identifier)
            else:
                self._save_event(body, event_map, job_identifier)
        except Exception as exc:
            tb = traceback.format_exc()
            logger.error('Callback Task Processor Raised Exception: %r', exc)
            logger.error('Detail: {}'.format(tb))

    def _log_body(self, body):
        """Log a syntax-highlighted dump of ``body``, cut to 4096 characters."""
        from pygments import highlight
        from pygments.lexers import PythonLexer
        from pygments.formatters import Terminal256Formatter
        from pprint import pformat
        logger.info('Body: {}'.format(
            highlight(pformat(body, width=160), PythonLexer(), Terminal256Formatter(style='friendly'))
        )[:1024 * 4])

    def _handle_eof(self, body, job_identifier):
        """Emit the summary websocket event and send job notifications."""
        try:
            final_counter = body.get('final_counter', 0)
            logger.info('Event processing is finished for Job {}, sending notifications'.format(job_identifier))
            # EOF events are sent when stdout for the running task is
            # closed. don't actually persist them to the database; we
            # just use them to report `summary` websocket events as an
            # approximation for when a job is "done"
            emit_channel_notification(
                'jobs-summary',
                dict(group_name='jobs', unified_job_id=job_identifier, final_counter=final_counter)
            )
            # Additionally, when we've processed all events, we should
            # have all the data we need to send out success/failure
            # notification templates
            uj = UnifiedJob.objects.get(pk=job_identifier)
            if hasattr(uj, 'send_notification_templates'):
                retries = 0
                while retries < 5:
                    if uj.finished:
                        uj.send_notification_templates('succeeded' if uj.status == 'successful' else 'failed')
                        break
                    # wait a few seconds to avoid a race where the
                    # events are persisted _before_ the UJ.status
                    # changes from running -> successful
                    retries += 1
                    time.sleep(1)
                    uj = UnifiedJob.objects.get(pk=job_identifier)
                else:
                    # make the give-up visible instead of silently skipping
                    # the notifications
                    logger.warning(
                        'Job {} was not marked finished after 5 checks, '
                        'notification templates were not sent'.format(job_identifier)
                    )
        except Exception:
            logger.exception('Worker failed to emit notifications: Job {}'.format(job_identifier))

    def _save_event(self, body, event_map, job_identifier):
        """Persist ``body``, retrying when the database connection is lost.

        After ``MAX_RETRIES`` failed retries the parent process is sent
        SIGINT.  Other database errors are logged and the event is dropped.
        """
        retries = 0
        while retries <= self.MAX_RETRIES:
            try:
                for key, cls in event_map.items():
                    if key in body:
                        cls.create_from_data(**body)
                break
            except (OperationalError, InterfaceError, InternalError):
                if retries >= self.MAX_RETRIES:
                    logger.exception('Worker could not re-establish database connectivity, shutting down gracefully: Job {}'.format(job_identifier))
                    os.kill(os.getppid(), signal.SIGINT)
                    return
                delay = 60 * retries
                logger.exception('Database Error Saving Job Event, retry #{i} in {delay} seconds:'.format(
                    i=retries + 1,
                    delay=delay
                ))
                # drop the broken connection before the next attempt
                django_connection.close()
                time.sleep(delay)
                retries += 1
            except DatabaseError:
                logger.exception('Database Error Saving Job Event for Job {}'.format(job_identifier))
                break

View File

@ -1,231 +1,11 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.
# Python
import logging
import os
import signal
import time
from uuid import UUID
from multiprocessing import Process
from multiprocessing import Queue as MPQueue
from Queue import Empty as QueueEmpty
from Queue import Full as QueueFull
from kombu import Connection, Exchange, Queue
from kombu.mixins import ConsumerMixin
# Django
from django.conf import settings
from django.core.management.base import BaseCommand
from django.db import connection as django_connection
from django.db import DatabaseError, OperationalError
from django.db.utils import InterfaceError, InternalError
from django.core.cache import cache as django_cache
from kombu import Connection, Exchange, Queue
# AWX
from awx.main.models import * # noqa
from awx.main.consumers import emit_channel_notification
logger = logging.getLogger('awx.main.commands.run_callback_receiver')
class WorkerSignalHandler:
    """Turns SIGINT and SIGTERM into a ``kill_now`` flag.

    The constructor registers ``exit_gracefully`` for both signals; code
    that owns an instance checks ``kill_now`` to decide when to stop.
    """

    def __init__(self):
        self.kill_now = False
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.exit_gracefully)

    def exit_gracefully(self, *args, **kwargs):
        """Signal callback: remember that a shutdown was requested."""
        self.kill_now = True
class CallbackBrokerWorker(ConsumerMixin):
    """Callback receiver: consumes job events from the broker and hands them
    to worker processes that save them to the database."""

    # number of times a save is retried after a database connectivity error
    MAX_RETRIES = 2

    def __init__(self, connection, use_workers=True):
        """
        :param connection: broker connection, stored as ``self.connection``.
        :param use_workers: when true, worker processes are started
                            immediately.
        """
        self.connection = connection
        # one [messages_written, queue, process] record per worker
        self.worker_queues = []
        self.total_messages = 0
        self.init_workers(use_workers)

    def init_workers(self, use_workers=True):
        """Start the worker processes and install shutdown handlers."""
        def shutdown_handler(active_workers):
            def _handler(signum, frame):
                try:
                    for active_worker in active_workers:
                        active_worker.terminate()
                    signal.signal(signum, signal.SIG_DFL)
                    os.kill(os.getpid(), signum)  # Rethrow signal, this time without catching it
                except Exception:
                    logger.exception('Error in shutdown_handler')
            return _handler

        if use_workers:
            for idx in range(settings.JOB_EVENT_WORKERS):
                queue_actual = MPQueue(settings.JOB_EVENT_MAX_QUEUE_SIZE)
                w = Process(target=self.callback_worker, args=(queue_actual, idx,))
                if settings.DEBUG:
                    logger.info('Starting worker %s' % str(idx))
                self.worker_queues.append([0, queue_actual, w])

            # It's important to close these _right before_ we fork; we
            # don't want the forked processes to inherit the open sockets
            # for the DB and memcached connections (that way lies race
            # conditions)
            django_connection.close()
            django_cache.close()
            for _, _, w in self.worker_queues:
                w.start()

        elif settings.DEBUG:
            logger.warning('Started callback receiver (no workers)')

        signal.signal(signal.SIGINT, shutdown_handler([p[2] for p in self.worker_queues]))
        signal.signal(signal.SIGTERM, shutdown_handler([p[2] for p in self.worker_queues]))

    def get_consumers(self, Consumer, channel):
        """Return one JSON-only consumer bound to the callback queue."""
        return [Consumer(queues=[Queue(settings.CALLBACK_QUEUE,
                                       Exchange(settings.CALLBACK_QUEUE, type='direct'),
                                       routing_key=settings.CALLBACK_QUEUE)],
                         accept=['json'],
                         callbacks=[self.process_task])]

    def process_task(self, body, message):
        """Route ``body`` to a worker queue, then acknowledge ``message``.

        Payloads with a parseable ``uuid`` prefer the queue
        ``UUID(uuid).int % JOB_EVENT_WORKERS``; all others are spread
        round-robin using the running message count.
        """
        if "uuid" in body and body['uuid']:
            try:
                queue = UUID(body['uuid']).int % settings.JOB_EVENT_WORKERS
            except Exception:
                # unparseable uuid: fall back to round-robin
                queue = self.total_messages % settings.JOB_EVENT_WORKERS
        else:
            queue = self.total_messages % settings.JOB_EVENT_WORKERS
        self.write_queue_worker(queue, body)
        self.total_messages += 1
        message.ack()

    def write_queue_worker(self, preferred_queue, body):
        """Deliver ``body`` to a worker queue, trying ``preferred_queue`` first.

        Each queue is given up to five seconds to accept the message.

        :returns: the index of the queue that accepted the message, or
                  ``None`` if no queue accepted it.
        """
        # preferred queue first, the remaining queues in index order
        queue_order = list(range(settings.JOB_EVENT_WORKERS))
        if preferred_queue in queue_order:
            queue_order.insert(0, queue_order.pop(queue_order.index(preferred_queue)))

        write_attempt_order = []
        for queue_actual in queue_order:
            write_attempt_order.append(queue_actual)
            try:
                worker_actual = self.worker_queues[queue_actual]
                worker_actual[1].put(body, block=True, timeout=5)
                worker_actual[0] += 1
                return queue_actual
            except QueueFull:
                pass
            except Exception:
                import traceback
                tb = traceback.format_exc()
                logger.warning("Could not write to queue %s" % queue_actual)
                logger.warning("Detail: {}".format(tb))
        logger.warning("Could not write payload to any queue, attempted order: {}".format(write_attempt_order))
        return None

    def callback_worker(self, queue_actual, idx):
        """Worker-process entry point: save events read from ``queue_actual``.

        Loops until SIGINT or SIGTERM is received, or until database
        connectivity cannot be re-established (in which case the parent
        process is sent SIGINT and the worker returns).
        """
        signal_handler = WorkerSignalHandler()
        while not signal_handler.kill_now:
            try:
                # time out every second so a shutdown request is noticed
                body = queue_actual.get(block=True, timeout=1)
            except QueueEmpty:
                continue
            except Exception as e:
                logger.error("Exception on worker thread, restarting: " + str(e))
                continue
            try:

                event_map = {
                    'job_id': JobEvent,
                    'ad_hoc_command_id': AdHocCommandEvent,
                    'project_update_id': ProjectUpdateEvent,
                    'inventory_update_id': InventoryUpdateEvent,
                    'system_job_id': SystemJobEvent,
                }

                if not any(key in body for key in event_map):
                    raise Exception('Payload does not have a job identifier')
                if settings.DEBUG:
                    from pygments import highlight
                    from pygments.lexers import PythonLexer
                    from pygments.formatters import Terminal256Formatter
                    from pprint import pformat
                    logger.info('Body: {}'.format(
                        highlight(pformat(body, width=160), PythonLexer(), Terminal256Formatter(style='friendly'))
                    )[:1024 * 4])

                def _save_event_data():
                    for key, cls in event_map.items():
                        if key in body:
                            cls.create_from_data(**body)

                # the first identifier key found in the payload names the
                # job in log messages
                job_identifier = 'unknown job'
                for key in event_map:
                    if key in body:
                        job_identifier = body[key]
                        break

                if body.get('event') == 'EOF':
                    try:
                        final_counter = body.get('final_counter', 0)
                        logger.info('Event processing is finished for Job {}, sending notifications'.format(job_identifier))
                        # EOF events are sent when stdout for the running task is
                        # closed. don't actually persist them to the database; we
                        # just use them to report `summary` websocket events as an
                        # approximation for when a job is "done"
                        emit_channel_notification(
                            'jobs-summary',
                            dict(group_name='jobs', unified_job_id=job_identifier, final_counter=final_counter)
                        )
                        # Additionally, when we've processed all events, we should
                        # have all the data we need to send out success/failure
                        # notification templates
                        uj = UnifiedJob.objects.get(pk=job_identifier)
                        if hasattr(uj, 'send_notification_templates'):
                            retries = 0
                            while retries < 5:
                                if uj.finished:
                                    uj.send_notification_templates('succeeded' if uj.status == 'successful' else 'failed')
                                    break
                                else:
                                    # wait a few seconds to avoid a race where the
                                    # events are persisted _before_ the UJ.status
                                    # changes from running -> successful
                                    retries += 1
                                    time.sleep(1)
                                    uj = UnifiedJob.objects.get(pk=job_identifier)
                    except Exception:
                        logger.exception('Worker failed to emit notifications: Job {}'.format(job_identifier))
                    continue

                retries = 0
                while retries <= self.MAX_RETRIES:
                    try:
                        _save_event_data()
                        break
                    except (OperationalError, InterfaceError, InternalError):
                        if retries >= self.MAX_RETRIES:
                            logger.exception('Worker could not re-establish database connectivity, shutting down gracefully: Job {}'.format(job_identifier))
                            os.kill(os.getppid(), signal.SIGINT)
                            return
                        delay = 60 * retries
                        logger.exception('Database Error Saving Job Event, retry #{i} in {delay} seconds:'.format(
                            i=retries + 1,
                            delay=delay
                        ))
                        # drop the broken connection before the next attempt
                        django_connection.close()
                        time.sleep(delay)
                        retries += 1
                    except DatabaseError:
                        logger.exception('Database Error Saving Job Event for Job {}'.format(job_identifier))
                        break
            except Exception as exc:
                import traceback
                tb = traceback.format_exc()
                logger.error('Callback Task Processor Raised Exception: %r', exc)
                logger.error('Detail: {}'.format(tb))
from awx.main.dispatch.worker import AWXConsumer, CallbackBrokerWorker
class Command(BaseCommand):
@ -238,8 +18,22 @@ class Command(BaseCommand):
def handle(self, *arg, **options):
with Connection(settings.BROKER_URL) as conn:
consumer = None
try:
worker = CallbackBrokerWorker(conn)
worker.run()
consumer = AWXConsumer(
'callback_receiver',
conn,
CallbackBrokerWorker(),
[
Queue(
settings.CALLBACK_QUEUE,
Exchange(settings.CALLBACK_QUEUE, type='direct'),
routing_key=settings.CALLBACK_QUEUE
)
]
)
consumer.run()
except KeyboardInterrupt:
print('Terminating Callback Receiver')
if consumer:
consumer.stop()

View File

@ -0,0 +1,72 @@
import multiprocessing
import random
import sys
from uuid import uuid4
import pytest
from awx.main.dispatch.worker import BaseWorker
from awx.main.dispatch.pool import WorkerPool
class SimpleWorker(BaseWorker):
    """Worker that accepts any message and does nothing with it."""

    def perform_work(self, body, *args):
        return None
class ResultWriter(BaseWorker):
    """Worker that puts each message, suffixed with ``'!!!'``, on a result queue."""

    def perform_work(self, body, result_queue):
        reply = body + '!!!'
        result_queue.put(reply)
@pytest.mark.django_db
class TestWorkerPool:
    """Exercises ``WorkerPool`` with three real worker processes."""

    def setup_method(self, test_method):
        self.pool = WorkerPool(min_workers=3)

    def teardown_method(self, test_method):
        # terminate whatever worker processes the test started
        self.pool.stop()

    def test_worker(self):
        """Every worker starts alive with a zero message count."""
        self.pool.init_workers(SimpleWorker().work_loop)
        assert len(self.pool) == 3
        for worker in self.pool.workers:
            total, _, process = worker
            assert total == 0
            assert process.is_alive() is True

    def test_single_task(self):
        """A message written to queue 0 is counted against worker 0 only."""
        self.pool.init_workers(SimpleWorker().work_loop)
        self.pool.write(0, 'xyz')
        assert self.pool.workers[0][0] == 1  # worker at index 0 handled one task
        assert self.pool.workers[1][0] == 0
        assert self.pool.workers[2][0] == 0

    def test_queue_preference(self):
        """The preferred queue is the one that receives the message."""
        self.pool.init_workers(SimpleWorker().work_loop)
        self.pool.write(2, 'xyz')
        assert self.pool.workers[0][0] == 0
        assert self.pool.workers[1][0] == 0
        assert self.pool.workers[2][0] == 1  # worker at index 2 handled one task

    def test_worker_processing(self):
        """Ten messages spread over random queues are all processed."""
        result_queue = multiprocessing.Queue()
        self.pool.init_workers(ResultWriter().work_loop, result_queue)
        for i in range(10):
            self.pool.write(
                # pick a random queue *index*, not a worker's message counter
                random.randrange(len(self.pool)),
                'Hello, Worker {}'.format(i)
            )
        all_messages = [result_queue.get(timeout=1) for i in range(10)]
        all_messages.sort()
        assert all_messages == [
            'Hello, Worker {}!!!'.format(i)
            for i in range(10)
        ]

        total_handled = sum([worker[0] for worker in self.pool.workers])
        assert total_handled == 10

View File

@ -1027,7 +1027,10 @@ LOGGING = {
'timed_import': {
'()': 'awx.main.utils.formatters.TimeFormatter',
'format': '%(relativeSeconds)9.3f %(levelname)-8s %(message)s'
}
},
'dispatcher': {
'format': '%(asctime)s %(levelname)-8s %(name)s PID:%(process)d %(message)s',
},
},
'handlers': {
'console': {
@ -1068,14 +1071,14 @@ LOGGING = {
'backupCount': 5,
'formatter':'simple',
},
'callback_receiver': {
'dispatcher': {
'level': 'WARNING',
'class':'logging.handlers.RotatingFileHandler',
'filters': ['require_debug_false'],
'filename': os.path.join(LOG_ROOT, 'callback_receiver.log'),
'filename': os.path.join(LOG_ROOT, 'dispatcher.log'),
'maxBytes': 1024 * 1024 * 5, # 5 MB
'backupCount': 5,
'formatter':'simple',
'formatter':'dispatcher',
},
'inventory_import': {
'level': 'DEBUG',
@ -1158,8 +1161,9 @@ LOGGING = {
},
'awx.main': {
'handlers': ['null']
}, 'awx.main.commands.run_callback_receiver': {
'handlers': ['callback_receiver'],
},
'awx.main.dispatch': {
'handlers': ['dispatcher'],
},
'awx.isolated.manager.playbooks': {
'handlers': ['management_playbooks'],

View File

@ -197,7 +197,6 @@ LOGGING['handlers']['syslog'] = {
LOGGING['loggers']['django.request']['handlers'] = ['console']
LOGGING['loggers']['rest_framework.request']['handlers'] = ['console']
LOGGING['loggers']['awx']['handlers'] = ['console', 'external_logger']
LOGGING['loggers']['awx.main.commands.run_callback_receiver']['handlers'] = ['console']
LOGGING['loggers']['awx.main.tasks']['handlers'] = ['console', 'external_logger']
LOGGING['loggers']['awx.main.scheduler']['handlers'] = ['console', 'external_logger']
LOGGING['loggers']['django_auth_ldap']['handlers'] = ['console']