mirror of
https://github.com/ansible/awx.git
synced 2026-05-19 14:57:39 -02:30
session limit invalidation events via socket.io
This commit is contained in:
@@ -64,6 +64,7 @@ from awx.api.permissions import * # noqa
|
|||||||
from awx.api.renderers import * # noqa
|
from awx.api.renderers import * # noqa
|
||||||
from awx.api.serializers import * # noqa
|
from awx.api.serializers import * # noqa
|
||||||
from awx.fact.models import * # noqa
|
from awx.fact.models import * # noqa
|
||||||
|
from awx.main.utils import emit_websocket_notification
|
||||||
|
|
||||||
def api_exception_handler(exc):
|
def api_exception_handler(exc):
|
||||||
'''
|
'''
|
||||||
@@ -528,16 +529,20 @@ class AuthTokenView(APIView):
|
|||||||
reason='')[0]
|
reason='')[0]
|
||||||
token.refresh()
|
token.refresh()
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
token = AuthToken.objects.create(user=serializer.object['user'],
|
||||||
|
request_hash=request_hash)
|
||||||
# Get user un-expired tokens that are not invalidated that are
|
# Get user un-expired tokens that are not invalidated that are
|
||||||
# over the configured limit.
|
# over the configured limit.
|
||||||
# Mark them as invalid and inform the user
|
# Mark them as invalid and inform the user
|
||||||
invalid_tokens = AuthToken.get_tokens_over_limit(serializer.object['user'])
|
invalid_tokens = AuthToken.get_tokens_over_limit(serializer.object['user'])
|
||||||
for t in invalid_tokens:
|
for t in invalid_tokens:
|
||||||
# TODO: send socket notification
|
# TODO: send socket notification
|
||||||
|
emit_websocket_notification('/socket.io/control',
|
||||||
|
'limit_reached',
|
||||||
|
dict(reason=unicode(AuthToken.reason_long('limit_reached'))),
|
||||||
|
token_key=t.key)
|
||||||
t.invalidate(reason='limit_reached')
|
t.invalidate(reason='limit_reached')
|
||||||
|
|
||||||
token = AuthToken.objects.create(user=serializer.object['user'],
|
|
||||||
request_hash=request_hash)
|
|
||||||
return Response({'token': token.key, 'expires': token.expires})
|
return Response({'token': token.key, 'expires': token.expires})
|
||||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
|
import weakref
|
||||||
from optparse import make_option
|
from optparse import make_option
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
@@ -24,51 +25,139 @@ from socketio.namespace import BaseNamespace
|
|||||||
|
|
||||||
logger = logging.getLogger('awx.main.commands.run_socketio_service')
|
logger = logging.getLogger('awx.main.commands.run_socketio_service')
|
||||||
|
|
||||||
valid_sockets = []
|
class SocketSession(object):
|
||||||
|
def __init__(self, session_id, token_key, socket):
|
||||||
|
self.socket = weakref.ref(socket)
|
||||||
|
self.session_id = session_id
|
||||||
|
self.token_key = token_key
|
||||||
|
self._valid = True
|
||||||
|
|
||||||
|
def is_valid(self):
|
||||||
|
return bool(self._valid)
|
||||||
|
|
||||||
|
def invalidate(self):
|
||||||
|
self._valid = False
|
||||||
|
|
||||||
|
def is_db_token_valid(self):
|
||||||
|
auth_token = AuthToken.objects.filter(key=self.token_key, reason='')
|
||||||
|
if not auth_token.exists():
|
||||||
|
return False
|
||||||
|
auth_token = auth_token[0]
|
||||||
|
return bool(not auth_token.is_expired())
|
||||||
|
|
||||||
|
class SocketSessionManager(object):
|
||||||
|
sessions = []
|
||||||
|
session_token_key_map = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _prune(cls):
|
||||||
|
#cls.sessions = [s for s in cls.sessions if s.is_valid()]
|
||||||
|
if len(cls.sessions) > 1000:
|
||||||
|
del cls.session_token_key_map[cls.sessions[0].token_key]
|
||||||
|
cls.sessions = cls.sessions[1:]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def lookup(cls, token_key=None):
|
||||||
|
if not token_key:
|
||||||
|
raise ValueError("token_key required")
|
||||||
|
for s in cls.sessions:
|
||||||
|
if s.token_key == token_key:
|
||||||
|
return s
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_session(cls, session):
|
||||||
|
cls.sessions.append(session)
|
||||||
|
cls.session_token_key_map[session.token_key] = session
|
||||||
|
cls._prune()
|
||||||
|
|
||||||
|
class SocketController(object):
|
||||||
|
server = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def broadcast_packet(cls, packet):
|
||||||
|
# Broadcast message to everyone at endpoint
|
||||||
|
# Loop over the 'raw' list of sockets (don't trust our list)
|
||||||
|
for session_id, socket in list(cls.server.sockets.iteritems()):
|
||||||
|
socket_session = socket.session.get('socket_session', None)
|
||||||
|
if socket_session and socket_session.is_valid():
|
||||||
|
try:
|
||||||
|
socket.send_packet(packet)
|
||||||
|
except Exception, e:
|
||||||
|
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
||||||
|
logger.error("Error was: " + str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_packet(cls, packet, token_key):
|
||||||
|
if not token_key:
|
||||||
|
raise ValueError("token_key is required")
|
||||||
|
socket_session = SocketSessionManager.lookup(token_key=token_key)
|
||||||
|
# We may not find the socket_session if the user disconnected
|
||||||
|
# (it's actually more compliciated than that because of our prune logic)
|
||||||
|
if socket_session and socket_session.is_valid():
|
||||||
|
socket = socket_session.socket()
|
||||||
|
if socket:
|
||||||
|
try:
|
||||||
|
socket.send_packet(packet)
|
||||||
|
except Exception, e:
|
||||||
|
logger.error("Error sending client packet to %s: %s" % (str(socket_session.session_id), str(packet)))
|
||||||
|
logger.error("Error was: " + str(e))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_server(cls, server):
|
||||||
|
cls.server = server
|
||||||
|
return server
|
||||||
|
|
||||||
|
#
|
||||||
|
# Socket session is attached to self.session['socket_session']
|
||||||
|
# self.session and self.socket.session point to the same dict
|
||||||
|
#
|
||||||
class TowerBaseNamespace(BaseNamespace):
|
class TowerBaseNamespace(BaseNamespace):
|
||||||
|
|
||||||
def get_allowed_methods(self):
|
def get_allowed_methods(self):
|
||||||
return ['recv_disconnect']
|
return ['recv_disconnect']
|
||||||
|
|
||||||
def get_initial_acl(self):
|
def get_initial_acl(self):
|
||||||
global valid_sockets
|
request_token = self._get_request_token()
|
||||||
v_user = self.valid_user()
|
if request_token:
|
||||||
self.is_valid_connection = False
|
# (1) This is the first time the socket has been seen (first
|
||||||
if v_user:
|
# namespace joined).
|
||||||
if self.socket.sessid not in valid_sockets:
|
# (2) This socket has already been seen (already joined and maybe
|
||||||
valid_sockets.append(self.socket.sessid)
|
# left a namespace)
|
||||||
self.is_valid_connection = True
|
#
|
||||||
if len(valid_sockets) > 1000:
|
# Note: Assume that the user token is valid if the session is found
|
||||||
valid_sockets = valid_sockets[1:]
|
socket_session = self.session.get('socket_session', None)
|
||||||
|
if not socket_session:
|
||||||
|
socket_session = SocketSession(self.socket.sessid, request_token, self.socket)
|
||||||
|
if socket_session.is_db_token_valid():
|
||||||
|
self.session['socket_session'] = socket_session
|
||||||
|
SocketSessionManager.add_session(socket_session)
|
||||||
|
else:
|
||||||
|
socket_session.invalidate()
|
||||||
|
|
||||||
return set(['recv_connect'] + self.get_allowed_methods())
|
return set(['recv_connect'] + self.get_allowed_methods())
|
||||||
else:
|
else:
|
||||||
logger.warn("Authentication Failure validating user")
|
logger.warn("Authentication Failure validating user")
|
||||||
self.emit("connect_failed", "Authentication failed")
|
self.emit("connect_failed", "Authentication failed")
|
||||||
return set(['recv_connect'])
|
return set(['recv_connect'])
|
||||||
|
|
||||||
def valid_user(self):
|
def _get_request_token(self):
|
||||||
if 'QUERY_STRING' not in self.environ:
|
if 'QUERY_STRING' not in self.environ:
|
||||||
return False
|
return False
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
k, v = self.environ['QUERY_STRING'].split("=")
|
k, v = self.environ['QUERY_STRING'].split("=")
|
||||||
if k == "Token":
|
if k == "Token":
|
||||||
token_actual = urllib.unquote_plus(v).decode().replace("\"","")
|
token_actual = urllib.unquote_plus(v).decode().replace("\"","")
|
||||||
auth_token = AuthToken.objects.filter(key=token_actual, reason='')
|
return token_actual
|
||||||
if not auth_token.exists():
|
except Exception, e:
|
||||||
return False
|
logger.error("Exception validating user: " + str(e))
|
||||||
auth_token = auth_token[0]
|
return False
|
||||||
if not auth_token.is_expired():
|
return False
|
||||||
return auth_token.user
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
except Exception, e:
|
|
||||||
logger.error("Exception validating user: " + str(e))
|
|
||||||
return False
|
|
||||||
|
|
||||||
def recv_connect(self):
|
def recv_connect(self):
|
||||||
if not self.is_valid_connection:
|
socket_session = self.session.get('socket_session', None)
|
||||||
|
if socket_session and not socket_session.is_valid():
|
||||||
self.disconnect(silent=False)
|
self.disconnect(silent=False)
|
||||||
|
|
||||||
class TestNamespace(TowerBaseNamespace):
|
class TestNamespace(TowerBaseNamespace):
|
||||||
@@ -106,6 +195,14 @@ class ScheduleNamespace(TowerBaseNamespace):
|
|||||||
logger.info("Received client connect for schedule namespace from %s" % str(self.environ['REMOTE_ADDR']))
|
logger.info("Received client connect for schedule namespace from %s" % str(self.environ['REMOTE_ADDR']))
|
||||||
super(ScheduleNamespace, self).recv_connect()
|
super(ScheduleNamespace, self).recv_connect()
|
||||||
|
|
||||||
|
# Catch-all namespace.
|
||||||
|
# Deliver 'global' events over this namespace
|
||||||
|
class ControlNamespace(TowerBaseNamespace):
|
||||||
|
|
||||||
|
def recv_connect(self):
|
||||||
|
logger.warn("Received client connect for control namespace from %s" % str(self.environ['REMOTE_ADDR']))
|
||||||
|
super(ControlNamespace, self).recv_connect()
|
||||||
|
|
||||||
class TowerSocket(object):
|
class TowerSocket(object):
|
||||||
|
|
||||||
def __call__(self, environ, start_response):
|
def __call__(self, environ, start_response):
|
||||||
@@ -115,7 +212,8 @@ class TowerSocket(object):
|
|||||||
'/socket.io/jobs': JobNamespace,
|
'/socket.io/jobs': JobNamespace,
|
||||||
'/socket.io/job_events': JobEventNamespace,
|
'/socket.io/job_events': JobEventNamespace,
|
||||||
'/socket.io/ad_hoc_command_events': AdHocCommandEventNamespace,
|
'/socket.io/ad_hoc_command_events': AdHocCommandEventNamespace,
|
||||||
'/socket.io/schedules': ScheduleNamespace})
|
'/socket.io/schedules': ScheduleNamespace,
|
||||||
|
'/socket.io/control': ControlNamespace})
|
||||||
else:
|
else:
|
||||||
logger.warn("Invalid connect path received: " + path)
|
logger.warn("Invalid connect path received: " + path)
|
||||||
start_response('404 Not Found', [])
|
start_response('404 Not Found', [])
|
||||||
@@ -130,13 +228,12 @@ def notification_handler(server):
|
|||||||
'name': message['event'],
|
'name': message['event'],
|
||||||
'type': 'event',
|
'type': 'event',
|
||||||
}
|
}
|
||||||
for session_id, socket in list(server.sockets.iteritems()):
|
|
||||||
if session_id in valid_sockets:
|
if 'token_key' in message:
|
||||||
try:
|
# Best practice not to send the token over the socket
|
||||||
socket.send_packet(packet)
|
SocketController.send_packet(packet, message.pop('token_key'))
|
||||||
except Exception, e:
|
else:
|
||||||
logger.error("Error sending client packet to %s: %s" % (str(session_id), str(packet)))
|
SocketController.broadcast_packet(packet)
|
||||||
logger.error("Error was: " + str(e))
|
|
||||||
|
|
||||||
class Command(NoArgsCommand):
|
class Command(NoArgsCommand):
|
||||||
'''
|
'''
|
||||||
@@ -164,6 +261,7 @@ class Command(NoArgsCommand):
|
|||||||
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
logger.info('Listening on port http://0.0.0.0:' + str(socketio_listen_port))
|
||||||
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
server = SocketIOServer(('0.0.0.0', socketio_listen_port), TowerSocket(), resource='socket.io')
|
||||||
|
|
||||||
|
SocketController.set_server(server)
|
||||||
handler_thread = Thread(target=notification_handler, args=(server,))
|
handler_thread = Thread(target=notification_handler, args=(server,))
|
||||||
handler_thread.daemon = True
|
handler_thread.daemon = True
|
||||||
handler_thread.start()
|
handler_thread.start()
|
||||||
|
|||||||
@@ -389,11 +389,13 @@ def get_system_task_capacity():
|
|||||||
return 50 + ((int(total_mem_value) / 1024) - 2) * 75
|
return 50 + ((int(total_mem_value) / 1024) - 2) * 75
|
||||||
|
|
||||||
|
|
||||||
def emit_websocket_notification(endpoint, event, payload):
|
def emit_websocket_notification(endpoint, event, payload, token_key=None):
|
||||||
from awx.main.socket import Socket
|
from awx.main.socket import Socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Socket('websocket', 'w', nowait=True, logger=logger) as websocket:
|
with Socket('websocket', 'w', nowait=True, logger=logger) as websocket:
|
||||||
|
if token_key:
|
||||||
|
payload['token_key'] = token_key
|
||||||
payload['event'] = event
|
payload['event'] = event
|
||||||
payload['endpoint'] = endpoint
|
payload['endpoint'] = endpoint
|
||||||
websocket.publish(payload)
|
websocket.publish(payload)
|
||||||
|
|||||||
Reference in New Issue
Block a user