Merge remote-tracking branch 'origin/zeromq'

* origin/zeromq:
  Update supervisor configuration during deployment to allow pulling in zeromq based dependencies and so that the new callback handler will get started alongside celery
  Update unit tests to manage zeromq based tasks
  Cleanup and refactor some parts of the new zeromq based callback receiver
  Update playbook to install zeromq3 and the dependent python module
  Some cleanup and documentation for zeromq implementation
  Pull results off zeromq and distribute to workers
  Manage the zeromq connection per-pid
  Initial 0mq implementation
This commit is contained in:
Matthew Jones
2014-02-18 14:15:57 -05:00
10 changed files with 219 additions and 181 deletions

View File

@@ -0,0 +1,134 @@
# Copyright (c) 2014 AnsibleWorks, Inc.
# All Rights Reserved.
# Python
import datetime
import logging
import json
from optparse import make_option
from multiprocessing import Process
# Django
from django.conf import settings
from django.core.management.base import NoArgsCommand, CommandError
from django.db import transaction, DatabaseError
from django.contrib.auth.models import User
from django.utils.dateparse import parse_datetime
from django.utils.timezone import now, is_aware, make_aware
from django.utils.tzinfo import FixedOffset
# AWX
from awx.main.models import *
# ZeroMQ
import zmq
def run_subscriber(consumer_port, queue_port, use_workers=True):
    '''
    Bind the callback consumer socket and fan incoming job events out to a
    pool of Worker processes.

    :param consumer_port: zeromq endpoint events are pushed to by the
        callback plugin (e.g. "tcp://127.0.0.1:5556").
    :param queue_port: zeromq endpoint used to distribute events to workers.
    :param use_workers: when False (used by tests), skip spawning workers
        and process each event inline instead.
    '''
    # Socket receiving events from the callback plugin.
    receiver_context = zmq.Context()
    receiver = receiver_context.socket(zmq.PULL)
    receiver.bind(consumer_port)
    # Socket distributing events to the worker pool.
    distributor_context = zmq.Context()
    distributor = distributor_context.socket(zmq.PUSH)
    distributor.bind(queue_port)
    if use_workers:
        workers = [Worker(queue_port) for _ in range(4)]
        for worker in workers:
            worker.start()
    while True: # Handle signal
        event_msg = receiver.recv_json()
        if use_workers:
            distributor.send_json(event_msg)
        else:
            process_job_event(event_msg)
@transaction.commit_on_success
def process_job_event(data):
    '''
    Validate a job event dict received from the callback plugin and persist
    it as a JobEvent, retrying transient database errors.

    :param data: dict with at least 'job_id' and 'event' keys; may also
        carry 'event_data' and a 'created' timestamp (datetime or ISO string).
    Events missing 'event' or 'job_id' are silently dropped.
    '''
    # Fix: the module never imports `time` nor defines a module-level
    # `logger`, so the retry path below used to raise NameError. Bind both
    # locally to keep this fix self-contained.
    import time
    logger = logging.getLogger('awx.main.commands.run_callback_receiver')
    event = data.get('event', '')
    if not event or 'job_id' not in data:
        return
    # Normalize 'created' to a timezone-aware datetime; drop it entirely if
    # it cannot be parsed so the model default applies instead.
    try:
        if not isinstance(data['created'], datetime.datetime):
            data['created'] = parse_datetime(data['created'])
        if not data['created'].tzinfo:
            data['created'] = data['created'].replace(tzinfo=FixedOffset(0))
    except (KeyError, ValueError):
        data.pop('created', None)
    if settings.DEBUG:
        print(data)
    # Whitelist only the fields JobEvent accepts. Iterate over a snapshot of
    # the keys since we mutate the dict while looping.
    for key in list(data.keys()):
        if key not in ('job_id', 'event', 'event_data', 'created'):
            data.pop(key)
    data['play'] = data.get('event_data', {}).get('play', '').strip()
    data['task'] = data.get('event_data', {}).get('task', '').strip()
    # Retry transient database errors up to 10 times, one second apart.
    for retry_count in range(11):
        try:
            # Commit any outstanding events before saving stats.
            if event == 'playbook_on_stats':
                transaction.commit()
            job_event = JobEvent(**data)
            job_event.save(post_process=True)
            if not event.startswith('runner_'):
                transaction.commit()
            break
        except DatabaseError as e:
            transaction.rollback()
            logger.debug('Database error saving job event, retrying in '
                         '1 second (retry #%d): %s', retry_count + 1, e)
            time.sleep(1)
    else:
        logger.error('Failed to save job event after %d retries.',
                     retry_count)
