Upgrade kombu to 3.0.21

This commit is contained in:
Matthew Jones 2014-08-06 15:32:17 -04:00
parent cbe26c4619
commit 07b538e5c5
40 changed files with 412 additions and 177 deletions

View File

@ -31,7 +31,7 @@ httplib2==0.9 (httplib2/*)
importlib==1.0.3 (importlib/*, needed for Python 2.6 support)
iso8601==0.1.10 (iso8601/*)
keyring==4.0 (keyring/*, excluded bin/keyring)
kombu==3.0.14 (kombu/*)
kombu==3.0.21 (kombu/*)
Markdown==2.4 (markdown/*, excluded bin/markdown_py)
mock==1.0.1 (mock.py)
ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)

View File

@ -7,7 +7,7 @@ version_info_t = namedtuple(
'version_info_t', ('major', 'minor', 'micro', 'releaselevel', 'serial'),
)
VERSION = version_info_t(3, 0, 14, '', '')
VERSION = version_info_t(3, 0, 21, '', '')
__version__ = '{0.major}.{0.minor}.{0.micro}{0.releaselevel}'.format(VERSION)
__author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org'
@ -99,6 +99,7 @@ new_module.__dict__.update({
'__homepage__': __homepage__,
'__docformat__': __docformat__,
'__package__': package,
'version_info_t': version_info_t,
'VERSION': VERSION})
if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover

View File

@ -272,35 +272,39 @@ class Hub(object):
item()
poll_timeout = fire_timers(propagate=propagate) if scheduled else 1
#print('[[[HUB]]]: %s' % (self.repr_active(), ))
if readers or writers:
to_consolidate = []
try:
events = poll(poll_timeout)
#print('[EVENTS]: %s' % (self.nepr_events(events or []), ))
except ValueError: # Issue 882
raise StopIteration()
for fileno, event in events or ():
if fileno in consolidate and \
writers.get(fileno) is None:
to_consolidate.append(fileno)
for fd, event in events or ():
if fd in consolidate and \
writers.get(fd) is None:
to_consolidate.append(fd)
continue
cb = cbargs = None
try:
if event & READ:
cb, cbargs = readers[fileno]
elif event & WRITE:
cb, cbargs = writers[fileno]
elif event & ERR:
try:
cb, cbargs = (readers.get(fileno) or
writers.get(fileno))
except TypeError:
pass
except (KeyError, Empty):
hub_remove(fileno)
continue
if event & READ:
try:
cb, cbargs = readers[fd]
except KeyError:
self.remove_reader(fd)
continue
elif event & WRITE:
try:
cb, cbargs = writers[fd]
except KeyError:
self.remove_writer(fd)
continue
elif event & ERR:
try:
cb, cbargs = (readers.get(fd) or
writers.get(fd))
except TypeError:
pass
if cb is None:
continue
if isinstance(cb, generator):
@ -309,11 +313,11 @@ class Hub(object):
except OSError as exc:
if get_errno(exc) != errno.EBADF:
raise
hub_remove(fileno)
hub_remove(fd)
except StopIteration:
pass
except Exception:
hub_remove(fileno)
hub_remove(fd)
raise
else:
try:

View File

@ -22,7 +22,6 @@ from amqp import RecoverableConnectionError
from .entity import Exchange, Queue
from .five import range
from .log import get_logger
from .messaging import Consumer as _Consumer
from .serialization import registry as serializers
from .utils import uuid
@ -91,33 +90,43 @@ def declaration_cached(entity, channel):
def maybe_declare(entity, channel=None, retry=False, **retry_policy):
if not entity.is_bound:
is_bound = entity.is_bound
if not is_bound:
assert channel
entity = entity.bind(channel)
if retry:
return _imaybe_declare(entity, **retry_policy)
return _maybe_declare(entity)
if channel is None:
assert is_bound
channel = entity.channel
def _maybe_declare(entity):
channel = entity.channel
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
if entity.can_cache_declaration:
declared = ident = None
if channel.connection and entity.can_cache_declaration:
declared = channel.connection.client.declared_entities
ident = hash(entity)
if ident not in declared:
entity.declare()
declared.add(ident)
return True
return False
if ident in declared:
return False
if retry:
return _imaybe_declare(entity, declared, ident,
channel, **retry_policy)
return _maybe_declare(entity, declared, ident, channel)
def _maybe_declare(entity, declared, ident, channel):
channel = channel or entity.channel
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
entity.declare()
if declared is not None and ident:
declared.add(ident)
return True
def _imaybe_declare(entity, **retry_policy):
def _imaybe_declare(entity, declared, ident, channel, **retry_policy):
return entity.channel.connection.client.ensure(
entity, _maybe_declare, **retry_policy)(entity)
entity, _maybe_declare, **retry_policy)(
entity, declared, ident, channel)
def drain_consumer(consumer, limit=1, timeout=None, callbacks=None):
@ -138,8 +147,8 @@ def drain_consumer(consumer, limit=1, timeout=None, callbacks=None):
def itermessages(conn, channel, queue, limit=1, timeout=None,
Consumer=_Consumer, callbacks=None, **kwargs):
return drain_consumer(Consumer(channel, queues=[queue], **kwargs),
callbacks=None, **kwargs):
return drain_consumer(conn.Consumer(channel, queues=[queue], **kwargs),
limit=limit, timeout=timeout, callbacks=callbacks)
@ -181,8 +190,6 @@ def eventloop(conn, limit=None, timeout=None, ignore_timeouts=False):
except socket.timeout:
if timeout and not ignore_timeouts: # pragma: no cover
raise
except socket.error: # pragma: no cover
pass
def send_reply(exchange, req, msg,

View File

@ -7,7 +7,7 @@ Compression utilities.
"""
from __future__ import absolute_import
from kombu.utils.encoding import ensure_bytes, bytes_to_str
from kombu.utils.encoding import ensure_bytes
import zlib
@ -67,7 +67,7 @@ def decompress(body, content_type):
:param content_type: mime-type of compression method used.
"""
return bytes_to_str(get_decoder(content_type)(body))
return get_decoder(content_type)(body)
register(zlib.compress,

View File

@ -11,13 +11,8 @@ import os
import socket
from contextlib import contextmanager
from functools import partial
from itertools import count, cycle
from operator import itemgetter
try:
from urllib.parse import quote
except ImportError: # Py2
from urllib import quote # noqa
# jython breaks on relative import for .exceptions for some reason
# (Issue #112)
@ -25,10 +20,10 @@ from kombu import exceptions
from .five import Empty, range, string_t, text_t, LifoQueue as _LifoQueue
from .log import get_logger
from .transport import get_transport_cls, supports_librabbitmq
from .utils import cached_property, retry_over_time, shufflecycle
from .utils import cached_property, retry_over_time, shufflecycle, HashedSeq
from .utils.compat import OrderedDict
from .utils.functional import lazy
from .utils.url import parse_url, urlparse
from .utils.url import as_url, parse_url, quote, urlparse
__all__ = ['Connection', 'ConnectionPool', 'ChannelPool']
@ -199,6 +194,7 @@ class Connection(object):
"""Switch connection parameters to use a new URL (does not
reconnect)"""
self.close()
self.declared_entities.clear()
self._closed = False
self._init_params(**dict(self._initial_params, **parse_url(url)))
@ -565,40 +561,27 @@ class Connection(object):
return OrderedDict(self._info())
def __eqhash__(self):
return hash('%s|%s|%s|%s|%s|%s' % (
self.transport_cls, self.hostname, self.userid,
self.password, self.virtual_host, self.port))
return HashedSeq(self.transport_cls, self.hostname, self.userid,
self.password, self.virtual_host, self.port,
repr(self.transport_options))
def as_uri(self, include_password=False, mask=''):
def as_uri(self, include_password=False, mask='**',
getfields=itemgetter('port', 'userid', 'password',
'virtual_host', 'transport')):
"""Convert connection parameters to URL form."""
hostname = self.hostname or 'localhost'
if self.transport.can_parse_url:
if self.uri_prefix:
return '%s+%s' % (self.uri_prefix, hostname)
return self.hostname
quoteS = partial(quote, safe='') # strict quote
fields = self.info()
port, userid, password, transport = itemgetter(
'port', 'userid', 'password', 'transport'
)(fields)
url = '%s://' % transport
if userid or password:
if userid:
url += quoteS(userid)
if password:
if include_password:
url += ':' + quoteS(password)
else:
url += ':' + mask if mask else ''
url += '@'
url += quoteS(fields['hostname'])
if port:
url += ':%s' % (port, )
url += '/' + quote(fields['virtual_host'])
if self.uri_prefix:
return '%s+%s' % (self.uri_prefix, url)
return url
port, userid, password, vhost, transport = getfields(fields)
scheme = ('{0}+{1}'.format(self.uri_prefix, transport)
if self.uri_prefix else transport)
return as_url(
scheme, hostname, port, userid, password, quote(vhost),
sanitize=not include_password, mask=mask,
)
def Pool(self, limit=None, preload=None):
"""Pool of connections.
@ -731,6 +714,10 @@ class Connection(object):
def __exit__(self, *args):
self.release()
@property
def qos_semantics_matches_spec(self):
return self.transport.qos_semantics_matches_spec(self.connection)
@property
def connected(self):
"""Return true if the connection has been established."""

View File

@ -288,7 +288,7 @@ class Exchange(MaybeChannelBound):
@property
def can_cache_declaration(self):
return self.durable and not self.auto_delete
return not self.auto_delete
class binding(object):
@ -672,7 +672,7 @@ class Queue(MaybeChannelBound):
@property
def can_cache_declaration(self):
return self.durable and not self.auto_delete
return not self.auto_delete
@classmethod
def from_dict(self, queue, **options):

View File

@ -10,7 +10,7 @@
"""
from __future__ import absolute_import
############## py3k #########################################################
# ############# py3k #########################################################
import sys
PY3 = sys.version_info[0] == 3
@ -34,7 +34,7 @@ try:
except NameError: # pragma: no cover
bytes_t = str # noqa
############## time.monotonic ################################################
# ############# time.monotonic ###############################################
if sys.version_info < (3, 3):
@ -89,7 +89,7 @@ try:
except ImportError:
monotonic = _monotonic # noqa
############## Py3 <-> Py2 ###################################################
# ############# Py3 <-> Py2 ##################################################
if PY3: # pragma: no cover
import builtins

View File

@ -11,6 +11,7 @@ import numbers
from itertools import count
from .common import maybe_declare
from .compression import compress
from .connection import maybe_channel, is_connection
from .entity import Exchange, Queue, DELIVERY_MODES
@ -107,7 +108,6 @@ class Producer(object):
"""Declare the exchange if it hasn't already been declared
during this session."""
if entity:
from .common import maybe_declare
return maybe_declare(entity, self.channel, retry, **retry_policy)
def publish(self, body, routing_key=None, delivery_mode=None,
@ -521,7 +521,6 @@ class Consumer(object):
whole messages.
:param apply_global: Apply new settings globally on all channels.
Currently not supported by RabbitMQ.
"""
return self.channel.basic_qos(prefetch_size,

View File

@ -135,7 +135,8 @@ class Node(object):
def reply(self, data, exchange, routing_key, ticket, **kwargs):
self.mailbox._publish_reply(data, exchange, routing_key, ticket,
channel=self.channel)
channel=self.channel,
serializer=self.mailbox.serializer)
class Mailbox(object):
@ -161,8 +162,12 @@ class Mailbox(object):
#: Only accepts json messages by default.
accept = ['json']
#: Message serializer
serializer = None
def __init__(self, namespace,
type='direct', connection=None, clock=None, accept=None):
type='direct', connection=None, clock=None,
accept=None, serializer=None):
self.namespace = namespace
self.connection = connection
self.type = type
@ -172,6 +177,7 @@ class Mailbox(object):
self._tls = local()
self.unclaimed = defaultdict(deque)
self.accept = self.accept if accept is None else accept
self.serializer = self.serializer if serializer is None else serializer
def __call__(self, connection):
bound = copy(self)
@ -204,14 +210,14 @@ class Mailbox(object):
def get_reply_queue(self):
oid = self.oid
return Queue('%s.%s' % (oid, self.reply_exchange.name),
exchange=self.reply_exchange,
routing_key=oid,
durable=False,
auto_delete=True,
queue_arguments={
'x-expires': int(REPLY_QUEUE_EXPIRES * 1000),
})
return Queue(
'%s.%s' % (oid, self.reply_exchange.name),
exchange=self.reply_exchange,
routing_key=oid,
durable=False,
auto_delete=True,
queue_arguments={'x-expires': int(REPLY_QUEUE_EXPIRES * 1000)},
)
@cached_property
def reply_queue(self):
@ -242,7 +248,8 @@ class Mailbox(object):
pass # queue probably deleted and no one is expecting a reply.
def _publish(self, type, arguments, destination=None,
reply_ticket=None, channel=None, timeout=None):
reply_ticket=None, channel=None, timeout=None,
serializer=None):
message = {'method': type,
'arguments': arguments,
'destination': destination}
@ -253,16 +260,18 @@ class Mailbox(object):
message.update(ticket=reply_ticket,
reply_to={'exchange': self.reply_exchange.name,
'routing_key': self.oid})
serializer = serializer or self.serializer
producer = Producer(chan, auto_declare=False)
producer.publish(
message, exchange=exchange.name, declare=[exchange],
headers={'clock': self.clock.forward(),
'expires': time() + timeout if timeout else 0},
serializer=serializer,
)
def _broadcast(self, command, arguments=None, destination=None,
reply=False, timeout=1, limit=None,
callback=None, channel=None):
callback=None, channel=None, serializer=None):
if destination is not None and \
not isinstance(destination, (list, tuple)):
raise ValueError(
@ -277,10 +286,12 @@ class Mailbox(object):
if limit is None and destination:
limit = destination and len(destination) or None
serializer = serializer or self.serializer
self._publish(command, arguments, destination=destination,
reply_ticket=reply_ticket,
channel=chan,
timeout=timeout)
timeout=timeout,
serializer=serializer)
if reply_ticket:
return self._collect(reply_ticket, limit=limit,

View File

@ -21,7 +21,7 @@ def select_blocking_method(type):
def _detect_environment():
## -eventlet-
# ## -eventlet-
if 'eventlet' in sys.modules:
try:
from eventlet.patcher import is_monkey_patched as is_eventlet
@ -32,7 +32,7 @@ def _detect_environment():
except ImportError:
pass
# -gevent-
# ## -gevent-
if 'gevent' in sys.modules:
try:
from gevent import socket as _gsocket

View File

@ -105,6 +105,8 @@ class test_maybe_declare(Case):
def test_with_retry(self):
channel = Mock()
client = channel.connection.client = Mock()
client.declared_entities = set()
entity = Mock()
entity.can_cache_declaration = True
entity.is_bound = True
@ -265,8 +267,8 @@ class test_itermessages(Case):
conn = self.MockConnection()
channel = Mock()
channel.connection.client = conn
it = common.itermessages(conn, channel, 'q', limit=1,
Consumer=MockConsumer)
conn.Consumer = MockConsumer
it = common.itermessages(conn, channel, 'q', limit=1)
ret = next(it)
self.assertTupleEqual(ret, ('body', 'message'))
@ -279,8 +281,8 @@ class test_itermessages(Case):
conn.should_raise_timeout = True
channel = Mock()
channel.connection.client = conn
it = common.itermessages(conn, channel, 'q', limit=1,
Consumer=MockConsumer)
conn.Consumer = MockConsumer
it = common.itermessages(conn, channel, 'q', limit=1)
with self.assertRaises(StopIteration):
next(it)
@ -291,8 +293,8 @@ class test_itermessages(Case):
deque_instance.popleft.side_effect = IndexError()
conn = self.MockConnection()
channel = Mock()
it = common.itermessages(conn, channel, 'q', limit=1,
Consumer=MockConsumer)
conn.Consumer = MockConsumer
it = common.itermessages(conn, channel, 'q', limit=1)
with self.assertRaises(StopIteration):
next(it)

View File

@ -34,7 +34,7 @@ class test_compression(Case):
self.assertIn('application/x-bz2', encoders)
def test_compress__decompress__zlib(self):
text = 'The Quick Brown Fox Jumps Over The Lazy Dog'
text = b'The Quick Brown Fox Jumps Over The Lazy Dog'
c, ctype = compression.compress(text, 'zlib')
self.assertNotEqual(text, c)
d = compression.decompress(c, ctype)
@ -43,7 +43,7 @@ class test_compression(Case):
def test_compress__decompress__bzip2(self):
if not self.has_bzip2:
raise SkipTest('bzip2 not available')
text = 'The Brown Quick Fox Over The Lazy Dog Jumps'
text = b'The Brown Quick Fox Over The Lazy Dog Jumps'
c, ctype = compression.compress(text, 'bzip2')
self.assertNotEqual(text, c)
d = compression.decompress(c, ctype)

View File

@ -17,7 +17,7 @@ class test_connection_utils(Case):
def setUp(self):
self.url = 'amqp://user:pass@localhost:5672/my/vhost'
self.nopass = 'amqp://user@localhost:5672/my/vhost'
self.nopass = 'amqp://user:**@localhost:5672/my/vhost'
self.expected = {
'transport': 'amqp',
'userid': 'user',
@ -31,10 +31,6 @@ class test_connection_utils(Case):
result = parse_url(self.url)
self.assertDictEqual(result, self.expected)
def test_parse_url_mongodb(self):
result = parse_url('mongodb://example.com/')
self.assertEqual(result['hostname'], 'example.com/')
def test_parse_generated_as_uri(self):
conn = Connection(self.url)
info = conn.info()

View File

@ -76,7 +76,7 @@ class test_Exchange(Case):
def test_can_cache_declaration(self):
self.assertTrue(Exchange('a', durable=True).can_cache_declaration)
self.assertFalse(Exchange('a', durable=False).can_cache_declaration)
self.assertTrue(Exchange('a', durable=False).can_cache_declaration)
def test_pickle(self):
e1 = Exchange('foo', 'direct')
@ -285,7 +285,7 @@ class test_Queue(Case):
def test_can_cache_declaration(self):
self.assertTrue(Queue('a', durable=True).can_cache_declaration)
self.assertFalse(Queue('a', durable=False).can_cache_declaration)
self.assertTrue(Queue('a', durable=False).can_cache_declaration)
def test_eq(self):
q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx')

View File

@ -36,7 +36,7 @@ class test_Producer(Case):
p = Producer(None)
self.assertFalse(p._channel)
@patch('kombu.common.maybe_declare')
@patch('kombu.messaging.maybe_declare')
def test_maybe_declare(self, maybe_declare):
p = self.connection.Producer()
q = Queue('foo')

View File

@ -90,7 +90,6 @@ class test_ConsumerMixin(Case):
def test_Consumer_context(self):
c, Acons, Bcons = self._context()
_connref = _chanref = None
with c.Consumer() as (conn, channel, consumer):
self.assertIs(conn, c.connection)
@ -104,7 +103,6 @@ class test_ConsumerMixin(Case):
self.assertIs(subcons.channel, conn.default_channel)
Acons.__enter__.assert_called_with()
Bcons.__enter__.assert_called_with()
_connref, _chanref = conn, channel
c.on_consume_end.assert_called_with(conn, channel)

View File

@ -220,6 +220,9 @@ class test_fun_PoolGroup(Case):
assert eqhash(c1) != eqhash(c2)
assert eqhash(c1) == eqhash(c3)
c4 = Connection(c1u, transport_options={'confirm_publish': True})
self.assertNotEqual(eqhash(c3), eqhash(c4))
p1 = pools.connections[c1]
p2 = pools.connections[c2]
p3 = pools.connections[c3]

View File

@ -38,9 +38,12 @@ class test_syn(Case):
def test_detect_environment_gevent(self):
with patch('gevent.socket', create=True) as m:
prev, socket.socket = socket.socket, m.socket
self.assertTrue(sys.modules['gevent'])
env = syn._detect_environment()
self.assertEqual(env, 'gevent')
try:
self.assertTrue(sys.modules['gevent'])
env = syn._detect_environment()
self.assertEqual(env, 'gevent')
finally:
socket.socket = prev
def test_detect_environment_no_eventlet_or_gevent(self):
try:

View File

@ -2,7 +2,7 @@ from __future__ import absolute_import
from kombu import Connection
from kombu.tests.case import Case, SkipTest, skip_if_not_module
from kombu.tests.case import Case, SkipTest, Mock, skip_if_not_module
class MockConnection(dict):
@ -16,8 +16,14 @@ class test_mongodb(Case):
def _get_connection(self, url, **kwargs):
from kombu.transport import mongodb
class _Channel(mongodb.Channel):
def _create_client(self):
self._client = Mock(name='client')
class Transport(mongodb.Transport):
Connection = MockConnection
Channel = _Channel
return Connection(url, transport=Transport, **kwargs).connect()
@ -48,7 +54,7 @@ class test_mongodb(Case):
self.assertEquals(dbname, 'dbname')
@skip_if_not_module('pymongo')
def test_custom_credentions(self):
def test_custom_credentials(self):
url = 'mongodb://localhost/dbname'
c = self._get_connection(url, userid='foo', password='bar')
hostname, dbname, options = c.channels[0]._parse_uri()

View File

@ -220,6 +220,7 @@ class Transport(redis.Transport):
class test_Channel(Case):
@skip_if_not_module('redis')
def setUp(self):
self.connection = self.create_connection()
self.channel = self.connection.default_channel
@ -616,10 +617,12 @@ class test_Channel(Case):
self.channel.connection.client.virtual_host = 'dwqeq'
self.channel._connparams()
@skip_if_not_module('redis')
def test_connparams_allows_slash_in_db(self):
self.channel.connection.client.virtual_host = '/123'
self.assertEqual(self.channel._connparams()['db'], 123)
@skip_if_not_module('redis')
def test_connparams_db_can_be_int(self):
self.channel.connection.client.virtual_host = 124
self.assertEqual(self.channel._connparams()['db'], 124)
@ -630,6 +633,7 @@ class test_Channel(Case):
redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True)
self.assertIn('elaine', self.channel.auto_delete_queues)
@skip_if_not_module('redis')
def test_connparams_regular_hostname(self):
self.channel.connection.client.hostname = 'george.vandelay.com'
self.assertEqual(
@ -776,13 +780,16 @@ class test_Channel(Case):
with patch('kombu.transport.redis.Channel._create_client'):
with Connection('redis+socket:///tmp/redis.sock') as conn:
connparams = conn.default_channel._connparams()
self.assertEqual(connparams['connection_class'],
redis.redis.UnixDomainSocketConnection)
self.assertTrue(issubclass(
connparams['connection_class'],
redis.redis.UnixDomainSocketConnection,
))
self.assertEqual(connparams['path'], '/tmp/redis.sock')
class test_Redis(Case):
@skip_if_not_module('redis')
def setUp(self):
self.connection = Connection(transport=Transport)
self.exchange = Exchange('test_Redis', type='direct')
@ -939,6 +946,7 @@ def _redis_modules():
class test_MultiChannelPoller(Case):
@skip_if_not_module('redis')
def setUp(self):
self.Poller = redis.MultiChannelPoller
@ -1043,7 +1051,6 @@ class test_MultiChannelPoller(Case):
p._channels.clear.assert_called_with()
p._fd_to_chan.clear.assert_called_with()
p._chan_to_sock.clear.assert_called_with()
self.assertIsNone(p.poller)
def test_register_when_registered_reregisters(self):
p = self.Poller()

View File

@ -267,8 +267,8 @@ class test_Channel(Case):
c.exchange_declare(n)
c.queue_declare(n)
c.queue_bind(n, n, n)
c.queue_bind(n, n, n) # tests code path that returns
# if queue already bound.
# tests code path that returns if queue already bound.
c.queue_bind(n, n, n)
c.queue_delete(n, if_empty=True)
self.assertIn(n, c.state.bindings)

View File

@ -11,7 +11,9 @@ if sys.version_info >= (3, 0):
else:
from StringIO import StringIO, StringIO as BytesIO # noqa
from kombu import version_info_t
from kombu import utils
from kombu.utils.text import version_string_as_tuple
from kombu.five import string_t
from kombu.tests.case import (
@ -379,3 +381,32 @@ class test_shufflecycle(Case):
next(cycle)
finally:
utils.repeat = prev_repeat
class test_version_string_as_tuple(Case):
def test_versions(self):
self.assertTupleEqual(
version_string_as_tuple('3'),
version_info_t(3, 0, 0, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3'),
version_info_t(3, 3, 0, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1'),
version_info_t(3, 3, 1, '', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1a3'),
version_info_t(3, 3, 1, 'a3', ''),
)
self.assertTupleEqual(
version_string_as_tuple('3.3.1a3-40c32'),
version_info_t(3, 3, 1, 'a3', '40c32'),
)
self.assertEqual(
version_string_as_tuple('3.3.1.a3.40c32'),
version_info_t(3, 3, 1, 'a3', '40c32'),
)

View File

@ -17,11 +17,27 @@ except ImportError:
pass
from struct import unpack
from amqplib import client_0_8 as amqp
from amqplib.client_0_8 import transport
from amqplib.client_0_8.channel import Channel as _Channel
from amqplib.client_0_8.exceptions import AMQPConnectionException
from amqplib.client_0_8.exceptions import AMQPChannelException
class NA(object):
pass
try:
from amqplib import client_0_8 as amqp
from amqplib.client_0_8 import transport
from amqplib.client_0_8.channel import Channel as _Channel
from amqplib.client_0_8.exceptions import AMQPConnectionException
from amqplib.client_0_8.exceptions import AMQPChannelException
except ImportError: # pragma: no cover
class NAx(object):
pass
amqp = NA
amqp.Connection = NA
transport = _Channel = NA # noqa
# Sphinx crashes if this is NA, must be different class
transport.TCPTransport = transport.SSLTransport = NAx
AMQPConnectionException = AMQPChannelException = NA # noqa
from kombu.five import items
from kombu.utils.encoding import str_to_bytes
@ -321,6 +337,9 @@ class Transport(base.Transport):
self.client = client
self.default_port = kwargs.get('default_port') or self.default_port
if amqp is NA:
raise ImportError('Missing amqplib library (pip install amqplib)')
def create_channel(self, connection):
return connection.channel()

View File

@ -152,6 +152,9 @@ class Transport(object):
return _read
def qos_semantics_matches_spec(self, connection):
return True
def on_readable(self, connection, loop):
reader = self.__reader
if reader is None:

View File

@ -10,7 +10,6 @@ Beanstalk transport.
"""
from __future__ import absolute_import
import beanstalkc
import socket
from anyjson import loads, dumps
@ -20,6 +19,11 @@ from kombu.utils.encoding import bytes_to_str
from . import virtual
try:
import beanstalkc
except ImportError: # pragma: no cover
beanstalkc = None # noqa
DEFAULT_PORT = 11300
__author__ = 'David Ziegler <david.ziegler@gmail.com>'
@ -127,16 +131,25 @@ class Transport(virtual.Transport):
default_port = DEFAULT_PORT
connection_errors = (
virtual.Transport.connection_errors + (
socket.error, beanstalkc.SocketError, IOError)
socket.error, IOError,
getattr(beanstalkc, 'SocketError', None),
)
)
channel_errors = (
virtual.Transport.channel_errors + (
socket.error, IOError,
beanstalkc.SocketError,
beanstalkc.BeanstalkcException)
getattr(beanstalkc, 'SocketError', None),
getattr(beanstalkc, 'BeanstalkcException', None),
)
)
driver_type = 'beanstalk'
driver_name = 'beanstalkc'
def __init__(self, *args, **kwargs):
if beanstalkc is None:
raise ImportError(
'Missing beanstalkc library (pip install beanstalkc)')
super(Transport, self).__init__(*args, **kwargs)
def driver_version(self):
return beanstalkc.__version__

View File

@ -11,7 +11,6 @@ CouchDB transport.
from __future__ import absolute_import
import socket
import couchdb
from anyjson import loads, dumps
@ -21,6 +20,11 @@ from kombu.utils.encoding import bytes_to_str
from . import virtual
try:
import couchdb
except ImportError: # pragma: no cover
couchdb = None # noqa
DEFAULT_PORT = 5984
DEFAULT_DATABASE = 'kombu_default'
@ -80,7 +84,9 @@ class Channel(virtual.Channel):
port))
# Use username and password if avaliable
try:
server.resource.credentials = (conninfo.userid, conninfo.password)
if conninfo.userid:
server.resource.credentials = (conninfo.userid,
conninfo.password)
except AttributeError:
pass
try:
@ -110,20 +116,27 @@ class Transport(virtual.Transport):
connection_errors = (
virtual.Transport.connection_errors + (
socket.error,
couchdb.HTTPError,
couchdb.ServerError,
couchdb.Unauthorized)
getattr(couchdb, 'HTTPError', None),
getattr(couchdb, 'ServerError', None),
getattr(couchdb, 'Unauthorized', None),
)
)
channel_errors = (
virtual.Transport.channel_errors + (
couchdb.HTTPError,
couchdb.ServerError,
couchdb.PreconditionFailed,
couchdb.ResourceConflict,
couchdb.ResourceNotFound)
getattr(couchdb, 'HTTPError', None),
getattr(couchdb, 'ServerError', None),
getattr(couchdb, 'PreconditionFailed', None),
getattr(couchdb, 'ResourceConflict', None),
getattr(couchdb, 'ResourceNotFound', None),
)
)
driver_type = 'couchdb'
driver_name = 'couchdb'
def __init__(self, *args, **kwargs):
if couchdb is None:
raise ImportError('Missing couchdb library (pip install couchdb)')
super(Transport, self).__init__(*args, **kwargs)
def driver_version(self):
return couchdb.__version__

View File

@ -35,7 +35,6 @@ class Channel(virtual.Channel):
super(Channel, self).basic_consume(queue, *args, **kwargs)
def _get(self, queue):
#self.refresh_connection()
m = Queue.objects.fetch(queue)
if m:
return loads(bytes_to_str(m))

View File

@ -11,6 +11,7 @@ from __future__ import absolute_import
import os
import socket
import warnings
try:
import librabbitmq as amqp
@ -24,9 +25,14 @@ except ImportError: # pragma: no cover
from kombu.five import items, values
from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base
W_VERSION = """
librabbitmq version too old to detect RabbitMQ version information
so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3
"""
DEFAULT_PORT = 5672
NO_SSL_ERROR = """\
@ -150,6 +156,16 @@ class Transport(base.Transport):
def get_manager(self, *args, **kwargs):
return get_manager(self.client, *args, **kwargs)
def qos_semantics_matches_spec(self, connection):
try:
props = connection.server_properties
except AttributeError:
warnings.warn(UserWarning(W_VERSION))
else:
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property
def default_connection_params(self):
return {'userid': 'guest', 'password': 'guest',

View File

@ -55,14 +55,14 @@ class BroadcastCursor(object):
def __iter__(self):
return self
def next(self):
def __next__(self):
while True:
try:
msg = next(self._cursor)
except pymongo.errors.OperationFailure, e:
except pymongo.errors.OperationFailure as exc:
# In some cases tailed cursor can become invalid
# and have to be reinitalized
if 'not valid at server' in e.message:
if 'not valid at server' in exc.message:
self.purge()
continue
@ -74,6 +74,7 @@ class BroadcastCursor(object):
self._offset += 1
return msg
next = __next__
class Channel(virtual.Channel):
@ -86,6 +87,9 @@ class Channel(virtual.Channel):
self._broadcast_cursors = {}
# Evaluate connection
self._create_client()
def _new_queue(self, queue, **kwargs):
pass
@ -206,7 +210,7 @@ class Channel(virtual.Channel):
self.get_broadcast().ensure_index([('queue', 1)])
self.get_routing().ensure_index([('queue', 1), ('exchange', 1)])
#TODO Store a more complete exchange metatable in the routing collection
# TODO Store a more complete exchange metatable in the routing collection
def get_table(self, exchange):
"""Get table of bindings for ``exchange``."""
localRoutes = frozenset(self.state.exchanges[exchange]['table'])
@ -249,12 +253,14 @@ class Channel(virtual.Channel):
self._fanout_queues.pop(queue)
def _create_client(self):
self._open()
self._ensure_indexes()
@property
def client(self):
if self._client is None:
self._open()
self._ensure_indexes()
self._create_client()
return self._client
def get_messages(self):

View File

@ -11,6 +11,7 @@ import amqp
from kombu.five import items
from kombu.utils.amq_manager import get_manager
from kombu.utils.text import version_string_as_tuple
from . import base
@ -129,6 +130,12 @@ class Transport(base.Transport):
def heartbeat_check(self, connection, rate=2):
return connection.heartbeat_tick(rate=rate)
def qos_semantics_matches_spec(self, connection):
props = connection.server_properties
if props.get('product') == 'RabbitMQ':
return version_string_as_tuple(props['version']) < (3, 3)
return True
@property
def default_connection_params(self):
return {'userid': 'guest', 'password': 'guest',

View File

@ -246,7 +246,6 @@ class MultiChannelPoller(object):
self._channels.clear()
self._fd_to_chan.clear()
self._chan_to_sock.clear()
self.poller = None
def add(self, channel):
self._channels.add(channel)
@ -254,6 +253,11 @@ class MultiChannelPoller(object):
def discard(self, channel):
self._channels.discard(channel)
def _on_connection_disconnect(self, connection):
sock = getattr(connection, '_sock', None)
if sock is not None:
self.poller.unregister(sock)
def _register(self, channel, client, type):
if (channel, client, type) in self._chan_to_sock:
self._unregister(channel, client, type)
@ -450,6 +454,10 @@ class Channel(virtual.Channel):
if self._pool is not None:
self._pool.disconnect()
def _on_connection_disconnect(self, connection):
if self.connection and self.connection.cycle:
self.connection.cycle._on_connection_disconnect(connection)
def _do_restore_message(self, payload, exchange, routing_key,
client=None, leftmost=False):
with self.conn_or_acquire(client) as client:
@ -466,6 +474,8 @@ class Channel(virtual.Channel):
crit('Could not restore message: %r', payload, exc_info=True)
def _restore(self, message, leftmost=False):
if not self.ack_emulation:
return super(Channel, self)._restore(message)
tag = message.delivery_tag
with self.conn_or_acquire() as client:
P, _ = client.pipeline() \
@ -778,6 +788,19 @@ class Channel(virtual.Channel):
connparams.pop('port', None)
connparams['db'] = self._prepare_virtual_host(
connparams.pop('virtual_host', None))
channel = self
connection_cls = (
connparams.get('connection_class') or
redis.Connection
)
class Connection(connection_cls):
def disconnect(self):
channel._on_connection_disconnect(self)
super(Connection, self).disconnect()
connparams['connection_class'] = Connection
return connparams
def _create_client(self):
@ -888,6 +911,8 @@ class Transport(virtual.Transport):
driver_name = 'redis'
def __init__(self, *args, **kwargs):
if redis is None:
raise ImportError('Missing redis library (pip install redis)')
super(Transport, self).__init__(*args, **kwargs)
# Get redis-py exceptions.
@ -905,6 +930,11 @@ class Transport(virtual.Transport):
add_reader = loop.add_reader
on_readable = self.on_readable
def _on_disconnect(connection):
if connection._sock:
loop.remove(connection._sock)
cycle._on_connection_disconnect = _on_disconnect
def on_poll_start():
cycle_poll_start()
[add_reader(fd, on_readable, fd) for fd in cycle.fds]

View File

@ -153,6 +153,7 @@ class Transport(virtual.Transport):
default_port = 0
driver_type = 'sql'
driver_name = 'sqlalchemy'
connection_errors = (OperationalError, )
def driver_version(self):
import sqlalchemy

View File

@ -520,7 +520,7 @@ class Channel(AbstractChannel, base.StdChannel):
return self.typeof(exchange).deliver(
message, exchange, routing_key, **kwargs
)
# anon exchange: routing_key is the destintaion queue
# anon exchange: routing_key is the destination queue
return self._put(routing_key, message, **kwargs)
def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):

View File

@ -101,6 +101,19 @@ def symbol_by_name(name, aliases={}, imp=None, package=None,
return default
class HashedSeq(list):
"""type used for hash() to make sure the hash is not generated
multiple times."""
__slots__ = 'hashvalue'
def __init__(self, *seq):
self[:] = seq
self.hashvalue = hash(seq)
def __hash__(self):
return self.hashvalue
def eqhash(o):
try:
return o.__eqhash__()

View File

@ -8,7 +8,7 @@ Helps compatibility with older Python versions.
from __future__ import absolute_import
############## timedelta_seconds() -> delta.total_seconds ####################
# ############# timedelta_seconds() -> delta.total_seconds ###################
from datetime import timedelta
HAVE_TIMEDELTA_TOTAL_SECONDS = hasattr(timedelta, 'total_seconds')
@ -36,7 +36,7 @@ else: # pragma: no cover
return 0
return delta.days * 86400 + delta.seconds + (delta.microseconds / 10e5)
############## socket.error.errno ############################################
# ############# socket.error.errno ###########################################
def get_errno(exc):
@ -53,7 +53,7 @@ def get_errno(exc):
pass
return 0
############## collections.OrderedDict #######################################
# ############# collections.OrderedDict ######################################
try:
from collections import OrderedDict
except ImportError:

View File

@ -83,7 +83,7 @@ class _epoll(Poller):
def unregister(self, fd):
try:
self._epoll.unregister(fd)
except (socket.error, ValueError, KeyError):
except (socket.error, ValueError, KeyError, TypeError):
pass
except (IOError, OSError) as exc:
if get_errno(exc) != errno.ENOENT:
@ -202,7 +202,14 @@ class _select(Poller):
self.unregister(fd)
def unregister(self, fd):
fd = fileno(fd)
try:
fd = fileno(fd)
except socket.error as exc:
# we don't know the previous fd of this object
# but it will be removed by the next poll iteration.
if get_errno(exc) in SELECT_BAD_FD:
return
raise
self._rfd.discard(fd)
self._wfd.discard(fd)
self._efd.discard(fd)

View File

@ -25,10 +25,10 @@ class TokenBucket(object):
"""
#: The rate in tokens/second that the bucket will be refilled
#: The rate in tokens/second that the bucket will be refilled.
fill_rate = None
#: Maximum number of tokensin the bucket.
#: Maximum number of tokens in the bucket.
capacity = 1
#: Timestamp of the last time a token was taken out of the bucket.

View File

@ -3,6 +3,9 @@ from __future__ import absolute_import
from difflib import SequenceMatcher
from kombu import version_info_t
from kombu.five import string_t
def fmatch_iter(needle, haystack, min_ratio=0.6):
for key in haystack:
@ -18,3 +21,27 @@ def fmatch_best(needle, haystack, min_ratio=0.6):
)[0][1]
except IndexError:
pass
def version_string_as_tuple(s):
v = _unpack_version(*s.split('.'))
# X.Y.3a1 -> (X, Y, 3, 'a1')
if isinstance(v.micro, string_t):
v = version_info_t(v.major, v.minor, *_splitmicro(*v[2:]))
# X.Y.3a1-40 -> (X, Y, 3, 'a1', '40')
if not v.serial and v.releaselevel and '-' in v.releaselevel:
v = version_info_t(*list(v[0:3]) + v.releaselevel.split('-'))
return v
def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''):
return version_info_t(int(major), int(minor), micro, releaselevel, serial)
def _splitmicro(micro, releaselevel='', serial=''):
for index, char in enumerate(micro):
if not char.isdigit():
break
else:
return int(micro or 0), releaselevel, serial
return int(micro[:index]), micro[index:], serial

View File

@ -1,12 +1,17 @@
from __future__ import absolute_import
from functools import partial
try:
from urllib.parse import unquote, urlparse, parse_qsl
from urllib.parse import parse_qsl, quote, unquote, urlparse
except ImportError:
from urllib import unquote # noqa
from urllib import quote, unquote # noqa
from urlparse import urlparse, parse_qsl # noqa
from . import kwdict
from kombu.five import string_t
safequote = partial(quote, safe='')
def _parse_url(url):
@ -14,17 +19,9 @@ def _parse_url(url):
schemeless = url[len(scheme) + 3:]
# parse with HTTP URL semantics
parts = urlparse('http://' + schemeless)
# The first pymongo.Connection() argument (host) can be
# a mongodb connection URI. If this is the case, don't
# use port but let pymongo get the port(s) from the URI instead.
# This enables the use of replica sets and sharding.
# See pymongo.Connection() for more info.
port = scheme != 'mongodb' and parts.port or None
hostname = schemeless if scheme == 'mongodb' else parts.hostname
path = parts.path or ''
path = path[1:] if path and path[0] == '/' else path
return (scheme, unquote(hostname or '') or None, port,
return (scheme, unquote(parts.hostname or '') or None, parts.port,
unquote(parts.username or '') or None,
unquote(parts.password or '') or None,
unquote(path or '') or None,
@ -36,3 +33,32 @@ def parse_url(url):
return dict(transport=scheme, hostname=host,
port=port, userid=user,
password=password, virtual_host=path, **query)
def as_url(scheme, host=None, port=None, user=None, password=None,
path=None, query=None, sanitize=False, mask='**'):
parts = ['{0}://'.format(scheme)]
if user or password:
if user:
parts.append(safequote(user))
if password:
if sanitize:
parts.extend([':', mask] if mask else [':'])
else:
parts.extend([':', safequote(password)])
parts.append('@')
parts.append(safequote(host) if host else '')
if port:
parts.extend([':', port])
parts.extend(['/', path])
return ''.join(str(part) for part in parts if part)
def sanitize_url(url, mask='**'):
return as_url(*_parse_url(url), sanitize=True, mask=mask)
def maybe_sanitize_url(url, mask='**'):
if isinstance(url, string_t) and '://' in url:
return sanitize_url(url, mask)
return url