From b350eef3f0256245ae888f316dcc639e0a3f24e1 Mon Sep 17 00:00:00 2001 From: Chris Meyers Date: Tue, 29 Sep 2015 14:10:26 -0400 Subject: [PATCH] session limit invalidation events via socket.io --- awx/api/views.py | 9 +- .../commands/run_socketio_service.py | 170 ++++++++++++++---- awx/main/utils.py | 4 +- 3 files changed, 144 insertions(+), 39 deletions(-) diff --git a/awx/api/views.py b/awx/api/views.py index 6e53ed9997..b1ca33ed27 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -64,6 +64,7 @@ from awx.api.permissions import * # noqa from awx.api.renderers import * # noqa from awx.api.serializers import * # noqa from awx.fact.models import * # noqa +from awx.main.utils import emit_websocket_notification def api_exception_handler(exc): ''' @@ -528,16 +529,20 @@ class AuthTokenView(APIView): reason='')[0] token.refresh() except IndexError: + token = AuthToken.objects.create(user=serializer.object['user'], + request_hash=request_hash) # Get user un-expired tokens that are not invalidated that are # over the configured limit. # Mark them as invalid and inform the user invalid_tokens = AuthToken.get_tokens_over_limit(serializer.object['user']) for t in invalid_tokens: # TODO: send socket notification + emit_websocket_notification('/socket.io/control', + 'limit_reached', + dict(reason=unicode(AuthToken.reason_long('limit_reached'))), + token_key=t.key) t.invalidate(reason='limit_reached') - token = AuthToken.objects.create(user=serializer.object['user'], - request_hash=request_hash) return Response({'token': token.key, 'expires': token.expires}) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/awx/main/management/commands/run_socketio_service.py b/awx/main/management/commands/run_socketio_service.py index f2f0155642..8a4706aea9 100644 --- a/awx/main/management/commands/run_socketio_service.py +++ b/awx/main/management/commands/run_socketio_service.py @@ -5,6 +5,7 @@ import os import logging import urllib +import weakref from optparse import make_option from threading import Thread @@ -24,51 +25,139 @@ from socketio.namespace import BaseNamespace logger = logging.getLogger('awx.main.commands.run_socketio_service') -valid_sockets = [] +class SocketSession(object): + def __init__(self, session_id, token_key, socket): + self.socket = weakref.ref(socket) + self.session_id = session_id + self.token_key = token_key + self._valid = True + def is_valid(self): + return bool(self._valid) + + def invalidate(self): + self._valid = False + + def is_db_token_valid(self): + auth_token = AuthToken.objects.filter(key=self.token_key, reason='') + if not auth_token.exists(): + return False + auth_token = auth_token[0] + return bool(not auth_token.is_expired()) + +class SocketSessionManager(object): + sessions = [] + session_token_key_map = {} + + @classmethod + def _prune(cls): + #cls.sessions = [s for s in cls.sessions if s.is_valid()] + if len(cls.sessions) > 1000: + del cls.session_token_key_map[cls.sessions[0].token_key] + cls.sessions = cls.sessions[1:] + + @classmethod + def lookup(cls, token_key=None): + if not token_key: + raise ValueError("token_key required") + for s in cls.sessions: + if s.token_key == token_key: + return s + return None + + @classmethod + def add_session(cls, session): + cls.sessions.append(session) + cls.session_token_key_map[session.token_key] = session + cls._prune() + +class SocketController(object): + server = None + + @classmethod + def broadcast_packet(cls, packet): + # Broadcast message to everyone at endpoint + # Loop over the 'raw' list of sockets (don't trust our list) + for session_id, socket in list(cls.server.sockets.iteritems()): + socket_session = socket.session.get('socket_session', None) + if socket_session and socket_session.is_valid(): + try: + socket.send_packet(packet) + except Exception, e: + logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet))) + logger.error("Error was: " + str(e)) + + @classmethod + def send_packet(cls, packet, token_key): + if not token_key: + raise ValueError("token_key is required") + socket_session = SocketSessionManager.lookup(token_key=token_key) + # We may not find the socket_session if the user disconnected + # (it's actually more compliciated than that because of our prune logic) + if socket_session and socket_session.is_valid(): + socket = socket_session.socket() + if socket: + try: + socket.send_packet(packet) + except Exception, e: + logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet))) + logger.error("Error was: " + str(e)) + + @classmethod + def set_server(cls, server): + cls.server = server + return server + +# +# Socket session is attached to self.session['socket_session'] +# self.session and self.socket.session point to the same dict +# class TowerBaseNamespace(BaseNamespace): def get_allowed_methods(self): return ['recv_disconnect'] def get_initial_acl(self): - global valid_sockets - v_user = self.valid_user() - self.is_valid_connection = False - if v_user: - if self.socket.sessid not in valid_sockets: - valid_sockets.append(self.socket.sessid) - self.is_valid_connection = True - if len(valid_sockets) > 1000: - valid_sockets = valid_sockets[1:] + request_token = self._get_request_token() + if request_token: + # (1) This is the first time the socket has been seen (first + # namespace joined). + # (2) This socket has already been seen (already joined and maybe + # left a namespace) + # + # Note: Assume that the user token is valid if the session is found + socket_session = self.session.get('socket_session', None) + if not socket_session: + socket_session = SocketSession(self.socket.sessid, request_token, self.socket) + if socket_session.is_db_token_valid(): + self.session['socket_session'] = socket_session + SocketSessionManager.add_session(socket_session) + else: + socket_session.invalidate() + return set(['recv_connect'] + self.get_allowed_methods()) else: logger.warn("Authentication Failure validating user") self.emit("connect_failed", "Authentication failed") return set(['recv_connect']) - def valid_user(self): + def _get_request_token(self): if 'QUERY_STRING' not in self.environ: return False - else: - try: - k, v = self.environ['QUERY_STRING'].split("=") - if k == "Token": - token_actual = urllib.unquote_plus(v).decode().replace("\"","") - auth_token = AuthToken.objects.filter(key=token_actual, reason='') - if not auth_token.exists(): - return False - auth_token = auth_token[0] - if not auth_token.is_expired(): - return auth_token.user - else: - return False - except Exception, e: - logger.error("Exception validating user: " + str(e)) - return False + + try: + k, v = self.environ['QUERY_STRING'].split("=") + if k == "Token": + token_actual = urllib.unquote_plus(v).decode().replace("\"","") + return token_actual + except Exception, e: + logger.error("Exception validating user: " + str(e)) + return False + return False def recv_connect(self): - if not self.is_valid_connection: + socket_session = self.session.get('socket_session', None) + if socket_session and not socket_session.is_valid(): self.disconnect(silent=False) class TestNamespace(TowerBaseNamespace): @@ -106,6 +195,14 @@ class ScheduleNamespace(TowerBaseNamespace): logger.info("Received client connect for schedule namespace from %s" % str(self.environ['REMOTE_ADDR'])) super(ScheduleNamespace, self).recv_connect() +# Catch-all namespace. +# Deliver 'global' events over this namespace +class ControlNamespace(TowerBaseNamespace): + + def recv_connect(self): + logger.warn("Received client connect for control namespace from %s" % str(self.environ['REMOTE_ADDR'])) + super(ControlNamespace, self).recv_connect() + class TowerSocket(object): def __call__(self, environ, start_response): @@ -115,7 +212,8 @@ class TowerSocket(object): '/socket.io/jobs': JobNamespace, '/socket.io/job_events': JobEventNamespace, '/socket.io/ad_hoc_command_events': AdHocCommandEventNamespace, - '/socket.io/schedules': ScheduleNamespace}) + '/socket.io/schedules': ScheduleNamespace, + '/socket.io/control': ControlNamespace}) else: logger.warn("Invalid connect path received: " + path) start_response('404 Not Found', []) @@ -130,13 +228,12 @@ def notification_handler(server): 'name': message['event'], 'type': 'event', } - for session_id, socket in list(server.sockets.iteritems()): - if session_id in valid_sockets: - try: - socket.send_packet(packet) - except Exception, e: - logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet))) - logger.error("Error was: " + str(e)) + + if 'token_key' in message: + # Best practice not to send the token over the socket + SocketController.send_packet(packet, message.pop('token_key')) + else: + SocketController.broadcast_packet(packet) class Command(NoArgsCommand): ''' @@ -164,6 +261,7 @@ class Command(NoArgsCommand): logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port)) server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io') + SocketController.set_server(server) handler_thread = Thread(target=notification_handler, args=(server,)) handler_thread.daemon = True handler_thread.start() diff --git a/awx/main/utils.py b/awx/main/utils.py index 202e9989fe..9e6a005dc1 100644 --- a/awx/main/utils.py +++ b/awx/main/utils.py @@ -389,11 +389,13 @@ def get_system_task_capacity(): return 50 + ((int(total_mem_value) / 1024) - 2) * 75 -def emit_websocket_notification(endpoint, event, payload): +def emit_websocket_notification(endpoint, event, payload, token_key=None): from awx.main.socket import Socket try: with Socket('websocket', 'w', nowait=True, logger=logger) as websocket: + if token_key: + payload['token_key'] = token_key payload['event'] = event payload['endpoint'] = endpoint websocket.publish(payload)