diff --git a/awx/lib/site-packages/README b/awx/lib/site-packages/README index 823c532620..2531bc8901 100644 --- a/awx/lib/site-packages/README +++ b/awx/lib/site-packages/README @@ -25,7 +25,7 @@ django-split-settings==0.1.1 (split_settings/*) django-taggit==0.11.2 (taggit/*) djangorestframework==2.3.13 (rest_framework/*) django-qsstats-magic==0.7.2 (django-qsstats-magic/*, minor fix in qsstats/__init__.py) -gevent-socketio==0.3.5-rc1 (socketio/*) +gevent-socketio==0.3.6 (socketio/*) gevent-websocket==0.9.3 (geventwebsocket/*) httplib2==0.8 (httplib2/*) importlib==1.0.3 (importlib/*, needed for Python 2.6 support) diff --git a/awx/lib/site-packages/socketio/__init__.py b/awx/lib/site-packages/socketio/__init__.py index ee4faba6a9..bffb366263 100644 --- a/awx/lib/site-packages/socketio/__init__.py +++ b/awx/lib/site-packages/socketio/__init__.py @@ -6,7 +6,8 @@ import gevent log = logging.getLogger(__name__) -def socketio_manage(environ, namespaces, request=None, error_handler=None): +def socketio_manage(environ, namespaces, request=None, error_handler=None, + json_loads=None, json_dumps=None): """Main SocketIO management function, call from within your Framework of choice's view. @@ -20,7 +21,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None): use Socket.GLOBAL_NS to be more explicit. So it would look like: .. code-block:: python - + namespaces={'': GlobalNamespace, '/chat': ChatNamespace} @@ -35,6 +36,11 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None): The callable you pass in should have the same signature as the default error handler. + The ``json_loads`` and ``json_dumps`` are overrides for the default + ``json.loads`` and ``json.dumps`` function calls. Override these at + the top-most level here. This will affect all sockets created by this + socketio manager, and all namespaces inside. + This function will block the current "view" or "controller" in your framework to do the recv/send on the socket, and dispatch incoming messages to your namespaces. @@ -45,6 +51,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None): def my_view(request): socketio_manage(request.environ, {'': GlobalNamespace}, request) + NOTE: You must understand that this function is going to be called *only once* per socket opening, *even though* you are using a long polling mechanism. The subsequent calls (for long polling) will @@ -67,10 +74,14 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None): if error_handler: socket._set_error_handler(error_handler) - receiver_loop = socket._spawn_receiver_loop() - watcher = socket._spawn_watcher() + if json_loads: + socket._set_json_loads(json_loads) + if json_dumps: + socket._set_json_dumps(json_dumps) - gevent.joinall([receiver_loop, watcher]) + receiver_loop = socket._spawn_receiver_loop() + + gevent.joinall([receiver_loop]) # TODO: double check, what happens to the WSGI request here ? it vanishes ? return diff --git a/awx/lib/site-packages/socketio/defaultjson.py b/awx/lib/site-packages/socketio/defaultjson.py new file mode 100644 index 0000000000..65dd9bf0b0 --- /dev/null +++ b/awx/lib/site-packages/socketio/defaultjson.py @@ -0,0 +1,21 @@ +### default json loaders +try: + import simplejson as json + json_decimal_args = {"use_decimal": True} # pragma: no cover +except ImportError: + import json + import decimal + + class DecimalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, decimal.Decimal): + return float(o) + return super(DecimalEncoder, self).default(o) + json_decimal_args = {"cls": DecimalEncoder} + +def default_json_dumps(data): + return json.dumps(data, separators=(',', ':'), + **json_decimal_args) + +def default_json_loads(data): + return json.loads(data) diff --git a/awx/lib/site-packages/socketio/exceptions.py b/awx/lib/site-packages/socketio/exceptions.py new file mode 100644 index 0000000000..d9b738903f --- /dev/null +++ b/awx/lib/site-packages/socketio/exceptions.py @@ -0,0 +1,6 @@ +class SessionNotFound(Exception): + def __init__(self, sessid): + self.sessid = sessid + + def __str__(self): + return "Session %s not found!" % self.sessid diff --git a/awx/lib/site-packages/socketio/handler.py b/awx/lib/site-packages/socketio/handler.py index 2bdae0b263..e4b972f30a 100644 --- a/awx/lib/site-packages/socketio/handler.py +++ b/awx/lib/site-packages/socketio/handler.py @@ -5,17 +5,23 @@ import urlparse from gevent.pywsgi import WSGIHandler from socketio import transports -from geventwebsocket.handler import WebSocketHandler - class SocketIOHandler(WSGIHandler): RE_REQUEST_URL = re.compile(r""" - ^/(?P[^/]+) - /(?P[^/]+) + ^/(?P.+?) + /1 /(?P[^/]+) /(?P[^/]+)/?$ """, re.X) - RE_HANDSHAKE_URL = re.compile(r"^/(?P[^/]+)/1/$", re.X) + RE_HANDSHAKE_URL = re.compile(r"^/(?P.+?)/1/$", re.X) + # new socket.io versions (> 0.9.8) call an obscure url with two slashes + # instead of a transport when disconnecting + # https://github.com/LearnBoost/socket.io-client/blob/0.9.16/lib/socket.js#L361 + RE_DISCONNECT_URL = re.compile(r""" + ^/(?P.+?) + /(?P[^/]+) + //(?P[^/]+)/?$ + """, re.X) handler_types = { 'websocket': transports.WebsocketTransport, @@ -26,9 +32,16 @@ class SocketIOHandler(WSGIHandler): 'jsonp-polling': transports.JSONPolling, } - def __init__(self, *args, **kwargs): + def __init__(self, config, *args, **kwargs): + """Create a new SocketIOHandler. + + :param config: dict Configuration for timeouts and intervals + that will go down to the other components, transports, etc.. + + """ self.socketio_connection = False self.allowed_paths = None + self.config = config super(SocketIOHandler, self).__init__(*args, **kwargs) @@ -36,7 +49,7 @@ class SocketIOHandler(WSGIHandler): if self.server.transports: self.transports = self.server.transports if not set(self.transports).issubset(set(self.handler_types)): - raise Exception("transports should be elements of: %s" % + raise ValueError("transports should be elements of: %s" % (self.handler_types.keys())) def _do_handshake(self, tokens): @@ -44,7 +57,10 @@ class SocketIOHandler(WSGIHandler): self.log_error("socket.io URL mismatch") else: socket = self.server.get_socket() - data = "%s:15:10:%s" % (socket.sessid, ",".join(self.transports)) + data = "%s:%s:%s:%s" % (socket.sessid, + self.config['heartbeat_timeout'] or '', + self.config['close_timeout'] or '', + ",".join(self.transports)) self.write_smart(data) def write_jsonp_result(self, data, wrapper="0"): @@ -74,10 +90,16 @@ class SocketIOHandler(WSGIHandler): self.process_result() def handle_one_response(self): + """This function deals with *ONE INCOMING REQUEST* from the web. + + It will wire and exchange message to the queues for long-polling + methods, otherwise, will stay alive for websockets. + + """ path = self.environ.get('PATH_INFO') # Kick non-socket.io requests to our superclass - if not path.lstrip('/').startswith(self.server.resource): + if not path.lstrip('/').startswith(self.server.resource + '/'): return super(SocketIOHandler, self).handle_one_response() self.status = None @@ -85,64 +107,118 @@ class SocketIOHandler(WSGIHandler): self.result = None self.response_length = 0 self.response_use_chunked = False + + # This is analyzed for each and every HTTP requests involved + # in the Socket.IO protocol, whether long-running or long-polling + # (read: websocket or xhr-polling methods) request_method = self.environ.get("REQUEST_METHOD") request_tokens = self.RE_REQUEST_URL.match(path) + handshake_tokens = self.RE_HANDSHAKE_URL.match(path) + disconnect_tokens = self.RE_DISCONNECT_URL.match(path) - # Parse request URL and QUERY_STRING and do handshake - if request_tokens: - request_tokens = request_tokens.groupdict() + if handshake_tokens: + # Deal with first handshake here, create the Socket and push + # the config up. + return self._do_handshake(handshake_tokens.groupdict()) + elif disconnect_tokens: + # it's a disconnect request via XHR + tokens = disconnect_tokens.groupdict() + elif request_tokens: + tokens = request_tokens.groupdict() + # and continue... else: - handshake_tokens = self.RE_HANDSHAKE_URL.match(path) + # This is no socket.io request. Let the WSGI app handle it. + return super(SocketIOHandler, self).handle_one_response() - if handshake_tokens: - return self._do_handshake(handshake_tokens.groupdict()) - else: - # This is no socket.io request. Let the WSGI app handle it. - return super(SocketIOHandler, self).handle_one_response() - - # Setup the transport and socket - transport = self.handler_types.get(request_tokens["transport_id"]) - sessid = request_tokens["sessid"] + # Setup socket + sessid = tokens["sessid"] socket = self.server.get_socket(sessid) + if not socket: + self.handle_bad_request() + return [] # Do not say the session is not found, just bad request + # so they don't start brute forcing to find open sessions + + if self.environ['QUERY_STRING'].startswith('disconnect'): + # according to socket.io specs disconnect requests + # have a `disconnect` query string + # https://github.com/LearnBoost/socket.io-spec#forced-socket-disconnection + socket.disconnect() + self.handle_disconnect_request() + return [] + + # Setup transport + transport = self.handler_types.get(tokens["transport_id"]) # In case this is WebSocket request, switch to the WebSocketHandler # FIXME: fix this ugly class change + old_class = None if issubclass(transport, (transports.WebsocketTransport, transports.FlashSocketTransport)): - self.__class__ = WebSocketHandler + old_class = self.__class__ + self.__class__ = self.server.ws_handler_class self.prevent_wsgi_call = True # thank you # TODO: any errors, treat them ?? - self.handle_one_response() + self.handle_one_response() # does the Websocket dance before we continue # Make the socket object available for WSGI apps self.environ['socketio'] = socket # Create a transport and handle the request likewise - self.transport = transport(self) + self.transport = transport(self, self.config) - jobs = self.transport.connect(socket, request_method) - # Keep track of those jobs (reading, writing and heartbeat jobs) so - # that we can kill them later with Socket.kill() - socket.jobs.extend(jobs) + # transports register their own spawn'd jobs now + self.transport.do_exchange(socket, request_method) - try: - # We'll run the WSGI app if it wasn't already done. - if socket.wsgi_app_greenlet is None: - # TODO: why don't we spawn a call to handle_one_response here ? - # why call directly the WSGI machinery ? - start_response = lambda status, headers, exc=None: None - socket.wsgi_app_greenlet = gevent.spawn(self.application, - self.environ, - start_response) - except: - self.handle_error(*sys.exc_info()) + if not socket.connection_established: + # This is executed only on the *first* packet of the establishment + # of the virtual Socket connection. + socket.connection_established = True + socket.state = socket.STATE_CONNECTED + socket._spawn_heartbeat() + socket._spawn_watcher() - # TODO DOUBLE-CHECK: do we need to joinall here ? - gevent.joinall(jobs) + try: + # We'll run the WSGI app if it wasn't already done. + if socket.wsgi_app_greenlet is None: + # TODO: why don't we spawn a call to handle_one_response here ? + # why call directly the WSGI machinery ? + start_response = lambda status, headers, exc=None: None + socket.wsgi_app_greenlet = gevent.spawn(self.application, + self.environ, + start_response) + except: + self.handle_error(*sys.exc_info()) + + # we need to keep the connection open if we are an open socket + if tokens['transport_id'] in ['flashsocket', 'websocket']: + # wait here for all jobs to finished, when they are done + gevent.joinall(socket.jobs) + + # Switch back to the old class so references to this don't use the + # incorrect class. Useful for debugging. + if old_class: + self.__class__ = old_class + + # Clean up circular references so they can be garbage collected. + if hasattr(self, 'websocket') and self.websocket: + if hasattr(self.websocket, 'environ'): + del self.websocket.environ + del self.websocket + if self.environ: + del self.environ def handle_bad_request(self): self.close_connection = True - self.start_reponse("400 Bad Request", [ + self.start_response("400 Bad Request", [ + ('Content-Type', 'text/plain'), + ('Connection', 'close'), + ('Content-Length', 0) + ]) + + + def handle_disconnect_request(self): + self.close_connection = True + self.start_response("200 OK", [ ('Content-Type', 'text/plain'), ('Connection', 'close'), ('Content-Length', 0) diff --git a/awx/lib/site-packages/socketio/namespace.py b/awx/lib/site-packages/socketio/namespace.py index 6fc42af80c..cef5ca1d0a 100644 --- a/awx/lib/site-packages/socketio/namespace.py +++ b/awx/lib/site-packages/socketio/namespace.py @@ -27,6 +27,10 @@ class BaseNamespace(object): def on_my_second_event(self, whatever): print "This holds the first arg that was passed", whatever + Handlers are automatically dispatched based on the name of the incoming + event. For example, a 'user message' event will be handled by + ``on_user_message()``. To change this, override :meth:`process_event`. + We can also access the full packet directly by making an event handler that accepts a single argument named 'packet': @@ -43,9 +47,11 @@ class BaseNamespace(object): self.session = self.socket.session # easily accessible session self.request = request self.ns_name = ns_name - self.allowed_methods = None # be careful, None means all methods - # are allowed while an empty list - # means totally closed. + #: Store for ACL allowed methods. Be careful as ``None`` means + #: that all methods are allowed, while an empty list means every + #: method is denied. Value: list of strings or ``None``. You + #: can and should use the various ``acl`` methods to tweak this. + self.allowed_methods = None self.jobs = [] self.reset_acl() @@ -95,7 +101,7 @@ class BaseNamespace(object): access to all of the ``on_*()`` and ``recv_*()`` functions, etc.. methods. - Return something like: ``['on_connect', 'on_public_method']`` + Return something like: ``set(['recv_connect', 'on_public_method'])`` You can later modify this list dynamically (inside ``on_connect()`` for example) using: @@ -113,6 +119,9 @@ class BaseNamespace(object): **Beware**, returning ``None`` leaves the namespace completely accessible. + + The methods that are open are stored in the ``allowed_methods`` + attribute of the ``Namespace`` instance. """ return None @@ -160,10 +169,7 @@ class BaseNamespace(object): if not callback: print "ERROR: No such callback for ackId %s" % packet['ackId'] return - try: - return callback(*(packet['args'])) - except TypeError, e: - print "ERROR: Call to callback function failed", packet + return callback(*(packet['args'])) elif packet_type == 'disconnect': # Force a disconnect on the namespace. return self.call_method_with_acl('recv_disconnect', packet) @@ -178,17 +184,30 @@ class BaseNamespace(object): ``on_``-prefixed methods. You could then implement your own dispatch. See the source code for inspiration. - To process events that have callbacks on the client side, you - must define your event with a single parameter: ``packet``. - In this case, it will be the full ``packet`` object and you - can inspect its ``ack`` and ``id`` keys to define if and how - you reply. A correct reply to an event with a callback would - look like this: + There are two ways to deal with callbacks from the client side + (meaning, the browser has a callback waiting for data that this + server will be sending back): + + The first one is simply to return an object. If the incoming + packet requested has an 'ack' field set, meaning the browser is + waiting for callback data, it will automatically be packaged + and sent, associated with the 'ackId' from the browser. The + return value must be a *sequence* of elements, that will be + mapped to the positional parameters of the callback function + on the browser side. + + If you want to *know* that you're dealing with a packet + that requires a return value, you can do those things manually + by inspecting the ``ack`` and ``id`` keys from the ``packet`` + object. Your callback will behave specially if the name of + the argument to your method is ``packet``. It will fill it + with the unprocessed ``packet`` object for your inspection, + like this: .. code-block:: python def on_my_callback(self, packet): - if 'ack' in packet': + if 'ack' in packet: self.emit('go_back', 'param1', id=packet['id']) """ args = packet['args'] @@ -227,12 +246,17 @@ class BaseNamespace(object): Those are the two behaviors: * If there is only one parameter on the dispatched method and - it is equal to ``packet``, then pass in the packet as the + it is named ``packet``, then pass in the packet dict as the sole parameter. * Otherwise, pass in the arguments as specified by the different ``recv_*()`` methods args specs, or the :meth:`process_event` documentation. + + This method will also consider the + ``exception_handler_decorator``. See Namespace documentation + for details and examples. + """ method = getattr(self, method_name, None) if method is None: @@ -247,6 +271,11 @@ class BaseNamespace(object): "The server-side method is invalid, as it doesn't " "have 'self' as its first argument") return + + # Check if we need to decorate to handle exceptions + if hasattr(self, 'exception_handler_decorator'): + method = self.exception_handler_decorator(method) + if len(func_args) == 2 and func_args[1] == 'packet': return method(packet) else: @@ -264,10 +293,14 @@ class BaseNamespace(object): :func:`~socketio.socketio_manage`) without clogging the memory. - If you override this method, you probably want to only - initialize the variables you're going to use in the events of - this namespace, say, with some default values, but not perform - any operation that assumes authentication/authorization. + If you override this method, you probably want to initialize + the variables you're going to use in the events handled by this + namespace, setup ACLs, etc.. + + This method is called on all base classes following the _`method resolution order ` + so you don't need to call super() to initialize the mixins or + other derived classes. + """ pass @@ -432,8 +465,14 @@ class BaseNamespace(object): It will be monitored by the "watcher" process in the Socket. If the socket disconnects, all these greenlets are going to be killed, after calling BaseNamespace.disconnect() + + This method uses the ``exception_handler_decorator``. See + Namespace documentation for more information. + """ # self.log.debug("Spawning sub-Namespace Greenlet: %s" % fn.__name__) + if hasattr(self, 'exception_handler_decorator'): + fn = self.exception_handler_decorator(fn) new = gevent.spawn(fn, *args, **kwargs) self.jobs.append(new) return new @@ -454,8 +493,12 @@ class BaseNamespace(object): packet = {"type": "disconnect", "endpoint": self.ns_name} self.socket.send_packet(packet) - self.socket.remove_namespace(self.ns_name) - self.kill_local_jobs() + # remove_namespace might throw GreenletExit so + # kill_local_jobs must be in finally + try: + self.socket.remove_namespace(self.ns_name) + finally: + self.kill_local_jobs() def kill_local_jobs(self): """Kills all the jobs spawned with BaseNamespace.spawn() on a namespace diff --git a/awx/lib/site-packages/socketio/packet.py b/awx/lib/site-packages/socketio/packet.py index 35cd5c3eae..f626d40816 100644 --- a/awx/lib/site-packages/socketio/packet.py +++ b/awx/lib/site-packages/socketio/packet.py @@ -1,17 +1,4 @@ -try: - import simplejson as json - json_decimal_args = {"use_decimal": True} # pragma: no cover -except ImportError: - import json - import decimal - - class DecimalEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, decimal.Decimal): - return float(o) - return super(DecimalEncoder, self).default(o) - json_decimal_args = {"cls": DecimalEncoder} - +from socketio.defaultjson import default_json_dumps, default_json_loads MSG_TYPES = { 'disconnect': 0, @@ -45,7 +32,7 @@ socketio_packet_attributes = ['type', 'name', 'data', 'endpoint', 'args', 'ackId', 'reason', 'advice', 'qs', 'id'] -def encode(data): +def encode(data, json_dumps=default_json_dumps): """ Encode an attribute dict into a byte string. """ @@ -72,14 +59,13 @@ def encode(data): if msg == '3': payload = data['data'] if msg == '4': - payload = json.dumps(data['data'], separators=(',', ':'), - **json_decimal_args) + payload = json_dumps(data['data']) if msg == '5': d = {} d['name'] = data['name'] if 'args' in data and data['args'] != []: d['args'] = data['args'] - payload = json.dumps(d, separators=(',', ':'), **json_decimal_args) + payload = json_dumps(d) if 'id' in data: msg += ':' + str(data['id']) if data['ack'] == 'data': @@ -98,8 +84,7 @@ def encode(data): # '6:::' [id] '+' [data] msg += '::' + data.get('endpoint', '') + ':' + str(data['ackId']) if 'args' in data and data['args'] != []: - msg += '+' + json.dumps(data['args'], separators=(',', ':'), - **json_decimal_args) + msg += '+' + json_dumps(data['args']) elif msg == '7': # '7::' [endpoint] ':' [reason] '+' [advice] @@ -117,7 +102,7 @@ def encode(data): return msg -def decode(rawstr): +def decode(rawstr, json_loads=default_json_loads): """ Decode a rawstr packet arriving from the socket into a dict. """ @@ -163,11 +148,11 @@ def decode(rawstr): decoded_msg['data'] = data elif msg_type == "4": # json msg - decoded_msg['data'] = json.loads(data) + decoded_msg['data'] = json_loads(data) elif msg_type == "5": # event try: - data = json.loads(data) + data = json_loads(data) except ValueError, e: print("Invalid JSON event message", data) decoded_msg['args'] = [] @@ -182,7 +167,7 @@ def decode(rawstr): if '+' in data: ackId, data = data.split('+') decoded_msg['ackId'] = int(ackId) - decoded_msg['args'] = json.loads(data) + decoded_msg['args'] = json_loads(data) else: decoded_msg['ackId'] = int(data) decoded_msg['args'] = [] diff --git a/awx/lib/site-packages/socketio/sdjango.py b/awx/lib/site-packages/socketio/sdjango.py new file mode 100644 index 0000000000..d764cdc783 --- /dev/null +++ b/awx/lib/site-packages/socketio/sdjango.py @@ -0,0 +1,72 @@ +import logging + +from socketio import socketio_manage +from django.http import HttpResponse +from django.views.decorators.csrf import csrf_exempt +from django.utils.importlib import import_module + +# for Django 1.3 support +try: + from django.conf.urls import patterns, url, include +except ImportError: + from django.conf.urls.defaults import patterns, url, include + + +SOCKETIO_NS = {} + + +LOADING_SOCKETIO = False + + + +def autodiscover(): + """ + Auto-discover INSTALLED_APPS sockets.py modules and fail silently when + not present. NOTE: socketio_autodiscover was inspired/copied from + django.contrib.admin autodiscover + """ + global LOADING_SOCKETIO + if LOADING_SOCKETIO: + return + LOADING_SOCKETIO = True + + import imp + from django.conf import settings + + for app in settings.INSTALLED_APPS: + + try: + app_path = import_module(app).__path__ + except AttributeError: + continue + + try: + imp.find_module('sockets', app_path) + except ImportError: + continue + + import_module("%s.sockets" % app) + + LOADING_SOCKETIO = False + + +class namespace(object): + def __init__(self, name=''): + self.name = name + + def __call__(self, handler): + SOCKETIO_NS[self.name] = handler + return handler + + + +@csrf_exempt +def socketio(request): + try: + socketio_manage(request.environ, SOCKETIO_NS, request) + except: + logging.getLogger("socketio").error("Exception while handling socketio connection", exc_info=True) + return HttpResponse("") + + +urls = patterns("", (r'', socketio)) diff --git a/awx/lib/site-packages/socketio/server.py b/awx/lib/site-packages/socketio/server.py index 100ec577f0..74f242825e 100644 --- a/awx/lib/site-packages/socketio/server.py +++ b/awx/lib/site-packages/socketio/server.py @@ -8,6 +8,7 @@ from gevent.pywsgi import WSGIServer from socketio.handler import SocketIOHandler from socketio.policyserver import FlashPolicyServer from socketio.virtsocket import Socket +from geventwebsocket.handler import WebSocketHandler __all__ = ['SocketIOServer'] @@ -16,19 +17,41 @@ class SocketIOServer(WSGIServer): """A WSGI Server with a resource that acts like an SocketIO.""" def __init__(self, *args, **kwargs): - """ - This is just like the standard WSGIServer __init__, except with a + """This is just like the standard WSGIServer __init__, except with a few additional ``kwargs``: - :param resource: The URL which has to be identified as a socket.io request. Defaults to the /socket.io/ URL. + :param resource: The URL which has to be identified as a + socket.io request. Defaults to the /socket.io/ URL. + :param transports: Optional list of transports to allow. List of strings, each string should be one of handler.SocketIOHandler.handler_types. + :param policy_server: Boolean describing whether or not to use the Flash policy server. Default True. - :param policy_listener : A tuple containing (host, port) for the - policy server. This is optional and used only if policy server + + :param policy_listener: A tuple containing (host, port) for the + policy server. This is optional and used only if policy server is set to true. The default value is 0.0.0.0:843 + + :param heartbeat_interval: int The timeout for the server, we + should receive a heartbeat from the client within this + interval. This should be less than the + ``heartbeat_timeout``. + + :param heartbeat_timeout: int The timeout for the client when + it should send a new heartbeat to the server. This value + is sent to the client after a successful handshake. + + :param close_timeout: int The timeout for the client, when it + closes the connection it still X amounts of seconds to do + re open of the connection. This value is sent to the + client after a successful handshake. + + :param log_file: str The file in which you want the PyWSGI + server to write its access log. If not specified, it + is sent to `stderr` (with gevent 0.13). + """ self.sockets = {} if 'namespace' in kwargs: @@ -40,18 +63,48 @@ class SocketIOServer(WSGIServer): self.transports = kwargs.pop('transports', None) if kwargs.pop('policy_server', True): - policylistener = kwargs.pop('policy_listener', (args[0][0], 10843)) + try: + address = args[0][0] + except TypeError: + try: + address = args[0].address[0] + except AttributeError: + address = args[0].cfg_addr[0] + policylistener = kwargs.pop('policy_listener', (address, 10843)) self.policy_server = FlashPolicyServer(policylistener) else: self.policy_server = None - kwargs['handler_class'] = SocketIOHandler + # Extract other config options + self.config = { + 'heartbeat_timeout': 60, + 'close_timeout': 60, + 'heartbeat_interval': 25, + } + for f in ('heartbeat_timeout', 'heartbeat_interval', 'close_timeout'): + if f in kwargs: + self.config[f] = int(kwargs.pop(f)) + + if not 'handler_class' in kwargs: + kwargs['handler_class'] = SocketIOHandler + + + if not 'ws_handler_class' in kwargs: + self.ws_handler_class = WebSocketHandler + else: + self.ws_handler_class = kwargs.pop('ws_handler_class') + + log_file = kwargs.pop('log_file', None) + if log_file: + kwargs['log'] = open(log_file, 'a') + super(SocketIOServer, self).__init__(*args, **kwargs) def start_accepting(self): if self.policy_server is not None: try: - self.policy_server.start() + if not self.policy_server.started: + self.policy_server.start() except error, ex: sys.stderr.write( 'FAILED to start flash policy server: %s\n' % (ex, )) @@ -60,13 +113,14 @@ class SocketIOServer(WSGIServer): sys.stderr.write('FAILED to start flash policy server.\n\n') super(SocketIOServer, self).start_accepting() - def kill(self): + def stop(self, timeout=None): if self.policy_server is not None: - self.policy_server.kill() - super(SocketIOServer, self).kill() + self.policy_server.stop() + super(SocketIOServer, self).stop(timeout=timeout) def handle(self, socket, address): - handler = self.handler_class(socket, address, self) + # Pass in the config about timeouts, heartbeats, also... + handler = self.handler_class(self.config, socket, address, self) handler.handle() def get_socket(self, sessid=''): @@ -74,10 +128,59 @@ class SocketIOServer(WSGIServer): socket = self.sockets.get(sessid) + if sessid and not socket: + return None # you ask for a session that doesn't exist! if socket is None: - socket = Socket(self) + socket = Socket(self, self.config) self.sockets[socket.sessid] = socket else: socket.incr_hits() return socket + + +def serve(app, **kw): + _quiet = kw.pop('_quiet', False) + _resource = kw.pop('resource', 'socket.io') + if not _quiet: # pragma: no cover + # idempotent if logging has already been set up + import logging + logging.basicConfig() + + host = kw.pop('host', '127.0.0.1') + port = int(kw.pop('port', 6543)) + + transports = kw.pop('transports', None) + if transports: + transports = [x.strip() for x in transports.split(',')] + + policy_server = kw.pop('policy_server', False) + if policy_server in (True, 'True', 'true', 'enable', 'yes', 'on', '1'): + policy_server = True + policy_listener_host = kw.pop('policy_listener_host', host) + policy_listener_port = int(kw.pop('policy_listener_port', 10843)) + kw['policy_listener'] = (policy_listener_host, policy_listener_port) + else: + policy_server = False + + server = SocketIOServer((host, port), + app, + resource=_resource, + transports=transports, + policy_server=policy_server, + **kw) + if not _quiet: + print('serving on http://%s:%s' % (host, port)) + server.serve_forever() + + +def serve_paste(app, global_conf, **kw): + """pserve / paster serve / waitress replacement / integration + + You can pass as parameters: + + transports = websockets, xhr-multipart, xhr-longpolling, etc... + policy_server = True + """ + serve(app, **kw) + return 0 diff --git a/awx/lib/site-packages/socketio/sgunicorn.py b/awx/lib/site-packages/socketio/sgunicorn.py index aaa01131d9..a5a993a576 100644 --- a/awx/lib/site-packages/socketio/sgunicorn.py +++ b/awx/lib/site-packages/socketio/sgunicorn.py @@ -1,44 +1,182 @@ import os - import gevent +import time + from gevent.pool import Pool +from gevent.server import StreamServer from gunicorn.workers.ggevent import GeventPyWSGIWorker +from gunicorn.workers.ggevent import PyWSGIHandler +from gunicorn.workers.ggevent import GeventResponse +from gunicorn import version_info as gunicorn_version from socketio.server import SocketIOServer from socketio.handler import SocketIOHandler +from geventwebsocket.handler import WebSocketHandler + +from datetime import datetime +from functools import partial + + +class GunicornWSGIHandler(PyWSGIHandler, SocketIOHandler): + pass + + +class GunicornWebSocketWSGIHandler(WebSocketHandler): + def log_request(self): + start = datetime.fromtimestamp(self.time_start) + finish = datetime.fromtimestamp(self.time_finish) + response_time = finish - start + resp = GeventResponse(self.status, [], + self.response_length) + req_headers = [h.split(":", 1) for h in self.headers.headers] + self.server.log.access( + resp, req_headers, self.environ, response_time) + class GeventSocketIOBaseWorker(GeventPyWSGIWorker): """ The base gunicorn worker class """ + + transports = None + + def __init__(self, age, ppid, socket, app, timeout, cfg, log): + if os.environ.get('POLICY_SERVER', None) is None: + if self.policy_server: + os.environ['POLICY_SERVER'] = 'true' + else: + self.policy_server = False + + super(GeventSocketIOBaseWorker, self).__init__( + age, ppid, socket, app, timeout, cfg, log) + def run(self): - self.socket.setblocking(1) - pool = Pool(self.worker_connections) - self.server_class.base_env['wsgi.multiprocess'] = \ - self.cfg.workers > 1 - server = self.server_class( - self.socket, application=self.wsgi, - spawn=pool, handler_class=self.wsgi_handler, - namespace=self.namespace, policy_server=self.policy_server) - server.start() - try: - while self.alive: - self.notify() + if gunicorn_version >= (0, 17, 0): + servers = [] + ssl_args = {} - if self.ppid != os.getppid(): - self.log.info("Parent changed, shutting down: %s", self) - break + if self.cfg.is_ssl: + ssl_args = dict( + server_side=True, + do_handshake_on_connect=False, + **self.cfg.ssl_options + ) - gevent.sleep(1.0) + for s in self.sockets: + s.setblocking(1) + pool = Pool(self.worker_connections) + if self.server_class is not None: + self.server_class.base_env['wsgi.multiprocess'] = \ + self.cfg.workers > 1 - except KeyboardInterrupt: - pass + server = self.server_class( + s, + application=self.wsgi, + spawn=pool, + resource=self.resource, + log=self.log, + policy_server=self.policy_server, + handler_class=self.wsgi_handler, + ws_handler_class=self.ws_wsgi_handler, + **ssl_args + ) + else: + hfun = partial(self.handle, s) + server = StreamServer( + s, handle=hfun, spawn=pool, **ssl_args) - # try to stop the connections - try: - self.notify() - server.stop(timeout=self.timeout) - except: - pass + server.start() + servers.append(server) + + pid = os.getpid() + try: + while self.alive: + self.notify() + + if pid == os.getpid() and self.ppid != os.getppid(): + self.log.info( + "Parent changed, shutting down: %s", self) + break + + gevent.sleep(1.0) + + except KeyboardInterrupt: + pass + + try: + # Stop accepting requests + [server.stop_accepting() for server in servers] + + # Handle current requests until graceful_timeout + ts = time.time() + while time.time() - ts <= self.cfg.graceful_timeout: + accepting = 0 + for server in servers: + if server.pool.free_count() != server.pool.size: + accepting += 1 + + if not accepting: + return + + self.notify() + gevent.sleep(1.0) + + # Force kill all active the handlers + self.log.warning("Worker graceful timeout (pid:%s)" % self.pid) + [server.stop(timeout=1) for server in servers] + except: + pass + else: + self.socket.setblocking(1) + pool = Pool(self.worker_connections) + self.server_class.base_env['wsgi.multiprocess'] = \ + self.cfg.workers > 1 + + server = self.server_class( + self.socket, + application=self.wsgi, + spawn=pool, + resource=self.resource, + log=self.log, + policy_server=self.policy_server, + handler_class=self.wsgi_handler, + ws_handler_class=self.ws_wsgi_handler, + ) + + server.start() + pid = os.getpid() + + try: + while self.alive: + self.notify() + + if pid == os.getpid() and self.ppid != os.getppid(): + self.log.info( + "Parent changed, shutting down: %s", self) + break + + gevent.sleep(1.0) + + except KeyboardInterrupt: + pass + + try: + # Stop accepting requests + server.kill() + + # Handle current requests until graceful_timeout + ts = time.time() + while time.time() - ts <= self.cfg.graceful_timeout: + if server.pool.free_count() == server.pool.size: + return # all requests was handled + + self.notify() + gevent.sleep(1.0) + + # Force kill all active the handlers + self.log.warning("Worker graceful timeout (pid:%s)" % self.pid) + server.stop(timeout=1) + except: + pass class GeventSocketIOWorker(GeventSocketIOBaseWorker): @@ -49,9 +187,20 @@ class GeventSocketIOWorker(GeventSocketIOBaseWorker): being disabled. """ server_class = SocketIOServer - wsgi_handler = SocketIOHandler + wsgi_handler = GunicornWSGIHandler + ws_wsgi_handler = GunicornWebSocketWSGIHandler # We need to define a namespace for the server, it would be nice if this # was a configuration option, will probably end up how this implemented, # for now this is just a proof of concept to make sure this will work - namespace = 'socket.io' - policy_server = False # Don't run the flash policy server + resource = 'socket.io' + policy_server = True + + +class NginxGeventSocketIOWorker(GeventSocketIOWorker): + """ + Worker which will not attempt to connect via websocket transport + + Nginx is not compatible with websockets and therefore will not add the + wsgi.websocket key to the wsgi environment. + """ + transports = ['xhr-polling'] diff --git a/awx/lib/site-packages/socketio/storage.py b/awx/lib/site-packages/socketio/storage.py new file mode 100644 index 0000000000..e980b0a547 --- /dev/null +++ b/awx/lib/site-packages/socketio/storage.py @@ -0,0 +1,30 @@ +import gevent +import weakref + +try: + import redis +except ImportError: + pass + + +class RedisStorage(object): + def __init__(self, server, **kwargs): + self.server = weakref.proxy(server) + self.jobs = [] + self.host = kwargs.get('host', 'localhost') + self.port = kwargs.get('port', 6379) + r = redis.StrictRedis(host=self.host, port=self.port) + self.conn = r.pubsub() + self.spawn(self.listener) + + def listener(self): + for m in self.conn.listen(): + print("===============NEW MESSAGE!!!====== %s", m) + + def spawn(self, fn, *args, **kwargs): + new = gevent.spawn(fn, *args, **kwargs) + self.jobs.append(new) + return new + + def new_request(self, environ): + print("===========NEW REQUEST %s===========" % environ) diff --git a/awx/lib/site-packages/socketio/transports.py b/awx/lib/site-packages/socketio/transports.py index 6564b5abd8..d4b7ab22c0 100644 --- a/awx/lib/site-packages/socketio/transports.py +++ b/awx/lib/site-packages/socketio/transports.py @@ -1,13 +1,21 @@ import gevent import urllib - +import urlparse +from geventwebsocket import WebSocketError from gevent.queue import Empty class BaseTransport(object): """Base class for all transports. Mostly wraps handler class functions.""" - def __init__(self, handler): + def __init__(self, handler, config, **kwargs): + """Base transport class. + + :param config: dict Should contain the config keys, like + ``heartbeat_interval``, ``heartbeat_timeout`` and + ``close_timeout``. + + """ self.content_type = ("Content-Type", "text/plain; charset=UTF-8") self.headers = [ ("Access-Control-Allow-Origin", "*"), @@ -15,22 +23,28 @@ class BaseTransport(object): ("Access-Control-Allow-Methods", "POST, GET, OPTIONS"), ("Access-Control-Max-Age", 3600), ] - self.headers_list = [] self.handler = handler + self.config = config def write(self, data=""): - if 'Content-Length' not in self.handler.response_headers_list: - self.handler.response_headers.append(('Content-Length', len(data))) - self.handler.response_headers_list.append('Content-Length') + # Gevent v 0.13 + if hasattr(self.handler, 'response_headers_list'): + if 'Content-Length' not in self.handler.response_headers_list: + self.handler.response_headers.append(('Content-Length', len(data))) + self.handler.response_headers_list.append('Content-Length') + elif not hasattr(self.handler, 'provided_content_length') or self.handler.provided_content_length is None: + # Gevent 1.0bX + l = len(data) + self.handler.provided_content_length = l + self.handler.response_headers.append(('Content-Length', l)) - self.handler.write(data) + self.handler.write_smart(data) def start_response(self, status, headers, **kwargs): if "Content-Type" not in [x[0] for x in headers]: headers.append(self.content_type) headers.extend(self.headers) - #print headers self.handler.start_response(status, headers, **kwargs) @@ -46,15 +60,14 @@ class XHRPollingTransport(BaseTransport): def get(self, socket): socket.heartbeat() - payload = self.get_messages_payload(socket, timeout=5.0) + heartbeat_interval = self.config['heartbeat_interval'] + payload = self.get_messages_payload(socket, timeout=heartbeat_interval) if not payload: payload = "8::" # NOOP self.start_response("200 OK", []) self.write(payload) - return [] - def _request_body(self): return self.handler.wsgi_input.readline() @@ -68,8 +81,6 @@ class XHRPollingTransport(BaseTransport): ]) self.write("1") - return [] - def get_messages_payload(self, socket, timeout=None): """This will fetch the messages from the Socket's queue, and if there are many messes, pack multiple messages in one payload and return @@ -87,14 +98,16 @@ class XHRPollingTransport(BaseTransport): ``messages`` - List of raw messages to encode, if necessary """ - if not messages: + if not messages or messages[0] is None: return '' if len(messages) == 1: return messages[0].encode('utf-8') - payload = u''.join(u'\ufffd%d\ufffd%s' % (len(p), p) - for p in messages) + payload = u''.join([(u'\ufffd%d\ufffd%s' % (len(p), p)) + for p in messages if p is not None]) + # FIXME: why is it so that we must filter None from here ? How + # is it even possible that a None gets in there ? return payload.encode('utf-8') @@ -115,7 +128,6 @@ class XHRPollingTransport(BaseTransport): """ payload = payload.decode('utf-8') if payload[0] == u"\ufffd": - #print "MULTIMSG FULL", payload ret = [] while len(payload) != 0: len_end = payload.find(u"\ufffd", 1) @@ -123,21 +135,19 @@ class XHRPollingTransport(BaseTransport): msg_start = len_end + 1 msg_end = length + msg_start message = payload[msg_start:msg_end] - #print "MULTIMSG", length, message ret.append(message) payload = payload[msg_end:] return ret return [payload] - def connect(self, socket, request_method): - if not socket.connection_confirmed: - socket.connection_confirmed = True + def do_exchange(self, socket, request_method): + if not socket.connection_established: + # Runs only the first time we get a Socket opening self.start_response("200 OK", [ ("Connection", "close"), ]) self.write("1::") # 'connect' packet - - return [] + return elif request_method in ("GET", "POST", "OPTIONS"): return getattr(self, request_method.lower())(socket) else: @@ -145,21 +155,33 @@ class XHRPollingTransport(BaseTransport): class JSONPolling(XHRPollingTransport): - def __init__(self, handler): - super(JSONPolling, self).__init__(handler) + def __init__(self, handler, config): + super(JSONPolling, self).__init__(handler, config) self.content_type = ("Content-Type", "text/javascript; charset=UTF-8") def _request_body(self): data = super(JSONPolling, self)._request_body() # resolve %20%3F's, take out wrapping d="...", etc.. - return urllib.unquote_plus(data)[3:-1] \ + data = urllib.unquote_plus(data)[3:-1] \ .replace(r'\"', '"') \ .replace(r"\\", "\\") + # For some reason, in case of multiple messages passed in one + # query, IE7 sends it escaped, not utf-8 encoded. This dirty + # hack handled it + if data[0] == "\\": + data = data.decode("unicode_escape").encode("utf-8") + return data + def write(self, data): """Just quote out stuff before sending it out""" + args = urlparse.parse_qs(self.handler.environ.get("QUERY_STRING")) + if "i" in args: + i = args["i"] + else: + i = "0" # TODO: don't we need to quote this data in here ? - super(JSONPolling, self).write("io.j[0]('%s');" % data) + super(JSONPolling, self).write("io.j[%s]('%s');" % (i, data)) class XHRMultipartTransport(XHRPollingTransport): @@ -170,11 +192,9 @@ class XHRMultipartTransport(XHRPollingTransport): "multipart/x-mixed-replace;boundary=\"socketio\"" ) - def connect(self, socket, request_method): + def do_exchange(self, socket, request_method): if request_method == "GET": - # TODO: double verify this, because we're not sure. look at git revs. - heartbeat = socket._spawn_heartbeat() - return [heartbeat] + self.get(socket) + return self.get(socket) elif request_method == "POST": return self.post(socket) else: @@ -196,22 +216,28 @@ class XHRMultipartTransport(XHRPollingTransport): if not payload: # That would mean the call to Queue.get() returned Empty, # so it was in fact killed, since we pass no timeout=.. - socket.kill() - break + return + # See below else: try: self.write_multipart(header) self.write_multipart(payload) self.write_multipart("--socketio\r\n") except socket.error: - socket.kill() - break + # The client might try to reconnect, even with a socket + # error, so let's just let it go, and not kill the + # socket completely. Other processes will ensure + # we kill everything if the user expires the timeouts. + # + # WARN: this means that this payload is LOST, unless we + # decide to re-inject it into the queue. + return - return [gevent.spawn(chunk)] + socket.spawn(chunk) class WebsocketTransport(BaseTransport): - def connect(self, socket, request_method): + def do_exchange(self, socket, request_method): websocket = self.handler.environ['wsgi.websocket'] websocket.send("1::") # 'connect' packet @@ -220,27 +246,26 @@ class WebsocketTransport(BaseTransport): message = socket.get_client_msg() if message is None: - socket.kill() break - - websocket.send(message) + try: + websocket.send(message) + except (WebSocketError, TypeError): + # We can't send a message on the socket + # it is dead, let the other sockets know + socket.disconnect() def read_from_ws(): while True: message = websocket.receive() - if not message: - socket.kill() + if message is None: break else: if message is not None: socket.put_server_msg(message) - gr1 = gevent.spawn(send_into_ws) - gr2 = gevent.spawn(read_from_ws) - heartbeat1, heartbeat2 = socket._spawn_heartbeat() - - return [gr1, gr2, heartbeat1, heartbeat2] + socket.spawn(send_into_ws) + socket.spawn(read_from_ws) class FlashSocketTransport(WebsocketTransport): @@ -250,21 +275,43 @@ class FlashSocketTransport(WebsocketTransport): class HTMLFileTransport(XHRPollingTransport): """Not tested at all!""" - def __init__(self, handler): - super(HTMLFileTransport, self).__init__(handler) + def __init__(self, handler, config): + super(HTMLFileTransport, self).__init__(handler, config) self.content_type = ("Content-Type", "text/html") def write_packed(self, data): - self.write("" % data) + self.write("" % data) - def handle_get_response(self, socket): + def write(self, data): + l = 1024 * 5 + super(HTMLFileTransport, self).write("%d\r\n%s%s\r\n" % (l, data, " " * (l - len(data)))) + + def do_exchange(self, socket, request_method): + return super(HTMLFileTransport, self).do_exchange(socket, request_method) + + def get(self, socket): self.start_response("200 OK", [ ("Connection", "keep-alive"), ("Content-Type", "text/html"), ("Transfer-Encoding", "chunked"), ]) - self.write("" + " " * 244) + self.write("") + self.write_packed("1::") # 'connect' packet - payload = self.get_messages_payload(socket, timeout=5.0) - self.write_packed(payload) + def chunk(): + while True: + payload = self.get_messages_payload(socket) + + if not payload: + # That would mean the call to Queue.get() returned Empty, + # so it was in fact killed, since we pass no timeout=.. + return + else: + try: + self.write_packed(payload) + except socket.error: + # See comments for XHRMultipart + return + + socket.spawn(chunk) diff --git a/awx/lib/site-packages/socketio/virtsocket.py b/awx/lib/site-packages/socketio/virtsocket.py index 8218313f25..bd79c01a35 100644 --- a/awx/lib/site-packages/socketio/virtsocket.py +++ b/awx/lib/site-packages/socketio/virtsocket.py @@ -11,12 +11,17 @@ in a different way """ import random import weakref +import logging import gevent from gevent.queue import Queue from gevent.event import Event from socketio import packet +from socketio.defaultjson import default_json_loads, default_json_dumps + + +log = logging.getLogger(__name__) def default_error_handler(socket, error_name, error_message, endpoint, @@ -40,10 +45,11 @@ def default_error_handler(socket, error_name, error_message, endpoint, # Send an error event through the Socket if not quiet: socket.send_packet(pkt) - + # Log that error somewhere for debugging... - print "default_error_handler: %s, %s (endpoint=%s, msg_id=%s)" % ( - error_name, error_message, endpoint, msg_id) + log.error(u"default_error_handler: {}, {} (endpoint={}, msg_id={})".format( + error_name, error_message, endpoint, msg_id + )) class Socket(object): @@ -65,7 +71,10 @@ class Socket(object): """Use this to be explicit when specifying a Global Namespace (an endpoint with no name, not '/chat' or anything.""" - def __init__(self, server, error_handler=None): + json_loads = staticmethod(default_json_loads) + json_dumps = staticmethod(default_json_dumps) + + def __init__(self, server, config, error_handler=None): self.server = weakref.proxy(server) self.sessid = str(random.random())[2:] self.session = {} # the session dict, for general developer usage @@ -76,7 +85,7 @@ class Socket(object): self.timeout = Event() self.wsgi_app_greenlet = None self.state = "NEW" - self.connection_confirmed = False + self.connection_established = False self.ack_callbacks = {} self.ack_counter = 0 self.request = None @@ -85,6 +94,7 @@ class Socket(object): self.active_ns = {} # Namespace sessions that were instantiated self.jobs = [] self.error_handler = default_error_handler + self.config = config if error_handler is not None: self.error_handler = error_handler @@ -116,6 +126,22 @@ class Socket(object): """ self.error_handler = error_handler + def _set_json_loads(self, json_loads): + """Change the default JSON decoder. + + This should be a callable that accepts a single string, and returns + a well-formed object. + """ + self.json_loads = json_loads + + def _set_json_dumps(self, json_dumps): + """Change the default JSON decoder. + + This should be a callable that accepts a single string, and returns + a well-formed object. + """ + self.json_dumps = json_dumps + def _get_next_msgid(self): """This retrieves the next value for the 'id' field when sending an 'event' or 'message' or 'json' that asks the remote client @@ -175,9 +201,6 @@ class Socket(object): def incr_hits(self): self.hits += 1 - if self.hits == 1: - self.state = self.STATE_CONNECTED - def heartbeat(self): """This makes the heart beat for another X seconds. Call this when you get a heartbeat packet in. @@ -186,7 +209,7 @@ class Socket(object): """ self.timeout.set() - def kill(self): + def kill(self, detach=False): """This function must/will be called when a socket is to be completely shut down, closed by connection timeout, connection error or explicit disconnection from the client. @@ -202,14 +225,23 @@ class Socket(object): self.state = self.STATE_DISCONNECTING self.server_queue.put_nowait(None) self.client_queue.put_nowait(None) - self.disconnect() + if len(self.active_ns) > 0: + log.debug("Calling disconnect() on %s" % self) + self.disconnect() - if self.sessid in self.server.sockets: - self.server.sockets.pop(self.sessid) + if detach: + self.detach() - # gevent.kill(self.wsgi_app_greenlet) - else: - pass # Fail silently + gevent.killall(self.jobs) + + def detach(self): + """Detach this socket from the server. This should be done in + conjunction with kill(), once all the jobs are dead, detach the + socket for garbage collection.""" + + log.debug("Removing %s from server sockets" % self) + if self.sessid in self.server.sockets: + self.server.sockets.pop(self.sessid) def put_server_msg(self, msg): """Writes to the server's pipe, to end up in in the Namespaces""" @@ -218,7 +250,6 @@ class Socket(object): def put_client_msg(self, msg): """Writes to the client's pipe, to end up in the browser""" - self.heartbeat() self.client_queue.put_nowait(msg) def get_client_msg(self, **kwargs): @@ -293,10 +324,13 @@ class Socket(object): if namespace in self.active_ns: del self.active_ns[namespace] + if len(self.active_ns) == 0 and self.connected: + self.kill(detach=True) + def send_packet(self, pkt): """Low-level interface to queue a packet on the wire (encoded as wire protocol""" - self.put_client_msg(packet.encode(pkt)) + self.put_client_msg(packet.encode(pkt, self.json_dumps)) def spawn(self, fn, *args, **kwargs): """Spawn a new Greenlet, attached to this Socket instance. @@ -304,7 +338,7 @@ class Socket(object): It will be monitored by the "watcher" method """ - self.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__) + log.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__) job = gevent.spawn(fn, *args, **kwargs) self.jobs.append(job) return job @@ -312,6 +346,13 @@ class Socket(object): def _receiver_loop(self): """This is the loop that takes messages from the queue for the server to consume, decodes them and dispatches them. + + It is the main loop for a socket. We join on this process before + returning control to the web framework. + + This process is not tracked by the socket itself, it is not going + to be killed by the ``gevent.killall(socket.jobs)``, so it must + exit gracefully itself. """ while True: @@ -320,7 +361,7 @@ class Socket(object): if not rawdata: continue # or close the connection ? try: - pkt = packet.decode(rawdata) + pkt = packet.decode(rawdata, self.json_loads) except (ValueError, KeyError, Exception), e: self.error('invalid_packet', "There was a decoding error when dealing with packet " @@ -334,7 +375,7 @@ class Socket(object): if pkt['type'] == 'disconnect' and pkt['endpoint'] == '': # On global namespace, we kill everything. - self.kill() + self.kill(detach=True) continue endpoint = pkt['endpoint'] @@ -350,8 +391,12 @@ class Socket(object): new_ns_class = self.namespaces[endpoint] pkt_ns = new_ns_class(self.environ, endpoint, request=self.request) - pkt_ns.initialize() # use this instead of __init__, - # for less confusion + # This calls initialize() on all the classes and mixins, etc.. + # in the order of the MRO + for cls in type(pkt_ns).__mro__: + if hasattr(cls, 'initialize'): + cls.initialize(pkt_ns) # use this instead of __init__, + # for less confusion self.active_ns[endpoint] = pkt_ns @@ -359,15 +404,20 @@ class Socket(object): # Has the client requested an 'ack' with the reply parameters ? if pkt.get('ack') == "data" and pkt.get('id'): + if type(retval) is tuple: + args = list(retval) + else: + args = [retval] returning_ack = dict(type='ack', ackId=pkt['id'], - args=retval, + args=args, endpoint=pkt.get('endpoint', '')) self.send_packet(returning_ack) # Now, are we still connected ? if not self.connected: - self.kill() # ?? what,s the best clean-up when its not a - # user-initiated disconnect + self.kill(detach=True) # ?? what,s the best clean-up + # when its not a + # user-initiated disconnect return def _spawn_receiver_loop(self): @@ -379,50 +429,51 @@ class Socket(object): return job def _watcher(self): - """Watch if any of the greenlets for a request have died. If so, kill - the request and the socket. - """ - # TODO: add that if any of the request.jobs die, kill them all and exit - gevent.sleep(5.0) + """Watch out if we've been disconnected, in that case, kill + all the jobs. + """ while True: gevent.sleep(1.0) - if not self.connected: - # Killing Socket-level jobs - gevent.killall(self.jobs) for ns_name, ns in list(self.active_ns.iteritems()): ns.recv_disconnect() + # Killing Socket-level jobs + gevent.killall(self.jobs) break def _spawn_watcher(self): + """This one is not waited for with joinall(socket.jobs), as it + is an external watcher, to clean up when everything is done.""" job = gevent.spawn(self._watcher) return job def _heartbeat(self): """Start the heartbeat Greenlet to check connection health.""" - self.state = self.STATE_CONNECTED - + interval = self.config['heartbeat_interval'] while self.connected: - gevent.sleep(5.0) # FIXME: make this a setting + gevent.sleep(interval) # TODO: this process could use a timeout object like the disconnect # timeout thing, and ONLY send packets when none are sent! # We would do that by calling timeout.set() for a "sending" # timeout. If we're sending 100 messages a second, there is # no need to push some heartbeats in there also. - self.put_client_msg("2::") # TODO: make it a heartbeat packet + self.put_client_msg("2::") - def _disconnect_timeout(self): - self.timeout.clear() + def _heartbeat_timeout(self): + timeout = float(self.config['heartbeat_timeout']) + while True: + self.timeout.clear() + gevent.sleep(0) + wait_res = self.timeout.wait(timeout=timeout) + if not wait_res: + if self.connected: + log.debug("heartbeat timed out, killing socket") + self.kill(detach=True) + return - if self.timeout.wait(10.0): - gevent.spawn(self._disconnect_timeout) - else: - self.kill() def _spawn_heartbeat(self): """This functions returns a list of jobs""" - job_sender = gevent.spawn(self._heartbeat) - job_waiter = gevent.spawn(self._disconnect_timeout) - self.jobs.extend((job_sender, job_waiter)) - return job_sender, job_waiter + self.spawn(self._heartbeat) + self.spawn(self._heartbeat_timeout)