class Worker(Process):
    '''
    Worker process that pulls job events off the zeromq queue socket,
    validates them, and saves them to the database.
    '''
    def __init__(self, port):
        # Fix: Process.__init__ was never called, leaving the
        # multiprocessing machinery uninitialized so start() would fail.
        super(Worker, self).__init__()
        # zeromq endpoint to pull job events from.
        self.port = port

    def run(self):
        '''Connect to the queue socket and process events forever.'''
        print("Starting worker")
        worker_context = zmq.Context()
        subscriber = worker_context.socket(zmq.PULL)
        subscriber.connect(self.port)
        while True:
            message = subscriber.recv_json()
            process_job_event(message)
class Command(NoArgsCommand):
    '''
    Save Job Callback receiver (see awx.plugins.callbacks.job_event_callback)
    Runs as a management command and receives job save events. It then hands
    them off to worker processors (see Worker) which writes them to the database
    '''

    help = 'Launch the job callback receiver'

    option_list = NoArgsCommand.option_list + (
        make_option('--port', dest='port', type='int', default=5556,
                    help='Port to listen for requests on'),)

    def init_logging(self):
        '''Configure a stream logger whose level tracks --verbosity.'''
        # Verbosity 0..3 maps to ERROR/INFO/DEBUG/everything (level 0).
        level_by_verbosity = {0: logging.ERROR, 1: logging.INFO,
                              2: logging.DEBUG, 3: 0}
        self.logger = logging.getLogger('awx.main.commands.run_callback_receiver')
        self.logger.setLevel(level_by_verbosity.get(self.verbosity, 0))
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(stream_handler)
        self.logger.propagate = False

    def handle_noargs(self, **options):
        '''Entry point: set up logging and run the subscriber loop.'''
        self.verbosity = int(options.get('verbosity', 1))
        self.init_logging()
        run_subscriber(settings.CALLBACK_CONSUMER_PORT,
                       settings.CALLBACK_QUEUE_PORT)

View File

