mirror of
https://github.com/ansible/awx.git
synced 2026-05-17 14:27:42 -02:30
adds socket tests
This commit is contained in:
@@ -46,43 +46,52 @@ class SocketSession(object):
|
|||||||
return bool(not auth_token.is_expired())
|
return bool(not auth_token.is_expired())
|
||||||
|
|
||||||
class SocketSessionManager(object):
|
class SocketSessionManager(object):
|
||||||
socket_sessions = []
|
|
||||||
socket_session_token_key_map = {}
|
|
||||||
|
|
||||||
@classmethod
|
def __init__(self):
|
||||||
def _prune(cls):
|
self.SESSIONS_MAX = 1000
|
||||||
if len(cls.socket_sessions) > 1000:
|
self.socket_sessions = []
|
||||||
session = cls.socket_session[0]
|
self.socket_session_token_key_map = {}
|
||||||
del cls.socket_session_token_key_map[session.token_key]
|
|
||||||
cls.sessions = cls.socket_sessions[1:]
|
def _prune(self):
|
||||||
|
if len(self.socket_sessions) > self.SESSIONS_MAX:
|
||||||
|
session = self.socket_sessions[0]
|
||||||
|
entries = self.socket_session_token_key_map[session.token_key]
|
||||||
|
del entries[session.session_id]
|
||||||
|
if len(entries) == 0:
|
||||||
|
del self.socket_session_token_key_map[session.token_key]
|
||||||
|
self.socket_sessions.pop(0)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Returns an dict of sessions <session_id, session>
|
Returns an dict of sessions <session_id, session>
|
||||||
'''
|
'''
|
||||||
@classmethod
|
def lookup(self, token_key=None):
|
||||||
def lookup(cls, token_key=None):
|
|
||||||
if not token_key:
|
if not token_key:
|
||||||
raise ValueError("token_key required")
|
raise ValueError("token_key required")
|
||||||
return cls.socket_session_token_key_map.get(token_key, None)
|
return self.socket_session_token_key_map.get(token_key, None)
|
||||||
|
|
||||||
@classmethod
|
def add_session(self, session):
|
||||||
def add_session(cls, session):
|
self.socket_sessions.append(session)
|
||||||
cls.socket_sessions.append(session)
|
entries = self.socket_session_token_key_map.get(session.token_key, None)
|
||||||
entries = cls.socket_session_token_key_map.get(session.token_key, None)
|
|
||||||
if not entries:
|
if not entries:
|
||||||
entries = {}
|
entries = {}
|
||||||
cls.socket_session_token_key_map[session.token_key] = entries
|
self.socket_session_token_key_map[session.token_key] = entries
|
||||||
entries[session.session_id] = session
|
entries[session.session_id] = session
|
||||||
cls._prune()
|
self._prune()
|
||||||
|
return session
|
||||||
|
|
||||||
class SocketController(object):
|
class SocketController(object):
|
||||||
server = None
|
|
||||||
|
|
||||||
@classmethod
|
def __init__(self, SocketSessionManager):
|
||||||
def broadcast_packet(cls, packet):
|
self.server = None
|
||||||
|
self.SocketSessionManager = SocketSessionManager
|
||||||
|
|
||||||
|
def add_session(self, session):
|
||||||
|
return self.SocketSessionManager.add_session(session)
|
||||||
|
|
||||||
|
def broadcast_packet(self, packet):
|
||||||
# Broadcast message to everyone at endpoint
|
# Broadcast message to everyone at endpoint
|
||||||
# Loop over the 'raw' list of sockets (don't trust our list)
|
# Loop over the 'raw' list of sockets (don't trust our list)
|
||||||
for session_id, socket in list(cls.server.sockets.iteritems()):
|
for session_id, socket in list(self.server.sockets.iteritems()):
|
||||||
socket_session = socket.session.get('socket_session', None)
|
socket_session = socket.session.get('socket_session', None)
|
||||||
if socket_session and socket_session.is_valid():
|
if socket_session and socket_session.is_valid():
|
||||||
try:
|
try:
|
||||||
@@ -91,11 +100,10 @@ class SocketController(object):
|
|||||||
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
||||||
logger.error("Error was: " + str(e))
|
logger.error("Error was: " + str(e))
|
||||||
|
|
||||||
@classmethod
|
def send_packet(self, packet, token_key):
|
||||||
def send_packet(cls, packet, token_key):
|
|
||||||
if not token_key:
|
if not token_key:
|
||||||
raise ValueError("token_key is required")
|
raise ValueError("token_key is required")
|
||||||
socket_sessions = SocketSessionManager.lookup(token_key=token_key)
|
socket_sessions = self.SocketSessionManager.lookup(token_key=token_key)
|
||||||
# We may not find the socket_session if the user disconnected
|
# We may not find the socket_session if the user disconnected
|
||||||
# (it's actually more compliciated than that because of our prune logic)
|
# (it's actually more compliciated than that because of our prune logic)
|
||||||
if not socket_sessions:
|
if not socket_sessions:
|
||||||
@@ -112,11 +120,12 @@ class SocketController(object):
|
|||||||
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
|
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
|
||||||
logger.error("Error was: " + str(e))
|
logger.error("Error was: " + str(e))
|
||||||
|
|
||||||
@classmethod
|
def set_server(self, server):
|
||||||
def set_server(cls, server):
|
self.server = server
|
||||||
cls.server = server
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
socketController = SocketController(SocketSessionManager())
|
||||||
|
|
||||||
#
|
#
|
||||||
# Socket session is attached to self.session['socket_session']
|
# Socket session is attached to self.session['socket_session']
|
||||||
# self.session and self.socket.session point to the same dict
|
# self.session and self.socket.session point to the same dict
|
||||||
@@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace):
|
|||||||
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
|
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
|
||||||
if socket_session.is_db_token_valid():
|
if socket_session.is_db_token_valid():
|
||||||
self.session['socket_session'] = socket_session
|
self.session['socket_session'] = socket_session
|
||||||
SocketSessionManager.add_session(socket_session)
|
socketController.add_session(socket_session)
|
||||||
else:
|
else:
|
||||||
socket_session.invalidate()
|
socket_session.invalidate()
|
||||||
|
|
||||||
@@ -240,9 +249,9 @@ def notification_handler(server):
|
|||||||
|
|
||||||
if 'token_key' in message:
|
if 'token_key' in message:
|
||||||
# Best practice not to send the token over the socket
|
# Best practice not to send the token over the socket
|
||||||
SocketController.send_packet(packet, message.pop('token_key'))
|
socketController.send_packet(packet, message.pop('token_key'))
|
||||||
else:
|
else:
|
||||||
SocketController.broadcast_packet(packet)
|
socketController.broadcast_packet(packet)
|
||||||
|
|
||||||
class Command(NoArgsCommand):
|
class Command(NoArgsCommand):
|
||||||
'''
|
'''
|
||||||
@@ -270,7 +279,7 @@ class Command(NoArgsCommand):
|
|||||||
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
||||||
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
||||||
|
|
||||||
SocketController.set_server(server)
|
socketController.set_server(server)
|
||||||
handler_thread = Thread(target=notification_handler, args=(server,))
|
handler_thread = Thread(target=notification_handler, args=(server,))
|
||||||
handler_thread.daemon = True
|
handler_thread.daemon = True
|
||||||
handler_thread.start()
|
handler_thread.start()
|
||||||
|
|||||||
@@ -8,4 +8,5 @@ from .commands_monolithic import * # noqa
|
|||||||
from .cleanup_facts import * # noqa
|
from .cleanup_facts import * # noqa
|
||||||
from .age_deleted import * # noqa
|
from .age_deleted import * # noqa
|
||||||
from .remove_instance import * # noqa
|
from .remove_instance import * # noqa
|
||||||
|
from .run_socketio_service import * # noqa
|
||||||
|
|
||||||
|
|||||||
116
awx/main/tests/commands/run_socketio_service.py
Normal file
116
awx/main/tests/commands/run_socketio_service.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright (c) 2015 Ansible, Inc.
|
||||||
|
# All Rights Reserved
|
||||||
|
|
||||||
|
# Python
|
||||||
|
from mock import MagicMock, Mock
|
||||||
|
|
||||||
|
# Django
|
||||||
|
from django.test import SimpleTestCase
|
||||||
|
|
||||||
|
# AWX
|
||||||
|
from awx.fact.models.fact import * # noqa
|
||||||
|
from awx.main.management.commands.run_socketio_service import SocketSessionManager, SocketSession, SocketController
|
||||||
|
|
||||||
|
__all__ = ['SocketSessionManagerUnitTest', 'SocketControllerUnitTest',]
|
||||||
|
|
||||||
|
class WeakRefable():
|
||||||
|
pass
|
||||||
|
|
||||||
|
class SocketSessionManagerUnitTest(SimpleTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.session_manager = SocketSessionManager()
|
||||||
|
super(SocketSessionManagerUnitTest, self).setUp()
|
||||||
|
|
||||||
|
def create_sessions(self, count, token_key=None):
|
||||||
|
self.sessions = []
|
||||||
|
self.count = count
|
||||||
|
for i in range(0, count):
|
||||||
|
self.sessions.append(SocketSession(i, token_key or i, WeakRefable()))
|
||||||
|
self.session_manager.add_session(self.sessions[i])
|
||||||
|
|
||||||
|
def test_multiple_session_diff_token(self):
|
||||||
|
self.create_sessions(10)
|
||||||
|
|
||||||
|
for s in self.sessions:
|
||||||
|
self.assertIn(s.token_key, self.session_manager.socket_session_token_key_map)
|
||||||
|
self.assertEqual(s, self.session_manager.socket_session_token_key_map[s.token_key][s.session_id])
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_session_same_token(self):
|
||||||
|
self.create_sessions(10, token_key='foo')
|
||||||
|
|
||||||
|
sessions_dict = self.session_manager.lookup("foo")
|
||||||
|
self.assertEqual(len(sessions_dict), 10)
|
||||||
|
for s in self.sessions:
|
||||||
|
self.assertIn(s.session_id, sessions_dict)
|
||||||
|
self.assertEqual(s, sessions_dict[s.session_id])
|
||||||
|
|
||||||
|
def test_prune_sessions_max(self):
|
||||||
|
self.create_sessions(self.session_manager.SESSIONS_MAX + 10)
|
||||||
|
|
||||||
|
self.assertEqual(len(self.session_manager.socket_sessions), self.session_manager.SESSIONS_MAX)
|
||||||
|
|
||||||
|
|
||||||
|
class SocketControllerUnitTest(SimpleTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.socket_controller = SocketController(SocketSessionManager())
|
||||||
|
server = Mock()
|
||||||
|
self.socket_controller.set_server(server)
|
||||||
|
super(SocketControllerUnitTest, self).setUp()
|
||||||
|
|
||||||
|
def create_clients(self, count, token_key=None):
|
||||||
|
self.sessions = []
|
||||||
|
self.sockets =[]
|
||||||
|
self.count = count
|
||||||
|
self.sockets_dict = {}
|
||||||
|
for i in range(0, count):
|
||||||
|
if isinstance(token_key, list):
|
||||||
|
token_key_actual = token_key[i]
|
||||||
|
else:
|
||||||
|
token_key_actual = token_key or i
|
||||||
|
socket = MagicMock(session=dict())
|
||||||
|
socket_session = SocketSession(i, token_key_actual, socket)
|
||||||
|
self.sockets.append(socket)
|
||||||
|
self.sessions.append(socket_session)
|
||||||
|
self.sockets_dict[i] = socket
|
||||||
|
self.socket_controller.add_session(socket_session)
|
||||||
|
|
||||||
|
socket.session['socket_session'] = socket_session
|
||||||
|
socket.send_packet = Mock()
|
||||||
|
self.socket_controller.server.sockets = self.sockets_dict
|
||||||
|
|
||||||
|
def test_broadcast_packet(self):
|
||||||
|
self.create_clients(10)
|
||||||
|
packet = {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
self.socket_controller.broadcast_packet(packet)
|
||||||
|
for s in self.sockets:
|
||||||
|
s.send_packet.assert_called_with(packet)
|
||||||
|
|
||||||
|
def test_send_packet(self):
|
||||||
|
self.create_clients(5, token_key=[0, 1, 2, 3, 4])
|
||||||
|
packet = {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
self.socket_controller.send_packet(packet, 2)
|
||||||
|
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
|
||||||
|
self.assertEqual(0, len(self.sockets[1].send_packet.mock_calls))
|
||||||
|
self.sockets[2].send_packet.assert_called_once_with(packet)
|
||||||
|
self.assertEqual(0, len(self.sockets[3].send_packet.mock_calls))
|
||||||
|
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))
|
||||||
|
|
||||||
|
def test_send_packet_multiple_sessions_one_token(self):
|
||||||
|
self.create_clients(5, token_key=[0, 1, 1, 1, 2])
|
||||||
|
packet = {
|
||||||
|
"hello": "world"
|
||||||
|
}
|
||||||
|
self.socket_controller.send_packet(packet, 1)
|
||||||
|
self.assertEqual(0, len(self.sockets[0].send_packet.mock_calls))
|
||||||
|
self.sockets[1].send_packet.assert_called_once_with(packet)
|
||||||
|
self.sockets[2].send_packet.assert_called_once_with(packet)
|
||||||
|
self.sockets[3].send_packet.assert_called_once_with(packet)
|
||||||
|
self.assertEqual(0, len(self.sockets[4].send_packet.mock_calls))
|
||||||
|
|
||||||
Reference in New Issue
Block a user