Replace the Redis-backed PubSub helper with a ZeroMQ-backed Socket class.

Adds awx/main/socket.py (Socket: a context manager that connects a zmq.REQ
socket for 'w' and binds a zmq.REP socket for 'r'), switches the callback
receiver, the socket.io service, emit_websocket_notification and the job
event callback plugin over to it, removes PubSub from awx/main/queue.py,
and turns the port settings into bare integers that Socket.connect()
expands to tcp://127.0.0.1:PORT.

Socket.publish() and Socket.listen() treat any falsy port as the
intentional "dummy" no-op case, matching Socket.connect().

diff --git a/awx/main/management/commands/run_callback_receiver.py b/awx/main/management/commands/run_callback_receiver.py
index 6f9f9ed837..bd3827dbf4 100644
--- a/awx/main/management/commands/run_callback_receiver.py
+++ b/awx/main/management/commands/run_callback_receiver.py
@@ -25,7 +25,7 @@ from django.db import connection
 
 # AWX
 from awx.main.models import *
-from awx.main.queue import PubSub
+from awx.main.socket import Socket
 
 MAX_REQUESTS = 10000
 WORKERS = 4
@@ -102,8 +102,8 @@ class CallbackReceiver(object):
         total_messages = 0
         last_parent_events = {}
 
-        with closing(PubSub('callbacks')) as callbacks:
-            for message in callbacks.subscribe(wait=0.1):
+        with Socket('callbacks', 'r') as callbacks:
+            for message in callbacks.listen():
                 total_messages += 1
                 if not use_workers:
                     self.process_job_event(message)
diff --git a/awx/main/management/commands/run_socketio_service.py b/awx/main/management/commands/run_socketio_service.py
index 20afe8fc86..570959b121 100644
--- a/awx/main/management/commands/run_socketio_service.py
+++ b/awx/main/management/commands/run_socketio_service.py
@@ -24,7 +24,7 @@ from django.utils.tzinfo import FixedOffset
 # AWX
 import awx
 from awx.main.models import *
-from awx.main.queue import PubSub
+from awx.main.socket import Socket
 
 # gevent & socketio
 import gevent
@@ -119,16 +119,16 @@ class TowerSocket(object):
         return ['Tower version %s' % awx.__version__]
 
 def notification_handler(server):
