Upgrade gevent-socketio to 0.3.6

This commit is contained in:
Matthew Jones
2014-08-06 14:00:16 -04:00
parent e3bdd966ed
commit b18b075bc9
13 changed files with 831 additions and 237 deletions

View File

@@ -25,7 +25,7 @@ django-split-settings==0.1.1 (split_settings/*)
django-taggit==0.11.2 (taggit/*) django-taggit==0.11.2 (taggit/*)
djangorestframework==2.3.13 (rest_framework/*) djangorestframework==2.3.13 (rest_framework/*)
django-qsstats-magic==0.7.2 (django-qsstats-magic/*, minor fix in qsstats/__init__.py) django-qsstats-magic==0.7.2 (django-qsstats-magic/*, minor fix in qsstats/__init__.py)
gevent-socketio==0.3.5-rc1 (socketio/*) gevent-socketio==0.3.6 (socketio/*)
gevent-websocket==0.9.3 (geventwebsocket/*) gevent-websocket==0.9.3 (geventwebsocket/*)
httplib2==0.8 (httplib2/*) httplib2==0.8 (httplib2/*)
importlib==1.0.3 (importlib/*, needed for Python 2.6 support) importlib==1.0.3 (importlib/*, needed for Python 2.6 support)

View File

@@ -6,7 +6,8 @@ import gevent
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def socketio_manage(environ, namespaces, request=None, error_handler=None): def socketio_manage(environ, namespaces, request=None, error_handler=None,
json_loads=None, json_dumps=None):
"""Main SocketIO management function, call from within your Framework of """Main SocketIO management function, call from within your Framework of
choice's view. choice's view.
@@ -20,7 +21,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
use Socket.GLOBAL_NS to be more explicit. So it would look like: use Socket.GLOBAL_NS to be more explicit. So it would look like:
.. code-block:: python .. code-block:: python
namespaces={'': GlobalNamespace, namespaces={'': GlobalNamespace,
'/chat': ChatNamespace} '/chat': ChatNamespace}
@@ -35,6 +36,11 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
The callable you pass in should have the same signature as the default The callable you pass in should have the same signature as the default
error handler. error handler.
The ``json_loads`` and ``json_dumps`` are overrides for the default
``json.loads`` and ``json.dumps`` function calls. Override these at
the top-most level here. This will affect all sockets created by this
socketio manager, and all namespaces inside.
This function will block the current "view" or "controller" in your This function will block the current "view" or "controller" in your
framework to do the recv/send on the socket, and dispatch incoming messages framework to do the recv/send on the socket, and dispatch incoming messages
to your namespaces. to your namespaces.
@@ -45,6 +51,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
def my_view(request): def my_view(request):
socketio_manage(request.environ, {'': GlobalNamespace}, request) socketio_manage(request.environ, {'': GlobalNamespace}, request)
NOTE: You must understand that this function is going to be called NOTE: You must understand that this function is going to be called
*only once* per socket opening, *even though* you are using a long *only once* per socket opening, *even though* you are using a long
polling mechanism. The subsequent calls (for long polling) will polling mechanism. The subsequent calls (for long polling) will
@@ -67,10 +74,14 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
if error_handler: if error_handler:
socket._set_error_handler(error_handler) socket._set_error_handler(error_handler)
receiver_loop = socket._spawn_receiver_loop() if json_loads:
watcher = socket._spawn_watcher() socket._set_json_loads(json_loads)
if json_dumps:
socket._set_json_dumps(json_dumps)
gevent.joinall([receiver_loop, watcher]) receiver_loop = socket._spawn_receiver_loop()
gevent.joinall([receiver_loop])
# TODO: double check, what happens to the WSGI request here ? it vanishes ? # TODO: double check, what happens to the WSGI request here ? it vanishes ?
return return

View File

@@ -0,0 +1,21 @@
### default json loaders
try:
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal
class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}
def default_json_dumps(data):
return json.dumps(data, separators=(',', ':'),
**json_decimal_args)
def default_json_loads(data):
return json.loads(data)

View File

@@ -0,0 +1,6 @@
class SessionNotFound(Exception):
def __init__(self, sessid):
self.sessid = sessid
def __str__(self):
return "Session %s not found!" % self.sessid

View File

@@ -5,17 +5,23 @@ import urlparse
from gevent.pywsgi import WSGIHandler from gevent.pywsgi import WSGIHandler
from socketio import transports from socketio import transports
from geventwebsocket.handler import WebSocketHandler
class SocketIOHandler(WSGIHandler): class SocketIOHandler(WSGIHandler):
RE_REQUEST_URL = re.compile(r""" RE_REQUEST_URL = re.compile(r"""
^/(?P<resource>[^/]+) ^/(?P<resource>.+?)
/(?P<protocol_version>[^/]+) /1
/(?P<transport_id>[^/]+) /(?P<transport_id>[^/]+)
/(?P<sessid>[^/]+)/?$ /(?P<sessid>[^/]+)/?$
""", re.X) """, re.X)
RE_HANDSHAKE_URL = re.compile(r"^/(?P<resource>[^/]+)/1/$", re.X) RE_HANDSHAKE_URL = re.compile(r"^/(?P<resource>.+?)/1/$", re.X)
# new socket.io versions (> 0.9.8) call an obscure url with two slashes
# instead of a transport when disconnecting
# https://github.com/LearnBoost/socket.io-client/blob/0.9.16/lib/socket.js#L361
RE_DISCONNECT_URL = re.compile(r"""
^/(?P<resource>.+?)
/(?P<protocol_version>[^/]+)
//(?P<sessid>[^/]+)/?$
""", re.X)
handler_types = { handler_types = {
'websocket': transports.WebsocketTransport, 'websocket': transports.WebsocketTransport,
@@ -26,9 +32,16 @@ class SocketIOHandler(WSGIHandler):
'jsonp-polling': transports.JSONPolling, 'jsonp-polling': transports.JSONPolling,
} }
def __init__(self, *args, **kwargs): def __init__(self, config, *args, **kwargs):
"""Create a new SocketIOHandler.
:param config: dict Configuration for timeouts and intervals
that will go down to the other components, transports, etc..
"""
self.socketio_connection = False self.socketio_connection = False
self.allowed_paths = None self.allowed_paths = None
self.config = config
super(SocketIOHandler, self).__init__(*args, **kwargs) super(SocketIOHandler, self).__init__(*args, **kwargs)
@@ -36,7 +49,7 @@ class SocketIOHandler(WSGIHandler):
if self.server.transports: if self.server.transports:
self.transports = self.server.transports self.transports = self.server.transports
if not set(self.transports).issubset(set(self.handler_types)): if not set(self.transports).issubset(set(self.handler_types)):
raise Exception("transports should be elements of: %s" % raise ValueError("transports should be elements of: %s" %
(self.handler_types.keys())) (self.handler_types.keys()))
def _do_handshake(self, tokens): def _do_handshake(self, tokens):
@@ -44,7 +57,10 @@ class SocketIOHandler(WSGIHandler):
self.log_error("socket.io URL mismatch") self.log_error("socket.io URL mismatch")
else: else:
socket = self.server.get_socket() socket = self.server.get_socket()
data = "%s:15:10:%s" % (socket.sessid, ",".join(self.transports)) data = "%s:%s:%s:%s" % (socket.sessid,
self.config['heartbeat_timeout'] or '',
self.config['close_timeout'] or '',
",".join(self.transports))
self.write_smart(data) self.write_smart(data)
def write_jsonp_result(self, data, wrapper="0"): def write_jsonp_result(self, data, wrapper="0"):
@@ -74,10 +90,16 @@ class SocketIOHandler(WSGIHandler):
self.process_result() self.process_result()
def handle_one_response(self): def handle_one_response(self):
"""This function deals with *ONE INCOMING REQUEST* from the web.
It will wire and exchange message to the queues for long-polling
methods, otherwise, will stay alive for websockets.
"""
path = self.environ.get('PATH_INFO') path = self.environ.get('PATH_INFO')
# Kick non-socket.io requests to our superclass # Kick non-socket.io requests to our superclass
if not path.lstrip('/').startswith(self.server.resource): if not path.lstrip('/').startswith(self.server.resource + '/'):
return super(SocketIOHandler, self).handle_one_response() return super(SocketIOHandler, self).handle_one_response()
self.status = None self.status = None
@@ -85,64 +107,118 @@ class SocketIOHandler(WSGIHandler):
self.result = None self.result = None
self.response_length = 0 self.response_length = 0
self.response_use_chunked = False self.response_use_chunked = False
# This is analyzed for each and every HTTP requests involved
# in the Socket.IO protocol, whether long-running or long-polling
# (read: websocket or xhr-polling methods)
request_method = self.environ.get("REQUEST_METHOD") request_method = self.environ.get("REQUEST_METHOD")
request_tokens = self.RE_REQUEST_URL.match(path) request_tokens = self.RE_REQUEST_URL.match(path)
handshake_tokens = self.RE_HANDSHAKE_URL.match(path)
disconnect_tokens = self.RE_DISCONNECT_URL.match(path)
# Parse request URL and QUERY_STRING and do handshake if handshake_tokens:
if request_tokens: # Deal with first handshake here, create the Socket and push
request_tokens = request_tokens.groupdict() # the config up.
return self._do_handshake(handshake_tokens.groupdict())
elif disconnect_tokens:
# it's a disconnect request via XHR
tokens = disconnect_tokens.groupdict()
elif request_tokens:
tokens = request_tokens.groupdict()
# and continue...
else: else:
handshake_tokens = self.RE_HANDSHAKE_URL.match(path) # This is no socket.io request. Let the WSGI app handle it.
return super(SocketIOHandler, self).handle_one_response()
if handshake_tokens: # Setup socket
return self._do_handshake(handshake_tokens.groupdict()) sessid = tokens["sessid"]
else:
# This is no socket.io request. Let the WSGI app handle it.
return super(SocketIOHandler, self).handle_one_response()
# Setup the transport and socket
transport = self.handler_types.get(request_tokens["transport_id"])
sessid = request_tokens["sessid"]
socket = self.server.get_socket(sessid) socket = self.server.get_socket(sessid)
if not socket:
self.handle_bad_request()
return [] # Do not say the session is not found, just bad request
# so they don't start brute forcing to find open sessions
if self.environ['QUERY_STRING'].startswith('disconnect'):
# according to socket.io specs disconnect requests
# have a `disconnect` query string
# https://github.com/LearnBoost/socket.io-spec#forced-socket-disconnection
socket.disconnect()
self.handle_disconnect_request()
return []
# Setup transport
transport = self.handler_types.get(tokens["transport_id"])
# In case this is WebSocket request, switch to the WebSocketHandler # In case this is WebSocket request, switch to the WebSocketHandler
# FIXME: fix this ugly class change # FIXME: fix this ugly class change
old_class = None
if issubclass(transport, (transports.WebsocketTransport, if issubclass(transport, (transports.WebsocketTransport,
transports.FlashSocketTransport)): transports.FlashSocketTransport)):
self.__class__ = WebSocketHandler old_class = self.__class__
self.__class__ = self.server.ws_handler_class
self.prevent_wsgi_call = True # thank you self.prevent_wsgi_call = True # thank you
# TODO: any errors, treat them ?? # TODO: any errors, treat them ??
self.handle_one_response() self.handle_one_response() # does the Websocket dance before we continue
# Make the socket object available for WSGI apps # Make the socket object available for WSGI apps
self.environ['socketio'] = socket self.environ['socketio'] = socket
# Create a transport and handle the request likewise # Create a transport and handle the request likewise
self.transport = transport(self) self.transport = transport(self, self.config)
jobs = self.transport.connect(socket, request_method) # transports register their own spawn'd jobs now
# Keep track of those jobs (reading, writing and heartbeat jobs) so self.transport.do_exchange(socket, request_method)
# that we can kill them later with Socket.kill()
socket.jobs.extend(jobs)
try: if not socket.connection_established:
# We'll run the WSGI app if it wasn't already done. # This is executed only on the *first* packet of the establishment
if socket.wsgi_app_greenlet is None: # of the virtual Socket connection.
# TODO: why don't we spawn a call to handle_one_response here ? socket.connection_established = True
# why call directly the WSGI machinery ? socket.state = socket.STATE_CONNECTED
start_response = lambda status, headers, exc=None: None socket._spawn_heartbeat()
socket.wsgi_app_greenlet = gevent.spawn(self.application, socket._spawn_watcher()
self.environ,
start_response)
except:
self.handle_error(*sys.exc_info())
# TODO DOUBLE-CHECK: do we need to joinall here ? try:
gevent.joinall(jobs) # We'll run the WSGI app if it wasn't already done.
if socket.wsgi_app_greenlet is None:
# TODO: why don't we spawn a call to handle_one_response here ?
# why call directly the WSGI machinery ?
start_response = lambda status, headers, exc=None: None
socket.wsgi_app_greenlet = gevent.spawn(self.application,
self.environ,
start_response)
except:
self.handle_error(*sys.exc_info())
# we need to keep the connection open if we are an open socket
if tokens['transport_id'] in ['flashsocket', 'websocket']:
# wait here for all jobs to finished, when they are done
gevent.joinall(socket.jobs)
# Switch back to the old class so references to this don't use the
# incorrect class. Useful for debugging.
if old_class:
self.__class__ = old_class
# Clean up circular references so they can be garbage collected.
if hasattr(self, 'websocket') and self.websocket:
if hasattr(self.websocket, 'environ'):
del self.websocket.environ
del self.websocket
if self.environ:
del self.environ
def handle_bad_request(self): def handle_bad_request(self):
self.close_connection = True self.close_connection = True
self.start_reponse("400 Bad Request", [ self.start_response("400 Bad Request", [
('Content-Type', 'text/plain'),
('Connection', 'close'),
('Content-Length', 0)
])
def handle_disconnect_request(self):
self.close_connection = True
self.start_response("200 OK", [
('Content-Type', 'text/plain'), ('Content-Type', 'text/plain'),
('Connection', 'close'), ('Connection', 'close'),
('Content-Length', 0) ('Content-Length', 0)

View File

@@ -27,6 +27,10 @@ class BaseNamespace(object):
def on_my_second_event(self, whatever): def on_my_second_event(self, whatever):
print "This holds the first arg that was passed", whatever print "This holds the first arg that was passed", whatever
Handlers are automatically dispatched based on the name of the incoming
event. For example, a 'user message' event will be handled by
``on_user_message()``. To change this, override :meth:`process_event`.
We can also access the full packet directly by making an event handler We can also access the full packet directly by making an event handler
that accepts a single argument named 'packet': that accepts a single argument named 'packet':
@@ -43,9 +47,11 @@ class BaseNamespace(object):
self.session = self.socket.session # easily accessible session self.session = self.socket.session # easily accessible session
self.request = request self.request = request
self.ns_name = ns_name self.ns_name = ns_name
self.allowed_methods = None # be careful, None means all methods #: Store for ACL allowed methods. Be careful as ``None`` means
# are allowed while an empty list #: that all methods are allowed, while an empty list means every
# means totally closed. #: method is denied. Value: list of strings or ``None``. You
#: can and should use the various ``acl`` methods to tweak this.
self.allowed_methods = None
self.jobs = [] self.jobs = []
self.reset_acl() self.reset_acl()
@@ -95,7 +101,7 @@ class BaseNamespace(object):
access to all of the ``on_*()`` and ``recv_*()`` functions, access to all of the ``on_*()`` and ``recv_*()`` functions,
etc.. methods. etc.. methods.
Return something like: ``['on_connect', 'on_public_method']`` Return something like: ``set(['recv_connect', 'on_public_method'])``
You can later modify this list dynamically (inside You can later modify this list dynamically (inside
``on_connect()`` for example) using: ``on_connect()`` for example) using:
@@ -113,6 +119,9 @@ class BaseNamespace(object):
**Beware**, returning ``None`` leaves the namespace completely **Beware**, returning ``None`` leaves the namespace completely
accessible. accessible.
The methods that are open are stored in the ``allowed_methods``
attribute of the ``Namespace`` instance.
""" """
return None return None
@@ -160,10 +169,7 @@ class BaseNamespace(object):
if not callback: if not callback:
print "ERROR: No such callback for ackId %s" % packet['ackId'] print "ERROR: No such callback for ackId %s" % packet['ackId']
return return
try: return callback(*(packet['args']))
return callback(*(packet['args']))
except TypeError, e:
print "ERROR: Call to callback function failed", packet
elif packet_type == 'disconnect': elif packet_type == 'disconnect':
# Force a disconnect on the namespace. # Force a disconnect on the namespace.
return self.call_method_with_acl('recv_disconnect', packet) return self.call_method_with_acl('recv_disconnect', packet)
@@ -178,17 +184,30 @@ class BaseNamespace(object):
``on_``-prefixed methods. You could then implement your own dispatch. ``on_``-prefixed methods. You could then implement your own dispatch.
See the source code for inspiration. See the source code for inspiration.
To process events that have callbacks on the client side, you There are two ways to deal with callbacks from the client side
must define your event with a single parameter: ``packet``. (meaning, the browser has a callback waiting for data that this
In this case, it will be the full ``packet`` object and you server will be sending back):
can inspect its ``ack`` and ``id`` keys to define if and how
you reply. A correct reply to an event with a callback would The first one is simply to return an object. If the incoming
look like this: packet requested has an 'ack' field set, meaning the browser is
waiting for callback data, it will automatically be packaged
and sent, associated with the 'ackId' from the browser. The
return value must be a *sequence* of elements, that will be
mapped to the positional parameters of the callback function
on the browser side.
If you want to *know* that you're dealing with a packet
that requires a return value, you can do those things manually
by inspecting the ``ack`` and ``id`` keys from the ``packet``
object. Your callback will behave specially if the name of
the argument to your method is ``packet``. It will fill it
with the unprocessed ``packet`` object for your inspection,
like this:
.. code-block:: python .. code-block:: python
def on_my_callback(self, packet): def on_my_callback(self, packet):
if 'ack' in packet': if 'ack' in packet:
self.emit('go_back', 'param1', id=packet['id']) self.emit('go_back', 'param1', id=packet['id'])
""" """
args = packet['args'] args = packet['args']
@@ -227,12 +246,17 @@ class BaseNamespace(object):
Those are the two behaviors: Those are the two behaviors:
* If there is only one parameter on the dispatched method and * If there is only one parameter on the dispatched method and
it is equal to ``packet``, then pass in the packet as the it is named ``packet``, then pass in the packet dict as the
sole parameter. sole parameter.
* Otherwise, pass in the arguments as specified by the * Otherwise, pass in the arguments as specified by the
different ``recv_*()`` methods args specs, or the different ``recv_*()`` methods args specs, or the
:meth:`process_event` documentation. :meth:`process_event` documentation.
This method will also consider the
``exception_handler_decorator``. See Namespace documentation
for details and examples.
""" """
method = getattr(self, method_name, None) method = getattr(self, method_name, None)
if method is None: if method is None:
@@ -247,6 +271,11 @@ class BaseNamespace(object):
"The server-side method is invalid, as it doesn't " "The server-side method is invalid, as it doesn't "
"have 'self' as its first argument") "have 'self' as its first argument")
return return
# Check if we need to decorate to handle exceptions
if hasattr(self, 'exception_handler_decorator'):
method = self.exception_handler_decorator(method)
if len(func_args) == 2 and func_args[1] == 'packet': if len(func_args) == 2 and func_args[1] == 'packet':
return method(packet) return method(packet)
else: else:
@@ -264,10 +293,14 @@ class BaseNamespace(object):
:func:`~socketio.socketio_manage`) without clogging the :func:`~socketio.socketio_manage`) without clogging the
memory. memory.
If you override this method, you probably want to only If you override this method, you probably want to initialize
initialize the variables you're going to use in the events of the variables you're going to use in the events handled by this
this namespace, say, with some default values, but not perform namespace, setup ACLs, etc..
any operation that assumes authentication/authorization.
This method is called on all base classes following the _`method resolution order <http://docs.python.org/library/stdtypes.html?highlight=mro#class.__mro__>`
so you don't need to call super() to initialize the mixins or
other derived classes.
""" """
pass pass
@@ -432,8 +465,14 @@ class BaseNamespace(object):
It will be monitored by the "watcher" process in the Socket. If the It will be monitored by the "watcher" process in the Socket. If the
socket disconnects, all these greenlets are going to be killed, after socket disconnects, all these greenlets are going to be killed, after
calling BaseNamespace.disconnect() calling BaseNamespace.disconnect()
This method uses the ``exception_handler_decorator``. See
Namespace documentation for more information.
""" """
# self.log.debug("Spawning sub-Namespace Greenlet: %s" % fn.__name__) # self.log.debug("Spawning sub-Namespace Greenlet: %s" % fn.__name__)
if hasattr(self, 'exception_handler_decorator'):
fn = self.exception_handler_decorator(fn)
new = gevent.spawn(fn, *args, **kwargs) new = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(new) self.jobs.append(new)
return new return new
@@ -454,8 +493,12 @@ class BaseNamespace(object):
packet = {"type": "disconnect", packet = {"type": "disconnect",
"endpoint": self.ns_name} "endpoint": self.ns_name}
self.socket.send_packet(packet) self.socket.send_packet(packet)
self.socket.remove_namespace(self.ns_name) # remove_namespace might throw GreenletExit so
self.kill_local_jobs() # kill_local_jobs must be in finally
try:
self.socket.remove_namespace(self.ns_name)
finally:
self.kill_local_jobs()
def kill_local_jobs(self): def kill_local_jobs(self):
"""Kills all the jobs spawned with BaseNamespace.spawn() on a namespace """Kills all the jobs spawned with BaseNamespace.spawn() on a namespace

View File

@@ -1,17 +1,4 @@
try: from socketio.defaultjson import default_json_dumps, default_json_loads
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal
class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}
MSG_TYPES = { MSG_TYPES = {
'disconnect': 0, 'disconnect': 0,
@@ -45,7 +32,7 @@ socketio_packet_attributes = ['type', 'name', 'data', 'endpoint', 'args',
'ackId', 'reason', 'advice', 'qs', 'id'] 'ackId', 'reason', 'advice', 'qs', 'id']
def encode(data): def encode(data, json_dumps=default_json_dumps):
""" """
Encode an attribute dict into a byte string. Encode an attribute dict into a byte string.
""" """
@@ -72,14 +59,13 @@ def encode(data):
if msg == '3': if msg == '3':
payload = data['data'] payload = data['data']
if msg == '4': if msg == '4':
payload = json.dumps(data['data'], separators=(',', ':'), payload = json_dumps(data['data'])
**json_decimal_args)
if msg == '5': if msg == '5':
d = {} d = {}
d['name'] = data['name'] d['name'] = data['name']
if 'args' in data and data['args'] != []: if 'args' in data and data['args'] != []:
d['args'] = data['args'] d['args'] = data['args']
payload = json.dumps(d, separators=(',', ':'), **json_decimal_args) payload = json_dumps(d)
if 'id' in data: if 'id' in data:
msg += ':' + str(data['id']) msg += ':' + str(data['id'])
if data['ack'] == 'data': if data['ack'] == 'data':
@@ -98,8 +84,7 @@ def encode(data):
# '6:::' [id] '+' [data] # '6:::' [id] '+' [data]
msg += '::' + data.get('endpoint', '') + ':' + str(data['ackId']) msg += '::' + data.get('endpoint', '') + ':' + str(data['ackId'])
if 'args' in data and data['args'] != []: if 'args' in data and data['args'] != []:
msg += '+' + json.dumps(data['args'], separators=(',', ':'), msg += '+' + json_dumps(data['args'])
**json_decimal_args)
elif msg == '7': elif msg == '7':
# '7::' [endpoint] ':' [reason] '+' [advice] # '7::' [endpoint] ':' [reason] '+' [advice]
@@ -117,7 +102,7 @@ def encode(data):
return msg return msg
def decode(rawstr): def decode(rawstr, json_loads=default_json_loads):
""" """
Decode a rawstr packet arriving from the socket into a dict. Decode a rawstr packet arriving from the socket into a dict.
""" """
@@ -163,11 +148,11 @@ def decode(rawstr):
decoded_msg['data'] = data decoded_msg['data'] = data
elif msg_type == "4": # json msg elif msg_type == "4": # json msg
decoded_msg['data'] = json.loads(data) decoded_msg['data'] = json_loads(data)
elif msg_type == "5": # event elif msg_type == "5": # event
try: try:
data = json.loads(data) data = json_loads(data)
except ValueError, e: except ValueError, e:
print("Invalid JSON event message", data) print("Invalid JSON event message", data)
decoded_msg['args'] = [] decoded_msg['args'] = []
@@ -182,7 +167,7 @@ def decode(rawstr):
if '+' in data: if '+' in data:
ackId, data = data.split('+') ackId, data = data.split('+')
decoded_msg['ackId'] = int(ackId) decoded_msg['ackId'] = int(ackId)
decoded_msg['args'] = json.loads(data) decoded_msg['args'] = json_loads(data)
else: else:
decoded_msg['ackId'] = int(data) decoded_msg['ackId'] = int(data)
decoded_msg['args'] = [] decoded_msg['args'] = []

View File

@@ -0,0 +1,72 @@
import logging
from socketio import socketio_manage
from django.http import HttpResponse
from django.views.decorators.csrf import csrf_exempt
from django.utils.importlib import import_module
# for Django 1.3 support
try:
from django.conf.urls import patterns, url, include
except ImportError:
from django.conf.urls.defaults import patterns, url, include
SOCKETIO_NS = {}
LOADING_SOCKETIO = False
def autodiscover():
"""
Auto-discover INSTALLED_APPS sockets.py modules and fail silently when
not present. NOTE: socketio_autodiscover was inspired/copied from
django.contrib.admin autodiscover
"""
global LOADING_SOCKETIO
if LOADING_SOCKETIO:
return
LOADING_SOCKETIO = True
import imp
from django.conf import settings
for app in settings.INSTALLED_APPS:
try:
app_path = import_module(app).__path__
except AttributeError:
continue
try:
imp.find_module('sockets', app_path)
except ImportError:
continue
import_module("%s.sockets" % app)
LOADING_SOCKETIO = False
class namespace(object):
def __init__(self, name=''):
self.name = name
def __call__(self, handler):
SOCKETIO_NS[self.name] = handler
return handler
@csrf_exempt
def socketio(request):
try:
socketio_manage(request.environ, SOCKETIO_NS, request)
except:
logging.getLogger("socketio").error("Exception while handling socketio connection", exc_info=True)
return HttpResponse("")
urls = patterns("", (r'', socketio))

View File

@@ -8,6 +8,7 @@ from gevent.pywsgi import WSGIServer
from socketio.handler import SocketIOHandler from socketio.handler import SocketIOHandler
from socketio.policyserver import FlashPolicyServer from socketio.policyserver import FlashPolicyServer
from socketio.virtsocket import Socket from socketio.virtsocket import Socket
from geventwebsocket.handler import WebSocketHandler
__all__ = ['SocketIOServer'] __all__ = ['SocketIOServer']
@@ -16,19 +17,41 @@ class SocketIOServer(WSGIServer):
"""A WSGI Server with a resource that acts like an SocketIO.""" """A WSGI Server with a resource that acts like an SocketIO."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """This is just like the standard WSGIServer __init__, except with a
This is just like the standard WSGIServer __init__, except with a
few additional ``kwargs``: few additional ``kwargs``:
:param resource: The URL which has to be identified as a socket.io request. Defaults to the /socket.io/ URL. :param resource: The URL which has to be identified as a
socket.io request. Defaults to the /socket.io/ URL.
:param transports: Optional list of transports to allow. List of :param transports: Optional list of transports to allow. List of
strings, each string should be one of strings, each string should be one of
handler.SocketIOHandler.handler_types. handler.SocketIOHandler.handler_types.
:param policy_server: Boolean describing whether or not to use the :param policy_server: Boolean describing whether or not to use the
Flash policy server. Default True. Flash policy server. Default True.
:param policy_listener : A tuple containing (host, port) for the
policy server. This is optional and used only if policy server :param policy_listener: A tuple containing (host, port) for the
policy server. This is optional and used only if policy server
is set to true. The default value is 0.0.0.0:843 is set to true. The default value is 0.0.0.0:843
:param heartbeat_interval: int The timeout for the server, we
should receive a heartbeat from the client within this
interval. This should be less than the
``heartbeat_timeout``.
:param heartbeat_timeout: int The timeout for the client when
it should send a new heartbeat to the server. This value
is sent to the client after a successful handshake.
:param close_timeout: int The timeout for the client, when it
closes the connection it still X amounts of seconds to do
re open of the connection. This value is sent to the
client after a successful handshake.
:param log_file: str The file in which you want the PyWSGI
server to write its access log. If not specified, it
is sent to `stderr` (with gevent 0.13).
""" """
self.sockets = {} self.sockets = {}
if 'namespace' in kwargs: if 'namespace' in kwargs:
@@ -40,18 +63,48 @@ class SocketIOServer(WSGIServer):
self.transports = kwargs.pop('transports', None) self.transports = kwargs.pop('transports', None)
if kwargs.pop('policy_server', True): if kwargs.pop('policy_server', True):
policylistener = kwargs.pop('policy_listener', (args[0][0], 10843)) try:
address = args[0][0]
except TypeError:
try:
address = args[0].address[0]
except AttributeError:
address = args[0].cfg_addr[0]
policylistener = kwargs.pop('policy_listener', (address, 10843))
self.policy_server = FlashPolicyServer(policylistener) self.policy_server = FlashPolicyServer(policylistener)
else: else:
self.policy_server = None self.policy_server = None
kwargs['handler_class'] = SocketIOHandler # Extract other config options
self.config = {
'heartbeat_timeout': 60,
'close_timeout': 60,
'heartbeat_interval': 25,
}
for f in ('heartbeat_timeout', 'heartbeat_interval', 'close_timeout'):
if f in kwargs:
self.config[f] = int(kwargs.pop(f))
if not 'handler_class' in kwargs:
kwargs['handler_class'] = SocketIOHandler
if not 'ws_handler_class' in kwargs:
self.ws_handler_class = WebSocketHandler
else:
self.ws_handler_class = kwargs.pop('ws_handler_class')
log_file = kwargs.pop('log_file', None)
if log_file:
kwargs['log'] = open(log_file, 'a')
super(SocketIOServer, self).__init__(*args, **kwargs) super(SocketIOServer, self).__init__(*args, **kwargs)
def start_accepting(self): def start_accepting(self):
if self.policy_server is not None: if self.policy_server is not None:
try: try:
self.policy_server.start() if not self.policy_server.started:
self.policy_server.start()
except error, ex: except error, ex:
sys.stderr.write( sys.stderr.write(
'FAILED to start flash policy server: %s\n' % (ex, )) 'FAILED to start flash policy server: %s\n' % (ex, ))
@@ -60,13 +113,14 @@ class SocketIOServer(WSGIServer):
sys.stderr.write('FAILED to start flash policy server.\n\n') sys.stderr.write('FAILED to start flash policy server.\n\n')
super(SocketIOServer, self).start_accepting() super(SocketIOServer, self).start_accepting()
def kill(self): def stop(self, timeout=None):
if self.policy_server is not None: if self.policy_server is not None:
self.policy_server.kill() self.policy_server.stop()
super(SocketIOServer, self).kill() super(SocketIOServer, self).stop(timeout=timeout)
def handle(self, socket, address): def handle(self, socket, address):
handler = self.handler_class(socket, address, self) # Pass in the config about timeouts, heartbeats, also...
handler = self.handler_class(self.config, socket, address, self)
handler.handle() handler.handle()
def get_socket(self, sessid=''): def get_socket(self, sessid=''):
@@ -74,10 +128,59 @@ class SocketIOServer(WSGIServer):
socket = self.sockets.get(sessid) socket = self.sockets.get(sessid)
if sessid and not socket:
return None # you ask for a session that doesn't exist!
if socket is None: if socket is None:
socket = Socket(self) socket = Socket(self, self.config)
self.sockets[socket.sessid] = socket self.sockets[socket.sessid] = socket
else: else:
socket.incr_hits() socket.incr_hits()
return socket return socket
def serve(app, **kw):
_quiet = kw.pop('_quiet', False)
_resource = kw.pop('resource', 'socket.io')
if not _quiet: # pragma: no cover
# idempotent if logging has already been set up
import logging
logging.basicConfig()
host = kw.pop('host', '127.0.0.1')
port = int(kw.pop('port', 6543))
transports = kw.pop('transports', None)
if transports:
transports = [x.strip() for x in transports.split(',')]
policy_server = kw.pop('policy_server', False)
if policy_server in (True, 'True', 'true', 'enable', 'yes', 'on', '1'):
policy_server = True
policy_listener_host = kw.pop('policy_listener_host', host)
policy_listener_port = int(kw.pop('policy_listener_port', 10843))
kw['policy_listener'] = (policy_listener_host, policy_listener_port)
else:
policy_server = False
server = SocketIOServer((host, port),
app,
resource=_resource,
transports=transports,
policy_server=policy_server,
**kw)
if not _quiet:
print('serving on http://%s:%s' % (host, port))
server.serve_forever()
def serve_paste(app, global_conf, **kw):
"""pserve / paster serve / waitress replacement / integration
You can pass as parameters:
transports = websockets, xhr-multipart, xhr-longpolling, etc...
policy_server = True
"""
serve(app, **kw)
return 0

View File

@@ -1,44 +1,182 @@
import os import os
import gevent import gevent
import time
from gevent.pool import Pool from gevent.pool import Pool
from gevent.server import StreamServer
from gunicorn.workers.ggevent import GeventPyWSGIWorker from gunicorn.workers.ggevent import GeventPyWSGIWorker
from gunicorn.workers.ggevent import PyWSGIHandler
from gunicorn.workers.ggevent import GeventResponse
from gunicorn import version_info as gunicorn_version
from socketio.server import SocketIOServer from socketio.server import SocketIOServer
from socketio.handler import SocketIOHandler from socketio.handler import SocketIOHandler
from geventwebsocket.handler import WebSocketHandler
from datetime import datetime
from functools import partial
class GunicornWSGIHandler(PyWSGIHandler, SocketIOHandler):
pass
class GunicornWebSocketWSGIHandler(WebSocketHandler):
def log_request(self):
start = datetime.fromtimestamp(self.time_start)
finish = datetime.fromtimestamp(self.time_finish)
response_time = finish - start
resp = GeventResponse(self.status, [],
self.response_length)
req_headers = [h.split(":", 1) for h in self.headers.headers]
self.server.log.access(
resp, req_headers, self.environ, response_time)
class GeventSocketIOBaseWorker(GeventPyWSGIWorker): class GeventSocketIOBaseWorker(GeventPyWSGIWorker):
""" The base gunicorn worker class """ """ The base gunicorn worker class """
transports = None
def __init__(self, age, ppid, socket, app, timeout, cfg, log):
if os.environ.get('POLICY_SERVER', None) is None:
if self.policy_server:
os.environ['POLICY_SERVER'] = 'true'
else:
self.policy_server = False
super(GeventSocketIOBaseWorker, self).__init__(
age, ppid, socket, app, timeout, cfg, log)
def run(self): def run(self):
self.socket.setblocking(1) if gunicorn_version >= (0, 17, 0):
pool = Pool(self.worker_connections) servers = []
self.server_class.base_env['wsgi.multiprocess'] = \ ssl_args = {}
self.cfg.workers > 1
server = self.server_class(
self.socket, application=self.wsgi,
spawn=pool, handler_class=self.wsgi_handler,
namespace=self.namespace, policy_server=self.policy_server)
server.start()
try:
while self.alive:
self.notify()
if self.ppid != os.getppid(): if self.cfg.is_ssl:
self.log.info("Parent changed, shutting down: %s", self) ssl_args = dict(
break server_side=True,
do_handshake_on_connect=False,
**self.cfg.ssl_options
)
gevent.sleep(1.0) for s in self.sockets:
s.setblocking(1)
pool = Pool(self.worker_connections)
if self.server_class is not None:
self.server_class.base_env['wsgi.multiprocess'] = \
self.cfg.workers > 1
except KeyboardInterrupt: server = self.server_class(
pass s,
application=self.wsgi,
spawn=pool,
resource=self.resource,
log=self.log,
policy_server=self.policy_server,
handler_class=self.wsgi_handler,
ws_handler_class=self.ws_wsgi_handler,
**ssl_args
)
else:
hfun = partial(self.handle, s)
server = StreamServer(
s, handle=hfun, spawn=pool, **ssl_args)
# try to stop the connections server.start()
try: servers.append(server)
self.notify()
server.stop(timeout=self.timeout) pid = os.getpid()
except: try:
pass while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
self.log.info(
"Parent changed, shutting down: %s", self)
break
gevent.sleep(1.0)
except KeyboardInterrupt:
pass
try:
# Stop accepting requests
[server.stop_accepting() for server in servers]
# Handle current requests until graceful_timeout
ts = time.time()
while time.time() - ts <= self.cfg.graceful_timeout:
accepting = 0
for server in servers:
if server.pool.free_count() != server.pool.size:
accepting += 1
if not accepting:
return
self.notify()
gevent.sleep(1.0)
# Force kill all active the handlers
self.log.warning("Worker graceful timeout (pid:%s)" % self.pid)
[server.stop(timeout=1) for server in servers]
except:
pass
else:
self.socket.setblocking(1)
pool = Pool(self.worker_connections)
self.server_class.base_env['wsgi.multiprocess'] = \
self.cfg.workers > 1
server = self.server_class(
self.socket,
application=self.wsgi,
spawn=pool,
resource=self.resource,
log=self.log,
policy_server=self.policy_server,
handler_class=self.wsgi_handler,
ws_handler_class=self.ws_wsgi_handler,
)
server.start()
pid = os.getpid()
try:
while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
self.log.info(
"Parent changed, shutting down: %s", self)
break
gevent.sleep(1.0)
except KeyboardInterrupt:
pass
try:
# Stop accepting requests
server.kill()
# Handle current requests until graceful_timeout
ts = time.time()
while time.time() - ts <= self.cfg.graceful_timeout:
if server.pool.free_count() == server.pool.size:
return # all requests was handled
self.notify()
gevent.sleep(1.0)
# Force kill all active the handlers
self.log.warning("Worker graceful timeout (pid:%s)" % self.pid)
server.stop(timeout=1)
except:
pass
class GeventSocketIOWorker(GeventSocketIOBaseWorker): class GeventSocketIOWorker(GeventSocketIOBaseWorker):
@@ -49,9 +187,20 @@ class GeventSocketIOWorker(GeventSocketIOBaseWorker):
being disabled. being disabled.
""" """
server_class = SocketIOServer server_class = SocketIOServer
wsgi_handler = SocketIOHandler wsgi_handler = GunicornWSGIHandler
ws_wsgi_handler = GunicornWebSocketWSGIHandler
# We need to define a namespace for the server, it would be nice if this # We need to define a namespace for the server, it would be nice if this
# was a configuration option, will probably end up how this implemented, # was a configuration option, will probably end up how this implemented,
# for now this is just a proof of concept to make sure this will work # for now this is just a proof of concept to make sure this will work
namespace = 'socket.io' resource = 'socket.io'
policy_server = False # Don't run the flash policy server policy_server = True
class NginxGeventSocketIOWorker(GeventSocketIOWorker):
"""
Worker which will not attempt to connect via websocket transport
Nginx is not compatible with websockets and therefore will not add the
wsgi.websocket key to the wsgi environment.
"""
transports = ['xhr-polling']

View File

@@ -0,0 +1,30 @@
import gevent
import weakref
try:
import redis
except ImportError:
pass
class RedisStorage(object):
def __init__(self, server, **kwargs):
self.server = weakref.proxy(server)
self.jobs = []
self.host = kwargs.get('host', 'localhost')
self.port = kwargs.get('port', 6379)
r = redis.StrictRedis(host=self.host, port=self.port)
self.conn = r.pubsub()
self.spawn(self.listener)
def listener(self):
for m in self.conn.listen():
print("===============NEW MESSAGE!!!====== %s", m)
def spawn(self, fn, *args, **kwargs):
new = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(new)
return new
def new_request(self, environ):
print("===========NEW REQUEST %s===========" % environ)

View File

@@ -1,13 +1,21 @@
import gevent import gevent
import urllib import urllib
import urlparse
from geventwebsocket import WebSocketError
from gevent.queue import Empty from gevent.queue import Empty
class BaseTransport(object): class BaseTransport(object):
"""Base class for all transports. Mostly wraps handler class functions.""" """Base class for all transports. Mostly wraps handler class functions."""
def __init__(self, handler): def __init__(self, handler, config, **kwargs):
"""Base transport class.
:param config: dict Should contain the config keys, like
``heartbeat_interval``, ``heartbeat_timeout`` and
``close_timeout``.
"""
self.content_type = ("Content-Type", "text/plain; charset=UTF-8") self.content_type = ("Content-Type", "text/plain; charset=UTF-8")
self.headers = [ self.headers = [
("Access-Control-Allow-Origin", "*"), ("Access-Control-Allow-Origin", "*"),
@@ -15,22 +23,28 @@ class BaseTransport(object):
("Access-Control-Allow-Methods", "POST, GET, OPTIONS"), ("Access-Control-Allow-Methods", "POST, GET, OPTIONS"),
("Access-Control-Max-Age", 3600), ("Access-Control-Max-Age", 3600),
] ]
self.headers_list = []
self.handler = handler self.handler = handler
self.config = config
def write(self, data=""): def write(self, data=""):
if 'Content-Length' not in self.handler.response_headers_list: # Gevent v 0.13
self.handler.response_headers.append(('Content-Length', len(data))) if hasattr(self.handler, 'response_headers_list'):
self.handler.response_headers_list.append('Content-Length') if 'Content-Length' not in self.handler.response_headers_list:
self.handler.response_headers.append(('Content-Length', len(data)))
self.handler.response_headers_list.append('Content-Length')
elif not hasattr(self.handler, 'provided_content_length') or self.handler.provided_content_length is None:
# Gevent 1.0bX
l = len(data)
self.handler.provided_content_length = l
self.handler.response_headers.append(('Content-Length', l))
self.handler.write(data) self.handler.write_smart(data)
def start_response(self, status, headers, **kwargs): def start_response(self, status, headers, **kwargs):
if "Content-Type" not in [x[0] for x in headers]: if "Content-Type" not in [x[0] for x in headers]:
headers.append(self.content_type) headers.append(self.content_type)
headers.extend(self.headers) headers.extend(self.headers)
#print headers
self.handler.start_response(status, headers, **kwargs) self.handler.start_response(status, headers, **kwargs)
@@ -46,15 +60,14 @@ class XHRPollingTransport(BaseTransport):
def get(self, socket): def get(self, socket):
socket.heartbeat() socket.heartbeat()
payload = self.get_messages_payload(socket, timeout=5.0) heartbeat_interval = self.config['heartbeat_interval']
payload = self.get_messages_payload(socket, timeout=heartbeat_interval)
if not payload: if not payload:
payload = "8::" # NOOP payload = "8::" # NOOP
self.start_response("200 OK", []) self.start_response("200 OK", [])
self.write(payload) self.write(payload)
return []
def _request_body(self): def _request_body(self):
return self.handler.wsgi_input.readline() return self.handler.wsgi_input.readline()
@@ -68,8 +81,6 @@ class XHRPollingTransport(BaseTransport):
]) ])
self.write("1") self.write("1")
return []
def get_messages_payload(self, socket, timeout=None): def get_messages_payload(self, socket, timeout=None):
"""This will fetch the messages from the Socket's queue, and if """This will fetch the messages from the Socket's queue, and if
there are many messes, pack multiple messages in one payload and return there are many messes, pack multiple messages in one payload and return
@@ -87,14 +98,16 @@ class XHRPollingTransport(BaseTransport):
``messages`` - List of raw messages to encode, if necessary ``messages`` - List of raw messages to encode, if necessary
""" """
if not messages: if not messages or messages[0] is None:
return '' return ''
if len(messages) == 1: if len(messages) == 1:
return messages[0].encode('utf-8') return messages[0].encode('utf-8')
payload = u''.join(u'\ufffd%d\ufffd%s' % (len(p), p) payload = u''.join([(u'\ufffd%d\ufffd%s' % (len(p), p))
for p in messages) for p in messages if p is not None])
# FIXME: why is it so that we must filter None from here ? How
# is it even possible that a None gets in there ?
return payload.encode('utf-8') return payload.encode('utf-8')
@@ -115,7 +128,6 @@ class XHRPollingTransport(BaseTransport):
""" """
payload = payload.decode('utf-8') payload = payload.decode('utf-8')
if payload[0] == u"\ufffd": if payload[0] == u"\ufffd":
#print "MULTIMSG FULL", payload
ret = [] ret = []
while len(payload) != 0: while len(payload) != 0:
len_end = payload.find(u"\ufffd", 1) len_end = payload.find(u"\ufffd", 1)
@@ -123,21 +135,19 @@ class XHRPollingTransport(BaseTransport):
msg_start = len_end + 1 msg_start = len_end + 1
msg_end = length + msg_start msg_end = length + msg_start
message = payload[msg_start:msg_end] message = payload[msg_start:msg_end]
#print "MULTIMSG", length, message
ret.append(message) ret.append(message)
payload = payload[msg_end:] payload = payload[msg_end:]
return ret return ret
return [payload] return [payload]
def connect(self, socket, request_method): def do_exchange(self, socket, request_method):
if not socket.connection_confirmed: if not socket.connection_established:
socket.connection_confirmed = True # Runs only the first time we get a Socket opening
self.start_response("200 OK", [ self.start_response("200 OK", [
("Connection", "close"), ("Connection", "close"),
]) ])
self.write("1::") # 'connect' packet self.write("1::") # 'connect' packet
return
return []
elif request_method in ("GET", "POST", "OPTIONS"): elif request_method in ("GET", "POST", "OPTIONS"):
return getattr(self, request_method.lower())(socket) return getattr(self, request_method.lower())(socket)
else: else:
@@ -145,21 +155,33 @@ class XHRPollingTransport(BaseTransport):
class JSONPolling(XHRPollingTransport): class JSONPolling(XHRPollingTransport):
def __init__(self, handler): def __init__(self, handler, config):
super(JSONPolling, self).__init__(handler) super(JSONPolling, self).__init__(handler, config)
self.content_type = ("Content-Type", "text/javascript; charset=UTF-8") self.content_type = ("Content-Type", "text/javascript; charset=UTF-8")
def _request_body(self): def _request_body(self):
data = super(JSONPolling, self)._request_body() data = super(JSONPolling, self)._request_body()
# resolve %20%3F's, take out wrapping d="...", etc.. # resolve %20%3F's, take out wrapping d="...", etc..
return urllib.unquote_plus(data)[3:-1] \ data = urllib.unquote_plus(data)[3:-1] \
.replace(r'\"', '"') \ .replace(r'\"', '"') \
.replace(r"\\", "\\") .replace(r"\\", "\\")
# For some reason, in case of multiple messages passed in one
# query, IE7 sends it escaped, not utf-8 encoded. This dirty
# hack handled it
if data[0] == "\\":
data = data.decode("unicode_escape").encode("utf-8")
return data
def write(self, data): def write(self, data):
"""Just quote out stuff before sending it out""" """Just quote out stuff before sending it out"""
args = urlparse.parse_qs(self.handler.environ.get("QUERY_STRING"))
if "i" in args:
i = args["i"]
else:
i = "0"
# TODO: don't we need to quote this data in here ? # TODO: don't we need to quote this data in here ?
super(JSONPolling, self).write("io.j[0]('%s');" % data) super(JSONPolling, self).write("io.j[%s]('%s');" % (i, data))
class XHRMultipartTransport(XHRPollingTransport): class XHRMultipartTransport(XHRPollingTransport):
@@ -170,11 +192,9 @@ class XHRMultipartTransport(XHRPollingTransport):
"multipart/x-mixed-replace;boundary=\"socketio\"" "multipart/x-mixed-replace;boundary=\"socketio\""
) )
def connect(self, socket, request_method): def do_exchange(self, socket, request_method):
if request_method == "GET": if request_method == "GET":
# TODO: double verify this, because we're not sure. look at git revs. return self.get(socket)
heartbeat = socket._spawn_heartbeat()
return [heartbeat] + self.get(socket)
elif request_method == "POST": elif request_method == "POST":
return self.post(socket) return self.post(socket)
else: else:
@@ -196,22 +216,28 @@ class XHRMultipartTransport(XHRPollingTransport):
if not payload: if not payload:
# That would mean the call to Queue.get() returned Empty, # That would mean the call to Queue.get() returned Empty,
# so it was in fact killed, since we pass no timeout=.. # so it was in fact killed, since we pass no timeout=..
socket.kill() return
break # See below
else: else:
try: try:
self.write_multipart(header) self.write_multipart(header)
self.write_multipart(payload) self.write_multipart(payload)
self.write_multipart("--socketio\r\n") self.write_multipart("--socketio\r\n")
except socket.error: except socket.error:
socket.kill() # The client might try to reconnect, even with a socket
break # error, so let's just let it go, and not kill the
# socket completely. Other processes will ensure
# we kill everything if the user expires the timeouts.
#
# WARN: this means that this payload is LOST, unless we
# decide to re-inject it into the queue.
return
return [gevent.spawn(chunk)] socket.spawn(chunk)
class WebsocketTransport(BaseTransport): class WebsocketTransport(BaseTransport):
def connect(self, socket, request_method): def do_exchange(self, socket, request_method):
websocket = self.handler.environ['wsgi.websocket'] websocket = self.handler.environ['wsgi.websocket']
websocket.send("1::") # 'connect' packet websocket.send("1::") # 'connect' packet
@@ -220,27 +246,26 @@ class WebsocketTransport(BaseTransport):
message = socket.get_client_msg() message = socket.get_client_msg()
if message is None: if message is None:
socket.kill()
break break
try:
websocket.send(message) websocket.send(message)
except (WebSocketError, TypeError):
# We can't send a message on the socket
# it is dead, let the other sockets know
socket.disconnect()
def read_from_ws(): def read_from_ws():
while True: while True:
message = websocket.receive() message = websocket.receive()
if not message: if message is None:
socket.kill()
break break
else: else:
if message is not None: if message is not None:
socket.put_server_msg(message) socket.put_server_msg(message)
gr1 = gevent.spawn(send_into_ws) socket.spawn(send_into_ws)
gr2 = gevent.spawn(read_from_ws) socket.spawn(read_from_ws)
heartbeat1, heartbeat2 = socket._spawn_heartbeat()
return [gr1, gr2, heartbeat1, heartbeat2]
class FlashSocketTransport(WebsocketTransport): class FlashSocketTransport(WebsocketTransport):
@@ -250,21 +275,43 @@ class FlashSocketTransport(WebsocketTransport):
class HTMLFileTransport(XHRPollingTransport): class HTMLFileTransport(XHRPollingTransport):
"""Not tested at all!""" """Not tested at all!"""
def __init__(self, handler): def __init__(self, handler, config):
super(HTMLFileTransport, self).__init__(handler) super(HTMLFileTransport, self).__init__(handler, config)
self.content_type = ("Content-Type", "text/html") self.content_type = ("Content-Type", "text/html")
def write_packed(self, data): def write_packed(self, data):
self.write("<script>parent.s._('%s', document);</script>" % data) self.write("<script>_('%s');</script>" % data)
def handle_get_response(self, socket): def write(self, data):
l = 1024 * 5
super(HTMLFileTransport, self).write("%d\r\n%s%s\r\n" % (l, data, " " * (l - len(data))))
def do_exchange(self, socket, request_method):
return super(HTMLFileTransport, self).do_exchange(socket, request_method)
def get(self, socket):
self.start_response("200 OK", [ self.start_response("200 OK", [
("Connection", "keep-alive"), ("Connection", "keep-alive"),
("Content-Type", "text/html"), ("Content-Type", "text/html"),
("Transfer-Encoding", "chunked"), ("Transfer-Encoding", "chunked"),
]) ])
self.write("<html><body>" + " " * 244) self.write("<html><body><script>var _ = function (msg) { parent.s._(msg, document); };</script>")
self.write_packed("1::") # 'connect' packet
payload = self.get_messages_payload(socket, timeout=5.0)
self.write_packed(payload) def chunk():
while True:
payload = self.get_messages_payload(socket)
if not payload:
# That would mean the call to Queue.get() returned Empty,
# so it was in fact killed, since we pass no timeout=..
return
else:
try:
self.write_packed(payload)
except socket.error:
# See comments for XHRMultipart
return
socket.spawn(chunk)

View File

@@ -11,12 +11,17 @@ in a different way
""" """
import random import random
import weakref import weakref
import logging
import gevent import gevent
from gevent.queue import Queue from gevent.queue import Queue
from gevent.event import Event from gevent.event import Event
from socketio import packet from socketio import packet
from socketio.defaultjson import default_json_loads, default_json_dumps
log = logging.getLogger(__name__)
def default_error_handler(socket, error_name, error_message, endpoint, def default_error_handler(socket, error_name, error_message, endpoint,
@@ -40,10 +45,11 @@ def default_error_handler(socket, error_name, error_message, endpoint,
# Send an error event through the Socket # Send an error event through the Socket
if not quiet: if not quiet:
socket.send_packet(pkt) socket.send_packet(pkt)
# Log that error somewhere for debugging... # Log that error somewhere for debugging...
print "default_error_handler: %s, %s (endpoint=%s, msg_id=%s)" % ( log.error(u"default_error_handler: {}, {} (endpoint={}, msg_id={})".format(
error_name, error_message, endpoint, msg_id) error_name, error_message, endpoint, msg_id
))
class Socket(object): class Socket(object):
@@ -65,7 +71,10 @@ class Socket(object):
"""Use this to be explicit when specifying a Global Namespace (an endpoint """Use this to be explicit when specifying a Global Namespace (an endpoint
with no name, not '/chat' or anything.""" with no name, not '/chat' or anything."""
def __init__(self, server, error_handler=None): json_loads = staticmethod(default_json_loads)
json_dumps = staticmethod(default_json_dumps)
def __init__(self, server, config, error_handler=None):
self.server = weakref.proxy(server) self.server = weakref.proxy(server)
self.sessid = str(random.random())[2:] self.sessid = str(random.random())[2:]
self.session = {} # the session dict, for general developer usage self.session = {} # the session dict, for general developer usage
@@ -76,7 +85,7 @@ class Socket(object):
self.timeout = Event() self.timeout = Event()
self.wsgi_app_greenlet = None self.wsgi_app_greenlet = None
self.state = "NEW" self.state = "NEW"
self.connection_confirmed = False self.connection_established = False
self.ack_callbacks = {} self.ack_callbacks = {}
self.ack_counter = 0 self.ack_counter = 0
self.request = None self.request = None
@@ -85,6 +94,7 @@ class Socket(object):
self.active_ns = {} # Namespace sessions that were instantiated self.active_ns = {} # Namespace sessions that were instantiated
self.jobs = [] self.jobs = []
self.error_handler = default_error_handler self.error_handler = default_error_handler
self.config = config
if error_handler is not None: if error_handler is not None:
self.error_handler = error_handler self.error_handler = error_handler
@@ -116,6 +126,22 @@ class Socket(object):
""" """
self.error_handler = error_handler self.error_handler = error_handler
def _set_json_loads(self, json_loads):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_loads = json_loads
def _set_json_dumps(self, json_dumps):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_dumps = json_dumps
def _get_next_msgid(self): def _get_next_msgid(self):
"""This retrieves the next value for the 'id' field when sending """This retrieves the next value for the 'id' field when sending
an 'event' or 'message' or 'json' that asks the remote client an 'event' or 'message' or 'json' that asks the remote client
@@ -175,9 +201,6 @@ class Socket(object):
def incr_hits(self): def incr_hits(self):
self.hits += 1 self.hits += 1
if self.hits == 1:
self.state = self.STATE_CONNECTED
def heartbeat(self): def heartbeat(self):
"""This makes the heart beat for another X seconds. Call this when """This makes the heart beat for another X seconds. Call this when
you get a heartbeat packet in. you get a heartbeat packet in.
@@ -186,7 +209,7 @@ class Socket(object):
""" """
self.timeout.set() self.timeout.set()
def kill(self): def kill(self, detach=False):
"""This function must/will be called when a socket is to be completely """This function must/will be called when a socket is to be completely
shut down, closed by connection timeout, connection error or explicit shut down, closed by connection timeout, connection error or explicit
disconnection from the client. disconnection from the client.
@@ -202,14 +225,23 @@ class Socket(object):
self.state = self.STATE_DISCONNECTING self.state = self.STATE_DISCONNECTING
self.server_queue.put_nowait(None) self.server_queue.put_nowait(None)
self.client_queue.put_nowait(None) self.client_queue.put_nowait(None)
self.disconnect() if len(self.active_ns) > 0:
log.debug("Calling disconnect() on %s" % self)
self.disconnect()
if self.sessid in self.server.sockets: if detach:
self.server.sockets.pop(self.sessid) self.detach()
# gevent.kill(self.wsgi_app_greenlet) gevent.killall(self.jobs)
else:
pass # Fail silently def detach(self):
"""Detach this socket from the server. This should be done in
conjunction with kill(), once all the jobs are dead, detach the
socket for garbage collection."""
log.debug("Removing %s from server sockets" % self)
if self.sessid in self.server.sockets:
self.server.sockets.pop(self.sessid)
def put_server_msg(self, msg): def put_server_msg(self, msg):
"""Writes to the server's pipe, to end up in in the Namespaces""" """Writes to the server's pipe, to end up in in the Namespaces"""
@@ -218,7 +250,6 @@ class Socket(object):
def put_client_msg(self, msg): def put_client_msg(self, msg):
"""Writes to the client's pipe, to end up in the browser""" """Writes to the client's pipe, to end up in the browser"""
self.heartbeat()
self.client_queue.put_nowait(msg) self.client_queue.put_nowait(msg)
def get_client_msg(self, **kwargs): def get_client_msg(self, **kwargs):
@@ -293,10 +324,13 @@ class Socket(object):
if namespace in self.active_ns: if namespace in self.active_ns:
del self.active_ns[namespace] del self.active_ns[namespace]
if len(self.active_ns) == 0 and self.connected:
self.kill(detach=True)
def send_packet(self, pkt): def send_packet(self, pkt):
"""Low-level interface to queue a packet on the wire (encoded as wire """Low-level interface to queue a packet on the wire (encoded as wire
protocol""" protocol"""
self.put_client_msg(packet.encode(pkt)) self.put_client_msg(packet.encode(pkt, self.json_dumps))
def spawn(self, fn, *args, **kwargs): def spawn(self, fn, *args, **kwargs):
"""Spawn a new Greenlet, attached to this Socket instance. """Spawn a new Greenlet, attached to this Socket instance.
@@ -304,7 +338,7 @@ class Socket(object):
It will be monitored by the "watcher" method It will be monitored by the "watcher" method
""" """
self.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__) log.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__)
job = gevent.spawn(fn, *args, **kwargs) job = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(job) self.jobs.append(job)
return job return job
@@ -312,6 +346,13 @@ class Socket(object):
def _receiver_loop(self): def _receiver_loop(self):
"""This is the loop that takes messages from the queue for the server """This is the loop that takes messages from the queue for the server
to consume, decodes them and dispatches them. to consume, decodes them and dispatches them.
It is the main loop for a socket. We join on this process before
returning control to the web framework.
This process is not tracked by the socket itself, it is not going
to be killed by the ``gevent.killall(socket.jobs)``, so it must
exit gracefully itself.
""" """
while True: while True:
@@ -320,7 +361,7 @@ class Socket(object):
if not rawdata: if not rawdata:
continue # or close the connection ? continue # or close the connection ?
try: try:
pkt = packet.decode(rawdata) pkt = packet.decode(rawdata, self.json_loads)
except (ValueError, KeyError, Exception), e: except (ValueError, KeyError, Exception), e:
self.error('invalid_packet', self.error('invalid_packet',
"There was a decoding error when dealing with packet " "There was a decoding error when dealing with packet "
@@ -334,7 +375,7 @@ class Socket(object):
if pkt['type'] == 'disconnect' and pkt['endpoint'] == '': if pkt['type'] == 'disconnect' and pkt['endpoint'] == '':
# On global namespace, we kill everything. # On global namespace, we kill everything.
self.kill() self.kill(detach=True)
continue continue
endpoint = pkt['endpoint'] endpoint = pkt['endpoint']
@@ -350,8 +391,12 @@ class Socket(object):
new_ns_class = self.namespaces[endpoint] new_ns_class = self.namespaces[endpoint]
pkt_ns = new_ns_class(self.environ, endpoint, pkt_ns = new_ns_class(self.environ, endpoint,
request=self.request) request=self.request)
pkt_ns.initialize() # use this instead of __init__, # This calls initialize() on all the classes and mixins, etc..
# for less confusion # in the order of the MRO
for cls in type(pkt_ns).__mro__:
if hasattr(cls, 'initialize'):
cls.initialize(pkt_ns) # use this instead of __init__,
# for less confusion
self.active_ns[endpoint] = pkt_ns self.active_ns[endpoint] = pkt_ns
@@ -359,15 +404,20 @@ class Socket(object):
# Has the client requested an 'ack' with the reply parameters ? # Has the client requested an 'ack' with the reply parameters ?
if pkt.get('ack') == "data" and pkt.get('id'): if pkt.get('ack') == "data" and pkt.get('id'):
if type(retval) is tuple:
args = list(retval)
else:
args = [retval]
returning_ack = dict(type='ack', ackId=pkt['id'], returning_ack = dict(type='ack', ackId=pkt['id'],
args=retval, args=args,
endpoint=pkt.get('endpoint', '')) endpoint=pkt.get('endpoint', ''))
self.send_packet(returning_ack) self.send_packet(returning_ack)
# Now, are we still connected ? # Now, are we still connected ?
if not self.connected: if not self.connected:
self.kill() # ?? what,s the best clean-up when its not a self.kill(detach=True) # ?? what,s the best clean-up
# user-initiated disconnect # when its not a
# user-initiated disconnect
return return
def _spawn_receiver_loop(self): def _spawn_receiver_loop(self):
@@ -379,50 +429,51 @@ class Socket(object):
return job return job
def _watcher(self): def _watcher(self):
"""Watch if any of the greenlets for a request have died. If so, kill """Watch out if we've been disconnected, in that case, kill
the request and the socket. all the jobs.
"""
# TODO: add that if any of the request.jobs die, kill them all and exit
gevent.sleep(5.0)
"""
while True: while True:
gevent.sleep(1.0) gevent.sleep(1.0)
if not self.connected: if not self.connected:
# Killing Socket-level jobs
gevent.killall(self.jobs)
for ns_name, ns in list(self.active_ns.iteritems()): for ns_name, ns in list(self.active_ns.iteritems()):
ns.recv_disconnect() ns.recv_disconnect()
# Killing Socket-level jobs
gevent.killall(self.jobs)
break break
def _spawn_watcher(self): def _spawn_watcher(self):
"""This one is not waited for with joinall(socket.jobs), as it
is an external watcher, to clean up when everything is done."""
job = gevent.spawn(self._watcher) job = gevent.spawn(self._watcher)
return job return job
def _heartbeat(self): def _heartbeat(self):
"""Start the heartbeat Greenlet to check connection health.""" """Start the heartbeat Greenlet to check connection health."""
self.state = self.STATE_CONNECTED interval = self.config['heartbeat_interval']
while self.connected: while self.connected:
gevent.sleep(5.0) # FIXME: make this a setting gevent.sleep(interval)
# TODO: this process could use a timeout object like the disconnect # TODO: this process could use a timeout object like the disconnect
# timeout thing, and ONLY send packets when none are sent! # timeout thing, and ONLY send packets when none are sent!
# We would do that by calling timeout.set() for a "sending" # We would do that by calling timeout.set() for a "sending"
# timeout. If we're sending 100 messages a second, there is # timeout. If we're sending 100 messages a second, there is
# no need to push some heartbeats in there also. # no need to push some heartbeats in there also.
self.put_client_msg("2::") # TODO: make it a heartbeat packet self.put_client_msg("2::")
def _disconnect_timeout(self): def _heartbeat_timeout(self):
self.timeout.clear() timeout = float(self.config['heartbeat_timeout'])
while True:
self.timeout.clear()
gevent.sleep(0)
wait_res = self.timeout.wait(timeout=timeout)
if not wait_res:
if self.connected:
log.debug("heartbeat timed out, killing socket")
self.kill(detach=True)
return
if self.timeout.wait(10.0):
gevent.spawn(self._disconnect_timeout)
else:
self.kill()
def _spawn_heartbeat(self): def _spawn_heartbeat(self):
"""This functions returns a list of jobs""" """This functions returns a list of jobs"""
job_sender = gevent.spawn(self._heartbeat) self.spawn(self._heartbeat)
job_waiter = gevent.spawn(self._disconnect_timeout) self.spawn(self._heartbeat_timeout)
self.jobs.extend((job_sender, job_waiter))
return job_sender, job_waiter