Upgrade kombu to 3.0.21

This commit is contained in:
Matthew Jones
2014-08-06 15:32:17 -04:00
parent cbe26c4619
commit 07b538e5c5
40 changed files with 412 additions and 177 deletions

View File

@@ -31,7 +31,7 @@ httplib2==0.9 (httplib2/*)
importlib==1.0.3 (importlib/*, needed for Python 2.6 support) importlib==1.0.3 (importlib/*, needed for Python 2.6 support)
iso8601==0.1.10 (iso8601/*) iso8601==0.1.10 (iso8601/*)
keyring==4.0 (keyring/*, excluded bin/keyring) keyring==4.0 (keyring/*, excluded bin/keyring)
kombu==3.0.14 (kombu/*) kombu==3.0.21 (kombu/*)
Markdown==2.4 (markdown/*, excluded bin/markdown_py) Markdown==2.4 (markdown/*, excluded bin/markdown_py)
mock==1.0.1 (mock.py) mock==1.0.1 (mock.py)
ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support) ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)

View File

@@ -7,7 +7,7 @@ version_info_t = namedtuple(
'version_info_t', ('major', 'minor', 'micro', 'releaselevel', 'serial'), 'version_info_t', ('major', 'minor', 'micro', 'releaselevel', 'serial'),
) )
VERSION = version_info_t(3, 0, 14, '', '') VERSION = version_info_t(3, 0, 21, '', '')
__version__ = '{0.major}.{0.minor}.{0.micro}{0.releaselevel}'.format(VERSION) __version__ = '{0.major}.{0.minor}.{0.micro}{0.releaselevel}'.format(VERSION)
__author__ = 'Ask Solem' __author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org' __contact__ = 'ask@celeryproject.org'
@@ -99,6 +99,7 @@ new_module.__dict__.update({
'__homepage__': __homepage__, '__homepage__': __homepage__,
'__docformat__': __docformat__, '__docformat__': __docformat__,
'__package__': package, '__package__': package,
'version_info_t': version_info_t,
'VERSION': VERSION}) 'VERSION': VERSION})
if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover

View File

@@ -272,35 +272,39 @@ class Hub(object):
item() item()
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1 poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
#print('[[[HUB]]]: %s' % (self.repr_active(), ))
if readers or writers: if readers or writers:
to_consolidate = [] to_consolidate = []
try: try:
events = poll(poll_timeout) events = poll(poll_timeout)
#print('[EVENTS]: %s' % (self.nepr_events(events or []), ))
except ValueError: # Issue 882 except ValueError: # Issue 882
raise StopIteration() raise StopIteration()
for fileno, event in events or (): for fd, event in events or ():
if fileno in consolidate and \ if fd in consolidate and \
writers.get(fileno) is None: writers.get(fd) is None:
to_consolidate.append(fileno) to_consolidate.append(fd)
continue continue
cb = cbargs = None cb = cbargs = None
try:
if event & READ: if event & READ:
cb, cbargs = readers[fileno] try:
elif event & WRITE: cb, cbargs = readers[fd]
cb, cbargs = writers[fileno] except KeyError:
elif event & ERR: self.remove_reader(fd)
try: continue
cb, cbargs = (readers.get(fileno) or elif event & WRITE:
writers.get(fileno)) try:
except TypeError: cb, cbargs = writers[fd]
pass except KeyError:
except (KeyError, Empty): self.remove_writer(fd)
hub_remove(fileno) continue
continue elif event & ERR:
try:
cb, cbargs = (readers.get(fd) or
writers.get(fd))
except TypeError:
pass
if cb is None: if cb is None:
continue continue
if isinstance(cb, generator): if isinstance(cb, generator):
@@ -309,11 +313,11 @@ class Hub(object):
except OSError as exc: except OSError as exc:
if get_errno(exc) != errno.EBADF: if get_errno(exc) != errno.EBADF:
raise raise
hub_remove(fileno) hub_remove(fd)
except StopIteration: except StopIteration:
pass pass
except Exception: except Exception:
hub_remove(fileno) hub_remove(fd)
raise raise
else: else:
try: try:

View File

@@ -22,7 +22,6 @@ from amqp import RecoverableConnectionError
from .entity import Exchange, Queue from .entity import Exchange, Queue
from .five import range from .five import range
from .log import get_logger from .log import get_logger
from .messaging import Consumer as _Consumer
from .serialization import registry as serializers from .serialization import registry as serializers
from .utils import uuid from .utils import uuid
@@ -91,33 +90,43 @@ def declaration_cached(entity, channel):
def maybe_declare(entity, channel=None, retry=False, **retry_policy): def maybe_declare(entity, channel=None, retry=False, **retry_policy):
if not entity.is_bound: is_bound = entity.is_bound
if not is_bound:
assert channel assert channel
entity = entity.bind(channel) entity = entity.bind(channel)
if retry:
return _imaybe_declare(entity, **retry_policy)
return _maybe_declare(entity)
if channel is None:
assert is_bound
channel = entity.channel
def _maybe_declare(entity): declared = ident = None
channel = entity.channel if channel.connection and entity.can_cache_declaration:
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
if entity.can_cache_declaration:
declared = channel.connection.client.declared_entities declared = channel.connection.client.declared_entities
ident = hash(entity) ident = hash(entity)
if ident not in declared: if ident in declared:
entity.declare() return False
declared.add(ident)
return True if retry:
return False return _imaybe_declare(entity, declared, ident,
channel, **retry_policy)
return _maybe_declare(entity, declared, ident, channel)
def _maybe_declare(entity, declared, ident, channel):
channel = channel or entity.channel
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
entity.declare() entity.declare()
if declared is not None and ident:
declared.add(ident)
return True return True
def _imaybe_declare(entity, **retry_policy): def _imaybe_declare(entity, declared, ident, channel, **retry_policy):
return entity.channel.connection.client.ensure( return entity.channel.connection.client.ensure(
entity, _maybe_declare, **retry_policy)(entity) entity, _maybe_declare, **retry_policy)(
entity, declared, ident, channel)
def drain_consumer(consumer, limit=1, timeout=None, callbacks=None): def drain_consumer(consumer, limit=1, timeout=None, callbacks=None):
@@ -138,8 +147,8 @@ def drain_consumer(consumer, limit=1, timeout=None, callbacks=None):
def itermessages(conn, channel, queue, limit=1, timeout=None, def itermessages(conn, channel, queue, limit=1, timeout=None,
Consumer=_Consumer, callbacks=None, **kwargs): callbacks=None, **kwargs):
return drain_consumer(Consumer(channel, queues=[queue], **kwargs), return drain_consumer(conn.Consumer(channel, queues=[queue], **kwargs),
limit=limit, timeout=timeout, callbacks=callbacks) limit=limit, timeout=timeout, callbacks=callbacks)
@@ -181,8 +190,6 @@ def eventloop(conn, limit=None, timeout=None, ignore_timeouts=False):
except socket.timeout: except socket.timeout:
if timeout and not ignore_timeouts: # pragma: no cover if timeout and not ignore_timeouts: # pragma: no cover
raise raise
except socket.error: # pragma: no cover
pass
def send_reply(exchange, req, msg, def send_reply(exchange, req, msg,

View File

@@ -7,7 +7,7 @@ Compression utilities.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from kombu.utils.encoding import ensure_bytes, bytes_to_str from kombu.utils.encoding import ensure_bytes
import zlib import zlib
@@ -67,7 +67,7 @@ def decompress(body, content_type):
:param content_type: mime-type of compression method used. :param content_type: mime-type of compression method used.
""" """
return bytes_to_str(get_decoder(content_type)(body)) return get_decoder(content_type)(body)
register(zlib.compress, register(zlib.compress,

View File

@@ -11,13 +11,8 @@ import os
import socket import socket
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial
from itertools import count, cycle from itertools import count, cycle
from operator import itemgetter from operator import itemgetter
try:
from urllib.parse import quote
except ImportError: # Py2
from urllib import quote # noqa
# jython breaks on relative import for .exceptions for some reason # jython breaks on relative import for .exceptions for some reason
# (Issue #112) # (Issue #112)
@@ -25,10 +20,10 @@ from kombu import exceptions
from .five import Empty, range, string_t, text_t, LifoQueue as _LifoQueue from .five import Empty, range, string_t, text_t, LifoQueue as _LifoQueue
from .log import get_logger from .log import get_logger
from .transport import get_transport_cls, supports_librabbitmq from .transport import get_transport_cls, supports_librabbitmq
from .utils import cached_property, retry_over_time, shufflecycle from .utils import cached_property, retry_over_time, shufflecycle, HashedSeq
from .utils.compat import OrderedDict from .utils.compat import OrderedDict
from .utils.functional import lazy from .utils.functional import lazy
from .utils.url import parse_url, urlparse from .utils.url import as_url, parse_url, quote, urlparse
__all__ = ['Connection', 'ConnectionPool', 'ChannelPool'] __all__ = ['Connection', 'ConnectionPool', 'ChannelPool']
@@ -199,6 +194,7 @@ class Connection(object):
"""Switch connection parameters to use a new URL (does not """Switch connection parameters to use a new URL (does not
reconnect)""" reconnect)"""
self.close() self.close()
self.declared_entities.clear()
self._closed = False self._closed = False
self._init_params(**dict(self._initial_params, **parse_url(url))) self._init_params(**dict(self._initial_params, **parse_url(url)))
@@ -565,40 +561,27 @@ class Connection(object):
return OrderedDict(self._info()) return OrderedDict(self._info())
def __eqhash__(self): def __eqhash__(self):
return hash('%s|%s|%s|%s|%s|%s' % ( return HashedSeq(self.transport_cls, self.hostname, self.userid,
self.transport_cls, self.hostname, self.userid, self.password, self.virtual_host, self.port,
self.password, self.virtual_host, self.port)) repr(self.transport_options))
def as_uri(self, include_password=False, mask=''): def as_uri(self, include_password=False, mask='**',
getfields=itemgetter('port', 'userid', 'password',
'virtual_host', 'transport')):
"""Convert connection parameters to URL form.""" """Convert connection parameters to URL form."""
hostname = self.hostname or 'localhost' hostname = self.hostname or 'localhost'
if self.transport.can_parse_url: if self.transport.can_parse_url:
if self.uri_prefix: if self.uri_prefix:
return '%s+%s' % (self.uri_prefix, hostname) return '%s+%s' % (self.uri_prefix, hostname)
return self.hostname return self.hostname
quoteS = partial(quote, safe='') # strict quote
fields = self.info() fields = self.info()
port, userid, password, transport = itemgetter( port, userid, password, vhost, transport = getfields(fields)
'port', 'userid', 'password', 'transport' scheme = ('{0}+{1}'.format(self.uri_prefix, transport)
)(fields) if self.uri_prefix else transport)
url = '%s://' % transport return as_url(
if userid or password: scheme, hostname, port, userid, password, quote(vhost),
if userid: sanitize=not include_password, mask=mask,
url += quoteS(userid) )
if password:
if include_password:
url += ':' + quoteS(password)
else:
url += ':' + mask if mask else ''
url += '@'
url += quoteS(fields['hostname'])
if port:
url += ':%s' % (port, )
url += '/' + quote(fields['virtual_host'])
if self.uri_prefix:
return '%s+%s' % (self.uri_prefix, url)
return url
def Pool(self, limit=None, preload=None): def Pool(self, limit=None, preload=None):
"""Pool of connections. """Pool of connections.
@@ -731,6 +714,10 @@ class Connection(object):
def __exit__(self, *args): def __exit__(self, *args):
self.release() self.release()
@property
def qos_semantics_matches_spec(self):
return self.transport.qos_semantics_matches_spec(self.connection)
@property @property
def connected(self): def connected(self):
"""Return true if the connection has been established.""" """Return true if the connection has been established."""

View File

@@ -288,7 +288,7 @@ class Exchange(MaybeChannelBound):
@property @property
def can_cache_declaration(self): def can_cache_declaration(self):
return self.durable and not self.auto_delete return not self.auto_delete
class binding(object): class binding(object):
@@ -672,7 +672,7 @@ class Queue(MaybeChannelBound):
@property @property
def can_cache_declaration(self): def can_cache_declaration(self):
return self.durable and not self.auto_delete return not self.auto_delete
@classmethod @classmethod
def from_dict(self, queue, **options): def from_dict(self, queue, **options):

View File

@@ -10,7 +10,7 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
############## py3k ######################################################### # ############# py3k #########################################################
import sys import sys
PY3 = sys.version_info[0] == 3 PY3 = sys.version_info[0] == 3
@@ -34,7 +34,7 @@ try:
except NameError: # pragma: no cover except NameError: # pragma: no cover
bytes_t = str # noqa bytes_t = str # noqa
############## time.monotonic ################################################ # ############# time.monotonic ###############################################
if sys.version_info < (3, 3): if sys.version_info < (3, 3):
@@ -89,7 +89,7 @@ try:
except ImportError: except ImportError:
monotonic = _monotonic # noqa monotonic = _monotonic # noqa
############## Py3 <-> Py2 ################################################### # ############# Py3 <-> Py2 ##################################################
if PY3: # pragma: no cover if PY3: # pragma: no cover
import builtins import builtins

View File

@@ -11,6 +11,7 @@ import numbers
from itertools import count from itertools import count
from .common import maybe_declare
from .compression import compress from .compression import compress
from .connection import maybe_channel, is_connection from .connection import maybe_channel, is_connection
from .entity import Exchange, Queue, DELIVERY_MODES from .entity import Exchange, Queue, DELIVERY_MODES
@@ -107,7 +108,6 @@ class Producer(object):
"""Declare the exchange if it hasn't already been declared """Declare the exchange if it hasn't already been declared
during this session.""" during this session."""
if entity: if entity:
from .common import maybe_declare
return maybe_declare(entity, self.channel, retry, **retry_policy) return maybe_declare(entity, self.channel, retry, **retry_policy)
def publish(self, body, routing_key=None, delivery_mode=None, def publish(self, body, routing_key=None, delivery_mode=None,
@@ -521,7 +521,6 @@ class Consumer(object):
whole messages. whole messages.
:param apply_global: Apply new settings globally on all channels. :param apply_global: Apply new settings globally on all channels.
Currently not supported by RabbitMQ.
""" """
return self.channel.basic_qos(prefetch_size, return self.channel.basic_qos(prefetch_size,

View File

@@ -135,7 +135,8 @@ class Node(object):
def reply(self, data, exchange, routing_key, ticket, **kwargs): def reply(self, data, exchange, routing_key, ticket, **kwargs):
self.mailbox._publish_reply(data, exchange, routing_key, ticket, self.mailbox._publish_reply(data, exchange, routing_key, ticket,
channel=self.channel) channel=self.channel,
serializer=self.mailbox.serializer)
class Mailbox(object): class Mailbox(object):
@@ -161,8 +162,12 @@ class Mailbox(object):
#: Only accepts json messages by default. #: Only accepts json messages by default.
accept = ['json'] accept = ['json']
#: Message serializer
serializer = None
def __init__(self, namespace, def __init__(self, namespace,
type='direct', connection=None, clock=None, accept=None): type='direct', connection=None, clock=None,
accept=None, serializer=None):
self.namespace = namespace self.namespace = namespace
self.connection = connection self.connection = connection
self.type = type self.type = type
@@ -172,6 +177,7 @@ class Mailbox(object):
self._tls = local() self._tls = local()
self.unclaimed = defaultdict(deque) self.unclaimed = defaultdict(deque)
self.accept = self.accept if accept is None else accept self.accept = self.accept if accept is None else accept
self.serializer = self.serializer if serializer is None else serializer
def __call__(self, connection): def __call__(self, connection):
bound = copy(self) bound = copy(self)
@@ -204,14 +210,14 @@ class Mailbox(object):
def get_reply_queue(self): def get_reply_queue(self):
oid = self.oid oid = self.oid
return Queue('%s.%s' % (oid, self.reply_exchange.name), return Queue(
exchange=self.reply_exchange, '%s.%s' % (oid, self.reply_exchange.name),
routing_key=oid, exchange=self.reply_exchange,
durable=False, routing_key=oid,
auto_delete=True, durable=False,
queue_arguments={ auto_delete=True,
'x-expires': int(REPLY_QUEUE_EXPIRES * 1000), queue_arguments={'x-expires': int(REPLY_QUEUE_EXPIRES * 1000)},
}) )
@cached_property @cached_property
def reply_queue(self): def reply_queue(self):
@@ -242,7 +248,8 @@ class Mailbox(object):
pass # queue probably deleted and no one is expecting a reply. pass # queue probably deleted and no one is expecting a reply.
def _publish(self, type, arguments, destination=None, def _publish(self, type, arguments, destination=None,
reply_ticket=None, channel=None, timeout=None): reply_ticket=None, channel=None, timeout=None,
serializer=None):
message = {'method': type, message = {'method': type,
'arguments': arguments, 'arguments': arguments,
'destination': destination} 'destination': destination}
@@ -253,16 +260,18 @@ class Mailbox(object):
message.update(ticket=reply_ticket, message.update(ticket=reply_ticket,
reply_to={'exchange': self.reply_exchange.name, reply_to={'exchange': self.reply_exchange.name,
'routing_key': self.oid}) 'routing_key': self.oid})
serializer = serializer or self.serializer
producer = Producer(chan, auto_declare=False) producer = Producer(chan, auto_declare=False)
producer.publish( producer.publish(
message, exchange=exchange.name, declare=[exchange], message, exchange=exchange.name, declare=[exchange],
headers={'clock': self.clock.forward(), headers={'clock': self.clock.forward(),
'expires': time() + timeout if timeout else 0}, 'expires': time() + timeout if timeout else 0},
serializer=serializer,
) )
def _broadcast(self, command, arguments=None, destination=None, def _broadcast(self, command, arguments=None, destination=None,
reply=False, timeout=1, limit=None, reply=False, timeout=1, limit=None,
callback=None, channel=None): callback=None, channel=None, serializer=None):
if destination is not None and \ if destination is not None and \
not isinstance(destination, (list, tuple)): not isinstance(destination, (list, tuple)):
raise ValueError( raise ValueError(
@@ -277,10 +286,12 @@ class Mailbox(object):
if limit is None and destination: if limit is None and destination:
limit = destination and len(destination) or None limit = destination and len(destination) or None
serializer = serializer or self.serializer
self._publish(command, arguments, destination=destination, self._publish(command, arguments, destination=destination,
reply_ticket=reply_ticket, reply_ticket=reply_ticket,
channel=chan, channel=chan,
timeout=timeout) timeout=timeout,
serializer=serializer)
if reply_ticket: if reply_ticket:
return self._collect(reply_ticket, limit=limit, return self._collect(reply_ticket, limit=limit,

View File

@@ -21,7 +21,7 @@ def select_blocking_method(type):
def _detect_environment(): def _detect_environment():
## -eventlet- # ## -eventlet-
if 'eventlet' in sys.modules: if 'eventlet' in sys.modules:
try: try:
from eventlet.patcher import is_monkey_patched as is_eventlet from eventlet.patcher import is_monkey_patched as is_eventlet
@@ -32,7 +32,7 @@ def _detect_environment():
except ImportError: except ImportError:
pass pass
# -gevent- # ## -gevent-
if 'gevent' in sys.modules: if 'gevent' in sys.modules:
try: try:
from gevent import socket as _gsocket from gevent import socket as _gsocket

View File

@@ -105,6 +105,8 @@ class test_maybe_declare(Case):
def test_with_retry(self): def test_with_retry(self):
channel = Mock() channel = Mock()
client = channel.connection.client = Mock()
client.declared_entities = set()
entity = Mock() entity = Mock()
entity.can_cache_declaration = True entity.can_cache_declaration = True
entity.is_bound = True entity.is_bound = True
@@ -265,8 +267,8 @@ class test_itermessages(Case):
conn = self.MockConnection() conn = self.MockConnection()
channel = Mock() channel = Mock()
channel.connection.client = conn channel.connection.client = conn
it = common.itermessages(conn, channel, 'q', limit=1, conn.Consumer = MockConsumer
Consumer=MockConsumer) it = common.itermessages(conn, channel, 'q', limit=1)
ret = next(it) ret = next(it)
self.assertTupleEqual(ret, ('body', 'message')) self.assertTupleEqual(ret, ('body', 'message'))
@@ -279,8 +281,8 @@ class test_itermessages(Case):
conn.should_raise_timeout = True conn.should_raise_timeout = True
channel = Mock() channel = Mock()
channel.connection.client = conn channel.connection.client = conn
it = common.itermessages(conn, channel, 'q', limit=1, conn.Consumer = MockConsumer
Consumer=MockConsumer) it = common.itermessages(conn, channel, 'q', limit=1)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(it) next(it)
@@ -291,8 +293,8 @@ class test_itermessages(Case):
deque_instance.popleft.side_effect = IndexError() deque_instance.popleft.side_effect = IndexError()
conn = self.MockConnection() conn = self.MockConnection()
channel = Mock() channel = Mock()
it = common.itermessages(conn, channel, 'q', limit=1, conn.Consumer = MockConsumer
Consumer=MockConsumer) it = common.itermessages(conn, channel, 'q', limit=1)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
next(it) next(it)

View File

@@ -34,7 +34,7 @@ class test_compression(Case):
self.assertIn('application/x-bz2', encoders) self.assertIn('application/x-bz2', encoders)
def test_compress__decompress__zlib(self): def test_compress__decompress__zlib(self):
text = 'The Quick Brown Fox Jumps Over The Lazy Dog' text = b'The Quick Brown Fox Jumps Over The Lazy Dog'
c, ctype = compression.compress(text, 'zlib') c, ctype = compression.compress(text, 'zlib')
self.assertNotEqual(text, c) self.assertNotEqual(text, c)
d = compression.decompress(c, ctype) d = compression.decompress(c, ctype)
@@ -43,7 +43,7 @@ class test_compression(Case):
def test_compress__decompress__bzip2(self): def test_compress__decompress__bzip2(self):
if not self.has_bzip2: if not self.has_bzip2:
raise SkipTest('bzip2 not available') raise SkipTest('bzip2 not available')
text = 'The Brown Quick Fox Over The Lazy Dog Jumps' text = b'The Brown Quick Fox Over The Lazy Dog Jumps'
c, ctype = compression.compress(text, 'bzip2') c, ctype = compression.compress(text, 'bzip2')
self.assertNotEqual(text, c) self.assertNotEqual(text, c)
d = compression.decompress(c, ctype) d = compression.decompress(c, ctype)

View File

@@ -17,7 +17,7 @@ class test_connection_utils(Case):
def setUp(self): def setUp(self):
self.url = 'amqp://user:pass@localhost:5672/my/vhost' self.url = 'amqp://user:pass@localhost:5672/my/vhost'
self.nopass = 'amqp://user@localhost:5672/my/vhost' self.nopass = 'amqp://user:**@localhost:5672/my/vhost'
self.expected = { self.expected = {
'transport': 'amqp', 'transport': 'amqp',
'userid': 'user', 'userid': 'user',
@@ -31,10 +31,6 @@ class test_connection_utils(Case):
result = parse_url(self.url) result = parse_url(self.url)
self.assertDictEqual(result, self.expected) self.assertDictEqual(result, self.expected)
def test_parse_url_mongodb(self):
result = parse_url('mongodb://example.com/')
self.assertEqual(result['hostname'], 'example.com/')
def test_parse_generated_as_uri(self): def test_parse_generated_as_uri(self):
conn = Connection(self.url) conn = Connection(self.url)
info = conn.info() info = conn.info()

View File

@@ -76,7 +76,7 @@ class test_Exchange(Case):
def test_can_cache_declaration(self): def test_can_cache_declaration(self):
self.assertTrue(Exchange('a', durable=True).can_cache_declaration) self.assertTrue(Exchange('a', durable=True).can_cache_declaration)
self.assertFalse(Exchange('a', durable=False).can_cache_declaration) self.assertTrue(Exchange('a', durable=False).can_cache_declaration)
def test_pickle(self): def test_pickle(self):
e1 = Exchange('foo', 'direct') e1 = Exchange('foo', 'direct')
@@ -285,7 +285,7 @@ class test_Queue(Case):
def test_can_cache_declaration(self): def test_can_cache_declaration(self):
self.assertTrue(Queue('a', durable=True).can_cache_declaration) self.assertTrue(Queue('a', durable=True).can_cache_declaration)
self.assertFalse(Queue('a', durable=False).can_cache_declaration) self.assertTrue(Queue('a', durable=False).can_cache_declaration)
def test_eq(self): def test_eq(self):
q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx')

View File

@@ -36,7 +36,7 @@ class test_Producer(Case):
p = Producer(None) p = Producer(None)
self.assertFalse(p._channel) self.assertFalse(p._channel)
@patch('kombu.common.maybe_declare') @patch('kombu.messaging.maybe_declare')
def test_maybe_declare(self, maybe_declare): def test_maybe_declare(self, maybe_declare):
p = self.connection.Producer() p = self.connection.Producer()
q = Queue('foo') q = Queue('foo')

View File

@@ -90,7 +90,6 @@ class test_ConsumerMixin(Case):
def test_Consumer_context(self): def test_Consumer_context(self):
c, Acons, Bcons = self._context() c, Acons, Bcons = self._context()
_connref = _chanref = None
with c.Consumer() as (conn, channel, consumer): with c.Consumer() as (conn, channel, consumer):
self.assertIs(conn, c.connection) self.assertIs(conn, c.connection)
@@ -104,7 +103,6 @@ class test_ConsumerMixin(Case):
self.assertIs(subcons.channel, conn.default_channel) self.assertIs(subcons.channel, conn.default_channel)
Acons.__enter__.assert_called_with() Acons.__enter__.assert_called_with()
Bcons.__enter__.assert_called_with() Bcons.__enter__.assert_called_with()
_connref, _chanref = conn, channel
c.on_consume_end.assert_called_with(conn, channel) c.on_consume_end.assert_called_with(conn, channel)

View File

@@ -220,6 +220,9 @@ class test_fun_PoolGroup(Case):
assert eqhash(c1) != eqhash(c2) assert eqhash(c1) != eqhash(c2)
assert eqhash(c1) == eqhash(c3) assert eqhash(c1) == eqhash(c3)
c4 = Connection(c1u, transport_options={'confirm_publish': True})
self.assertNotEqual(eqhash(c3), eqhash(c4))
p1 = pools.connections[c1] p1 = pools.connections[c1]
p2 = pools.connections[c2] p2 = pools.connections[c2]
p3 = pools.connections[c3] p3 = pools.connections[c3]

View File

@@ -38,9 +38,12 @@ class test_syn(Case):
def test_detect_environment_gevent(self): def test_detect_environment_gevent(self):
with patch('gevent.socket', create=True) as m: with patch('gevent.socket', create=True) as m:
prev, socket.socket = socket.socket, m.socket prev, socket.socket = socket.socket, m.socket
self.assertTrue(sys.modules['gevent']) try:
env = syn._detect_environment() self.assertTrue(sys.modules['gevent'])
self.assertEqual(env, 'gevent') env = syn._detect_environment()
self.assertEqual(env, 'gevent')
finally:
socket.socket = prev
def test_detect_environment_no_eventlet_or_gevent(self): def test_detect_environment_no_eventlet_or_gevent(self):
try: try:

View File

@@ -2,7 +2,7 @@ from __future__ import absolute_import
from kombu import Connection from kombu import Connection
from kombu.tests.case import Case, SkipTest, skip_if_not_module from kombu.tests.case import Case, SkipTest, Mock, skip_if_not_module
class MockConnection(dict): class MockConnection(dict):
@@ -16,8 +16,14 @@ class test_mongodb(Case):
def _get_connection(self, url, **kwargs): def _get_connection(self, url, **kwargs):
from kombu.transport import mongodb from kombu.transport import mongodb
class _Channel(mongodb.Channel):
def _create_client(self):
self._client = Mock(name='client')
class Transport(mongodb.Transport): class Transport(mongodb.Transport):
Connection = MockConnection Connection = MockConnection
Channel = _Channel
return Connection(url, transport=Transport, **kwargs).connect() return Connection(url, transport=Transport, **kwargs).connect()
@@ -48,7 +54,7 @@ class test_mongodb(Case):
self.assertEquals(dbname, 'dbname') self.assertEquals(dbname, 'dbname')
@skip_if_not_module('pymongo') @skip_if_not_module('pymongo')
def test_custom_credentions(self): def test_custom_credentials(self):
url = 'mongodb://localhost/dbname' url = 'mongodb://localhost/dbname'
c = self._get_connection(url, userid='foo', password='bar') c = self._get_connection(url, userid='foo', password='bar')
hostname, dbname, options = c.channels[0]._parse_uri() hostname, dbname, options = c.channels[0]._parse_uri()

View File

@@ -220,6 +220,7 @@ class Transport(redis.Transport):
class test_Channel(Case): class test_Channel(Case):
@skip_if_not_module('redis')
def setUp(self): def setUp(self):
self.connection = self.create_connection() self.connection = self.create_connection()
self.channel = self.connection.default_channel self.channel = self.connection.default_channel
@@ -616,10 +617,12 @@ class test_Channel(Case):
self.channel.connection.client.virtual_host = 'dwqeq' self.channel.connection.client.virtual_host = 'dwqeq'
self.channel._connparams() self.channel._connparams()
@skip_if_not_module('redis')
def test_connparams_allows_slash_in_db(self): def test_connparams_allows_slash_in_db(self):
self.channel.connection.client.virtual_host = '/123' self.channel.connection.client.virtual_host = '/123'
self.assertEqual(self.channel._connparams()['db'], 123) self.assertEqual(self.channel._connparams()['db'], 123)
@skip_if_not_module('redis')
def test_connparams_db_can_be_int(self): def test_connparams_db_can_be_int(self):
self.channel.connection.client.virtual_host = 124 self.channel.connection.client.virtual_host = 124
self.assertEqual(self.channel._connparams()['db'], 124) self.assertEqual(self.channel._connparams()['db'], 124)
@@ -630,6 +633,7 @@ class test_Channel(Case):
redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True) redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True)
self.assertIn('elaine', self.channel.auto_delete_queues) self.assertIn('elaine', self.channel.auto_delete_queues)
@skip_if_not_module('redis')
def test_connparams_regular_hostname(self): def test_connparams_regular_hostname(self):
self.channel.connection.client.hostname = 'george.vandelay.com' self.channel.connection.client.hostname = 'george.vandelay.com'
self.assertEqual( self.assertEqual(
@@ -776,13 +780,16 @@ class test_Channel(Case):
with patch('kombu.transport.redis.Channel._create_client'): with patch('kombu.transport.redis.Channel._create_client'):
with Connection('redis+socket:///tmp/redis.sock') as conn: with Connection('redis+socket:///tmp/redis.sock') as conn:
connparams = conn.default_channel._connparams() connparams = conn.default_channel._connparams()
self.assertEqual(connparams['connection_class'], self.assertTrue(issubclass(
redis.redis.UnixDomainSocketConnection) connparams['connection_class'],
redis.redis.UnixDomainSocketConnection,
))
self.assertEqual(connparams['path'], '/tmp/redis.sock') self.assertEqual(connparams['path'], '/tmp/redis.sock')
class test_Redis(Case): class test_Redis(Case):
@skip_if_not_module('redis')
def setUp(self): def setUp(self):
self.connection = Connection(transport=Transport) self.connection = Connection(transport=Transport)
self.exchange = Exchange('test_Redis', type='direct') self.exchange = Exchange('test_Redis', type='direct')
@@ -939,6 +946,7 @@ def _redis_modules():
class test_MultiChannelPoller(Case): class test_MultiChannelPoller(Case):
@skip_if_not_module('redis')
def setUp(self): def setUp(self):
self.Poller = redis.MultiChannelPoller self.Poller = redis.MultiChannelPoller
@@ -1043,7 +1051,6 @@ class test_MultiChannelPoller(Case):
p._channels.clear.assert_called_with() p._channels.clear.assert_called_with()
p._fd_to_chan.clear.assert_called_with() p._fd_to_chan.clear.assert_called_with()
p._chan_to_sock.clear.assert_called_with() p._chan_to_sock.clear.assert_called_with()
self.assertIsNone(p.poller)
def test_register_when_registered_reregisters(self): def test_register_when_registered_reregisters(self):
p = self.Poller() p = self.Poller()

View File

@@ -267,8 +267,8 @@ class test_Channel(Case):
c.exchange_declare(n) c.exchange_declare(n)
c.queue_declare(n) c.queue_declare(n)
c.queue_bind(n, n, n) c.queue_bind(n, n, n)
c.queue_bind(n, n, n) # tests code path that returns # tests code path that returns if queue already bound.
# if queue already bound. c.queue_bind(n, n, n)
c.queue_delete(n, if_empty=True) c.queue_delete(n, if_empty=True)
self.assertIn(n, c.state.bindings) self.assertIn(n, c.state.bindings)

View File

@@ -11,7 +11,9 @@ if sys.version_info >= (3, 0):
else: else:
from StringIO import StringIO, StringIO as BytesIO # noqa from StringIO import StringIO, StringIO as BytesIO # noqa
from kombu import version_info_t
from kombu import utils from kombu import utils
from kombu.utils.text import version_string_as_tuple
from kombu.five import string_t from kombu.five import string_t
from kombu.tests.case import ( from kombu.tests.case import (
@@ -379,3 +381,32 @@ class test_shufflecycle(Case):
next(cycle) next(cycle)
finally: finally:
utils.repeat = prev_repeat utils.repeat = prev_repeat
class test_version_string_as_tuple(Case):
def test_versions(self):
self.assertTupleEqual(
version_string_as_tuple('3'),
version_info_t(3, 0, 0, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3'),
version_info_t(3, 3, 0, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1'),
version_info_t(3, 3, 1, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1a3'),
version_info_t(3, 3, 1, 'a3', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1a3-40c32'),
version_info_t(3, 3, 1, 'a3', '40c32'),
)
self.assertEqual(
version_string_as_tuple('3.3.1.a3.40c32'),
version_info_t(3, 3, 1, 'a3', '40c32'),
)

View File

@@ -17,11 +17,27 @@ except ImportError:
pass pass
from struct import unpack from struct import unpack
from amqplib import client_0_8 as amqp
from amqplib.client_0_8 import transport class NA(object):
from amqplib.client_0_8.channel import Channel as _Channel pass
from amqplib.client_0_8.exceptions import AMQPConnectionException
from amqplib.client_0_8.exceptions import AMQPChannelException try:
from amqplib import client_0_8 as amqp
from amqplib.client_0_8 import transport
from amqplib.client_0_8.channel import Channel as _Channel
from amqplib.client_0_8.exceptions import AMQPConnectionException
from amqplib.client_0_8.exceptions import AMQPChannelException
except ImportError: # pragma: no cover
class NAx(object):
pass
amqp = NA
amqp.Connection = NA
transport = _Channel = NA # noqa
# Sphinx crashes if this is NA, must be different class
transport.TCPTransport = transport.SSLTransport = NAx
AMQPConnectionException = AMQPChannelException = NA # noqa
from kombu.five import items from kombu.five import items
from kombu.utils.encoding import str_to_bytes from kombu.utils.encoding import str_to_bytes
@@ -321,6 +337,9 @@ class Transport(base.Transport):
self.client = client self.client = client
self.default_port = kwargs.get('default_port') or self.default_port self.default_port = kwargs.get('default_port') or self.default_port
if amqp is NA:
raise ImportError('Missing amqplib library (pip install amqplib)')
def create_channel(self, connection): def create_channel(self, connection):
return connection.channel() return connection.channel()

View File

@@ -152,6 +152,9 @@ class Transport(object):
return _read return _read
def qos_semantics_matches_spec(self, connection):
return True
def on_readable(self, connection, loop): def on_readable(self, connection, loop):
reader = self.__reader reader = self.__reader
if reader is None: if reader is None:

View File

@@ -10,7 +10,6 @@ Beanstalk transport.
""" """
from __future__ import absolute_import from __future__ import absolute_import
import beanstalkc
import socket import socket
from anyjson import loads, dumps from anyjson import loads, dumps
@@ -20,6 +19,11 @@ from kombu.utils.encoding import bytes_to_str
from . import virtual from . import virtual
try:
import beanstalkc
except ImportError: # pragma: no cover
beanstalkc = None # noqa
DEFAULT_PORT = 11300 DEFAULT_PORT = 11300
__author__ = 'David Ziegler <david.ziegler@gmail.com>' __author__ = 'David Ziegler <david.ziegler@gmail.com>'
@@ -127,16 +131,25 @@ class Transport(virtual.Transport):
default_port = DEFAULT_PORT default_port = DEFAULT_PORT
connection_errors = ( connection_errors = (
virtual.Transport.connection_errors + ( virtual.Transport.connection_errors + (
socket.error, beanstalkc.SocketError, IOError) socket.error, IOError,
getattr(beanstalkc, 'SocketError', None),
)
) )
channel_errors = ( channel_errors = (
virtual.Transport.channel_errors + ( virtual.Transport.channel_errors + (
socket.error, IOError, socket.error, IOError,
beanstalkc.SocketError, getattr(beanstalkc, 'SocketError', None),
beanstalkc.BeanstalkcException) getattr(beanstalkc, 'BeanstalkcException', None),
)
) )
driver_type = 'beanstalk' driver_type = 'beanstalk'
driver_name = 'beanstalkc' driver_name = 'beanstalkc'
def __init__(self, *args, **kwargs):
if beanstalkc is None:
raise ImportError(
'Missing beanstalkc library (pip install beanstalkc)')
super(Transport, self).__init__(*args, **kwargs)
def driver_version(self): def driver_version(self):
return beanstalkc.__version__ return beanstalkc.__version__

View File

@@ -11,7 +11,6 @@ CouchDB transport.
from __future__ import absolute_import from __future__ import absolute_import
import socket import socket
import couchdb
from anyjson import loads, dumps from anyjson import loads, dumps
@@ -21,6 +20,11 @@ from kombu.utils.encoding import bytes_to_str
from . import virtual from . import virtual
try:
import couchdb
except ImportError: # pragma: no cover
couchdb = None # noqa
DEFAULT_PORT = 5984 DEFAULT_PORT = 5984
DEFAULT_DATABASE = 'kombu_default' DEFAULT_DATABASE = 'kombu_default'
@@ -80,7 +84,9 @@ class Channel(virtual.Channel):
port)) port))
# Use username and password if avaliable # Use username and password if avaliable
try: try:
server.resource.credentials = (conninfo.userid, conninfo.password) if conninfo.userid:
server.resource.credentials = (conninfo.userid,
conninfo.password)
except AttributeError: except AttributeError:
pass pass
try: try:
@@ -110,20 +116,27 @@ class Transport(virtual.Transport):
connection_errors = ( connection_errors = (
virtual.Transport.connection_errors + ( virtual.Transport.connection_errors + (
socket.error, socket.error,
couchdb.HTTPError, getattr(couchdb, 'HTTPError', None),
couchdb.ServerError, getattr(couchdb, 'ServerError', None),
couchdb.Unauthorized) getattr(couchdb, 'Unauthorized', None),
)
) )
channel_errors = ( channel_errors = (
virtual.Transport.channel_errors + ( virtual.Transport.channel_errors + (
couchdb.HTTPError, getattr(couchdb, 'HTTPError', None),
couchdb.ServerError, getattr(couchdb, 'ServerError', None),
couchdb.PreconditionFailed, getattr(couchdb, 'PreconditionFailed', None),
couchdb.ResourceConflict, getattr(couchdb, 'ResourceConflict', None),
couchdb.ResourceNotFound) getattr(couchdb, 'ResourceNotFound', None),
)
) )
driver_type = 'couchdb' driver_type = 'couchdb'
driver_name = 'couchdb' driver_name = 'couchdb'
def __init__(self, *args, **kwargs):
if couchdb is None:
raise ImportError('Missing couchdb library (pip install couchdb)')
super(Transport, self).__init__(*args, **kwargs)
def driver_version(self): def driver_version(self):
return couchdb.__version__ return couchdb.__version__

View File

@@ -35,7 +35,6 @@ class Channel(virtual.Channel):
super(Channel, self).basic_consume(queue, *args, **kwargs) super(Channel, self).basic_consume(queue, *args, **kwargs)
def _get(self, queue): def _get(self, queue):
#self.refresh_connection()
m = Queue.objects.fetch(queue) m = Queue.objects.fetch(queue)
if m: if m:
return loads(bytes_to_str(m)) return loads(bytes_to_str(m))

View File

@@ -11,6 +11,7 @@ from __future__ import absolute_import
import os import os
import socket import socket
import warnings
try: try:
import librabbitmq as amqp import librabbitmq as amqp
@@ -24,9 +25,14 @@ except ImportError: # pragma: no cover
from kombu.five import items, values from kombu.five import items, values
from kombu.utils.amq_manager import get_manager from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base from . import base
W_VERSION = """
librabbitmq version too old to detect RabbitMQ version information
so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3
"""
DEFAULT_PORT = 5672 DEFAULT_PORT = 5672
NO_SSL_ERROR = """\ NO_SSL_ERROR = """\
@@ -150,6 +156,16 @@ class Transport(base.Transport):
def get_manager(self, *args, **kwargs): def get_manager(self, *args, **kwargs):
return get_manager(self.client, *args, **kwargs) return get_manager(self.client, *args, **kwargs)
def qos_semantics_matches_spec(self, connection):
try:
props = connection.server_properties
except AttributeError:
warnings.warn(UserWarning(W_VERSION))
else:
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property @property
def default_connection_params(self): def default_connection_params(self):
return {'userid': 'guest', 'password': 'guest', return {'userid': 'guest', 'password': 'guest',

View File

@@ -55,14 +55,14 @@ class BroadcastCursor(object):
def __iter__(self): def __iter__(self):
return self return self
def next(self): def __next__(self):
while True: while True:
try: try:
msg = next(self._cursor) msg = next(self._cursor)
except pymongo.errors.OperationFailure, e: except pymongo.errors.OperationFailure as exc:
# In some cases tailed cursor can become invalid # In some cases tailed cursor can become invalid
# and have to be reinitalized # and have to be reinitalized
if 'not valid at server' in e.message: if 'not valid at server' in exc.message:
self.purge() self.purge()
continue continue
@@ -74,6 +74,7 @@ class BroadcastCursor(object):
self._offset += 1 self._offset += 1
return msg return msg
next = __next__
class Channel(virtual.Channel): class Channel(virtual.Channel):
@@ -86,6 +87,9 @@ class Channel(virtual.Channel):
self._broadcast_cursors = {} self._broadcast_cursors = {}
# Evaluate connection
self._create_client()
def _new_queue(self, queue, **kwargs): def _new_queue(self, queue, **kwargs):
pass pass
@@ -206,7 +210,7 @@ class Channel(virtual.Channel):
self.get_broadcast().ensure_index([('queue', 1)]) self.get_broadcast().ensure_index([('queue', 1)])
self.get_routing().ensure_index([('queue', 1), ('exchange', 1)]) self.get_routing().ensure_index([('queue', 1), ('exchange', 1)])
#TODO Store a more complete exchange metatable in the routing collection # TODO Store a more complete exchange metatable in the routing collection
def get_table(self, exchange): def get_table(self, exchange):
"""Get table of bindings for ``exchange``.""" """Get table of bindings for ``exchange``."""
localRoutes = frozenset(self.state.exchanges[exchange]['table']) localRoutes = frozenset(self.state.exchanges[exchange]['table'])
@@ -249,12 +253,14 @@ class Channel(virtual.Channel):
self._fanout_queues.pop(queue) self._fanout_queues.pop(queue)
def _create_client(self):
self._open()
self._ensure_indexes()
@property @property
def client(self): def client(self):
if self._client is None: if self._client is None:
self._open() self._create_client()
self._ensure_indexes()
return self._client return self._client
def get_messages(self): def get_messages(self):

View File

@@ -11,6 +11,7 @@ import amqp
from kombu.five import items from kombu.five import items
from kombu.utils.amq_manager import get_manager from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base from . import base
@@ -129,6 +130,12 @@ class Transport(base.Transport):
def heartbeat_check(self, connection, rate=2): def heartbeat_check(self, connection, rate=2):
return connection.heartbeat_tick(rate=rate) return connection.heartbeat_tick(rate=rate)
def qos_semantics_matches_spec(self, connection):
props = connection.server_properties
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property @property
def default_connection_params(self): def default_connection_params(self):
return {'userid': 'guest', 'password': 'guest', return {'userid': 'guest', 'password': 'guest',

View File

@@ -246,7 +246,6 @@ class MultiChannelPoller(object):
self._channels.clear() self._channels.clear()
self._fd_to_chan.clear() self._fd_to_chan.clear()
self._chan_to_sock.clear() self._chan_to_sock.clear()
self.poller = None
def add(self, channel): def add(self, channel):
self._channels.add(channel) self._channels.add(channel)
@@ -254,6 +253,11 @@ class MultiChannelPoller(object):
def discard(self, channel): def discard(self, channel):
self._channels.discard(channel) self._channels.discard(channel)
def _on_connection_disconnect(self, connection):
sock = getattr(connection, '_sock', None)
if sock is not None:
self.poller.unregister(sock)
def _register(self, channel, client, type): def _register(self, channel, client, type):
if (channel, client, type) in self._chan_to_sock: if (channel, client, type) in self._chan_to_sock:
self._unregister(channel, client, type) self._unregister(channel, client, type)
@@ -450,6 +454,10 @@ class Channel(virtual.Channel):
if self._pool is not None: if self._pool is not None:
self._pool.disconnect() self._pool.disconnect()
def _on_connection_disconnect(self, connection):
if self.connection and self.connection.cycle:
self.connection.cycle._on_connection_disconnect(connection)
def _do_restore_message(self, payload, exchange, routing_key, def _do_restore_message(self, payload, exchange, routing_key,
client=None, leftmost=False): client=None, leftmost=False):
with self.conn_or_acquire(client) as client: with self.conn_or_acquire(client) as client:
@@ -466,6 +474,8 @@ class Channel(virtual.Channel):
crit('Could not restore message: %r', payload, exc_info=True) crit('Could not restore message: %r', payload, exc_info=True)
def _restore(self, message, leftmost=False): def _restore(self, message, leftmost=False):
if not self.ack_emulation:
return super(Channel, self)._restore(message)
tag = message.delivery_tag tag = message.delivery_tag
with self.conn_or_acquire() as client: with self.conn_or_acquire() as client:
P, _ = client.pipeline() \ P, _ = client.pipeline() \
@@ -778,6 +788,19 @@ class Channel(virtual.Channel):
connparams.pop('port', None) connparams.pop('port', None)
connparams['db'] = self._prepare_virtual_host( connparams['db'] = self._prepare_virtual_host(
connparams.pop('virtual_host', None)) connparams.pop('virtual_host', None))
channel = self
connection_cls = (
connparams.get('connection_class') or
redis.Connection
)
class Connection(connection_cls):
def disconnect(self):
channel._on_connection_disconnect(self)
super(Connection, self).disconnect()
connparams['connection_class'] = Connection
return connparams return connparams
def _create_client(self): def _create_client(self):
@@ -888,6 +911,8 @@ class Transport(virtual.Transport):
driver_name = 'redis' driver_name = 'redis'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if redis is None:
raise ImportError('Missing redis library (pip install redis)')
super(Transport, self).__init__(*args, **kwargs) super(Transport, self).__init__(*args, **kwargs)
# Get redis-py exceptions. # Get redis-py exceptions.
@@ -905,6 +930,11 @@ class Transport(virtual.Transport):
add_reader = loop.add_reader add_reader = loop.add_reader
on_readable = self.on_readable on_readable = self.on_readable
def _on_disconnect(connection):
if connection._sock:
loop.remove(connection._sock)
cycle._on_connection_disconnect = _on_disconnect
def on_poll_start(): def on_poll_start():
cycle_poll_start() cycle_poll_start()
[add_reader(fd, on_readable, fd) for fd in cycle.fds] [add_reader(fd, on_readable, fd) for fd in cycle.fds]

View File

@@ -153,6 +153,7 @@ class Transport(virtual.Transport):
default_port = 0 default_port = 0
driver_type = 'sql' driver_type = 'sql'
driver_name = 'sqlalchemy' driver_name = 'sqlalchemy'
connection_errors = (OperationalError, )
def driver_version(self): def driver_version(self):
import sqlalchemy import sqlalchemy

View File

@@ -520,7 +520,7 @@ class Channel(AbstractChannel, base.StdChannel):
return self.typeof(exchange).deliver( return self.typeof(exchange).deliver(
message, exchange, routing_key, **kwargs message, exchange, routing_key, **kwargs
) )
# anon exchange: routing_key is the destintaion queue # anon exchange: routing_key is the destination queue
return self._put(routing_key, message, **kwargs) return self._put(routing_key, message, **kwargs)
def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):

View File

@@ -101,6 +101,19 @@ def symbol_by_name(name, aliases={}, imp=None, package=None,
return default return default
class HashedSeq(list):
"""type used for hash() to make sure the hash is not generated
multiple times."""
__slots__ = 'hashvalue'
def __init__(self, *seq):
self[:] = seq
self.hashvalue = hash(seq)
def __hash__(self):
return self.hashvalue
def eqhash(o): def eqhash(o):
try: try:
return o.__eqhash__() return o.__eqhash__()

View File

@@ -8,7 +8,7 @@ Helps compatibility with older Python versions.
from __future__ import absolute_import from __future__ import absolute_import
############## timedelta_seconds() -> delta.total_seconds #################### # ############# timedelta_seconds() -> delta.total_seconds ###################
from datetime import timedelta from datetime import timedelta
HAVE_TIMEDELTA_TOTAL_SECONDS = hasattr(timedelta, 'total_seconds') HAVE_TIMEDELTA_TOTAL_SECONDS = hasattr(timedelta, 'total_seconds')
@@ -36,7 +36,7 @@ else: # pragma: no cover
return 0 return 0
return delta.days * 86400 + delta.seconds + (delta.microseconds / 10e5) return delta.days * 86400 + delta.seconds + (delta.microseconds / 10e5)
############## socket.error.errno ############################################ # ############# socket.error.errno ###########################################
def get_errno(exc): def get_errno(exc):
@@ -53,7 +53,7 @@ def get_errno(exc):
pass pass
return 0 return 0
############## collections.OrderedDict ####################################### # ############# collections.OrderedDict ######################################
try: try:
from collections import OrderedDict from collections import OrderedDict
except ImportError: except ImportError:

View File

@@ -83,7 +83,7 @@ class _epoll(Poller):
def unregister(self, fd): def unregister(self, fd):
try: try:
self._epoll.unregister(fd) self._epoll.unregister(fd)
except (socket.error, ValueError, KeyError): except (socket.error, ValueError, KeyError, TypeError):
pass pass
except (IOError, OSError) as exc: except (IOError, OSError) as exc:
if get_errno(exc) != errno.ENOENT: if get_errno(exc) != errno.ENOENT:
@@ -202,7 +202,14 @@ class _select(Poller):
self.unregister(fd) self.unregister(fd)
def unregister(self, fd): def unregister(self, fd):
fd = fileno(fd) try:
fd = fileno(fd)
except socket.error as exc:
# we don't know the previous fd of this object
# but it will be removed by the next poll iteration.
if get_errno(exc) in SELECT_BAD_FD:
return
raise
self._rfd.discard(fd) self._rfd.discard(fd)
self._wfd.discard(fd) self._wfd.discard(fd)
self._efd.discard(fd) self._efd.discard(fd)

View File

@@ -25,10 +25,10 @@ class TokenBucket(object):
""" """
#: The rate in tokens/second that the bucket will be refilled #: The rate in tokens/second that the bucket will be refilled.
fill_rate = None fill_rate = None
#: Maximum number of tokensin the bucket. #: Maximum number of tokens in the bucket.
capacity = 1 capacity = 1
#: Timestamp of the last time a token was taken out of the bucket. #: Timestamp of the last time a token was taken out of the bucket.

View File

@@ -3,6 +3,9 @@ from __future__ import absolute_import
from difflib import SequenceMatcher from difflib import SequenceMatcher
from kombu import version_info_t
from kombu.five import string_t
def fmatch_iter(needle, haystack, min_ratio=0.6): def fmatch_iter(needle, haystack, min_ratio=0.6):
for key in haystack: for key in haystack:
@@ -18,3 +21,27 @@ def fmatch_best(needle, haystack, min_ratio=0.6):
)[0][1] )[0][1]
except IndexError: except IndexError:
pass pass
def version_string_as_tuple(s):
v = _unpack_version(*s.split('.'))
# X.Y.3a1 -> (X, Y, 3, 'a1')
if isinstance(v.micro, string_t):
v = version_info_t(v.major, v.minor, *_splitmicro(*v[2:]))
# X.Y.3a1-40 -> (X, Y, 3, 'a1', '40')
if not v.serial and v.releaselevel and '-' in v.releaselevel:
v = version_info_t(*list(v[0:3]) + v.releaselevel.split('-'))
return v
def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''):
return version_info_t(int(major), int(minor), micro, releaselevel, serial)
def _splitmicro(micro, releaselevel='', serial=''):
for index, char in enumerate(micro):
if not char.isdigit():
break
else:
return int(micro or 0), releaselevel, serial
return int(micro[:index]), micro[index:], serial

View File

@@ -1,12 +1,17 @@
from __future__ import absolute_import from __future__ import absolute_import
from functools import partial
try: try:
from urllib.parse import unquote, urlparse, parse_qsl from urllib.parse import parse_qsl, quote, unquote, urlparse
except ImportError: except ImportError:
from urllib import unquote # noqa from urllib import quote, unquote # noqa
from urlparse import urlparse, parse_qsl # noqa from urlparse import urlparse, parse_qsl # noqa
from . import kwdict from . import kwdict
from kombu.five import string_t
safequote = partial(quote, safe='')
def _parse_url(url): def _parse_url(url):
@@ -14,17 +19,9 @@ def _parse_url(url):
schemeless = url[len(scheme) + 3:] schemeless = url[len(scheme) + 3:]
# parse with HTTP URL semantics # parse with HTTP URL semantics
parts = urlparse('http://' + schemeless) parts = urlparse('http://' + schemeless)
# The first pymongo.Connection() argument (host) can be
# a mongodb connection URI. If this is the case, don't
# use port but let pymongo get the port(s) from the URI instead.
# This enables the use of replica sets and sharding.
# See pymongo.Connection() for more info.
port = scheme != 'mongodb' and parts.port or None
hostname = schemeless if scheme == 'mongodb' else parts.hostname
path = parts.path or '' path = parts.path or ''
path = path[1:] if path and path[0] == '/' else path path = path[1:] if path and path[0] == '/' else path
return (scheme, unquote(hostname or '') or None, port, return (scheme, unquote(parts.hostname or '') or None, parts.port,
unquote(parts.username or '') or None, unquote(parts.username or '') or None,
unquote(parts.password or '') or None, unquote(parts.password or '') or None,
unquote(path or '') or None, unquote(path or '') or None,
@@ -36,3 +33,32 @@ def parse_url(url):
return dict(transport=scheme, hostname=host, return dict(transport=scheme, hostname=host,
port=port, userid=user, port=port, userid=user,
password=password, virtual_host=path, **query) password=password, virtual_host=path, **query)
def as_url(scheme, host=None, port=None, user=None, password=None,
path=None, query=None, sanitize=False, mask='**'):
parts = ['{0}://'.format(scheme)]
if user or password:
if user:
parts.append(safequote(user))
if password:
if sanitize:
parts.extend([':', mask] if mask else [':'])
else:
parts.extend([':', safequote(password)])
parts.append('@')
parts.append(safequote(host) if host else '')
if port:
parts.extend([':', port])
parts.extend(['/', path])
return ''.join(str(part) for part in parts if part)
def sanitize_url(url, mask='**'):
return as_url(*_parse_url(url), sanitize=True, mask=mask)
def maybe_sanitize_url(url, mask='**'):
if isinstance(url, string_t) and '://' in url:
return sanitize_url(url, mask)
return url