-    pubsub = PubSub('websocket')
-    for message in pubsub.subscribe():
-        packet = {
-            'args': message,
-            'endpoint': message['endpoint'],
-            'name': message['event'],
-            'type': 'event',
-        }
-        for session_id, socket in list(server.sockets.iteritems()):
-            socket.send_packet(packet)
+    with Socket('websocket', 'r') as websocket:
+        for message in websocket.listen():
+            packet = {
+                'args': message,
+                'endpoint': message['endpoint'],
+                'name': message['event'],
+                'type': 'event',
+            }
+            for session_id, socket in list(server.sockets.iteritems()):
+                socket.send_packet(packet)
 
 class Command(NoArgsCommand):
     '''
diff --git a/awx/main/queue.py b/awx/main/queue.py
index 8ea30702f6..54102b6c2f 100644
--- a/awx/main/queue.py
+++ b/awx/main/queue.py
@@ -8,7 +8,7 @@ from redis import StrictRedis
 
 from django.conf import settings
 
-__all__ = ['FifoQueue', 'PubSub']
+__all__ = ['FifoQueue']
 
 
 # Determine, based on settings.BROKER_URL (for celery), what the correct Redis
@@ -66,52 +66,3 @@ class FifoQueue(object):
         answer = redis.lpop(self._queue_name)
         if answer:
             return json.loads(answer)
-
-
-class PubSub(object):
-    """An abstraction class implemented for pubsub.
-
-    Intended to allow alteration of backend details in a single, consistent
-    way throughout the Tower application.
-    """
-    def __init__(self, queue_name):
-        """Instantiate a pubsub object, which is able to interact with a
-        Redis key as a pubsub.
-
-        Ideally this should be used with `contextmanager.closing` to ensure
-        well-behavedness:
-
-            from contextlib import closing
-
-            with closing(PubSub('foobar')) as foobar:
-                for message in foobar.subscribe(wait=0.1):
-
-        """
-        self._queue_name = queue_name
-        self._ps = redis.pubsub(ignore_subscribe_messages=True)
-        self._ps.subscribe(queue_name)
-
-    def publish(self, message):
-        """Publish a message to the given queue."""
-        redis.publish(self._queue_name, json.dumps(message))
-
-    def retrieve(self):
-        """Retrieve a single message from the subcription channel
-        and return it.
-        """
-        return self._ps.get_message()
-
-    def subscribe(self, wait=0.001):
-        """Listen to content from the subscription channel indefinitely,
-        and yield messages as they are retrieved.
-        """
-        while True:
-            message = self.retrieve()
-            if message is None:
-                time.sleep(max(wait, 0.001))
-            else:
-                yield json.loads(message['data'])
-
-    def close(self):
-        """Close the pubsub connection."""
-        self._ps.close()
diff --git a/awx/main/socket.py b/awx/main/socket.py
new file mode 100644
index 0000000000..679b5f5fd0
--- /dev/null
+++ b/awx/main/socket.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2014, Ansible, Inc.
+# All Rights Reserved.
+
+import os
+
+import zmq
+
+from django.conf import settings
+
+
+class Socket(object):
+    """An abstraction class implemented for a dumb OS socket.
+
+    Intended to allow alteration of backend details in a single, consistent
+    way throughout the Tower application.
+    """
+    def __init__(self, bucket, rw, debug=0, logger=None):
+        """Instantiate a Socket object, which uses ZeroMQ to actually perform
+        passing a message back and forth.
+
+        Designed to be used as a context manager:
+
+            with Socket('callbacks', 'w') as socket:
+                socket.publish({'message': 'foo bar baz'})
+
+        If listening for messages through a socket, the `listen` method
+        is a simple generator:
+
+            with Socket('callbacks', 'r') as socket:
+                for message in socket.listen():
+                    [...]
+        """
+        self._bucket = bucket
+        self._rw = {
+            'r': zmq.REP,
+            'w': zmq.REQ,
+        }[rw.lower()]
+
+        self._connection_pid = None
+        self._context = None
+        self._socket = None
+
+        self._debug = debug
+        self._logger = logger
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.close()
+
+    @property
+    def is_connected(self):
+        if self._socket:
+            return True
+        return False
+
+    @property
+    def port(self):
+        return {
+            'callbacks': os.environ.get('CALLBACK_CONSUMER_PORT',
+                                        settings.CALLBACK_CONSUMER_PORT),
+            'task_commands': settings.TASK_COMMAND_PORT,
+            'websocket': settings.SOCKETIO_NOTIFICATION_PORT,
+        }[self._bucket]
+
+    def connect(self):
+        """Connect to ZeroMQ."""
+
+        # Make sure that we are clearing everything out if there is
+        # a problem; PID crossover can cause bad news.
+        active_pid = os.getpid()
+        if self._connection_pid is None:
+            self._connection_pid = active_pid
+        if self._connection_pid != active_pid:
+            self._context = None
+            self._socket = None
+            self._connection_pid = active_pid
+
+        # If the port is an integer, convert it into tcp://
+        port = self.port
+        if isinstance(port, int):
+            port = 'tcp://127.0.0.1:%d' % port
+
+        # If the port is None, then this is an intentional dummy;
+        # honor this. (For testing.)
+        if not port:
+            return
+
+        # Okay, create the connection.
+        if self._context is None:
+            self._context = zmq.Context()
+            self._socket = self._context.socket(self._rw)
+            if self._rw == zmq.REQ:
+                self._socket.connect(port)
+            else:
+                self._socket.bind(port)
+
+    def close(self):
+        """Disconnect and tear down."""
+        if self._socket:
+            self._socket.close()
+        self._socket = None
+        self._context = None
+
+    def publish(self, message):
+        """Publish a message over the socket."""
+
+        # If there is no port (intentional dummy, as in connect), no-op.
+        if not self.port:
+            return
+
+        # If we are not connected, whine.
+        if not self.is_connected:
+            raise RuntimeError('Cannot publish a message when not connected '
+                               'to the socket.')
+
+        # If we are in the wrong mode, whine.
+        if self._rw != zmq.REQ:
+            raise RuntimeError('This socket is not opened for writing.')
+
+        # If we are in debug mode; provide the PID.
+        if self._debug:
+            message.update({'pid': os.getpid(),
+                            'connection_pid': self._connection_pid})
+
+        # Send the message.
+        for retry in xrange(4):
+            try:
+                self._socket.send_json(message)
+                self._socket.recv()
+                break
+            except Exception as ex:
+                if self._logger:
+                    self._logger.info('Publish Exception: %r; retry=%d',
+                                      ex, retry, exc_info=True)
+                if retry >= 3:
+                    raise
+
+    def listen(self):
+        """Yield messages from the socket indefinitely, sending a reply
+        to the publisher each time the consumer asks for the next one.
+        """
+        # If there is no port (intentional dummy, as in connect), no-op.
+        if not self.port:
+            return
+
+        # If we are not connected, whine.
+        if not self.is_connected:
+            raise RuntimeError('Cannot listen for messages when not connected '
+                               'to the socket.')
+
+        # If we are in the wrong mode, whine.
+        if self._rw != zmq.REP:
+            raise RuntimeError('This socket is not opened for reading.')
+
+        # Actually listen to the socket.
+        while True:
+            try:
+                message = self._socket.recv_json()
+                yield message
+            finally:
+                self._socket.send('1')
diff --git a/awx/main/tests/jobs.py b/awx/main/tests/jobs.py
index 47a4ec3787..0f7dec19e1 100644
--- a/awx/main/tests/jobs.py
+++ b/awx/main/tests/jobs.py
@@ -1193,7 +1193,6 @@ class JobTest(BaseJobTestMixin, django.test.TestCase):
 
 @override_settings(CELERY_ALWAYS_EAGER=True,
                    CELERY_EAGER_PROPAGATES_EXCEPTIONS=True,
-                   CALLBACK_CONSUMER_PORT='',
                    ANSIBLE_TRANSPORT='local')
 class JobStartCancelTest(BaseJobTestMixin, django.test.LiveServerTestCase):
     '''Job API tests that need to use the celery task backend.'''
diff --git a/awx/main/utils.py b/awx/main/utils.py
index d1b8a83f06..cac1b30149 100644
--- a/awx/main/utils.py
+++ b/awx/main/utils.py
@@ -361,11 +361,12 @@ def get_system_task_capacity():
 
 
 def emit_websocket_notification(endpoint, event, payload):
-    from awx.main.queue import PubSub
-    pubsub = PubSub('websocket')
-    payload['event'] = event
-    payload['endpoint'] = endpoint
-    pubsub.publish(payload)
+    from awx.main.socket import Socket
+
+    with Socket('websocket', 'w') as websocket:
+        payload['event'] = event
+        payload['endpoint'] = endpoint
+        websocket.publish(payload)
 
 
 _inventory_updates = threading.local()
diff --git a/awx/plugins/callback/job_event_callback.py b/awx/plugins/callback/job_event_callback.py
index 7b81f23597..5d985f9bd0 100644
--- a/awx/plugins/callback/job_event_callback.py
+++ b/awx/plugins/callback/job_event_callback.py
@@ -44,7 +44,7 @@ from contextlib import closing
 import requests
 
 # Tower
-from awx.main.queue import PubSub
+from awx.main.socket import Socket
 
 
 class TokenAuth(requests.auth.AuthBase):
@@ -115,26 +115,11 @@ class CallbackModule(object):
             'counter': self.counter,
             'created': datetime.datetime.utcnow().isoformat(),
         }
-        active_pid = os.getpid()
-        if self.job_callback_debug:
-            msg.update({
-                'pid': active_pid,
-            })
-        for retry_count in xrange(4):
-            try:
-                if not hasattr(self, 'connection_pid'):
-                    self.connection_pid = active_pid
 
-                # Publish the callback through Redis.
-                with closing(PubSub('callbacks')) as callbacks:
-                    callbacks.publish(msg)
-                    return
-            except Exception, e:
-                self.logger.info('Publish Exception: %r, retry=%d', e,
-                                 retry_count, exc_info=True)
-                # TODO: Maybe recycle connection here?
-                if retry_count >= 3:
-                    raise
+        # Publish the callback.
+        with Socket('callbacks', 'w', debug=self.job_callback_debug,
+                    logger=self.logger) as callbacks:
+            callbacks.publish(msg)
 
     def _post_rest_api_event(self, event, event_data):
         data = json.dumps({
diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py
index 90ed8390ef..6dd120d9fc 100644
--- a/awx/settings/defaults.py
+++ b/awx/settings/defaults.py
@@ -493,12 +493,12 @@ else:
 INTERNAL_API_URL = 'http://127.0.0.1:8000'
 
 # ZeroMQ callback settings.
-CALLBACK_CONSUMER_PORT = "tcp://127.0.0.1:5556"
+CALLBACK_CONSUMER_PORT = 5556
 CALLBACK_QUEUE_PORT = "ipc:///tmp/callback_receiver.ipc"
 
-TASK_COMMAND_PORT = "tcp://127.0.0.1:6559"
+TASK_COMMAND_PORT = 6559
 
-SOCKETIO_NOTIFICATION_PORT = "tcp://127.0.0.1:6557"
+SOCKETIO_NOTIFICATION_PORT = 6557
 SOCKETIO_LISTEN_PORT = 8080
 
 ORG_ADMINS_CAN_SEE_ALL_USERS = True