@@ -494,11 +494,11 @@ class RunJob(BaseTask):
elif job.status in ('pending', 'waiting'): elif job.status in ('pending', 'waiting'):
job = self.update_model(job.pk, status='pending') job = self.update_model(job.pk, status='pending')
# Start another task to process job events. # Start another task to process job events.
if settings.BROKER_URL.startswith('amqp://'): # if settings.BROKER_URL.startswith('amqp://'):
app = Celery('tasks', broker=settings.BROKER_URL) # app = Celery('tasks', broker=settings.BROKER_URL)
send_task('awx.main.tasks.save_job_events', kwargs={ # send_task('awx.main.tasks.save_job_events', kwargs={
'job_id': job.id, # 'job_id': job.id,
}, serializer='json') # }, serializer='json')
return True return True
else: else:
return False return False
@@ -511,20 +511,21 @@ class RunJob(BaseTask):
# Send a special message to this job's event queue after the job has run # Send a special message to this job's event queue after the job has run
# to tell the save job events task to end. # to tell the save job events task to end.
if settings.BROKER_URL.startswith('amqp://'): if settings.BROKER_URL.startswith('amqp://'):
job_events_exchange = Exchange('job_events', 'direct', durable=True) pass
job_events_queue = Queue('job_events[%d]' % job.id, # job_events_exchange = Exchange('job_events', 'direct', durable=True)
exchange=job_events_exchange, # job_events_queue = Queue('job_events[%d]' % job.id,
routing_key=('job_events[%d]' % job.id), # exchange=job_events_exchange,
auto_delete=True) # routing_key=('job_events[%d]' % job.id),
with Connection(settings.BROKER_URL, transport_options={'confirm_publish': True}) as conn: # auto_delete=True)
with conn.Producer(serializer='json') as producer: # with Connection(settings.BROKER_URL, transport_options={'confirm_publish': True}) as conn:
msg = { # with conn.Producer(serializer='json') as producer:
'job_id': job.id, # msg = {
'event': '__complete__' # 'job_id': job.id,
} # 'event': '__complete__'
producer.publish(msg, exchange=job_events_exchange, # }
routing_key=('job_events[%d]' % job.id), # producer.publish(msg, exchange=job_events_exchange,
declare=[job_events_queue]) # routing_key=('job_events[%d]' % job.id),
# declare=[job_events_queue])
# Update job event fields after job has completed (only when using REST # Update job event fields after job has completed (only when using REST
# API callback). # API callback).
@@ -532,91 +533,6 @@ class RunJob(BaseTask):
for job_event in job.job_events.order_by('pk'): for job_event in job.job_events.order_by('pk'):
job_event.save(post_process=True) job_event.save(post_process=True)
class SaveJobEvents(Task):
name = 'awx.main.tasks.save_job_events'
def process_job_event(self, data, message, events_received=None):
if events_received is None:
events_received = {}
begints = time.time()
event = data.get('event', '')
if not event or 'job_id' not in data:
return
try:
if not isinstance(data['created'], datetime.datetime):
data['created'] = parse_datetime(data['created'])
if not data['created'].tzinfo:
data['created'] = data['created'].replace(tzinfo=FixedOffset(0))
except (KeyError, ValueError):
data.pop('created', None)
if settings.DEBUG:
print data
for key in data.keys():
if key not in ('job_id', 'event', 'event_data', 'created'):
data.pop(key)
data['play'] = data.get('event_data', {}).get('play', '').strip()
data['task'] = data.get('event_data', {}).get('task', '').strip()
duplicate = False
if event != '__complete__':
for retry_count in xrange(11):
try:
# Commit any outstanding events before saving stats.
if event == 'playbook_on_stats':
transaction.commit()
if not JobEvent.objects.filter(**data).exists():
job_event = JobEvent(**data)
job_event.save(post_process=True)
if not event.startswith('runner_'):
transaction.commit()
else:
duplicate = True
if settings.DEBUG:
print 'skipping duplicate job event %r' % data
break
except DatabaseError as e:
transaction.rollback()
logger.debug('Database error saving job event, retrying in '
'1 second (retry #%d): %s', retry_count + 1, e)
time.sleep(1)
else:
logger.error('Failed to save job event after %d retries.',
retry_count)
if not duplicate:
if event not in events_received:
events_received[event] = 1
else:
events_received[event] += 1
if settings.DEBUG:
print 'saved job event in %0.3fs' % (time.time() - begints)
message.ack()
@transaction.commit_on_success
def run(self, *args, **kwargs):
job_id = kwargs.get('job_id', None)
if not job_id:
return {}
events_received = {}
process_job_event = functools.partial(self.process_job_event,
events_received=events_received)
job_events_exchange = Exchange('job_events', 'direct', durable=True)
job_events_queue = Queue('job_events[%d]' % job_id,
exchange=job_events_exchange,
routing_key=('job_events[%d]' % job_id),
auto_delete=True)
with Connection(settings.BROKER_URL, transport_options={'confirm_publish': True}) as conn:
with conn.Consumer(job_events_queue, callbacks=[process_job_event]) as consumer:
while '__complete__' not in events_received:
conn.drain_events()
return {
'job_id': job_id,
'total_events': sum(events_received.values())}
class RunProjectUpdate(BaseTask): class RunProjectUpdate(BaseTask):
name = 'awx.main.tasks.run_project_update' name = 'awx.main.tasks.run_project_update'

View File

@@ -10,6 +10,7 @@ import os
import shutil import shutil
import tempfile import tempfile
import time import time
from multiprocessing import Process
# PyYAML # PyYAML
import yaml import yaml
@@ -23,6 +24,8 @@ from django.test.client import Client
# AWX # AWX
from awx.main.models import * from awx.main.models import *
from awx.main.backend import LDAPSettings from awx.main.backend import LDAPSettings
from awx.main.management.commands.run_callback_receiver import run_subscriber
class BaseTestMixin(object): class BaseTestMixin(object):
''' '''
@@ -363,6 +366,14 @@ class BaseTestMixin(object):
for obj in response['results']: for obj in response['results']:
self.assertTrue(set(obj.keys()) <= set(fields)) self.assertTrue(set(obj.keys()) <= set(fields))
def start_queue(self, consumer_port, queue_port):
self.queue_process = Process(target=run_subscriber,
args=(consumer_port, queue_port, False,))
self.queue_process.start()
def terminate_queue(self):
self.queue_process.terminate()
class BaseTest(BaseTestMixin, django.test.TestCase): class BaseTest(BaseTestMixin, django.test.TestCase):
''' '''
Base class for unit tests. Base class for unit tests.

