adds socket tests

This commit is contained in:
Chris Meyers
2015-11-06 10:30:56 -05:00
parent 014419f6ca
commit 3f4913f5ab
3 changed files with 157 additions and 31 deletions

View File

@@ -46,43 +46,52 @@ class SocketSession(object):
return bool(not auth_token.is_expired())
class SocketSessionManager(object):
socket_sessions = []
socket_session_token_key_map = {}
@classmethod
def _prune(cls):
if len(cls.socket_sessions) > 1000:
session = cls.socket_session[0]
del cls.socket_session_token_key_map[session.token_key]
cls.sessions = cls.socket_sessions[1:]
def __init__(self):
self.SESSIONS_MAX = 1000
self.socket_sessions = []
self.socket_session_token_key_map = {}
def _prune(self):
if len(self.socket_sessions) > self.SESSIONS_MAX:
session = self.socket_sessions[0]
entries = self.socket_session_token_key_map[session.token_key]
del entries[session.session_id]
if len(entries) == 0:
del self.socket_session_token_key_map[session.token_key]
self.socket_sessions.pop(0)
'''
Returns an dict of sessions <session_id, session>
'''
@classmethod
def lookup(cls, token_key=None):
def lookup(self, token_key=None):
if not token_key:
raise ValueError("token_key required")
return cls.socket_session_token_key_map.get(token_key, None)
return self.socket_session_token_key_map.get(token_key, None)
@classmethod
def add_session(cls, session):
cls.socket_sessions.append(session)
entries = cls.socket_session_token_key_map.get(session.token_key, None)
def add_session(self, session):
self.socket_sessions.append(session)
entries = self.socket_session_token_key_map.get(session.token_key, None)
if not entries:
entries = {}
cls.socket_session_token_key_map[session.token_key] = entries
self.socket_session_token_key_map[session.token_key] = entries
entries[session.session_id] = session
cls._prune()
self._prune()
return session
class SocketController(object):
server = None
@classmethod
def broadcast_packet(cls, packet):
def __init__(self, SocketSessionManager):
self.server = None
self.SocketSessionManager = SocketSessionManager
def add_session(self, session):
return self.SocketSessionManager.add_session(session)
def broadcast_packet(self, packet):
# Broadcast message to everyone at endpoint
# Loop over the 'raw' list of sockets (don't trust our list)
for session_id, socket in list(cls.server.sockets.iteritems()):
for session_id, socket in list(self.server.sockets.iteritems()):
socket_session = socket.session.get('socket_session', None)
if socket_session and socket_session.is_valid():
try:
@@ -91,11 +100,10 @@ class SocketController(object):
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
logger.error("Error was: " + str(e))
@classmethod
def send_packet(cls, packet, token_key):
def send_packet(self, packet, token_key):
if not token_key:
raise ValueError("token_key is required")
socket_sessions = SocketSessionManager.lookup(token_key=token_key)
socket_sessions = self.SocketSessionManager.lookup(token_key=token_key)
# We may not find the socket_session if the user disconnected
# (it's actually more compliciated than that because of our prune logic)
if not socket_sessions:
@@ -112,11 +120,12 @@ class SocketController(object):
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
logger.error("Error was: " + str(e))
@classmethod
def set_server(cls, server):
cls.server = server
def set_server(self, server):
self.server = server
return server
socketController = SocketController(SocketSessionManager())
#
# Socket session is attached to self.session['socket_session']
# self.session and self.socket.session point to the same dict
@@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace):
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
if socket_session.is_db_token_valid():
self.session['socket_session'] = socket_session
SocketSessionManager.add_session(socket_session)
socketController.add_session(socket_session)
else:
socket_session.invalidate()
@@ -240,9 +249,9 @@ def notification_handler(server):
if 'token_key' in message:
# Best practice not to send the token over the socket
SocketController.send_packet(packet, message.pop('token_key'))
socketController.send_packet(packet, message.pop('token_key'))
else:
SocketController.broadcast_packet(packet)
socketController.broadcast_packet(packet)
class Command(NoArgsCommand):
'''
@@ -270,7 +279,7 @@ class Command(NoArgsCommand):
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
SocketController.set_server(server)
socketController.set_server(server)
handler_thread = Thread(target=notification_handler, args=(server,))
handler_thread.daemon = True
handler_thread.start()