From 3f4913f5ab8ad7f10395574bb91453fdeff0e23e Mon Sep 17 00:00:00 2001 From: Chris Meyers Date: Fri, 6 Nov 2015 10:30:56 -0500 Subject: [PATCH] adds socket tests --- .../commands/run_socketio_service.py | 71 ++++++----- awx/main/tests/commands/__init__.py | 1 + .../tests/commands/run_socketio_service.py | 116 ++++++++++++++++++ 3 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 awx/main/tests/commands/run_socketio_service.py diff --git a/awx/main/management/commands/run_socketio_service.py b/awx/main/management/commands/run_socketio_service.py index c4439218db..56d8cc5b1c 100644 --- a/awx/main/management/commands/run_socketio_service.py +++ b/awx/main/management/commands/run_socketio_service.py @@ -46,43 +46,52 @@ class SocketSession(object): return bool(not auth_token.is_expired()) class SocketSessionManager(object): - socket_sessions = [] - socket_session_token_key_map = {} - @classmethod - def _prune(cls): - if len(cls.socket_sessions) > 1000: - session = cls.socket_session[0] - del cls.socket_session_token_key_map[session.token_key] - cls.sessions = cls.socket_sessions[1:] + def __init__(self): + self.SESSIONS_MAX = 1000 + self.socket_sessions = [] + self.socket_session_token_key_map = {} + + def _prune(self): + if len(self.socket_sessions) > self.SESSIONS_MAX: + session = self.socket_sessions[0] + entries = self.socket_session_token_key_map[session.token_key] + del entries[session.session_id] + if len(entries) == 0: + del self.socket_session_token_key_map[session.token_key] + self.socket_sessions.pop(0) ''' Returns an dict of sessions ''' - @classmethod - def lookup(cls, token_key=None): + def lookup(self, token_key=None): if not token_key: raise ValueError("token_key required") - return cls.socket_session_token_key_map.get(token_key, None) + return self.socket_session_token_key_map.get(token_key, None) - @classmethod - def add_session(cls, session): - cls.socket_sessions.append(session) - entries = cls.socket_session_token_key_map.get(session.token_key, None) + def add_session(self, session): + self.socket_sessions.append(session) + entries = self.socket_session_token_key_map.get(session.token_key, None) if not entries: entries = {} - cls.socket_session_token_key_map[session.token_key] = entries + self.socket_session_token_key_map[session.token_key] = entries entries[session.session_id] = session - cls._prune() + self._prune() + return session class SocketController(object): - server = None - @classmethod - def broadcast_packet(cls, packet): + def __init__(self, SocketSessionManager): + self.server = None + self.SocketSessionManager = SocketSessionManager + + def add_session(self, session): + return self.SocketSessionManager.add_session(session) + + def broadcast_packet(self, packet): # Broadcast message to everyone at endpoint # Loop over the 'raw' list of sockets (don't trust our list) - for session_id, socket in list(cls.server.sockets.iteritems()): + for session_id, socket in list(self.server.sockets.iteritems()): socket_session = socket.session.get('socket_session', None) if socket_session and socket_session.is_valid(): try: @@ -91,11 +100,10 @@ class SocketController(object): logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet))) logger.error("Error was: " + str(e)) - @classmethod - def send_packet(cls, packet, token_key): + def send_packet(self, packet, token_key): if not token_key: raise ValueError("token_key is required") - socket_sessions = SocketSessionManager.lookup(token_key=token_key) + socket_sessions = self.SocketSessionManager.lookup(token_key=token_key) # We may not find the socket_session if the user disconnected # (it's actually more compliciated than that because of our prune logic) if not socket_sessions: @@ -112,11 +120,12 @@ class SocketController(object): logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet))) logger.error("Error was: " + str(e)) - @classmethod - def set_server(cls, server): - cls.server = server + def set_server(self, server): + self.server = server return server +socketController = SocketController(SocketSessionManager()) + # # Socket session is attached to self.session['socket_session'] # self.session and self.socket.session point to the same dict @@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace): socket_session = SocketSession(self.socket.sessid, request_token, self.socket) if socket_session.is_db_token_valid(): self.session['socket_session'] = socket_session - SocketSessionManager.add_session(socket_session) + socketController.add_session(socket_session) else: socket_session.invalidate() @@ -240,9 +249,9 @@ def notification_handler(server): if 'token_key' in message: # Best practice not to send the token over the socket - SocketController.send_packet(packet, message.pop('token_key')) + socketController.send_packet(packet, message.pop('token_key')) else: - SocketController.broadcast_packet(packet) + socketController.broadcast_packet(packet) class Command(NoArgsCommand): ''' @@ -270,7 +279,7 @@ class Command(NoArgsCommand): logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port)) server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io') - SocketController.set_server(server) + socketController.set_server(server) handler_thread = Thread(target=notification_handler, args=(server,)) handler_thread.daemon = True handler_thread.start() diff --git a/awx/main/tests/commands/__init__.py b/awx/main/tests/commands/__init__.py index 683950169f..dc89a6f8b6 100644 --- a/awx/main/tests/commands/__init__.py +++ b/awx/main/tests/commands/__init__.py @@ -8,4 +8,5 @@ from .commands_monolithic import * # noqa from .cleanup_facts import * # noqa from .age_deleted import * # noqa from .remove_instance import * # noqa +from .run_socketio_service import * # noqa diff --git a/awx/main/tests/commands/run_socketio_service.py b/awx/main/tests/commands/run_socketio_service.py new file mode 100644 index 0000000000..be882d6b20 --- /dev/null +++ b/awx/main/tests/commands/run_socketio_service.py @@ -0,0 +1,116 @@ +# Copyright (c) 2015 Ansible, Inc. +# All Rights Reserved + +# Python +from mock import MagicMock, Mock + +# Django +from django.test import SimpleTestCase + +# AWX +from awx.fact.models.fact import * # noqa +from awx.main.management.commands.run_socketio_service import SocketSessionManager, SocketSession, SocketController + +__all__ = ['SocketSessionManagerUnitTest', 'SocketControllerUnitTest',] + +class WeakRefable(): + pass + +class SocketSessionManagerUnitTest(SimpleTestCase): + + def setUp(self): + self.session_manager = SocketSessionManager() + super(SocketSessionManagerUnitTest, self).setUp() + + def create_sessions(self, count, token_key=None): + self.sessions = [] + self.count = count + for i in range(0, count): + self.sessions.append(SocketSession(i, token_key or i, WeakRefable())) + self.session_manager.add_session(self.sessions[i]) + + def test_multiple_session_diff_token(self): + self.create_sessions(10) + + for s in self.sessions: + self.assertIn(s.token_key, self.session_manager.socket_session_token_key_map) + self.assertEqual(s, self.session_manager.socket_session_token_key_map[s.token_key][s.session_id]) + + + def test_multiple_session_same_token(self): + self.create_sessions(10, token_key='foo') + + sessions_dict = self.session_manager.lookup("foo") + self.assertEqual(len(sessions_dict), 10) + for s in self.sessions: + self.assertIn(s.session_id, sessions_dict) + self.assertEqual(s, sessions_dict[s.session_id]) + + def test_prune_sessions_max(self): + self.create_sessions(self.session_manager.SESSIONS_MAX + 10) + + self.assertEqual(len(self.session_manager.socket_sessions), self.session_manager.SESSIONS_MAX) + + +class SocketControllerUnitTest(SimpleTestCase): + + def setUp(self): + self.socket_controller = SocketController(SocketSessionManager()) + server = Mock() + self.socket_controller.set_server(server) + super(SocketControllerUnitTest, self).setUp() + + def create_clients(self, count, token_key=None): + self.sessions = [] + self.sockets =[] + self.count = count + self.sockets_dict = {} + for i in range(0, count): + if isinstance(token_key, list): + token_key_actual = token_key[i] + else: + token_key_actual = token_key or i + socket = MagicMock(session=dict()) + socket_session = SocketSession(i, token_key_actual, socket) + self.sockets.append(socket) + self.sessions.append(socket_session) + self.sockets_dict[i] = socket + self.socket_controller.add_session(socket_session) + + socket.session['socket_session'] = socket_session + socket.send_packet = Mock() + self.socket_controller.server.sockets = self.sockets_dict + + def test_broadcast_packet(self): + self.create_clients(10) + packet = { + "hello": "world" + } + self.socket_controller.broadcast_packet(packet) + for s in self.sockets: + s.send_packet.assert_called_with(packet) + + def test_send_packet(self): + self.create_clients(5, token_key=[0, 1, 2, 3, 4]) + packet = { + "hello": "world" + } + self.socket_controller.send_packet(packet, 2) + self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls)) + self.assertEqual(0, len(self.sockets[1].send_packet.mock_calls)) + self.sockets[2].send_packet.assert_called_once_with(packet) + self.assertEqual(0, len(self.sockets[3].send_packet.mock_calls)) + self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls)) + + def test_send_packet_multiple_sessions_one_token(self): + self.create_clients(5, token_key=[0, 1, 1, 1, 2]) + packet = { + "hello": "world" + } + self.socket_controller.send_packet(packet, 1) + self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls)) + self.sockets[1].send_packet.assert_called_once_with(packet) + self.sockets[2].send_packet.assert_called_once_with(packet) + self.sockets[3].send_packet.assert_called_once_with(packet) + self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls)) +