View File

@@ -301,9 +301,11 @@ class CleanupJobsTest(BaseCommandMixin, BaseLiveServerTest):
self.project = None self.project = None
self.credential = None self.credential = None
settings.INTERNAL_API_URL = self.live_server_url settings.INTERNAL_API_URL = self.live_server_url
self.start_queue(settings.CALLBACK_CONSUMER_PORT, settings.CALLBACK_QUEUE_PORT)
def tearDown(self): def tearDown(self):
super(CleanupJobsTest, self).tearDown() super(CleanupJobsTest, self).tearDown()
self.terminate_queue()
if self.test_project_path: if self.test_project_path:
shutil.rmtree(self.test_project_path, True) shutil.rmtree(self.test_project_path, True)

View File

@@ -991,7 +991,12 @@ class InventoryUpdatesTest(BaseTransactionTest):
self.group = self.inventory.groups.create(name='Cloud Group') self.group = self.inventory.groups.create(name='Cloud Group')
self.inventory2 = self.organization.inventories.create(name='Cloud Inventory 2') self.inventory2 = self.organization.inventories.create(name='Cloud Inventory 2')
self.group2 = self.inventory2.groups.create(name='Cloud Group 2') self.group2 = self.inventory2.groups.create(name='Cloud Group 2')
self.start_queue(settings.CALLBACK_CONSUMER_PORT, settings.CALLBACK_QUEUE_PORT)
def tearDown(self):
super(InventoryUpdatesTest, self).tearDown()
self.terminate_queue()
def update_inventory_source(self, group, **kwargs): def update_inventory_source(self, group, **kwargs):
inventory_source = group.inventory_source inventory_source = group.inventory_source
update_fields = [] update_fields = []

View File

