Upgrade gevent-socketio to 0.3.6

Matthew Jones 2014-08-06 14:00:16 -04:00
parent e3bdd966ed
commit b18b075bc9
13 changed files with 831 additions and 237 deletions

View File

@ -25,7 +25,7 @@ django-split-settings==0.1.1 (split_settings/*)
django-taggit==0.11.2 (taggit/*)
djangorestframework==2.3.13 (rest_framework/*)
django-qsstats-magic==0.7.2 (django-qsstats-magic/*, minor fix in qsstats/__init__.py)
gevent-socketio==0.3.5-rc1 (socketio/*)
gevent-socketio==0.3.6 (socketio/*)
gevent-websocket==0.9.3 (geventwebsocket/*)
httplib2==0.8 (httplib2/*)
importlib==1.0.3 (importlib/*, needed for Python 2.6 support)

View File

@ -6,7 +6,8 @@ import gevent
log = logging.getLogger(__name__)
def socketio_manage(environ, namespaces, request=None, error_handler=None):
def socketio_manage(environ, namespaces, request=None, error_handler=None,
json_loads=None, json_dumps=None):
"""Main SocketIO management function, call from within your Framework of
choice's view.
@ -20,7 +21,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
use Socket.GLOBAL_NS to be more explicit. So it would look like:
.. code-block:: python
namespaces={'': GlobalNamespace,
'/chat': ChatNamespace}
@ -35,6 +36,11 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
The callable you pass in should have the same signature as the default
error handler.
The ``json_loads`` and ``json_dumps`` are overrides for the default
``json.loads`` and ``json.dumps`` function calls. Override these at
the top-most level here. This will affect all sockets created by this
socketio manager, and all namespaces inside.
This function will block the current "view" or "controller" in your
framework to do the recv/send on the socket, and dispatch incoming messages
to your namespaces.
@ -45,6 +51,7 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
def my_view(request):
socketio_manage(request.environ, {'': GlobalNamespace}, request)
NOTE: You must understand that this function is going to be called
*only once* per socket opening, *even though* you are using a long
polling mechanism. The subsequent calls (for long polling) will
@ -67,10 +74,14 @@ def socketio_manage(environ, namespaces, request=None, error_handler=None):
if error_handler:
socket._set_error_handler(error_handler)
receiver_loop = socket._spawn_receiver_loop()
watcher = socket._spawn_watcher()
if json_loads:
socket._set_json_loads(json_loads)
if json_dumps:
socket._set_json_dumps(json_dumps)
gevent.joinall([receiver_loop, watcher])
receiver_loop = socket._spawn_receiver_loop()
gevent.joinall([receiver_loop])
# TODO: double check, what happens to the WSGI request here ? it vanishes ?
return
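A minimal usage sketch of the new ``json_loads``/``json_dumps`` overrides (the view and the ChatNamespace below are illustrative, not part of this commit); whatever callables are passed here are used for every packet encoded or decoded on the socket and in all of its namespaces.

import json

from socketio import socketio_manage
from socketio.namespace import BaseNamespace

class ChatNamespace(BaseNamespace):
    def on_ping(self, data):
        self.emit('pong', data)

def socketio_view(request):
    # The overrides are pushed down to the virtual socket and affect
    # every namespace attached to it.
    socketio_manage(request.environ,
                    {'/chat': ChatNamespace},
                    request,
                    json_loads=json.loads,
                    json_dumps=json.dumps)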

View File

@ -0,0 +1,21 @@
### default json loaders
try:
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal
class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}
def default_json_dumps(data):
return json.dumps(data, separators=(',', ':'),
**json_decimal_args)
def default_json_loads(data):
return json.loads(data)
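A quick sanity-check sketch of these defaults (not part of the diff): Decimal values survive serialization and the compact separators keep the wire payload small.

from decimal import Decimal
from socketio.defaultjson import default_json_dumps, default_json_loads

# Decimals are serialized (as floats with the stdlib json fallback,
# natively with simplejson), and the (',', ':') separators strip whitespace.
wire = default_json_dumps({'price': Decimal('9.99'), 'qty': 2})
data = default_json_loads(wire)   # e.g. {'price': 9.99, 'qty': 2}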

View File

@ -0,0 +1,6 @@
class SessionNotFound(Exception):
def __init__(self, sessid):
self.sessid = sessid
def __str__(self):
return "Session %s not found!" % self.sessid

View File

@ -5,17 +5,23 @@ import urlparse
from gevent.pywsgi import WSGIHandler
from socketio import transports
from geventwebsocket.handler import WebSocketHandler
class SocketIOHandler(WSGIHandler):
RE_REQUEST_URL = re.compile(r"""
^/(?P<resource>[^/]+)
/(?P<protocol_version>[^/]+)
^/(?P<resource>.+?)
/1
/(?P<transport_id>[^/]+)
/(?P<sessid>[^/]+)/?$
""", re.X)
RE_HANDSHAKE_URL = re.compile(r"^/(?P<resource>[^/]+)/1/$", re.X)
RE_HANDSHAKE_URL = re.compile(r"^/(?P<resource>.+?)/1/$", re.X)
# new socket.io versions (> 0.9.8) call an obscure url with two slashes
# instead of a transport when disconnecting
# https://github.com/LearnBoost/socket.io-client/blob/0.9.16/lib/socket.js#L361
RE_DISCONNECT_URL = re.compile(r"""
^/(?P<resource>.+?)
/(?P<protocol_version>[^/]+)
//(?P<sessid>[^/]+)/?$
""", re.X)
handler_types = {
'websocket': transports.WebsocketTransport,
@ -26,9 +32,16 @@ class SocketIOHandler(WSGIHandler):
'jsonp-polling': transports.JSONPolling,
}
def __init__(self, *args, **kwargs):
def __init__(self, config, *args, **kwargs):
"""Create a new SocketIOHandler.
:param config: dict Configuration for timeouts and intervals
that will go down to the other components, transports, etc..
"""
self.socketio_connection = False
self.allowed_paths = None
self.config = config
super(SocketIOHandler, self).__init__(*args, **kwargs)
@ -36,7 +49,7 @@ class SocketIOHandler(WSGIHandler):
if self.server.transports:
self.transports = self.server.transports
if not set(self.transports).issubset(set(self.handler_types)):
raise Exception("transports should be elements of: %s" %
raise ValueError("transports should be elements of: %s" %
(self.handler_types.keys()))
def _do_handshake(self, tokens):
@ -44,7 +57,10 @@ class SocketIOHandler(WSGIHandler):
self.log_error("socket.io URL mismatch")
else:
socket = self.server.get_socket()
data = "%s:15:10:%s" % (socket.sessid, ",".join(self.transports))
data = "%s:%s:%s:%s" % (socket.sessid,
self.config['heartbeat_timeout'] or '',
self.config['close_timeout'] or '',
",".join(self.transports))
self.write_smart(data)
def write_jsonp_result(self, data, wrapper="0"):
@ -74,10 +90,16 @@ class SocketIOHandler(WSGIHandler):
self.process_result()
def handle_one_response(self):
"""This function deals with *ONE INCOMING REQUEST* from the web.
It will wire and exchange messages with the queues for long-polling
methods; otherwise, it will stay alive for websockets.
"""
path = self.environ.get('PATH_INFO')
# Kick non-socket.io requests to our superclass
if not path.lstrip('/').startswith(self.server.resource):
if not path.lstrip('/').startswith(self.server.resource + '/'):
return super(SocketIOHandler, self).handle_one_response()
self.status = None
@ -85,64 +107,118 @@ class SocketIOHandler(WSGIHandler):
self.result = None
self.response_length = 0
self.response_use_chunked = False
# This is analyzed for each and every HTTP request involved
# in the Socket.IO protocol, whether long-running or long-polling
# (read: websocket or xhr-polling methods)
request_method = self.environ.get("REQUEST_METHOD")
request_tokens = self.RE_REQUEST_URL.match(path)
handshake_tokens = self.RE_HANDSHAKE_URL.match(path)
disconnect_tokens = self.RE_DISCONNECT_URL.match(path)
# Parse request URL and QUERY_STRING and do handshake
if request_tokens:
request_tokens = request_tokens.groupdict()
if handshake_tokens:
# Deal with first handshake here, create the Socket and push
# the config up.
return self._do_handshake(handshake_tokens.groupdict())
elif disconnect_tokens:
# it's a disconnect request via XHR
tokens = disconnect_tokens.groupdict()
elif request_tokens:
tokens = request_tokens.groupdict()
# and continue...
else:
handshake_tokens = self.RE_HANDSHAKE_URL.match(path)
# This is no socket.io request. Let the WSGI app handle it.
return super(SocketIOHandler, self).handle_one_response()
if handshake_tokens:
return self._do_handshake(handshake_tokens.groupdict())
else:
# This is no socket.io request. Let the WSGI app handle it.
return super(SocketIOHandler, self).handle_one_response()
# Setup the transport and socket
transport = self.handler_types.get(request_tokens["transport_id"])
sessid = request_tokens["sessid"]
# Setup socket
sessid = tokens["sessid"]
socket = self.server.get_socket(sessid)
if not socket:
self.handle_bad_request()
return [] # Do not say the session is not found, just bad request
# so they don't start brute forcing to find open sessions
if self.environ['QUERY_STRING'].startswith('disconnect'):
# according to socket.io specs disconnect requests
# have a `disconnect` query string
# https://github.com/LearnBoost/socket.io-spec#forced-socket-disconnection
socket.disconnect()
self.handle_disconnect_request()
return []
# Setup transport
transport = self.handler_types.get(tokens["transport_id"])
# In case this is WebSocket request, switch to the WebSocketHandler
# FIXME: fix this ugly class change
old_class = None
if issubclass(transport, (transports.WebsocketTransport,
transports.FlashSocketTransport)):
self.__class__ = WebSocketHandler
old_class = self.__class__
self.__class__ = self.server.ws_handler_class
self.prevent_wsgi_call = True # thank you
# TODO: any errors, treat them ??
self.handle_one_response()
self.handle_one_response() # does the Websocket dance before we continue
# Make the socket object available for WSGI apps
self.environ['socketio'] = socket
# Create a transport and handle the request likewise
self.transport = transport(self)
self.transport = transport(self, self.config)
jobs = self.transport.connect(socket, request_method)
# Keep track of those jobs (reading, writing and heartbeat jobs) so
# that we can kill them later with Socket.kill()
socket.jobs.extend(jobs)
# transports register their own spawn'd jobs now
self.transport.do_exchange(socket, request_method)
try:
# We'll run the WSGI app if it wasn't already done.
if socket.wsgi_app_greenlet is None:
# TODO: why don't we spawn a call to handle_one_response here ?
# why call directly the WSGI machinery ?
start_response = lambda status, headers, exc=None: None
socket.wsgi_app_greenlet = gevent.spawn(self.application,
self.environ,
start_response)
except:
self.handle_error(*sys.exc_info())
if not socket.connection_established:
# This is executed only on the *first* packet of the establishment
# of the virtual Socket connection.
socket.connection_established = True
socket.state = socket.STATE_CONNECTED
socket._spawn_heartbeat()
socket._spawn_watcher()
# TODO DOUBLE-CHECK: do we need to joinall here ?
gevent.joinall(jobs)
try:
# We'll run the WSGI app if it wasn't already done.
if socket.wsgi_app_greenlet is None:
# TODO: why don't we spawn a call to handle_one_response here ?
# why call directly the WSGI machinery ?
start_response = lambda status, headers, exc=None: None
socket.wsgi_app_greenlet = gevent.spawn(self.application,
self.environ,
start_response)
except:
self.handle_error(*sys.exc_info())
# we need to keep the connection open if we are an open socket
if tokens['transport_id'] in ['flashsocket', 'websocket']:
# wait here for all the jobs to finish before returning
gevent.joinall(socket.jobs)
# Switch back to the old class so references to this don't use the
# incorrect class. Useful for debugging.
if old_class:
self.__class__ = old_class
# Clean up circular references so they can be garbage collected.
if hasattr(self, 'websocket') and self.websocket:
if hasattr(self.websocket, 'environ'):
del self.websocket.environ
del self.websocket
if self.environ:
del self.environ
def handle_bad_request(self):
self.close_connection = True
self.start_reponse("400 Bad Request", [
self.start_response("400 Bad Request", [
('Content-Type', 'text/plain'),
('Connection', 'close'),
('Content-Length', 0)
])
def handle_disconnect_request(self):
self.close_connection = True
self.start_response("200 OK", [
('Content-Type', 'text/plain'),
('Connection', 'close'),
('Content-Length', 0)
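For reference, a hedged sketch of the handshake body that ``_do_handshake()`` now builds (the session id is made up, and the transport list depends on what the server allows):

# Format: <sessid>:<heartbeat_timeout>:<close_timeout>:<transports>
config = {'heartbeat_timeout': 60, 'close_timeout': 60}
transports = ['websocket', 'xhr-polling']
sessid = '70123456789'
data = "%s:%s:%s:%s" % (sessid,
                        config['heartbeat_timeout'] or '',
                        config['close_timeout'] or '',
                        ",".join(transports))
# data == '70123456789:60:60:websocket,xhr-polling'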

View File

@ -27,6 +27,10 @@ class BaseNamespace(object):
def on_my_second_event(self, whatever):
print "This holds the first arg that was passed", whatever
Handlers are automatically dispatched based on the name of the incoming
event. For example, a 'user message' event will be handled by
``on_user_message()``. To change this, override :meth:`process_event`.
We can also access the full packet directly by making an event handler
that accepts a single argument named 'packet':
@ -43,9 +47,11 @@ class BaseNamespace(object):
self.session = self.socket.session # easily accessible session
self.request = request
self.ns_name = ns_name
self.allowed_methods = None # be careful, None means all methods
# are allowed while an empty list
# means totally closed.
#: Store for ACL allowed methods. Be careful as ``None`` means
#: that all methods are allowed, while an empty list means every
#: method is denied. Value: list of strings or ``None``. You
#: can and should use the various ``acl`` methods to tweak this.
self.allowed_methods = None
self.jobs = []
self.reset_acl()
@ -95,7 +101,7 @@ class BaseNamespace(object):
access to all of the ``on_*()`` and ``recv_*()`` functions,
etc.. methods.
Return something like: ``['on_connect', 'on_public_method']``
Return something like: ``set(['recv_connect', 'on_public_method'])``
You can later modify this list dynamically (inside
``on_connect()`` for example) using:
@ -113,6 +119,9 @@ class BaseNamespace(object):
**Beware**, returning ``None`` leaves the namespace completely
accessible.
The methods that are open are stored in the ``allowed_methods``
attribute of the ``Namespace`` instance.
"""
return None
@ -160,10 +169,7 @@ class BaseNamespace(object):
if not callback:
print "ERROR: No such callback for ackId %s" % packet['ackId']
return
try:
return callback(*(packet['args']))
except TypeError, e:
print "ERROR: Call to callback function failed", packet
return callback(*(packet['args']))
elif packet_type == 'disconnect':
# Force a disconnect on the namespace.
return self.call_method_with_acl('recv_disconnect', packet)
@ -178,17 +184,30 @@ class BaseNamespace(object):
``on_``-prefixed methods. You could then implement your own dispatch.
See the source code for inspiration.
To process events that have callbacks on the client side, you
must define your event with a single parameter: ``packet``.
In this case, it will be the full ``packet`` object and you
can inspect its ``ack`` and ``id`` keys to define if and how
you reply. A correct reply to an event with a callback would
look like this:
There are two ways to deal with callbacks from the client side
(meaning, the browser has a callback waiting for data that this
server will be sending back):
The first one is simply to return an object. If the incoming
packet has an 'ack' field set, meaning the browser is
waiting for callback data, it will automatically be packaged
and sent, associated with the 'ackId' from the browser. The
return value must be a *sequence* of elements, that will be
mapped to the positional parameters of the callback function
on the browser side.
If you want to *know* that you're dealing with a packet
that requires a return value, you can do those things manually
by inspecting the ``ack`` and ``id`` keys from the ``packet``
object. Your handler behaves specially if the sole
argument of your method is named ``packet``: it will be filled
with the unprocessed ``packet`` object for your inspection,
like this:
.. code-block:: python
def on_my_callback(self, packet):
if 'ack' in packet':
if 'ack' in packet:
self.emit('go_back', 'param1', id=packet['id'])
"""
args = packet['args']
@ -227,12 +246,17 @@ class BaseNamespace(object):
Those are the two behaviors:
* If there is only one parameter on the dispatched method and
it is equal to ``packet``, then pass in the packet as the
it is named ``packet``, then pass in the packet dict as the
sole parameter.
* Otherwise, pass in the arguments as specified by the
different ``recv_*()`` methods args specs, or the
:meth:`process_event` documentation.
This method will also consider the
``exception_handler_decorator``. See Namespace documentation
for details and examples.
"""
method = getattr(self, method_name, None)
if method is None:
@ -247,6 +271,11 @@ class BaseNamespace(object):
"The server-side method is invalid, as it doesn't "
"have 'self' as its first argument")
return
# Check if we need to decorate to handle exceptions
if hasattr(self, 'exception_handler_decorator'):
method = self.exception_handler_decorator(method)
if len(func_args) == 2 and func_args[1] == 'packet':
return method(packet)
else:
@ -264,10 +293,14 @@ class BaseNamespace(object):
:func:`~socketio.socketio_manage`) without clogging the
memory.
If you override this method, you probably want to only
initialize the variables you're going to use in the events of
this namespace, say, with some default values, but not perform
any operation that assumes authentication/authorization.
If you override this method, you probably want to initialize
the variables you're going to use in the events handled by this
namespace, setup ACLs, etc..
This method is called on all base classes following the `method resolution order <http://docs.python.org/library/stdtypes.html?highlight=mro#class.__mro__>`_
so you don't need to call super() to initialize the mixins or
other derived classes.
"""
pass
@ -432,8 +465,14 @@ class BaseNamespace(object):
It will be monitored by the "watcher" process in the Socket. If the
socket disconnects, all these greenlets are going to be killed, after
calling BaseNamespace.disconnect()
This method uses the ``exception_handler_decorator``. See
Namespace documentation for more information.
"""
# self.log.debug("Spawning sub-Namespace Greenlet: %s" % fn.__name__)
if hasattr(self, 'exception_handler_decorator'):
fn = self.exception_handler_decorator(fn)
new = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(new)
return new
@ -454,8 +493,12 @@ class BaseNamespace(object):
packet = {"type": "disconnect",
"endpoint": self.ns_name}
self.socket.send_packet(packet)
self.socket.remove_namespace(self.ns_name)
self.kill_local_jobs()
# remove_namespace might throw GreenletExit so
# kill_local_jobs must be in finally
try:
self.socket.remove_namespace(self.ns_name)
finally:
self.kill_local_jobs()
def kill_local_jobs(self):
"""Kills all the jobs spawned with BaseNamespace.spawn() on a namespace

View File

@ -1,17 +1,4 @@
try:
import simplejson as json
json_decimal_args = {"use_decimal": True} # pragma: no cover
except ImportError:
import json
import decimal
class DecimalEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, decimal.Decimal):
return float(o)
return super(DecimalEncoder, self).default(o)
json_decimal_args = {"cls": DecimalEncoder}
from socketio.defaultjson import default_json_dumps, default_json_loads
MSG_TYPES = {
'disconnect': 0,
@ -45,7 +32,7 @@ socketio_packet_attributes = ['type', 'name', 'data', 'endpoint', 'args',
'ackId', 'reason', 'advice', 'qs', 'id']
def encode(data):
def encode(data, json_dumps=default_json_dumps):
"""
Encode an attribute dict into a byte string.
"""
@ -72,14 +59,13 @@ def encode(data):
if msg == '3':
payload = data['data']
if msg == '4':
payload = json.dumps(data['data'], separators=(',', ':'),
**json_decimal_args)
payload = json_dumps(data['data'])
if msg == '5':
d = {}
d['name'] = data['name']
if 'args' in data and data['args'] != []:
d['args'] = data['args']
payload = json.dumps(d, separators=(',', ':'), **json_decimal_args)
payload = json_dumps(d)
if 'id' in data:
msg += ':' + str(data['id'])
if data['ack'] == 'data':
@ -98,8 +84,7 @@ def encode(data):
# '6:::' [id] '+' [data]
msg += '::' + data.get('endpoint', '') + ':' + str(data['ackId'])
if 'args' in data and data['args'] != []:
msg += '+' + json.dumps(data['args'], separators=(',', ':'),
**json_decimal_args)
msg += '+' + json_dumps(data['args'])
elif msg == '7':
# '7::' [endpoint] ':' [reason] '+' [advice]
@ -117,7 +102,7 @@ def encode(data):
return msg
def decode(rawstr):
def decode(rawstr, json_loads=default_json_loads):
"""
Decode a rawstr packet arriving from the socket into a dict.
"""
@ -163,11 +148,11 @@ def decode(rawstr):
decoded_msg['data'] = data
elif msg_type == "4": # json msg
decoded_msg['data'] = json.loads(data)
decoded_msg['data'] = json_loads(data)
elif msg_type == "5": # event
try:
data = json.loads(data)
data = json_loads(data)
except ValueError, e:
print("Invalid JSON event message", data)
decoded_msg['args'] = []
@ -182,7 +167,7 @@ def decode(rawstr):
if '+' in data:
ackId, data = data.split('+')
decoded_msg['ackId'] = int(ackId)
decoded_msg['args'] = json.loads(data)
decoded_msg['args'] = json_loads(data)
else:
decoded_msg['ackId'] = int(data)
decoded_msg['args'] = []
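A hedged sketch of the new ``encode``/``decode`` signatures (the event packet is illustrative; the exact wire string depends on the packet fields):

from socketio import packet
from socketio.defaultjson import default_json_dumps, default_json_loads

pkt = {'type': 'event', 'name': 'user message',
       'args': ['hello'], 'endpoint': '/chat'}
raw = packet.encode(pkt, json_dumps=default_json_dumps)
# roughly: '5::/chat:{"name":"user message","args":["hello"]}'
decoded = packet.decode(raw, json_loads=default_json_loads)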

View File

@ -0,0 +1,72 @@
import logging
from socketio import socketio_manage
from django.http import HttpResponse
from django.views.decorators.csrf import csrf_exempt
from django.utils.importlib import import_module
# for Django 1.3 support
try:
from django.conf.urls import patterns, url, include
except ImportError:
from django.conf.urls.defaults import patterns, url, include
SOCKETIO_NS = {}
LOADING_SOCKETIO = False
def autodiscover():
"""
Auto-discover INSTALLED_APPS sockets.py modules and fail silently when
not present. NOTE: socketio_autodiscover was inspired/copied from
django.contrib.admin autodiscover
"""
global LOADING_SOCKETIO
if LOADING_SOCKETIO:
return
LOADING_SOCKETIO = True
import imp
from django.conf import settings
for app in settings.INSTALLED_APPS:
try:
app_path = import_module(app).__path__
except AttributeError:
continue
try:
imp.find_module('sockets', app_path)
except ImportError:
continue
import_module("%s.sockets" % app)
LOADING_SOCKETIO = False
class namespace(object):
def __init__(self, name=''):
self.name = name
def __call__(self, handler):
SOCKETIO_NS[self.name] = handler
return handler
@csrf_exempt
def socketio(request):
try:
socketio_manage(request.environ, SOCKETIO_NS, request)
except:
logging.getLogger("socketio").error("Exception while handling socketio connection", exc_info=True)
return HttpResponse("")
urls = patterns("", (r'', socketio))
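A hedged sketch of how a project would consume this module (the app, project, and namespace names are illustrative): declare handlers in an app's ``sockets.py`` with the ``@namespace`` decorator, then wire ``socketio.sdjango.urls`` into the urlconf and call ``autodiscover()``.

# myapp/sockets.py  (hypothetical app picked up by autodiscover())
from socketio.namespace import BaseNamespace
from socketio.sdjango import namespace

@namespace('/chat')
class ChatNamespace(BaseNamespace):
    def on_user_message(self, msg):
        self.emit('user message', msg)

# urls.py  (project urlconf, Django 1.4-era style)
from django.conf.urls import patterns, include
import socketio.sdjango

socketio.sdjango.autodiscover()

urlpatterns = patterns("",
    (r"^socket\.io", include(socketio.sdjango.urls)),
)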

View File

@ -8,6 +8,7 @@ from gevent.pywsgi import WSGIServer
from socketio.handler import SocketIOHandler
from socketio.policyserver import FlashPolicyServer
from socketio.virtsocket import Socket
from geventwebsocket.handler import WebSocketHandler
__all__ = ['SocketIOServer']
@ -16,19 +17,41 @@ class SocketIOServer(WSGIServer):
"""A WSGI Server with a resource that acts like an SocketIO."""
def __init__(self, *args, **kwargs):
"""
This is just like the standard WSGIServer __init__, except with a
"""This is just like the standard WSGIServer __init__, except with a
few additional ``kwargs``:
:param resource: The URL which has to be identified as a socket.io request. Defaults to the /socket.io/ URL.
:param resource: The URL which has to be identified as a
socket.io request. Defaults to the /socket.io/ URL.
:param transports: Optional list of transports to allow. List of
strings, each string should be one of
handler.SocketIOHandler.handler_types.
:param policy_server: Boolean describing whether or not to use the
Flash policy server. Default True.
:param policy_listener : A tuple containing (host, port) for the
policy server. This is optional and used only if policy server
:param policy_listener: A tuple containing (host, port) for the
policy server. This is optional and used only if policy server
is set to true. The default value is 0.0.0.0:843
:param heartbeat_interval: int The timeout for the server; we
should receive a heartbeat from the client within this
interval. This should be less than the
``heartbeat_timeout``.
:param heartbeat_timeout: int The timeout for the client when
it should send a new heartbeat to the server. This value
is sent to the client after a successful handshake.
:param close_timeout: int The grace period the client has, after
closing the connection, to re-open it. This value is sent to the
client after a successful handshake.
:param log_file: str The file in which you want the PyWSGI
server to write its access log. If not specified, it
is sent to `stderr` (with gevent 0.13).
"""
self.sockets = {}
if 'namespace' in kwargs:
@ -40,18 +63,48 @@ class SocketIOServer(WSGIServer):
self.transports = kwargs.pop('transports', None)
if kwargs.pop('policy_server', True):
policylistener = kwargs.pop('policy_listener', (args[0][0], 10843))
try:
address = args[0][0]
except TypeError:
try:
address = args[0].address[0]
except AttributeError:
address = args[0].cfg_addr[0]
policylistener = kwargs.pop('policy_listener', (address, 10843))
self.policy_server = FlashPolicyServer(policylistener)
else:
self.policy_server = None
kwargs['handler_class'] = SocketIOHandler
# Extract other config options
self.config = {
'heartbeat_timeout': 60,
'close_timeout': 60,
'heartbeat_interval': 25,
}
for f in ('heartbeat_timeout', 'heartbeat_interval', 'close_timeout'):
if f in kwargs:
self.config[f] = int(kwargs.pop(f))
if not 'handler_class' in kwargs:
kwargs['handler_class'] = SocketIOHandler
if not 'ws_handler_class' in kwargs:
self.ws_handler_class = WebSocketHandler
else:
self.ws_handler_class = kwargs.pop('ws_handler_class')
log_file = kwargs.pop('log_file', None)
if log_file:
kwargs['log'] = open(log_file, 'a')
super(SocketIOServer, self).__init__(*args, **kwargs)
def start_accepting(self):
if self.policy_server is not None:
try:
self.policy_server.start()
if not self.policy_server.started:
self.policy_server.start()
except error, ex:
sys.stderr.write(
'FAILED to start flash policy server: %s\n' % (ex, ))
@ -60,13 +113,14 @@ class SocketIOServer(WSGIServer):
sys.stderr.write('FAILED to start flash policy server.\n\n')
super(SocketIOServer, self).start_accepting()
def kill(self):
def stop(self, timeout=None):
if self.policy_server is not None:
self.policy_server.kill()
super(SocketIOServer, self).kill()
self.policy_server.stop()
super(SocketIOServer, self).stop(timeout=timeout)
def handle(self, socket, address):
handler = self.handler_class(socket, address, self)
# Pass in the config about timeouts, heartbeats, also...
handler = self.handler_class(self.config, socket, address, self)
handler.handle()
def get_socket(self, sessid=''):
@ -74,10 +128,59 @@ class SocketIOServer(WSGIServer):
socket = self.sockets.get(sessid)
if sessid and not socket:
return None # you ask for a session that doesn't exist!
if socket is None:
socket = Socket(self)
socket = Socket(self, self.config)
self.sockets[socket.sessid] = socket
else:
socket.incr_hits()
return socket
def serve(app, **kw):
_quiet = kw.pop('_quiet', False)
_resource = kw.pop('resource', 'socket.io')
if not _quiet: # pragma: no cover
# idempotent if logging has already been set up
import logging
logging.basicConfig()
host = kw.pop('host', '127.0.0.1')
port = int(kw.pop('port', 6543))
transports = kw.pop('transports', None)
if transports:
transports = [x.strip() for x in transports.split(',')]
policy_server = kw.pop('policy_server', False)
if policy_server in (True, 'True', 'true', 'enable', 'yes', 'on', '1'):
policy_server = True
policy_listener_host = kw.pop('policy_listener_host', host)
policy_listener_port = int(kw.pop('policy_listener_port', 10843))
kw['policy_listener'] = (policy_listener_host, policy_listener_port)
else:
policy_server = False
server = SocketIOServer((host, port),
app,
resource=_resource,
transports=transports,
policy_server=policy_server,
**kw)
if not _quiet:
print('serving on http://%s:%s' % (host, port))
server.serve_forever()
def serve_paste(app, global_conf, **kw):
"""pserve / paster serve / waitress replacement / integration
You can pass as parameters:
transports = websockets, xhr-multipart, xhr-longpolling, etc...
policy_server = True
"""
serve(app, **kw)
return 0
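A hedged sketch of the new timeout and logging kwargs (values are illustrative; the trivial WSGI app is a stand-in for a real application):

from socketio.server import SocketIOServer

def app(environ, start_response):
    start_response('404 Not Found', [('Content-Type', 'text/plain')])
    return ['']

server = SocketIOServer(('0.0.0.0', 8080), app,
                        resource='socket.io',
                        policy_server=False,
                        heartbeat_interval=25,   # server sends '2::' this often
                        heartbeat_timeout=60,    # sent to the client at handshake
                        close_timeout=60,        # sent to the client at handshake
                        log_file='/tmp/socketio-access.log')
server.serve_forever()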

View File

@ -1,44 +1,182 @@
import os
import gevent
import time
from gevent.pool import Pool
from gevent.server import StreamServer
from gunicorn.workers.ggevent import GeventPyWSGIWorker
from gunicorn.workers.ggevent import PyWSGIHandler
from gunicorn.workers.ggevent import GeventResponse
from gunicorn import version_info as gunicorn_version
from socketio.server import SocketIOServer
from socketio.handler import SocketIOHandler
from geventwebsocket.handler import WebSocketHandler
from datetime import datetime
from functools import partial
class GunicornWSGIHandler(PyWSGIHandler, SocketIOHandler):
pass
class GunicornWebSocketWSGIHandler(WebSocketHandler):
def log_request(self):
start = datetime.fromtimestamp(self.time_start)
finish = datetime.fromtimestamp(self.time_finish)
response_time = finish - start
resp = GeventResponse(self.status, [],
self.response_length)
req_headers = [h.split(":", 1) for h in self.headers.headers]
self.server.log.access(
resp, req_headers, self.environ, response_time)
class GeventSocketIOBaseWorker(GeventPyWSGIWorker):
""" The base gunicorn worker class """
transports = None
def __init__(self, age, ppid, socket, app, timeout, cfg, log):
if os.environ.get('POLICY_SERVER', None) is None:
if self.policy_server:
os.environ['POLICY_SERVER'] = 'true'
else:
self.policy_server = False
super(GeventSocketIOBaseWorker, self).__init__(
age, ppid, socket, app, timeout, cfg, log)
def run(self):
self.socket.setblocking(1)
pool = Pool(self.worker_connections)
self.server_class.base_env['wsgi.multiprocess'] = \
self.cfg.workers > 1
server = self.server_class(
self.socket, application=self.wsgi,
spawn=pool, handler_class=self.wsgi_handler,
namespace=self.namespace, policy_server=self.policy_server)
server.start()
try:
while self.alive:
self.notify()
if gunicorn_version >= (0, 17, 0):
servers = []
ssl_args = {}
if self.ppid != os.getppid():
self.log.info("Parent changed, shutting down: %s", self)
break
if self.cfg.is_ssl:
ssl_args = dict(
server_side=True,
do_handshake_on_connect=False,
**self.cfg.ssl_options
)
gevent.sleep(1.0)
for s in self.sockets:
s.setblocking(1)
pool = Pool(self.worker_connections)
if self.server_class is not None:
self.server_class.base_env['wsgi.multiprocess'] = \
self.cfg.workers > 1
except KeyboardInterrupt:
pass
server = self.server_class(
s,
application=self.wsgi,
spawn=pool,
resource=self.resource,
log=self.log,
policy_server=self.policy_server,
handler_class=self.wsgi_handler,
ws_handler_class=self.ws_wsgi_handler,
**ssl_args
)
else:
hfun = partial(self.handle, s)
server = StreamServer(
s, handle=hfun, spawn=pool, **ssl_args)
# try to stop the connections
try:
self.notify()
server.stop(timeout=self.timeout)
except:
pass
server.start()
servers.append(server)
pid = os.getpid()
try:
while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
self.log.info(
"Parent changed, shutting down: %s", self)
break
gevent.sleep(1.0)
except KeyboardInterrupt:
pass
try:
# Stop accepting requests
[server.stop_accepting() for server in servers]
# Handle current requests until graceful_timeout
ts = time.time()
while time.time() - ts <= self.cfg.graceful_timeout:
accepting = 0
for server in servers:
if server.pool.free_count() != server.pool.size:
accepting += 1
if not accepting:
return
self.notify()
gevent.sleep(1.0)
# Force kill all active the handlers
self.log.warning("Worker graceful timeout (pid:%s)" % self.pid)
[server.stop(timeout=1) for server in servers]
except:
pass
else:
self.socket.setblocking(1)
pool = Pool(self.worker_connections)
self.server_class.base_env['wsgi.multiprocess'] = \
self.cfg.workers > 1
server = self.server_class(
self.socket,
application=self.wsgi,
spawn=pool,
resource=self.resource,
log=self.log,
policy_server=self.policy_server,
handler_class=self.wsgi_handler,
ws_handler_class=self.ws_wsgi_handler,
)
server.start()
pid = os.getpid()
try:
while self.alive:
self.notify()
if pid == os.getpid() and self.ppid != os.getppid():
self.log.info(
"Parent changed, shutting down: %s", self)
break
gevent.sleep(1.0)
except KeyboardInterrupt:
pass
try:
# Stop accepting requests
server.kill()
# Handle current requests until graceful_timeout
ts = time.time()
while time.time() - ts <= self.cfg.graceful_timeout:
if server.pool.free_count() == server.pool.size:
return # all requests was handled
self.notify()
gevent.sleep(1.0)
# Force kill all active the handlers
self.log.warning("Worker graceful timeout (pid:%s)" % self.pid)
server.stop(timeout=1)
except:
pass
class GeventSocketIOWorker(GeventSocketIOBaseWorker):
@ -49,9 +187,20 @@ class GeventSocketIOWorker(GeventSocketIOBaseWorker):
being disabled.
"""
server_class = SocketIOServer
wsgi_handler = SocketIOHandler
wsgi_handler = GunicornWSGIHandler
ws_wsgi_handler = GunicornWebSocketWSGIHandler
# We need to define a namespace for the server. It would be nice if this
# were a configuration option, and it will probably end up being one; for
# now this is just a proof of concept to make sure this will work
namespace = 'socket.io'
policy_server = False # Don't run the flash policy server
resource = 'socket.io'
policy_server = True
class NginxGeventSocketIOWorker(GeventSocketIOWorker):
"""
Worker which will not attempt to connect via the websocket transport.
Nginx is not compatible with websockets and therefore will not add the
wsgi.websocket key to the wsgi environment.
"""
transports = ['xhr-polling']
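A hedged usage note for these workers (the module path ``socketio.sgunicorn`` is assumed from this file's location in the package; ``myproject.wsgi:application`` is illustrative):

# Plain gunicorn invocation, selecting the socket.io-aware worker:
#
#   gunicorn --worker-class socketio.sgunicorn.GeventSocketIOWorker myproject.wsgi:application
#
# Behind nginx (no websocket support at the time), fall back to polling only:
#
#   gunicorn --worker-class socketio.sgunicorn.NginxGeventSocketIOWorker myproject.wsgi:application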

View File

@ -0,0 +1,30 @@
import gevent
import weakref
try:
import redis
except ImportError:
pass
class RedisStorage(object):
def __init__(self, server, **kwargs):
self.server = weakref.proxy(server)
self.jobs = []
self.host = kwargs.get('host', 'localhost')
self.port = kwargs.get('port', 6379)
r = redis.StrictRedis(host=self.host, port=self.port)
self.conn = r.pubsub()
self.spawn(self.listener)
def listener(self):
for m in self.conn.listen():
print("===============NEW MESSAGE!!!====== %s", m)
def spawn(self, fn, *args, **kwargs):
new = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(new)
return new
def new_request(self, environ):
print("===========NEW REQUEST %s===========" % environ)

View File

@ -1,13 +1,21 @@
import gevent
import urllib
import urlparse
from geventwebsocket import WebSocketError
from gevent.queue import Empty
class BaseTransport(object):
"""Base class for all transports. Mostly wraps handler class functions."""
def __init__(self, handler):
def __init__(self, handler, config, **kwargs):
"""Base transport class.
:param config: dict Should contain the config keys, like
``heartbeat_interval``, ``heartbeat_timeout`` and
``close_timeout``.
"""
self.content_type = ("Content-Type", "text/plain; charset=UTF-8")
self.headers = [
("Access-Control-Allow-Origin", "*"),
@ -15,22 +23,28 @@ class BaseTransport(object):
("Access-Control-Allow-Methods", "POST, GET, OPTIONS"),
("Access-Control-Max-Age", 3600),
]
self.headers_list = []
self.handler = handler
self.config = config
def write(self, data=""):
if 'Content-Length' not in self.handler.response_headers_list:
self.handler.response_headers.append(('Content-Length', len(data)))
self.handler.response_headers_list.append('Content-Length')
# Gevent v 0.13
if hasattr(self.handler, 'response_headers_list'):
if 'Content-Length' not in self.handler.response_headers_list:
self.handler.response_headers.append(('Content-Length', len(data)))
self.handler.response_headers_list.append('Content-Length')
elif not hasattr(self.handler, 'provided_content_length') or self.handler.provided_content_length is None:
# Gevent 1.0bX
l = len(data)
self.handler.provided_content_length = l
self.handler.response_headers.append(('Content-Length', l))
self.handler.write(data)
self.handler.write_smart(data)
def start_response(self, status, headers, **kwargs):
if "Content-Type" not in [x[0] for x in headers]:
headers.append(self.content_type)
headers.extend(self.headers)
#print headers
self.handler.start_response(status, headers, **kwargs)
@ -46,15 +60,14 @@ class XHRPollingTransport(BaseTransport):
def get(self, socket):
socket.heartbeat()
payload = self.get_messages_payload(socket, timeout=5.0)
heartbeat_interval = self.config['heartbeat_interval']
payload = self.get_messages_payload(socket, timeout=heartbeat_interval)
if not payload:
payload = "8::" # NOOP
self.start_response("200 OK", [])
self.write(payload)
return []
def _request_body(self):
return self.handler.wsgi_input.readline()
@ -68,8 +81,6 @@ class XHRPollingTransport(BaseTransport):
])
self.write("1")
return []
def get_messages_payload(self, socket, timeout=None):
"""This will fetch the messages from the Socket's queue, and if
there are many messages, pack multiple messages into one payload and return
@ -87,14 +98,16 @@ class XHRPollingTransport(BaseTransport):
``messages`` - List of raw messages to encode, if necessary
"""
if not messages:
if not messages or messages[0] is None:
return ''
if len(messages) == 1:
return messages[0].encode('utf-8')
payload = u''.join(u'\ufffd%d\ufffd%s' % (len(p), p)
for p in messages)
payload = u''.join([(u'\ufffd%d\ufffd%s' % (len(p), p))
for p in messages if p is not None])
# FIXME: why is it so that we must filter None from here ? How
# is it even possible that a None gets in there ?
return payload.encode('utf-8')
@ -115,7 +128,6 @@ class XHRPollingTransport(BaseTransport):
"""
payload = payload.decode('utf-8')
if payload[0] == u"\ufffd":
#print "MULTIMSG FULL", payload
ret = []
while len(payload) != 0:
len_end = payload.find(u"\ufffd", 1)
@ -123,21 +135,19 @@ class XHRPollingTransport(BaseTransport):
msg_start = len_end + 1
msg_end = length + msg_start
message = payload[msg_start:msg_end]
#print "MULTIMSG", length, message
ret.append(message)
payload = payload[msg_end:]
return ret
return [payload]
def connect(self, socket, request_method):
if not socket.connection_confirmed:
socket.connection_confirmed = True
def do_exchange(self, socket, request_method):
if not socket.connection_established:
# Runs only the first time we get a Socket opening
self.start_response("200 OK", [
("Connection", "close"),
])
self.write("1::") # 'connect' packet
return []
return
elif request_method in ("GET", "POST", "OPTIONS"):
return getattr(self, request_method.lower())(socket)
else:
@ -145,21 +155,33 @@ class XHRPollingTransport(BaseTransport):
class JSONPolling(XHRPollingTransport):
def __init__(self, handler):
super(JSONPolling, self).__init__(handler)
def __init__(self, handler, config):
super(JSONPolling, self).__init__(handler, config)
self.content_type = ("Content-Type", "text/javascript; charset=UTF-8")
def _request_body(self):
data = super(JSONPolling, self)._request_body()
# resolve %20%3F's, take out wrapping d="...", etc..
return urllib.unquote_plus(data)[3:-1] \
data = urllib.unquote_plus(data)[3:-1] \
.replace(r'\"', '"') \
.replace(r"\\", "\\")
# For some reason, in case of multiple messages passed in one
# query, IE7 sends it escaped, not utf-8 encoded. This dirty
# hack handles it
if data[0] == "\\":
data = data.decode("unicode_escape").encode("utf-8")
return data
def write(self, data):
"""Just quote out stuff before sending it out"""
args = urlparse.parse_qs(self.handler.environ.get("QUERY_STRING"))
if "i" in args:
i = args["i"]
else:
i = "0"
# TODO: don't we need to quote this data in here ?
super(JSONPolling, self).write("io.j[0]('%s');" % data)
super(JSONPolling, self).write("io.j[%s]('%s');" % (i, data))
class XHRMultipartTransport(XHRPollingTransport):
@ -170,11 +192,9 @@ class XHRMultipartTransport(XHRPollingTransport):
"multipart/x-mixed-replace;boundary=\"socketio\""
)
def connect(self, socket, request_method):
def do_exchange(self, socket, request_method):
if request_method == "GET":
# TODO: double verify this, because we're not sure. look at git revs.
heartbeat = socket._spawn_heartbeat()
return [heartbeat] + self.get(socket)
return self.get(socket)
elif request_method == "POST":
return self.post(socket)
else:
@ -196,22 +216,28 @@ class XHRMultipartTransport(XHRPollingTransport):
if not payload:
# That would mean the call to Queue.get() returned Empty,
# so it was in fact killed, since we pass no timeout=..
socket.kill()
break
return
# See below
else:
try:
self.write_multipart(header)
self.write_multipart(payload)
self.write_multipart("--socketio\r\n")
except socket.error:
socket.kill()
break
# The client might try to reconnect, even with a socket
# error, so let's just let it go, and not kill the
# socket completely. Other processes will ensure
# we kill everything once the timeouts expire.
#
# WARN: this means that this payload is LOST, unless we
# decide to re-inject it into the queue.
return
return [gevent.spawn(chunk)]
socket.spawn(chunk)
class WebsocketTransport(BaseTransport):
def connect(self, socket, request_method):
def do_exchange(self, socket, request_method):
websocket = self.handler.environ['wsgi.websocket']
websocket.send("1::") # 'connect' packet
@ -220,27 +246,26 @@ class WebsocketTransport(BaseTransport):
message = socket.get_client_msg()
if message is None:
socket.kill()
break
websocket.send(message)
try:
websocket.send(message)
except (WebSocketError, TypeError):
# We can't send a message on the socket
# it is dead, let the other sockets know
socket.disconnect()
def read_from_ws():
while True:
message = websocket.receive()
if not message:
socket.kill()
if message is None:
break
else:
if message is not None:
socket.put_server_msg(message)
gr1 = gevent.spawn(send_into_ws)
gr2 = gevent.spawn(read_from_ws)
heartbeat1, heartbeat2 = socket._spawn_heartbeat()
return [gr1, gr2, heartbeat1, heartbeat2]
socket.spawn(send_into_ws)
socket.spawn(read_from_ws)
class FlashSocketTransport(WebsocketTransport):
@ -250,21 +275,43 @@ class FlashSocketTransport(WebsocketTransport):
class HTMLFileTransport(XHRPollingTransport):
"""Not tested at all!"""
def __init__(self, handler):
super(HTMLFileTransport, self).__init__(handler)
def __init__(self, handler, config):
super(HTMLFileTransport, self).__init__(handler, config)
self.content_type = ("Content-Type", "text/html")
def write_packed(self, data):
self.write("<script>parent.s._('%s', document);</script>" % data)
self.write("<script>_('%s');</script>" % data)
def handle_get_response(self, socket):
def write(self, data):
l = 1024 * 5
super(HTMLFileTransport, self).write("%d\r\n%s%s\r\n" % (l, data, " " * (l - len(data))))
def do_exchange(self, socket, request_method):
return super(HTMLFileTransport, self).do_exchange(socket, request_method)
def get(self, socket):
self.start_response("200 OK", [
("Connection", "keep-alive"),
("Content-Type", "text/html"),
("Transfer-Encoding", "chunked"),
])
self.write("<html><body>" + " " * 244)
self.write("<html><body><script>var _ = function (msg) { parent.s._(msg, document); };</script>")
self.write_packed("1::") # 'connect' packet
payload = self.get_messages_payload(socket, timeout=5.0)
self.write_packed(payload)
def chunk():
while True:
payload = self.get_messages_payload(socket)
if not payload:
# That would mean the call to Queue.get() returned Empty,
# so it was in fact killed, since we pass no timeout=..
return
else:
try:
self.write_packed(payload)
except socket.error:
# See comments for XHRMultipart
return
socket.spawn(chunk)
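As an aside, a small runnable illustration of the multi-message framing used by ``get_messages_payload()``: each message is wrapped as U+FFFD, its character length, U+FFFD, then the message itself.

# Mirrors the join in get_messages_payload() above (lengths are character counts).
messages = [u'5:::{"name":"ping"}', u'2::']
payload = u''.join(u'\ufffd%d\ufffd%s' % (len(p), p)
                   for p in messages if p is not None)
# payload == u'\ufffd19\ufffd5:::{"name":"ping"}\ufffd3\ufffd2::'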

View File

@ -11,12 +11,17 @@ in a different way
"""
import random
import weakref
import logging
import gevent
from gevent.queue import Queue
from gevent.event import Event
from socketio import packet
from socketio.defaultjson import default_json_loads, default_json_dumps
log = logging.getLogger(__name__)
def default_error_handler(socket, error_name, error_message, endpoint,
@ -40,10 +45,11 @@ def default_error_handler(socket, error_name, error_message, endpoint,
# Send an error event through the Socket
if not quiet:
socket.send_packet(pkt)
# Log that error somewhere for debugging...
print "default_error_handler: %s, %s (endpoint=%s, msg_id=%s)" % (
error_name, error_message, endpoint, msg_id)
log.error(u"default_error_handler: {}, {} (endpoint={}, msg_id={})".format(
error_name, error_message, endpoint, msg_id
))
class Socket(object):
@ -65,7 +71,10 @@ class Socket(object):
"""Use this to be explicit when specifying a Global Namespace (an endpoint
with no name, not '/chat' or anything."""
def __init__(self, server, error_handler=None):
json_loads = staticmethod(default_json_loads)
json_dumps = staticmethod(default_json_dumps)
def __init__(self, server, config, error_handler=None):
self.server = weakref.proxy(server)
self.sessid = str(random.random())[2:]
self.session = {} # the session dict, for general developer usage
@ -76,7 +85,7 @@ class Socket(object):
self.timeout = Event()
self.wsgi_app_greenlet = None
self.state = "NEW"
self.connection_confirmed = False
self.connection_established = False
self.ack_callbacks = {}
self.ack_counter = 0
self.request = None
@ -85,6 +94,7 @@ class Socket(object):
self.active_ns = {} # Namespace sessions that were instantiated
self.jobs = []
self.error_handler = default_error_handler
self.config = config
if error_handler is not None:
self.error_handler = error_handler
@ -116,6 +126,22 @@ class Socket(object):
"""
self.error_handler = error_handler
def _set_json_loads(self, json_loads):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_loads = json_loads
def _set_json_dumps(self, json_dumps):
"""Change the default JSON decoder.
This should be a callable that accepts a single string, and returns
a well-formed object.
"""
self.json_dumps = json_dumps
def _get_next_msgid(self):
"""This retrieves the next value for the 'id' field when sending
an 'event' or 'message' or 'json' that asks the remote client
@ -175,9 +201,6 @@ class Socket(object):
def incr_hits(self):
self.hits += 1
if self.hits == 1:
self.state = self.STATE_CONNECTED
def heartbeat(self):
"""This makes the heart beat for another X seconds. Call this when
you get a heartbeat packet in.
@ -186,7 +209,7 @@ class Socket(object):
"""
self.timeout.set()
def kill(self):
def kill(self, detach=False):
"""This function must/will be called when a socket is to be completely
shut down, closed by connection timeout, connection error or explicit
disconnection from the client.
@ -202,14 +225,23 @@ class Socket(object):
self.state = self.STATE_DISCONNECTING
self.server_queue.put_nowait(None)
self.client_queue.put_nowait(None)
self.disconnect()
if len(self.active_ns) > 0:
log.debug("Calling disconnect() on %s" % self)
self.disconnect()
if self.sessid in self.server.sockets:
self.server.sockets.pop(self.sessid)
if detach:
self.detach()
# gevent.kill(self.wsgi_app_greenlet)
else:
pass # Fail silently
gevent.killall(self.jobs)
def detach(self):
"""Detach this socket from the server. This should be done in
conjunction with kill(); once all the jobs are dead, detach the
socket for garbage collection."""
log.debug("Removing %s from server sockets" % self)
if self.sessid in self.server.sockets:
self.server.sockets.pop(self.sessid)
def put_server_msg(self, msg):
"""Writes to the server's pipe, to end up in in the Namespaces"""
@ -218,7 +250,6 @@ class Socket(object):
def put_client_msg(self, msg):
"""Writes to the client's pipe, to end up in the browser"""
self.heartbeat()
self.client_queue.put_nowait(msg)
def get_client_msg(self, **kwargs):
@ -293,10 +324,13 @@ class Socket(object):
if namespace in self.active_ns:
del self.active_ns[namespace]
if len(self.active_ns) == 0 and self.connected:
self.kill(detach=True)
def send_packet(self, pkt):
"""Low-level interface to queue a packet on the wire (encoded as wire
protocol"""
self.put_client_msg(packet.encode(pkt))
self.put_client_msg(packet.encode(pkt, self.json_dumps))
def spawn(self, fn, *args, **kwargs):
"""Spawn a new Greenlet, attached to this Socket instance.
@ -304,7 +338,7 @@ class Socket(object):
It will be monitored by the "watcher" method
"""
self.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__)
log.debug("Spawning sub-Socket Greenlet: %s" % fn.__name__)
job = gevent.spawn(fn, *args, **kwargs)
self.jobs.append(job)
return job
@ -312,6 +346,13 @@ class Socket(object):
def _receiver_loop(self):
"""This is the loop that takes messages from the queue for the server
to consume, decodes them and dispatches them.
It is the main loop for a socket. We join on this process before
returning control to the web framework.
This process is not tracked by the socket itself, it is not going
to be killed by the ``gevent.killall(socket.jobs)``, so it must
exit gracefully itself.
"""
while True:
@ -320,7 +361,7 @@ class Socket(object):
if not rawdata:
continue # or close the connection ?
try:
pkt = packet.decode(rawdata)
pkt = packet.decode(rawdata, self.json_loads)
except (ValueError, KeyError, Exception), e:
self.error('invalid_packet',
"There was a decoding error when dealing with packet "
@ -334,7 +375,7 @@ class Socket(object):
if pkt['type'] == 'disconnect' and pkt['endpoint'] == '':
# On global namespace, we kill everything.
self.kill()
self.kill(detach=True)
continue
endpoint = pkt['endpoint']
@ -350,8 +391,12 @@ class Socket(object):
new_ns_class = self.namespaces[endpoint]
pkt_ns = new_ns_class(self.environ, endpoint,
request=self.request)
pkt_ns.initialize() # use this instead of __init__,
# for less confusion
# This calls initialize() on all the classes and mixins, etc..
# in the order of the MRO
for cls in type(pkt_ns).__mro__:
if hasattr(cls, 'initialize'):
cls.initialize(pkt_ns) # use this instead of __init__,
# for less confusion
self.active_ns[endpoint] = pkt_ns
@ -359,15 +404,20 @@ class Socket(object):
# Has the client requested an 'ack' with the reply parameters ?
if pkt.get('ack') == "data" and pkt.get('id'):
if type(retval) is tuple:
args = list(retval)
else:
args = [retval]
returning_ack = dict(type='ack', ackId=pkt['id'],
args=retval,
args=args,
endpoint=pkt.get('endpoint', ''))
self.send_packet(returning_ack)
# Now, are we still connected ?
if not self.connected:
self.kill() # ?? what,s the best clean-up when its not a
# user-initiated disconnect
self.kill(detach=True)  # ?? what's the best clean-up
# when it's not a
# user-initiated disconnect
return
def _spawn_receiver_loop(self):
@ -379,50 +429,51 @@ class Socket(object):
return job
def _watcher(self):
"""Watch if any of the greenlets for a request have died. If so, kill
the request and the socket.
"""
# TODO: add that if any of the request.jobs die, kill them all and exit
gevent.sleep(5.0)
"""Watch out if we've been disconnected, in that case, kill
all the jobs.
"""
while True:
gevent.sleep(1.0)
if not self.connected:
# Killing Socket-level jobs
gevent.killall(self.jobs)
for ns_name, ns in list(self.active_ns.iteritems()):
ns.recv_disconnect()
# Killing Socket-level jobs
gevent.killall(self.jobs)
break
def _spawn_watcher(self):
"""This one is not waited for with joinall(socket.jobs), as it
is an external watcher, to clean up when everything is done."""
job = gevent.spawn(self._watcher)
return job
def _heartbeat(self):
"""Start the heartbeat Greenlet to check connection health."""
self.state = self.STATE_CONNECTED
interval = self.config['heartbeat_interval']
while self.connected:
gevent.sleep(5.0) # FIXME: make this a setting
gevent.sleep(interval)
# TODO: this process could use a timeout object like the disconnect
# timeout thing, and ONLY send packets when none are sent!
# We would do that by calling timeout.set() for a "sending"
# timeout. If we're sending 100 messages a second, there is
# no need to push some heartbeats in there also.
self.put_client_msg("2::") # TODO: make it a heartbeat packet
self.put_client_msg("2::")
def _disconnect_timeout(self):
self.timeout.clear()
def _heartbeat_timeout(self):
timeout = float(self.config['heartbeat_timeout'])
while True:
self.timeout.clear()
gevent.sleep(0)
wait_res = self.timeout.wait(timeout=timeout)
if not wait_res:
if self.connected:
log.debug("heartbeat timed out, killing socket")
self.kill(detach=True)
return
if self.timeout.wait(10.0):
gevent.spawn(self._disconnect_timeout)
else:
self.kill()
def _spawn_heartbeat(self):
"""This functions returns a list of jobs"""
job_sender = gevent.spawn(self._heartbeat)
job_waiter = gevent.spawn(self._disconnect_timeout)
self.jobs.extend((job_sender, job_waiter))
return job_sender, job_waiter
self.spawn(self._heartbeat)
self.spawn(self._heartbeat_timeout)
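Finally, a hedged sketch of a custom error handler (the trailing parameters of ``default_error_handler`` are assumed to be ``msg_id`` and ``quiet``, matching how errors are reported elsewhere in this module); it is installed through the ``error_handler`` argument of ``socketio_manage()``.

import logging

def my_error_handler(socket, error_name, error_message, endpoint,
                     msg_id, quiet):
    # Same signature as default_error_handler; just route errors to logging.
    logging.getLogger('socketio.errors').error(
        "%s: %s (endpoint=%s, msg_id=%s)",
        error_name, error_message, endpoint, msg_id)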