mirror of
https://github.com/ansible/awx.git
synced 2026-05-19 14:57:39 -02:30
adds socket tests
This commit is contained in:
@@ -46,43 +46,52 @@ class SocketSession(object):
|
||||
return bool(not auth_token.is_expired())
|
||||
|
||||
class SocketSessionManager(object):
|
||||
socket_sessions = []
|
||||
socket_session_token_key_map = {}
|
||||
|
||||
@classmethod
|
||||
def _prune(cls):
|
||||
if len(cls.socket_sessions) > 1000:
|
||||
session = cls.socket_session[0]
|
||||
del cls.socket_session_token_key_map[session.token_key]
|
||||
cls.sessions = cls.socket_sessions[1:]
|
||||
def __init__(self):
|
||||
self.SESSIONS_MAX = 1000
|
||||
self.socket_sessions = []
|
||||
self.socket_session_token_key_map = {}
|
||||
|
||||
def _prune(self):
|
||||
if len(self.socket_sessions) > self.SESSIONS_MAX:
|
||||
session = self.socket_sessions[0]
|
||||
entries = self.socket_session_token_key_map[session.token_key]
|
||||
del entries[session.session_id]
|
||||
if len(entries) == 0:
|
||||
del self.socket_session_token_key_map[session.token_key]
|
||||
self.socket_sessions.pop(0)
|
||||
|
||||
'''
|
||||
Returns an dict of sessions <session_id, session>
|
||||
'''
|
||||
@classmethod
|
||||
def lookup(cls, token_key=None):
|
||||
def lookup(self, token_key=None):
|
||||
if not token_key:
|
||||
raise ValueError("token_key required")
|
||||
return cls.socket_session_token_key_map.get(token_key, None)
|
||||
return self.socket_session_token_key_map.get(token_key, None)
|
||||
|
||||
@classmethod
|
||||
def add_session(cls, session):
|
||||
cls.socket_sessions.append(session)
|
||||
entries = cls.socket_session_token_key_map.get(session.token_key, None)
|
||||
def add_session(self, session):
|
||||
self.socket_sessions.append(session)
|
||||
entries = self.socket_session_token_key_map.get(session.token_key, None)
|
||||
if not entries:
|
||||
entries = {}
|
||||
cls.socket_session_token_key_map[session.token_key] = entries
|
||||
self.socket_session_token_key_map[session.token_key] = entries
|
||||
entries[session.session_id] = session
|
||||
cls._prune()
|
||||
self._prune()
|
||||
return session
|
||||
|
||||
class SocketController(object):
|
||||
server = None
|
||||
|
||||
@classmethod
|
||||
def broadcast_packet(cls, packet):
|
||||
def __init__(self, SocketSessionManager):
|
||||
self.server = None
|
||||
self.SocketSessionManager = SocketSessionManager
|
||||
|
||||
def add_session(self, session):
|
||||
return self.SocketSessionManager.add_session(session)
|
||||
|
||||
def broadcast_packet(self, packet):
|
||||
# Broadcast message to everyone at endpoint
|
||||
# Loop over the 'raw' list of sockets (don't trust our list)
|
||||
for session_id, socket in list(cls.server.sockets.iteritems()):
|
||||
for session_id, socket in list(self.server.sockets.iteritems()):
|
||||
socket_session = socket.session.get('socket_session', None)
|
||||
if socket_session and socket_session.is_valid():
|
||||
try:
|
||||
@@ -91,11 +100,10 @@ class SocketController(object):
|
||||
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
||||
logger.error("Error was: " + str(e))
|
||||
|
||||
@classmethod
|
||||
def send_packet(cls, packet, token_key):
|
||||
def send_packet(self, packet, token_key):
|
||||
if not token_key:
|
||||
raise ValueError("token_key is required")
|
||||
socket_sessions = SocketSessionManager.lookup(token_key=token_key)
|
||||
socket_sessions = self.SocketSessionManager.lookup(token_key=token_key)
|
||||
# We may not find the socket_session if the user disconnected
|
||||
# (it's actually more compliciated than that because of our prune logic)
|
||||
if not socket_sessions:
|
||||
@@ -112,11 +120,12 @@ class SocketController(object):
|
||||
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
|
||||
logger.error("Error was: " + str(e))
|
||||
|
||||
@classmethod
|
||||
def set_server(cls, server):
|
||||
cls.server = server
|
||||
def set_server(self, server):
|
||||
self.server = server
|
||||
return server
|
||||
|
||||
socketController = SocketController(SocketSessionManager())
|
||||
|
||||
#
|
||||
# Socket session is attached to self.session['socket_session']
|
||||
# self.session and self.socket.session point to the same dict
|
||||
@@ -140,7 +149,7 @@ class TowerBaseNamespace(BaseNamespace):
|
||||
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
|
||||
if socket_session.is_db_token_valid():
|
||||
self.session['socket_session'] = socket_session
|
||||
SocketSessionManager.add_session(socket_session)
|
||||
socketController.add_session(socket_session)
|
||||
else:
|
||||
socket_session.invalidate()
|
||||
|
||||
@@ -240,9 +249,9 @@ def notification_handler(server):
|
||||
|
||||
if 'token_key' in message:
|
||||
# Best practice not to send the token over the socket
|
||||
SocketController.send_packet(packet, message.pop('token_key'))
|
||||
socketController.send_packet(packet, message.pop('token_key'))
|
||||
else:
|
||||
SocketController.broadcast_packet(packet)
|
||||
socketController.broadcast_packet(packet)
|
||||
|
||||
class Command(NoArgsCommand):
|
||||
'''
|
||||
@@ -270,7 +279,7 @@ class Command(NoArgsCommand):
|
||||
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
||||
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
||||
|
||||
SocketController.set_server(server)
|
||||
socketController.set_server(server)
|
||||
handler_thread = Thread(target=notification_handler, args=(server,))
|
||||
handler_thread.daemon = True
|
||||
handler_thread.start()
|
||||
|
||||
Reference in New Issue
Block a user