@@ -442,6 +442,12 @@ class BaseJobTestMixin(BaseTestMixin):
def setUp(self): def setUp(self):
super(BaseJobTestMixin, self).setUp() super(BaseJobTestMixin, self).setUp()
self.populate() self.populate()
#self.start_queue("ipc:///tmp/test_consumer.ipc", "ipc:///tmp/test_queue.ipc")
self.start_queue(settings.CALLBACK_CONSUMER_PORT, settings.CALLBACK_QUEUE_PORT)
def tearDown(self):
super(BaseJobTestMixin, self).tearDown()
self.terminate_queue()
class JobTemplateTest(BaseJobTestMixin, django.test.TestCase): class JobTemplateTest(BaseJobTestMixin, django.test.TestCase):
@@ -773,6 +779,7 @@ MIDDLEWARE_CLASSES = filter(lambda x: not x.endswith('TransactionMiddleware'),
@override_settings(CELERY_ALWAYS_EAGER=True, @override_settings(CELERY_ALWAYS_EAGER=True,
CELERY_EAGER_PROPAGATES_EXCEPTIONS=True, CELERY_EAGER_PROPAGATES_EXCEPTIONS=True,
CALLBACK_BYPASS_QUEUE=True,
ANSIBLE_TRANSPORT='local', ANSIBLE_TRANSPORT='local',
MIDDLEWARE_CLASSES=MIDDLEWARE_CLASSES) MIDDLEWARE_CLASSES=MIDDLEWARE_CLASSES)
class JobStartCancelTest(BaseJobTestMixin, django.test.LiveServerTestCase): class JobStartCancelTest(BaseJobTestMixin, django.test.LiveServerTestCase):
@@ -904,6 +911,9 @@ class JobStartCancelTest(BaseJobTestMixin, django.test.LiveServerTestCase):
job = self.job_ops_east_run job = self.job_ops_east_run
job.start() job.start()
# Wait for events to filter in since we are using a single consumer
time.sleep(30)
# Check that the job detail has been updated. # Check that the job detail has been updated.
url = reverse('api:job_detail', args=(job.pk,)) url = reverse('api:job_detail', args=(job.pk,))
with self.current_user(self.user_sue): with self.current_user(self.user_sue):

View File

@@ -680,6 +680,11 @@ class ProjectUpdatesTest(BaseTransactionTest):
def setUp(self): def setUp(self):
super(ProjectUpdatesTest, self).setUp() super(ProjectUpdatesTest, self).setUp()
self.setup_users() self.setup_users()
self.start_queue(settings.CALLBACK_CONSUMER_PORT, settings.CALLBACK_QUEUE_PORT)
def tearDown(self):
super(ProjectUpdatesTest, self).tearDown()
self.terminate_queue()
def create_project(self, **kwargs): def create_project(self, **kwargs):
cred_fields = ['scm_username', 'scm_password', 'scm_key_data', cred_fields = ['scm_username', 'scm_password', 'scm_key_data',

View File

@@ -188,12 +188,14 @@ class RunJobTest(BaseCeleryTest):
return args return args
RunJob.build_args = new_build_args RunJob.build_args = new_build_args
settings.INTERNAL_API_URL = self.live_server_url settings.INTERNAL_API_URL = self.live_server_url
self.start_queue(settings.CALLBACK_CONSUMER_PORT, settings.CALLBACK_QUEUE_PORT)
def tearDown(self): def tearDown(self):
super(RunJobTest, self).tearDown() super(RunJobTest, self).tearDown()
if self.test_project_path: if self.test_project_path:
shutil.rmtree(self.test_project_path, True) shutil.rmtree(self.test_project_path, True)
RunJob.build_args = self.original_build_args RunJob.build_args = self.original_build_args
self.terminate_queue()
def create_test_credential(self, **kwargs): def create_test_credential(self, **kwargs):
opts = { opts = {

View File

@@ -38,26 +38,13 @@ import sys
import urllib import urllib
import urlparse import urlparse
# Requests / Kombu import requests
try:
import requests
from kombu import Connection, Exchange, Queue
except ImportError:
# If running from an AWX installation, use the local version of requests if
# if cannot be found globally.
local_site_packages = os.path.join(os.path.dirname(__file__), '..', '..',
'lib', 'site-packages')
sys.path.insert(0, local_site_packages)
import requests
from kombu import Connection, Exchange, Queue
# Check to see if librabbitmq is installed. # Django
try: from django.conf import settings
import librabbitmq
LIBRABBITMQ_INSTALLED = True
except ImportError:
LIBRABBITMQ_INSTALLED = False
# ZeroMQ
import zmq
class TokenAuth(requests.auth.AuthBase): class TokenAuth(requests.auth.AuthBase):
@@ -93,14 +80,10 @@ class CallbackModule(object):
self.job_id = int(os.getenv('JOB_ID')) self.job_id = int(os.getenv('JOB_ID'))
self.base_url = os.getenv('REST_API_URL', '') self.base_url = os.getenv('REST_API_URL', '')
self.auth_token = os.getenv('REST_API_TOKEN', '') self.auth_token = os.getenv('REST_API_TOKEN', '')
self.broker_url = os.getenv('BROKER_URL', '') self.context = None
self.socket = None
self._init_logging() self._init_logging()
# Since we don't yet have a way to confirm publish when using self._init_connection()
# librabbitmq, ensure we use pyamqp even if librabbitmq happens to be
# installed.
if LIBRABBITMQ_INSTALLED:
self.logger.info('Forcing use of pyamqp instead of librabbitmq')
self.broker_url = self.broker_url.replace('amqp://', 'pyamqp://')
def _init_logging(self): def _init_logging(self):
try: try:
@@ -120,76 +103,42 @@ class CallbackModule(object):
self.logger.addHandler(handler) self.logger.addHandler(handler)
self.logger.propagate = False self.logger.propagate = False
def __del__(self): def _init_connection(self):
self._cleanup_connection() self.context = None
self.socket = None
def _publish_errback(self, exc, interval): def _start_connection(self):
self.logger.info('Publish Error: %r', exc) self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
def _cleanup_connection(self): self.socket.connect("tcp://127.0.0.1:5556")
if hasattr(self, 'producer'):
try:
#self.logger.debug('Cleanup Producer: %r', self.producer)
self.producer.cancel()
except:
pass
del self.producer
if hasattr(self, 'connection'):
try:
#self.logger.debug('Cleanup Connection: %r', self.connection)
self.connection.release()
except:
pass
del self.connection
def _post_job_event_queue_msg(self, event, event_data): def _post_job_event_queue_msg(self, event, event_data):
if not hasattr(self, 'job_events_exchange'):
self.job_events_exchange = Exchange('job_events', 'direct',
durable=True)
if not hasattr(self, 'job_events_queue'):
self.job_events_queue = Queue('job_events[%d]' % self.job_id,
exchange=self.job_events_exchange,
routing_key=('job_events[%d]' % self.job_id),
auto_delete=True)
msg = { msg = {
'job_id': self.job_id, 'job_id': self.job_id,
'event': event, 'event': event,
'event_data': event_data, 'event_data': event_data,
'created': datetime.datetime.utcnow().isoformat(), 'created': datetime.datetime.utcnow().isoformat(),
} }
active_pid = os.getpid()
if self.job_callback_debug: if self.job_callback_debug:
msg.update({ msg.update({
'pid': os.getpid(), 'pid': active_pid,
}) })
for retry_count in xrange(4): for retry_count in xrange(4):
try: try:
if not hasattr(self, 'connection_pid'): if not hasattr(self, 'connection_pid'):
self.connection_pid = os.getpid() self.connection_pid = active_pid
if self.connection_pid != os.getpid(): if self.connection_pid != active_pid:
self._cleanup_connection() self._init_connection()
if not hasattr(self, 'connection'): if self.context is None:
self.connection = Connection(self.broker_url, transport_options={'confirm_publish': True}) self._start_connection()
self.logger.debug('New Connection: %r, retry=%d',
self.connection, retry_count) self.socket.send_json(msg)
if not hasattr(self, 'producer'):
channel = self.connection.channel()
self.producer = self.connection.Producer(channel, exchange=self.job_events_exchange, serializer='json')
self.publish = self.connection.ensure(self.producer, self.producer.publish,
errback=self._publish_errback,
max_retries=3, interval_start=1, interval_step=1, interval_max=10)
self.logger.debug('New Producer: %r, retry=%d',
self.producer, retry_count)
self.logger.debug('Publish: %r, retry=%d', msg, retry_count)
self.publish(msg, exchange=self.job_events_exchange,
routing_key=('job_events[%d]' % self.job_id),
declare=[self.job_events_queue])
if event == 'playbook_on_stats':
self._cleanup_connection()
return return
except Exception, e: except Exception, e:
self.logger.info('Publish Exception: %r, retry=%d', e, self.logger.info('Publish Exception: %r, retry=%d', e,
retry_count, exc_info=True) retry_count, exc_info=True)
self._cleanup_connection() # TODO: Maybe recycle connection here?
if retry_count >= 3: if retry_count >= 3:
raise raise
@@ -222,7 +171,7 @@ class CallbackModule(object):
task = getattr(getattr(self, 'task', None), 'name', '') task = getattr(getattr(self, 'task', None), 'name', '')
if task and event not in self.EVENTS_WITHOUT_TASK: if task and event not in self.EVENTS_WITHOUT_TASK:
event_data['task'] = task event_data['task'] = task
if self.broker_url: if not settings.CALLBACK_BYPASS_QUEUE:
self._post_job_event_queue_msg(event, event_data) self._post_job_event_queue_msg(event, event_data)
else: else:
self._post_rest_api_event(event, event_data) self._post_rest_api_event(event, event_data)

View File

@@ -345,6 +345,10 @@ if 'devserver' in INSTALLED_APPS:
else: else:
INTERNAL_API_URL = 'http://127.0.0.1:8000' INTERNAL_API_URL = 'http://127.0.0.1:8000'
CALLBACK_CONSUMER_PORT = "tcp://127.0.0.1:5556"
CALLBACK_QUEUE_PORT = "ipc:///tmp/callback_receiver.ipc"
CALLBACK_BYPASS_QUEUE = False
# Logging configuration. # Logging configuration.
LOGGING = { LOGGING = {
'version': 1, 'version': 1,