Updated third-party requirements to latest versions.

This commit is contained in:
Chris Church 2013-08-27 23:20:47 -04:00
parent 415fbc5362
commit 8d16485f7f
205 changed files with 2878 additions and 2374 deletions

View File

@ -1,17 +1,17 @@
Local versions of third-party packages required by AWX. Package names and Local versions of third-party packages required by AWX. Package names and
versions are listed below, along with notes on which files are included. versions are listed below, along with notes on which files are included.
amqp-1.0.11 (amqp/*) amqp-1.0.13 (amqp/*)
anyjson-0.3.3 (anyjson/*) anyjson-0.3.3 (anyjson/*)
billiard-2.7.3.28 (billiard/*, funtests/*, excluded _billiard.so) billiard-2.7.3.32 (billiard/*, funtests/*, excluded _billiard.so)
celery-3.0.19 (celery/*, excluded bin/celery* and bin/camqadm) celery-3.0.22 (celery/*, excluded bin/celery* and bin/camqadm)
django-celery-3.0.17 (djcelery/*, excluded bin/djcelerymon) django-celery-3.0.21 (djcelery/*, excluded bin/djcelerymon)
django-extensions-1.1.1 (django_extensions/*) django-extensions-1.2.0 (django_extensions/*)
django-jsonfield-0.9.10 (jsonfield/*) django-jsonfield-0.9.10 (jsonfield/*)
django-taggit-0.10a1 (taggit/*) django-taggit-0.10 (taggit/*)
djangorestframework-2.3.5 (rest_framework/*) djangorestframework-2.3.7 (rest_framework/*)
importlib-1.0.2 (importlib/*, needed for Python 2.6 support) importlib-1.0.2 (importlib/*, needed for Python 2.6 support)
kombu-2.5.10 (kombu/*) kombu-2.5.14 (kombu/*)
Markdown-2.3.1 (markdown/*, excluded bin/markdown_py) Markdown-2.3.1 (markdown/*, excluded bin/markdown_py)
ordereddict-1.1 (ordereddict.py, needed for Python 2.6 support) ordereddict-1.1 (ordereddict.py, needed for Python 2.6 support)
pexpect-2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py) pexpect-2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py)
@ -19,4 +19,4 @@ python-dateutil-2.1 (dateutil/*)
pytz-2013b (pytz/*) pytz-2013b (pytz/*)
requests-1.2.3 (requests/*) requests-1.2.3 (requests/*)
six-1.3.0 (six.py) six-1.3.0 (six.py)
South-0.8.1 (south/*) South-0.8.2 (south/*)

View File

@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import from __future__ import absolute_import
VERSION = (1, 0, 11) VERSION = (1, 0, 13)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Barry Pederson' __author__ = 'Barry Pederson'
__maintainer__ = 'Ask Solem' __maintainer__ = 'Ask Solem'

View File

@ -44,7 +44,7 @@ class Message(GenericContent):
('cluster_id', 'shortstr') ('cluster_id', 'shortstr')
] ]
def __init__(self, body='', children=None, **properties): def __init__(self, body='', children=None, channel=None, **properties):
"""Expected arg types """Expected arg types
body: string body: string
@ -107,6 +107,7 @@ class Message(GenericContent):
""" """
super(Message, self).__init__(**properties) super(Message, self).__init__(**properties)
self.body = body self.body = body
self.channel = channel
def __eq__(self, other): def __eq__(self, other):
"""Check if the properties and bodies of this Message and another """Check if the properties and bodies of this Message and another

View File

@ -16,9 +16,8 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import from __future__ import absolute_import
from collections import defaultdict from collections import defaultdict, deque
from struct import pack, unpack from struct import pack, unpack
from Queue import Queue
try: try:
bytes bytes
@ -61,12 +60,17 @@ class _PartialMessage(object):
self.complete = (self.body_size == 0) self.complete = (self.body_size == 0)
def add_payload(self, payload): def add_payload(self, payload):
self.body_parts.append(payload) parts = self.body_parts
self.body_received += len(payload) self.body_received += len(payload)
if self.body_received == self.body_size: if self.body_received == self.body_size:
self.msg.body = bytes().join(self.body_parts) if parts:
parts.append(payload)
self.msg.body = bytes().join(parts)
else:
self.msg.body = payload
self.complete = True self.complete = True
else:
parts.append(payload)
class MethodReader(object): class MethodReader(object):
@ -86,7 +90,7 @@ class MethodReader(object):
def __init__(self, source): def __init__(self, source):
self.source = source self.source = source
self.queue = Queue() self.queue = deque()
self.running = False self.running = False
self.partial_messages = {} self.partial_messages = {}
self.heartbeats = 0 self.heartbeats = 0
@ -94,32 +98,33 @@ class MethodReader(object):
self.expected_types = defaultdict(lambda: 1) self.expected_types = defaultdict(lambda: 1)
# not an actual byte count, just incremented whenever we receive # not an actual byte count, just incremented whenever we receive
self.bytes_recv = 0 self.bytes_recv = 0
self._quick_put = self.queue.append
self._quick_get = self.queue.popleft
def _next_method(self): def _next_method(self):
"""Read the next method from the source, once one complete method has """Read the next method from the source, once one complete method has
been assembled it is placed in the internal queue.""" been assembled it is placed in the internal queue."""
empty = self.queue.empty queue = self.queue
put = self._quick_put
read_frame = self.source.read_frame read_frame = self.source.read_frame
while empty(): while not queue:
try: try:
frame_type, channel, payload = read_frame() frame_type, channel, payload = read_frame()
except Exception, e: except Exception, e:
# #
# Connection was closed? Framing Error? # Connection was closed? Framing Error?
# #
self.queue.put(e) put(e)
break break
self.bytes_recv += 1 self.bytes_recv += 1
if frame_type not in (self.expected_types[channel], 8): if frame_type not in (self.expected_types[channel], 8):
self.queue.put(( put((
channel, channel,
AMQPError( AMQPError(
'Received frame type %s while expecting type: %s' % ( 'Received frame type %s while expecting type: %s' % (
frame_type, self.expected_types[channel]) frame_type, self.expected_types[channel]))))
),
))
elif frame_type == 1: elif frame_type == 1:
self._process_method_frame(channel, payload) self._process_method_frame(channel, payload)
elif frame_type == 2: elif frame_type == 2:
@ -144,7 +149,7 @@ class MethodReader(object):
self.partial_messages[channel] = _PartialMessage(method_sig, args) self.partial_messages[channel] = _PartialMessage(method_sig, args)
self.expected_types[channel] = 2 self.expected_types[channel] = 2
else: else:
self.queue.put((channel, method_sig, args, None)) self._quick_put((channel, method_sig, args, None))
def _process_content_header(self, channel, payload): def _process_content_header(self, channel, payload):
"""Process Content Header frames""" """Process Content Header frames"""
@ -155,8 +160,8 @@ class MethodReader(object):
# #
# a bodyless message, we're done # a bodyless message, we're done
# #
self.queue.put((channel, partial.method_sig, self._quick_put((channel, partial.method_sig,
partial.args, partial.msg)) partial.args, partial.msg))
self.partial_messages.pop(channel, None) self.partial_messages.pop(channel, None)
self.expected_types[channel] = 1 self.expected_types[channel] = 1
else: else:
@ -174,15 +179,15 @@ class MethodReader(object):
# Stick the message in the queue and go back to # Stick the message in the queue and go back to
# waiting for method frames # waiting for method frames
# #
self.queue.put((channel, partial.method_sig, self._quick_put((channel, partial.method_sig,
partial.args, partial.msg)) partial.args, partial.msg))
self.partial_messages.pop(channel, None) self.partial_messages.pop(channel, None)
self.expected_types[channel] = 1 self.expected_types[channel] = 1
def read_method(self): def read_method(self):
"""Read a method from the peer.""" """Read a method from the peer."""
self._next_method() self._next_method()
m = self.queue.get() m = self._quick_get()
if isinstance(m, Exception): if isinstance(m, Exception):
raise m raise m
if isinstance(m, tuple) and isinstance(m[1], AMQPError): if isinstance(m, tuple) and isinstance(m[1], AMQPError):

View File

@ -52,6 +52,8 @@ from .exceptions import AMQPError
AMQP_PORT = 5672 AMQP_PORT = 5672
EMPTY_BUFFER = bytes()
# Yes, Advanced Message Queuing Protocol Protocol is redundant # Yes, Advanced Message Queuing Protocol Protocol is redundant
AMQP_PROTOCOL_HEADER = 'AMQP\x01\x01\x00\x09'.encode('latin_1') AMQP_PROTOCOL_HEADER = 'AMQP\x01\x01\x00\x09'.encode('latin_1')
@ -139,11 +141,12 @@ class _AbstractTransport(object):
self.sock.close() self.sock.close()
self.sock = None self.sock = None
def read_frame(self): def read_frame(self, unpack=unpack):
"""Read an AMQP frame.""" """Read an AMQP frame."""
frame_type, channel, size = unpack('>BHI', self._read(7, True)) read = self._read
payload = self._read(size) frame_type, channel, size = unpack('>BHI', read(7, True))
ch = ord(self._read(1)) payload = read(size)
ch = ord(read(1))
if ch == 206: # '\xce' if ch == 206: # '\xce'
return frame_type, channel, payload return frame_type, channel, payload
else: else:
@ -164,7 +167,7 @@ class SSLTransport(_AbstractTransport):
def __init__(self, host, connect_timeout, ssl): def __init__(self, host, connect_timeout, ssl):
if isinstance(ssl, dict): if isinstance(ssl, dict):
self.sslopts = ssl self.sslopts = ssl
self.sslobj = None self._read_buffer = EMPTY_BUFFER
super(SSLTransport, self).__init__(host, connect_timeout) super(SSLTransport, self).__init__(host, connect_timeout)
def _setup_transport(self): def _setup_transport(self):
@ -173,43 +176,51 @@ class SSLTransport(_AbstractTransport):
lower version.""" lower version."""
if HAVE_PY26_SSL: if HAVE_PY26_SSL:
if hasattr(self, 'sslopts'): if hasattr(self, 'sslopts'):
self.sslobj = ssl.wrap_socket(self.sock, **self.sslopts) self.sock = ssl.wrap_socket(self.sock, **self.sslopts)
else: else:
self.sslobj = ssl.wrap_socket(self.sock) self.sock = ssl.wrap_socket(self.sock)
self.sslobj.do_handshake() self.sock.do_handshake()
else: else:
self.sslobj = socket.ssl(self.sock) self.sock = socket.ssl(self.sock)
self._quick_recv = self.sock.read
def _shutdown_transport(self): def _shutdown_transport(self):
"""Unwrap a Python 2.6 SSL socket, so we can call shutdown()""" """Unwrap a Python 2.6 SSL socket, so we can call shutdown()"""
if HAVE_PY26_SSL and (self.sslobj is not None): if HAVE_PY26_SSL and self.sock is not None:
self.sock = self.sslobj.unwrap()
self.sslobj = None
def _read(self, n, initial=False):
"""It seems that SSL Objects read() method may not supply as much
as you're asking for, at least with extremely large messages.
somewhere > 16K - found this in the test_channel.py test_large
unittest."""
result = ''
while len(result) < n:
try: try:
s = self.sslobj.read(n - len(result)) unwrap = self.sock.unwrap
except AttributeError:
return
self.sock = unwrap()
def _read(self, n, initial=False,
_errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)):
# According to SSL_read(3), it can at most return 16kb of data.
# Thus, we use an internal read buffer like TCPTransport._read
# to get the exact number of bytes wanted.
recv = self._quick_recv
rbuf = self._read_buffer
while len(rbuf) < n:
try:
s = recv(131072) # see note above
except socket.error, exc: except socket.error, exc:
if not initial and exc.errno in (errno.EAGAIN, errno.EINTR): # ssl.sock.read may cause ENOENT if the
# operation couldn't be performed (Issue celery#1414).
if not initial and exc.errno in _errnos:
continue continue
raise raise exc
if not s: if not s:
raise IOError('Socket closed') raise IOError('Socket closed')
result += s rbuf += s
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result return result
def _write(self, s): def _write(self, s):
"""Write a string out to the SSL socket fully.""" """Write a string out to the SSL socket fully."""
write = self.sock.write
while s: while s:
n = self.sslobj.write(s) n = write(s)
if not n: if not n:
raise IOError('Socket closed') raise IOError('Socket closed')
s = s[n:] s = s[n:]
@ -222,24 +233,25 @@ class TCPTransport(_AbstractTransport):
"""Setup to _write() directly to the socket, and """Setup to _write() directly to the socket, and
do our own buffered reads.""" do our own buffered reads."""
self._write = self.sock.sendall self._write = self.sock.sendall
self._read_buffer = bytes() self._read_buffer = EMPTY_BUFFER
self._quick_recv = self.sock.recv
def _read(self, n, initial=False): def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)):
"""Read exactly n bytes from the socket""" """Read exactly n bytes from the socket"""
while len(self._read_buffer) < n: recv = self._quick_recv
rbuf = self._read_buffer
while len(rbuf) < n:
try: try:
s = self.sock.recv(65536) s = recv(131072)
except socket.error, exc: except socket.error, exc:
if not initial and exc.errno in (errno.EAGAIN, errno.EINTR): if not initial and exc.errno in _errnos:
continue continue
raise raise
if not s: if not s:
raise IOError('Socket closed') raise IOError('Socket closed')
self._read_buffer += s rbuf += s
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result return result

View File

@ -20,7 +20,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement from __future__ import with_statement
VERSION = (2, 7, 3, 28) VERSION = (2, 7, 3, 32)
__version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:]) __version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:])
__author__ = 'R Oudkerk / Python Software Foundation' __author__ = 'R Oudkerk / Python Software Foundation'
__author_email__ = 'python-dev@python.org' __author_email__ = 'python-dev@python.org'
@ -232,7 +232,11 @@ def JoinableQueue(maxsize=0):
return JoinableQueue(maxsize) return JoinableQueue(maxsize)
def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None): def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None,
timeout=None, soft_timeout=None, lost_worker_timeout=None,
max_restarts=None, max_restart_freq=1, on_process_up=None,
on_process_down=None, on_timeout_set=None, on_timeout_cancel=None,
threads=True, semaphore=None, putlocks=False, allow_restart=False):
''' '''
Returns a process pool object Returns a process pool object
''' '''

View File

@ -4,6 +4,11 @@ import sys
supports_exec = True supports_exec = True
try:
import _winapi as win32
except ImportError: # pragma: no cover
win32 = None
if sys.platform.startswith("java"): if sys.platform.startswith("java"):
_billiard = None _billiard = None
else: else:
@ -18,7 +23,8 @@ else:
from multiprocessing.connection import Connection # noqa from multiprocessing.connection import Connection # noqa
PipeConnection = getattr(_billiard, "PipeConnection", None) PipeConnection = getattr(_billiard, "PipeConnection", None)
win32 = getattr(_billiard, "win32", None) if win32 is None:
win32 = getattr(_billiard, "win32", None) # noqa
def ensure_multiprocessing(): def ensure_multiprocessing():

View File

@ -1,12 +1,41 @@
# -*- coding: utf-8 -*-
"""
This module contains utilities added by billiard, to keep
"non-core" functionality out of ``.util``."""
from __future__ import absolute_import from __future__ import absolute_import
import signal import signal
import sys import sys
from time import time from time import time
import pickle as pypickle
try:
import cPickle as cpickle
except ImportError: # pragma: no cover
cpickle = None # noqa
from .exceptions import RestartFreqExceeded from .exceptions import RestartFreqExceeded
if sys.version_info < (2, 6): # pragma: no cover
# cPickle does not use absolute_imports
pickle = pypickle
pickle_load = pypickle.load
pickle_loads = pypickle.loads
else:
pickle = cpickle or pypickle
pickle_load = pickle.load
pickle_loads = pickle.loads
# cPickle.loads does not support buffer() objects,
# but we can just create a StringIO and use load.
if sys.version_info[0] == 3:
from io import BytesIO
else:
try:
from cStringIO import StringIO as BytesIO # noqa
except ImportError:
from StringIO import StringIO as BytesIO # noqa
TERMSIGS = ( TERMSIGS = (
'SIGHUP', 'SIGHUP',
'SIGQUIT', 'SIGQUIT',
@ -30,6 +59,11 @@ TERMSIGS = (
) )
def pickle_loads(s, load=pickle_load):
# used to support buffer objects
return load(BytesIO(s))
def _shutdown_cleanup(signum, frame): def _shutdown_cleanup(signum, frame):
sys.exit(-(256 - signum)) sys.exit(-(256 - signum))

View File

@ -17,7 +17,6 @@ from __future__ import with_statement
import collections import collections
import errno import errno
import itertools import itertools
import logging
import os import os
import platform import platform
import signal import signal
@ -29,7 +28,7 @@ import warnings
from . import Event, Process, cpu_count from . import Event, Process, cpu_count
from . import util from . import util
from .common import reset_signals, restart_state from .common import pickle_loads, reset_signals, restart_state
from .compat import get_errno from .compat import get_errno
from .einfo import ExceptionInfo from .einfo import ExceptionInfo
from .exceptions import ( from .exceptions import (
@ -163,15 +162,6 @@ class LaxBoundedSemaphore(_Semaphore):
if self._value < self._initial_value: if self._value < self._initial_value:
self._value += 1 self._value += 1
cond.notify_all() cond.notify_all()
if __debug__:
self._note(
"%s.release: success, value=%s", self, self._value,
)
else:
if __debug__:
self._note(
"%s.release: success, value=%s (unchanged)" % (
self, self._value))
def clear(self): def clear(self):
while self._value < self._initial_value: while self._value < self._initial_value:
@ -184,14 +174,6 @@ class LaxBoundedSemaphore(_Semaphore):
if self._Semaphore__value < self._initial_value: if self._Semaphore__value < self._initial_value:
self._Semaphore__value += 1 self._Semaphore__value += 1
cond.notifyAll() cond.notifyAll()
if __debug__:
self._note("%s.release: success, value=%s",
self, self._Semaphore__value)
else:
if __debug__:
self._note(
"%s.release: success, value=%s (unchanged)" % (
self, self._Semaphore__value))
def clear(self): # noqa def clear(self): # noqa
while self._Semaphore__value < self._initial_value: while self._Semaphore__value < self._initial_value:
@ -233,28 +215,26 @@ def soft_timeout_sighandler(signum, frame):
def worker(inqueue, outqueue, initializer=None, initargs=(), def worker(inqueue, outqueue, initializer=None, initargs=(),
maxtasks=None, sentinel=None): maxtasks=None, sentinel=None):
# Re-init logging system.
# Workaround for http://bugs.python.org/issue6721#msg140215
# Python logging module uses RLock() objects which are broken after
# fork. This can result in a deadlock (Issue #496).
logger_names = logging.Logger.manager.loggerDict.keys()
logger_names.append(None) # for root logger
for name in logger_names:
for handler in logging.getLogger(name).handlers:
handler.createLock()
logging._lock = threading.RLock()
pid = os.getpid() pid = os.getpid()
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
put = outqueue.put put = outqueue.put
get = inqueue.get get = inqueue.get
loads = pickle_loads
if hasattr(inqueue, '_reader'): if hasattr(inqueue, '_reader'):
def poll(timeout): if hasattr(inqueue, 'get_payload') and inqueue.get_payload:
if inqueue._reader.poll(timeout): get_payload = inqueue.get_payload
return True, get()
return False, None def poll(timeout):
if inqueue._reader.poll(timeout):
return True, loads(get_payload())
return False, None
else:
def poll(timeout):
if inqueue._reader.poll(timeout):
return True, get()
return False, None
else: else:
def poll(timeout): # noqa def poll(timeout): # noqa
@ -1236,8 +1216,13 @@ class Pool(object):
return result return result
def terminate_job(self, pid, sig=None): def terminate_job(self, pid, sig=None):
self.signalled.add(pid) try:
_kill(pid, sig or signal.SIGTERM) _kill(pid, sig or signal.SIGTERM)
except OSError, exc:
if get_errno(exc) != errno.ESRCH:
raise
else:
self.signalled.add(pid)
def map_async(self, func, iterable, chunksize=None, def map_async(self, func, iterable, chunksize=None,
callback=None, error_callback=None): callback=None, error_callback=None):

View File

@ -19,6 +19,8 @@ import sys
import signal import signal
import itertools import itertools
import binascii import binascii
import logging
import threading
from .compat import bytes from .compat import bytes
try: try:
@ -45,9 +47,17 @@ def current_process():
def _cleanup(): def _cleanup():
# check for processes which have finished # check for processes which have finished
for p in list(_current_process._children): if _current_process is not None:
if p._popen.poll() is not None: for p in list(_current_process._children):
_current_process._children.discard(p) if p._popen.poll() is not None:
_current_process._children.discard(p)
def _maybe_flush(f):
try:
f.flush()
except (AttributeError, EnvironmentError, NotImplementedError):
pass
def active_children(_cleanup=_cleanup): def active_children(_cleanup=_cleanup):
@ -59,7 +69,9 @@ def active_children(_cleanup=_cleanup):
except TypeError: except TypeError:
# called after gc collect so _cleanup does not exist anymore # called after gc collect so _cleanup does not exist anymore
return [] return []
return list(_current_process._children) if _current_process is not None:
return list(_current_process._children)
return []
class Process(object): class Process(object):
@ -242,6 +254,18 @@ class Process(object):
pass pass
old_process = _current_process old_process = _current_process
_current_process = self _current_process = self
# Re-init logging system.
# Workaround for http://bugs.python.org/issue6721#msg140215
# Python logging module uses RLock() objects which are broken after
# fork. This can result in a deadlock (Celery Issue #496).
logger_names = logging.Logger.manager.loggerDict.keys()
logger_names.append(None) # for root logger
for name in logger_names:
for handler in logging.getLogger(name).handlers:
handler.createLock()
logging._lock = threading.RLock()
try: try:
util._finalizer_registry.clear() util._finalizer_registry.clear()
util._run_after_forkers() util._run_after_forkers()
@ -262,7 +286,7 @@ class Process(object):
exitcode = e.args[0] exitcode = e.args[0]
else: else:
sys.stderr.write(str(e.args[0]) + '\n') sys.stderr.write(str(e.args[0]) + '\n')
sys.stderr.flush() _maybe_flush(sys.stderr)
exitcode = 0 if isinstance(e.args[0], str) else 1 exitcode = 0 if isinstance(e.args[0], str) else 1
except: except:
exitcode = 1 exitcode = 1
@ -273,8 +297,8 @@ class Process(object):
finally: finally:
util.info('process %s exiting with exitcode %d', util.info('process %s exiting with exitcode %d',
self.pid, exitcode) self.pid, exitcode)
sys.stdout.flush() _maybe_flush(sys.stdout)
sys.stderr.flush() _maybe_flush(sys.stderr)
return exitcode return exitcode
# #

View File

@ -334,6 +334,10 @@ class SimpleQueue(object):
def _make_methods(self): def _make_methods(self):
recv = self._reader.recv recv = self._reader.recv
try:
recv_payload = self._reader.recv_payload
except AttributeError:
recv_payload = None # C extension not installed
rlock = self._rlock rlock = self._rlock
def get(): def get():
@ -341,6 +345,12 @@ class SimpleQueue(object):
return recv() return recv()
self.get = get self.get = get
if recv_payload is not None:
def get_payload():
with rlock:
return recv_payload()
self.get_payload = get_payload
if self._wlock is None: if self._wlock is None:
# writes to a message oriented win32 pipe are atomic # writes to a message oriented win32 pipe are atomic
self.put = self._writer.send self.put = self._writer.send

View File

@ -244,7 +244,9 @@ class Finalize(object):
return x + '>' return x + '>'
def _run_finalizers(minpriority=None): def _run_finalizers(minpriority=None,
_finalizer_registry=_finalizer_registry,
sub_debug=sub_debug, error=error):
''' '''
Run all finalizers whose exit priority is not None and at least minpriority Run all finalizers whose exit priority is not None and at least minpriority
@ -280,7 +282,9 @@ def is_exiting():
return _exiting or _exiting is None return _exiting or _exiting is None
def _exit_function(): def _exit_function(info=info, debug=debug,
active_children=active_children,
_run_finalizers=_run_finalizers):
''' '''
Clean up on exit Clean up on exit
''' '''

View File

@ -14,7 +14,12 @@ from __future__ import absolute_import
import operator import operator
import sys import sys
from functools import reduce # import fails in python 2.5. fallback to reduce in stdlib
try:
from functools import reduce
except ImportError:
pass
from importlib import import_module from importlib import import_module
from types import ModuleType from types import ModuleType

View File

@ -8,7 +8,7 @@
from __future__ import absolute_import from __future__ import absolute_import
SERIES = 'Chiastic Slide' SERIES = 'Chiastic Slide'
VERSION = (3, 0, 19) VERSION = (3, 0, 22)
__version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:]) __version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:])
__author__ = 'Ask Solem' __author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org' __contact__ = 'ask@celeryproject.org'

View File

@ -307,8 +307,9 @@ def add_chord_task(app):
accept_magic_kwargs = False accept_magic_kwargs = False
ignore_result = False ignore_result = False
def run(self, header, body, partial_args=(), interval=1, countdown=1, def run(self, header, body, partial_args=(), interval=None,
max_retries=None, propagate=None, eager=False, **kwargs): countdown=1, max_retries=None, propagate=None,
eager=False, **kwargs):
propagate = default_propagate if propagate is None else propagate propagate = default_propagate if propagate is None else propagate
group_id = uuid() group_id = uuid()
AsyncResult = self.app.AsyncResult AsyncResult = self.app.AsyncResult

View File

@ -198,11 +198,11 @@ class Task(object):
serializer = None serializer = None
#: Hard time limit. #: Hard time limit.
#: Defaults to the :setting:`CELERY_TASK_TIME_LIMIT` setting. #: Defaults to the :setting:`CELERYD_TASK_TIME_LIMIT` setting.
time_limit = None time_limit = None
#: Soft time limit. #: Soft time limit.
#: Defaults to the :setting:`CELERY_TASK_SOFT_TIME_LIMIT` setting. #: Defaults to the :setting:`CELERYD_TASK_SOFT_TIME_LIMIT` setting.
soft_time_limit = None soft_time_limit = None
#: The result store backend used for this task. #: The result store backend used for this task.
@ -459,7 +459,8 @@ class Task(object):
args = (self.__self__, ) + tuple(args) args = (self.__self__, ) + tuple(args)
if conf.CELERY_ALWAYS_EAGER: if conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, task_id=task_id, **options) return self.apply(args, kwargs, task_id=task_id,
link=link, link_error=link_error, **options)
options = dict(extract_exec_options(self), **options) options = dict(extract_exec_options(self), **options)
options = router.route(options, self.name, args, kwargs) options = router.route(options, self.name, args, kwargs)
@ -580,7 +581,8 @@ class Task(object):
raise ret raise ret
return ret return ret
def apply(self, args=None, kwargs=None, **options): def apply(self, args=None, kwargs=None,
link=None, link_error=None, **options):
"""Execute this task locally, by blocking until the task returns. """Execute this task locally, by blocking until the task returns.
:param args: positional arguments passed on to the task. :param args: positional arguments passed on to the task.
@ -614,6 +616,8 @@ class Task(object):
'is_eager': True, 'is_eager': True,
'logfile': options.get('logfile'), 'logfile': options.get('logfile'),
'loglevel': options.get('loglevel', 0), 'loglevel': options.get('loglevel', 0),
'callbacks': maybe_list(link),
'errbacks': maybe_list(link_error),
'delivery_info': {'is_eager': True}} 'delivery_info': {'is_eager': True}}
if self.accept_magic_kwargs: if self.accept_magic_kwargs:
default_kwargs = {'task_name': task.name, default_kwargs = {'task_name': task.name,

View File

@ -251,7 +251,7 @@ class Worker(configurated):
'version': VERSION_BANNER, 'version': VERSION_BANNER,
'conninfo': self.app.connection().as_uri(), 'conninfo': self.app.connection().as_uri(),
'concurrency': concurrency, 'concurrency': concurrency,
'platform': _platform.platform(), 'platform': safe_str(_platform.platform()),
'events': events, 'events': events,
'queues': app.amqp.queues.format(indent=0, indent_first=False), 'queues': app.amqp.queues.format(indent=0, indent_first=False),
}).splitlines() }).splitlines()

View File

@ -114,6 +114,8 @@ class MongoBackend(BaseDictBackend):
if self._connection is not None: if self._connection is not None:
# MongoDB connection will be closed automatically when object # MongoDB connection will be closed automatically when object
# goes out of scope # goes out of scope
del(self.collection)
del(self.database)
self._connection = None self._connection = None
def _store_result(self, task_id, result, status, traceback=None): def _store_result(self, task_id, result, status, traceback=None):
@ -124,7 +126,7 @@ class MongoBackend(BaseDictBackend):
'date_done': datetime.utcnow(), 'date_done': datetime.utcnow(),
'traceback': Binary(self.encode(traceback)), 'traceback': Binary(self.encode(traceback)),
'children': Binary(self.encode(self.current_task_children()))} 'children': Binary(self.encode(self.current_task_children()))}
self.collection.save(meta, safe=True) self.collection.save(meta)
return result return result
@ -151,7 +153,7 @@ class MongoBackend(BaseDictBackend):
meta = {'_id': group_id, meta = {'_id': group_id,
'result': Binary(self.encode(result)), 'result': Binary(self.encode(result)),
'date_done': datetime.utcnow()} 'date_done': datetime.utcnow()}
self.collection.save(meta, safe=True) self.collection.save(meta)
return result return result
@ -183,7 +185,7 @@ class MongoBackend(BaseDictBackend):
# By using safe=True, this will wait until it receives a response from # By using safe=True, this will wait until it receives a response from
# the server. Likewise, it will raise an OperationsError if the # the server. Likewise, it will raise an OperationsError if the
# response was unable to be completed. # response was unable to be completed.
self.collection.remove({'_id': task_id}, safe=True) self.collection.remove({'_id': task_id})
def cleanup(self): def cleanup(self):
"""Delete expired metadata.""" """Delete expired metadata."""

View File

@ -614,7 +614,7 @@ class control(_RemoteControl):
def call(self, method, *args, **options): def call(self, method, *args, **options):
# XXX Python 2.5 doesn't support X(*args, reply=True, **kwargs) # XXX Python 2.5 doesn't support X(*args, reply=True, **kwargs)
return getattr(self.app.control, method)( return getattr(self.app.control, method)(
*args, **dict(options, retry=True)) *args, **dict(options, reply=True))
def pool_grow(self, method, n=1, **kwargs): def pool_grow(self, method, n=1, **kwargs):
"""[N=1]""" """[N=1]"""
@ -866,7 +866,7 @@ class CeleryCommand(BaseCommand):
cls = self.commands.get(command) or self.commands['help'] cls = self.commands.get(command) or self.commands['help']
try: try:
return cls(app=self.app).run_from_argv(self.prog_name, argv) return cls(app=self.app).run_from_argv(self.prog_name, argv)
except (TypeError, Error): except Error:
return self.execute('help', argv) return self.execute('help', argv)
def remove_options_at_beginning(self, argv, index=0): def remove_options_at_beginning(self, argv, index=0):

View File

@ -267,7 +267,8 @@ class Signature(dict):
class chain(Signature): class chain(Signature):
def __init__(self, *tasks, **options): def __init__(self, *tasks, **options):
tasks = tasks[0] if len(tasks) == 1 and is_list(tasks[0]) else tasks tasks = (regen(tasks[0]) if len(tasks) == 1 and is_list(tasks[0])
else tasks)
Signature.__init__( Signature.__init__(
self, 'celery.chain', (), {'tasks': tasks}, **options self, 'celery.chain', (), {'tasks': tasks}, **options
) )
@ -283,7 +284,7 @@ class chain(Signature):
tasks = d['kwargs']['tasks'] tasks = d['kwargs']['tasks']
if d['args'] and tasks: if d['args'] and tasks:
# partial args passed on to first task in chain (Issue #1057). # partial args passed on to first task in chain (Issue #1057).
tasks[0]['args'] = d['args'] + tasks[0]['args'] tasks[0]['args'] = tasks[0]._merge(d['args'])[0]
return chain(*d['kwargs']['tasks'], **kwdict(d['options'])) return chain(*d['kwargs']['tasks'], **kwdict(d['options']))
@property @property
@ -392,7 +393,7 @@ class group(Signature):
if d['args'] and tasks: if d['args'] and tasks:
# partial args passed on to all tasks in the group (Issue #1057). # partial args passed on to all tasks in the group (Issue #1057).
for task in tasks: for task in tasks:
task['args'] = d['args'] + task['args'] task['args'] = task._merge(d['args'])[0]
return group(tasks, **kwdict(d['options'])) return group(tasks, **kwdict(d['options']))
def __call__(self, *partial_args, **options): def __call__(self, *partial_args, **options):

View File

@ -34,7 +34,8 @@ if not EVENTLET_NOPATCH and not PATCHED[0]:
import eventlet import eventlet
import eventlet.debug import eventlet.debug
eventlet.monkey_patch() eventlet.monkey_patch()
eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK) if EVENTLET_DBLOCK:
eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK)
from time import time from time import time

View File

@ -236,6 +236,8 @@ def start_filter(app, conn, filter, limit=None, timeout=1.0,
consume_from=None, state=None, **kwargs): consume_from=None, state=None, **kwargs):
state = state or State() state = state or State()
queues = prepare_queues(queues) queues = prepare_queues(queues)
consume_from = [_maybe_queue(app, q)
for q in consume_from or queues.keys()]
if isinstance(tasks, basestring): if isinstance(tasks, basestring):
tasks = set(tasks.split(',')) tasks = set(tasks.split(','))
if tasks is None: if tasks is None:

View File

@ -472,7 +472,10 @@ class LimitedSet(object):
if time.time() < item[0] + self.expires: if time.time() < item[0] + self.expires:
heappush(H, item) heappush(H, item)
break break
self._data.pop(item[1]) try:
self._data.pop(item[1])
except KeyError: # out of sync with heap
pass
i += 1 i += 1
def update(self, other, heappush=heappush): def update(self, other, heappush=heappush):

View File

@ -210,11 +210,20 @@ class State(object):
task_count = 0 task_count = 0
def __init__(self, callback=None, def __init__(self, callback=None,
workers=None, tasks=None, taskheap=None,
max_workers_in_memory=5000, max_tasks_in_memory=10000): max_workers_in_memory=5000, max_tasks_in_memory=10000):
self.workers = LRUCache(limit=max_workers_in_memory)
self.tasks = LRUCache(limit=max_tasks_in_memory)
self.event_callback = callback self.event_callback = callback
self.workers = (LRUCache(max_workers_in_memory)
if workers is None else workers)
self.tasks = (LRUCache(max_tasks_in_memory)
if tasks is None else tasks)
self._taskheap = None # reserved for __reduce__ in 3.1
self.max_workers_in_memory = max_workers_in_memory
self.max_tasks_in_memory = max_tasks_in_memory
self._mutex = threading.Lock() self._mutex = threading.Lock()
self.handlers = {'task': self.task_event,
'worker': self.worker_event}
self._get_handler = self.handlers.__getitem__
def freeze_while(self, fun, *args, **kwargs): def freeze_while(self, fun, *args, **kwargs):
clear_after = kwargs.pop('clear_after', False) clear_after = kwargs.pop('clear_after', False)
@ -295,11 +304,14 @@ class State(object):
with self._mutex: with self._mutex:
return self._dispatch_event(event) return self._dispatch_event(event)
def _dispatch_event(self, event): def _dispatch_event(self, event, kwdict=kwdict):
self.event_count += 1 self.event_count += 1
event = kwdict(event) event = kwdict(event)
group, _, subject = event['type'].partition('-') group, _, subject = event['type'].partition('-')
getattr(self, group + '_event')(subject, event) try:
self._get_handler(group)(subject, event)
except KeyError:
pass
if self.event_callback: if self.event_callback:
self.event_callback(self, event) self.event_callback(self, event)
@ -356,14 +368,10 @@ class State(object):
return '<ClusterState: events=%s tasks=%s>' % (self.event_count, return '<ClusterState: events=%s tasks=%s>' % (self.event_count,
self.task_count) self.task_count)
def __getstate__(self): def __reduce__(self):
d = dict(vars(self)) return self.__class__, (
d.pop('_mutex') self.event_callback, self.workers, self.tasks, None,
return d self.max_workers_in_memory, self.max_tasks_in_memory,
)
def __setstate__(self, state):
self.__dict__ = state
self._mutex = threading.Lock()
state = State() state = State()

View File

@ -11,7 +11,7 @@ from __future__ import absolute_import
from celery._state import current_app from celery._state import current_app
from celery.utils import deprecated from celery.utils import deprecated
from celery.utils.imports import symbol_by_name from celery.utils.imports import symbol_by_name, import_from_cwd
LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader', LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader',
'default': 'celery.loaders.default:Loader', 'default': 'celery.loaders.default:Loader',
@ -20,7 +20,7 @@ LOADER_ALIASES = {'app': 'celery.loaders.app:AppLoader',
def get_loader_cls(loader): def get_loader_cls(loader):
"""Get loader class by name/alias""" """Get loader class by name/alias"""
return symbol_by_name(loader, LOADER_ALIASES) return symbol_by_name(loader, LOADER_ALIASES, imp=import_from_cwd)
@deprecated(deprecation='2.5', removal='4.0', @deprecated(deprecation='2.5', removal='4.0',

View File

@ -48,6 +48,12 @@ PIDFILE_MODE = ((os.R_OK | os.W_OK) << 6) | ((os.R_OK) << 3) | ((os.R_OK))
PIDLOCKED = """ERROR: Pidfile (%s) already exists. PIDLOCKED = """ERROR: Pidfile (%s) already exists.
Seems we're already running? (pid: %s)""" Seems we're already running? (pid: %s)"""
try:
from io import UnsupportedOperation
FILENO_ERRORS = (AttributeError, UnsupportedOperation)
except ImportError: # Py2
FILENO_ERRORS = (AttributeError, ) # noqa
def pyimplementation(): def pyimplementation():
"""Returns string identifying the current Python implementation.""" """Returns string identifying the current Python implementation."""
@ -253,17 +259,21 @@ def _create_pidlock(pidfile):
def fileno(f): def fileno(f):
"""Get object fileno, or :const:`None` if not defined.""" if isinstance(f, (int, long)):
if isinstance(f, int):
return f return f
return f.fileno()
def maybe_fileno(f):
"""Get object fileno, or :const:`None` if not defined."""
try: try:
return f.fileno() return fileno(f)
except AttributeError: except FILENO_ERRORS:
pass pass
def close_open_fds(keep=None): def close_open_fds(keep=None):
keep = [fileno(f) for f in keep if fileno(f)] if keep else [] keep = [maybe_fileno(f) for f in keep if maybe_fileno(f)] if keep else []
for fd in reversed(range(get_fdmax(default=2048))): for fd in reversed(range(get_fdmax(default=2048))):
if fd not in keep: if fd not in keep:
with ignore_errno(errno.EBADF): with ignore_errno(errno.EBADF):
@ -299,7 +309,7 @@ class DaemonContext(object):
close_open_fds(self.stdfds) close_open_fds(self.stdfds)
for fd in self.stdfds: for fd in self.stdfds:
self.redirect_to_null(fileno(fd)) self.redirect_to_null(maybe_fileno(fd))
self._is_open = True self._is_open = True
__enter__ = open __enter__ = open

View File

@ -350,7 +350,7 @@ class ResultSet(ResultBase):
def failed(self): def failed(self):
"""Did any of the tasks fail? """Did any of the tasks fail?
:returns: :const:`True` if any of the tasks failed. :returns: :const:`True` if one of the tasks failed.
(i.e., raised an exception) (i.e., raised an exception)
""" """
@ -359,7 +359,7 @@ class ResultSet(ResultBase):
def waiting(self): def waiting(self):
"""Are any of the tasks incomplete? """Are any of the tasks incomplete?
:returns: :const:`True` if any of the tasks is still :returns: :const:`True` if one of the tasks are still
waiting for execution. waiting for execution.
""" """
@ -368,7 +368,7 @@ class ResultSet(ResultBase):
def ready(self): def ready(self):
"""Did all of the tasks complete? (either by success of failure). """Did all of the tasks complete? (either by success of failure).
:returns: :const:`True` if all of the tasks been :returns: :const:`True` if all of the tasks has been
executed. executed.
""" """
@ -435,7 +435,7 @@ class ResultSet(ResultBase):
time.sleep(interval) time.sleep(interval)
elapsed += interval elapsed += interval
if timeout and elapsed >= timeout: if timeout and elapsed >= timeout:
raise TimeoutError("The operation timed out") raise TimeoutError('The operation timed out')
def get(self, timeout=None, propagate=True, interval=0.5): def get(self, timeout=None, propagate=True, interval=0.5):
"""See :meth:`join` """See :meth:`join`
@ -694,7 +694,7 @@ class EagerResult(AsyncResult):
self._state = states.REVOKED self._state = states.REVOKED
def __repr__(self): def __repr__(self):
return "<EagerResult: %s>" % self.id return '<EagerResult: %s>' % self.id
@property @property
def result(self): def result(self):

View File

@ -379,7 +379,11 @@ class crontab(schedule):
flag = (datedata.dom == len(days_of_month) or flag = (datedata.dom == len(days_of_month) or
day_out_of_range(datedata.year, day_out_of_range(datedata.year,
months_of_year[datedata.moy], months_of_year[datedata.moy],
days_of_month[datedata.dom])) days_of_month[datedata.dom]) or
(self.maybe_make_aware(datetime(datedata.year,
months_of_year[datedata.moy],
days_of_month[datedata.dom])) < last_run_at))
if flag: if flag:
datedata.dom = 0 datedata.dom = 0
datedata.moy += 1 datedata.moy += 1
@ -449,10 +453,11 @@ class crontab(schedule):
self._orig_day_of_month, self._orig_day_of_month,
self._orig_month_of_year), None) self._orig_month_of_year), None)
def remaining_estimate(self, last_run_at, tz=None): def remaining_delta(self, last_run_at, tz=None):
"""Returns when the periodic task should run next as a timedelta.""" """Returns when the periodic task should run next as a timedelta."""
tz = tz or self.tz tz = tz or self.tz
last_run_at = self.maybe_make_aware(last_run_at) last_run_at = self.maybe_make_aware(last_run_at)
now = self.maybe_make_aware(self.now())
dow_num = last_run_at.isoweekday() % 7 # Sunday is day 0, not day 7 dow_num = last_run_at.isoweekday() % 7 # Sunday is day 0, not day 7
execute_this_date = (last_run_at.month in self.month_of_year and execute_this_date = (last_run_at.month in self.month_of_year and
@ -460,6 +465,9 @@ class crontab(schedule):
dow_num in self.day_of_week) dow_num in self.day_of_week)
execute_this_hour = (execute_this_date and execute_this_hour = (execute_this_date and
last_run_at.day == now.day and
last_run_at.month == now.month and
last_run_at.year == now.year and
last_run_at.hour in self.hour and last_run_at.hour in self.hour and
last_run_at.minute < max(self.minute)) last_run_at.minute < max(self.minute))
@ -499,10 +507,11 @@ class crontab(schedule):
else: else:
delta = self._delta_to_next(last_run_at, delta = self._delta_to_next(last_run_at,
next_hour, next_minute) next_hour, next_minute)
return self.to_local(last_run_at), delta, self.to_local(now)
now = self.maybe_make_aware(self.now()) def remaining_estimate(self, last_run_at):
return remaining(self.to_local(last_run_at), delta, """Returns when the periodic task should run next as a timedelta."""
self.to_local(now)) return remaining(*self.remaining_delta(last_run_at))
def is_due(self, last_run_at): def is_due(self, last_run_at):
"""Returns tuple of two items `(is_due, next_time_to_run)`, """Returns tuple of two items `(is_due, next_time_to_run)`,

View File

@ -30,7 +30,10 @@ from celery.app import set_default_app
from celery.app.task import Task as BaseTask, Context from celery.app.task import Task as BaseTask, Context
from celery.datastructures import ExceptionInfo from celery.datastructures import ExceptionInfo
from celery.exceptions import Ignore, RetryTaskError from celery.exceptions import Ignore, RetryTaskError
from celery.utils.serialization import get_pickleable_exception from celery.utils.serialization import (
get_pickleable_exception,
get_pickleable_etype,
)
from celery.utils.log import get_logger from celery.utils.log import get_logger
_logger = get_logger(__name__) _logger = get_logger(__name__)
@ -128,7 +131,9 @@ class TraceInfo(object):
type_, _, tb = sys.exc_info() type_, _, tb = sys.exc_info()
try: try:
exc = self.retval exc = self.retval
einfo = ExceptionInfo((type_, get_pickleable_exception(exc), tb)) einfo = ExceptionInfo()
einfo.exception = get_pickleable_exception(einfo.exception)
einfo.type = get_pickleable_etype(einfo.type)
if store_errors: if store_errors:
task.backend.mark_as_failure(req.id, exc, einfo.traceback) task.backend.mark_as_failure(req.id, exc, einfo.traceback)
task.on_failure(exc, req.id, req.args, req.kwargs, einfo) task.on_failure(exc, req.id, req.args, req.kwargs, einfo)

View File

@ -11,8 +11,7 @@ from celery import current_app
from celery.result import AsyncResult, GroupResult from celery.result import AsyncResult, GroupResult
from celery.utils import serialization from celery.utils import serialization
from celery.utils.serialization import subclass_exception from celery.utils.serialization import subclass_exception
from celery.utils.serialization import \ from celery.utils.serialization import find_pickleable_exception as fnpe
find_nearest_pickleable_exception as fnpe
from celery.utils.serialization import UnpickleableExceptionWrapper from celery.utils.serialization import UnpickleableExceptionWrapper
from celery.utils.serialization import get_pickleable_exception as gpe from celery.utils.serialization import get_pickleable_exception as gpe

View File

@ -287,7 +287,7 @@ class test_MongoBackend(AppCase):
mock_database.__getitem__.assert_called_once_with( mock_database.__getitem__.assert_called_once_with(
MONGODB_COLLECTION) MONGODB_COLLECTION)
mock_collection.remove.assert_called_once_with( mock_collection.remove.assert_called_once_with(
{'_id': sentinel.task_id}, safe=True) {'_id': sentinel.task_id})
@patch('celery.backends.mongodb.MongoBackend._get_database') @patch('celery.backends.mongodb.MongoBackend._get_database')
def test_cleanup(self, mock_get_database): def test_cleanup(self, mock_get_database):

View File

@ -134,6 +134,11 @@ class test_chain(Case):
self.assertEqual(res.parent.parent.get(), 8) self.assertEqual(res.parent.parent.get(), 8)
self.assertIsNone(res.parent.parent.parent) self.assertIsNone(res.parent.parent.parent)
def test_accepts_generator_argument(self):
x = chain(add.s(i) for i in range(10))
self.assertTrue(x.tasks[0].type, add)
self.assertTrue(x.type)
class test_group(Case): class test_group(Case):

View File

@ -1,6 +1,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement from __future__ import with_statement
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
from mock import patch from mock import patch
@ -616,6 +618,14 @@ def monthly():
pass pass
@periodic_task(run_every=crontab(hour=22,
day_of_week='*',
month_of_year='2',
day_of_month='26,27,28'))
def monthly_moy():
pass
@periodic_task(run_every=crontab(hour=7, minute=30, @periodic_task(run_every=crontab(hour=7, minute=30,
day_of_week='thursday', day_of_week='thursday',
day_of_month='8-14', day_of_month='8-14',
@ -1212,6 +1222,40 @@ class test_crontab_is_due(Case):
self.assertFalse(due) self.assertFalse(due)
self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60) self.assertEqual(remaining, 4 * 24 * 60 * 60 - 3 * 60 * 60)
@patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
def test_monthly_moy_execution_is_due(self):
due, remaining = monthly_moy.run_every.is_due(
datetime(2013, 7, 4, 10, 0))
self.assertTrue(due)
self.assertEqual(remaining, 60.)
@patch_crontab_nowfun(monthly_moy, datetime(2013, 6, 28, 14, 30))
def test_monthly_moy_execution_is_not_due(self):
due, remaining = monthly_moy.run_every.is_due(
datetime(2013, 6, 28, 22, 14))
self.assertFalse(due)
attempt = (
time.mktime(datetime(2014, 2, 26, 22, 0).timetuple()) -
time.mktime(datetime(2013, 6, 28, 14, 30).timetuple()) -
60 * 60
)
self.assertEqual(remaining, attempt)
@patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 22, 0))
def test_monthly_moy_execution_is_due2(self):
due, remaining = monthly_moy.run_every.is_due(
datetime(2013, 2, 28, 10, 0))
self.assertTrue(due)
self.assertEqual(remaining, 60.)
@patch_crontab_nowfun(monthly_moy, datetime(2014, 2, 26, 21, 0))
def test_monthly_moy_execution_is_not_due2(self):
due, remaining = monthly_moy.run_every.is_due(
datetime(2013, 6, 28, 22, 14))
self.assertFalse(due)
attempt = 60 * 60
self.assertEqual(remaining, attempt)
@patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30)) @patch_crontab_nowfun(yearly, datetime(2010, 3, 11, 7, 30))
def test_yearly_execution_is_due(self): def test_yearly_execution_is_due(self):
due, remaining = yearly.run_every.is_due( due, remaining = yearly.run_every.is_due(

View File

@ -1,9 +1,5 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys
from nose import SkipTest
from celery.utils import encoding from celery.utils import encoding
from celery.tests.utils import Case from celery.tests.utils import Case
@ -15,17 +11,6 @@ class test_encoding(Case):
self.assertTrue(encoding.safe_str('foo')) self.assertTrue(encoding.safe_str('foo'))
self.assertTrue(encoding.safe_str(u'foo')) self.assertTrue(encoding.safe_str(u'foo'))
def test_safe_str_UnicodeDecodeError(self):
if sys.version_info >= (3, 0):
raise SkipTest('py3k: not relevant')
class foo(unicode):
def encode(self, *args, **kwargs):
raise UnicodeDecodeError('foo')
self.assertIn('<Unrepresentable', encoding.safe_str(foo()))
def test_safe_repr(self): def test_safe_repr(self):
self.assertTrue(encoding.safe_repr(object())) self.assertTrue(encoding.safe_repr(object()))

View File

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import from __future__ import absolute_import
from kombu.utils.encoding import str_t
from celery.utils import term from celery.utils import term
from celery.utils.term import colored, fg from celery.utils.term import colored, fg
@ -38,7 +40,7 @@ class test_colored(Case):
self.assertTrue(str(colored().iwhite('f'))) self.assertTrue(str(colored().iwhite('f')))
self.assertTrue(str(colored().reset('f'))) self.assertTrue(str(colored().reset('f')))
self.assertTrue(str(colored().green(u'∂bar'))) self.assertTrue(str_t(colored().green(u'∂bar')))
self.assertTrue( self.assertTrue(
colored().red(u'éefoo') + colored().green(u'∂bar')) colored().red(u'éefoo') + colored().green(u'∂bar'))

View File

@ -967,10 +967,12 @@ class test_WorkController(AppCase):
except ImportError: except ImportError:
raise SkipTest('multiprocessing not supported') raise SkipTest('multiprocessing not supported')
self.assertIsInstance(worker.ready_queue, AsyncTaskBucket) self.assertIsInstance(worker.ready_queue, AsyncTaskBucket)
self.assertFalse(worker.mediator) # XXX disabled until 3.1
self.assertNotEqual(worker.ready_queue.put, worker.process_task) #self.assertFalse(worker.mediator)
#self.assertNotEqual(worker.ready_queue.put, worker.process_task)
def test_disable_rate_limits_processes(self): def test_disable_rate_limits_processes(self):
raise SkipTest('disabled until v3.1')
try: try:
worker = self.create_worker(disable_rate_limits=True, worker = self.create_worker(disable_rate_limits=True,
use_eventloop=False, use_eventloop=False,
@ -1058,6 +1060,7 @@ class test_WorkController(AppCase):
self.assertTrue(w.disable_rate_limits) self.assertTrue(w.disable_rate_limits)
def test_Queues_pool_no_sem(self): def test_Queues_pool_no_sem(self):
raise SkipTest('disabled until v3.1')
w = Mock() w = Mock()
w.pool_cls.uses_semaphore = False w.pool_cls.uses_semaphore = False
Queues(w).create(w) Queues(w).create(w)
@ -1086,6 +1089,7 @@ class test_WorkController(AppCase):
w.hub.on_init = [] w.hub.on_init = []
w.pool_cls = Mock() w.pool_cls = Mock()
P = w.pool_cls.return_value = Mock() P = w.pool_cls.return_value = Mock()
P._cache = {}
P.timers = {Mock(): 30} P.timers = {Mock(): 30}
w.use_eventloop = True w.use_eventloop = True
w.consumer.restart_count = -1 w.consumer.restart_count = -1
@ -1105,23 +1109,13 @@ class test_WorkController(AppCase):
cbs['on_process_down'](w) cbs['on_process_down'](w)
hub.remove.assert_called_with(w.sentinel) hub.remove.assert_called_with(w.sentinel)
w.pool._tref_for_id = {}
result = Mock() result = Mock()
tref = result._tref
cbs['on_timeout_cancel'](result) cbs['on_timeout_cancel'](result)
tref.cancel.assert_called_with()
cbs['on_timeout_cancel'](result) # no more tref cbs['on_timeout_cancel'](result) # no more tref
cbs['on_timeout_set'](result, 10, 20)
tsoft, callback = hub.timer.apply_after.call_args[0]
callback()
cbs['on_timeout_set'](result, 10, None)
tsoft, callback = hub.timer.apply_after.call_args[0]
callback()
cbs['on_timeout_set'](result, None, 10)
cbs['on_timeout_set'](result, None, None)
with self.assertRaises(WorkerLostError): with self.assertRaises(WorkerLostError):
P.did_start_ok.return_value = False P.did_start_ok.return_value = False
w.consumer.restart_count = 0 w.consumer.restart_count = 0

View File

@ -28,15 +28,18 @@ class NotAPackage(Exception):
if sys.version_info >= (3, 3): # pragma: no cover if sys.version_info >= (3, 3): # pragma: no cover
def qualname(obj): def qualname(obj):
return obj.__qualname__ if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
obj = obj.__class__
q = getattr(obj, '__qualname__', None)
if '.' not in q:
q = '.'.join((obj.__module__, q))
return q
else: else:
def qualname(obj): # noqa def qualname(obj): # noqa
if not hasattr(obj, '__name__') and hasattr(obj, '__class__'): if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
return qualname(obj.__class__) return qualname(obj.__class__)
return '.'.join((obj.__module__, obj.__name__))
return '%s.%s' % (obj.__module__, obj.__name__)
def instantiate(name, *args, **kwargs): def instantiate(name, *args, **kwargs):

View File

@ -50,7 +50,8 @@ else:
return type(name, (parent,), {'__module__': module}) return type(name, (parent,), {'__module__': module})
def find_nearest_pickleable_exception(exc): def find_pickleable_exception(exc, loads=pickle.loads,
dumps=pickle.dumps):
"""With an exception instance, iterate over its super classes (by mro) """With an exception instance, iterate over its super classes (by mro)
and find the first super exception that is pickleable. It does and find the first super exception that is pickleable. It does
not go below :exc:`Exception` (i.e. it skips :exc:`Exception`, not go below :exc:`Exception` (i.e. it skips :exc:`Exception`,
@ -65,7 +66,19 @@ def find_nearest_pickleable_exception(exc):
:rtype :exc:`Exception`: :rtype :exc:`Exception`:
""" """
cls = exc.__class__ exc_args = getattr(exc, 'args', [])
for supercls in itermro(exc.__class__, unwanted_base_classes):
try:
superexc = supercls(*exc_args)
loads(dumps(superexc))
except:
pass
else:
return superexc
find_nearest_pickleable_exception = find_pickleable_exception # XXX compat
def itermro(cls, stop):
getmro_ = getattr(cls, 'mro', None) getmro_ = getattr(cls, 'mro', None)
# old-style classes doesn't have mro() # old-style classes doesn't have mro()
@ -77,18 +90,11 @@ def find_nearest_pickleable_exception(exc):
getmro_ = lambda: inspect.getmro(cls) getmro_ = lambda: inspect.getmro(cls)
for supercls in getmro_(): for supercls in getmro_():
if supercls in unwanted_base_classes: if supercls in stop:
# only BaseException and object, from here on down, # only BaseException and object, from here on down,
# we don't care about these. # we don't care about these.
return return
try: yield supercls
exc_args = getattr(exc, 'args', [])
superexc = supercls(*exc_args)
pickle.loads(pickle.dumps(superexc))
except:
pass
else:
return superexc
def create_exception_cls(name, module, parent=None): def create_exception_cls(name, module, parent=None):
@ -165,12 +171,19 @@ def get_pickleable_exception(exc):
pass pass
else: else:
return exc return exc
nearest = find_nearest_pickleable_exception(exc) nearest = find_pickleable_exception(exc)
if nearest: if nearest:
return nearest return nearest
return UnpickleableExceptionWrapper.from_exception(exc) return UnpickleableExceptionWrapper.from_exception(exc)
def get_pickleable_etype(cls, loads=pickle.loads, dumps=pickle.dumps):
try:
loads(dumps(cls))
except:
return Exception
def get_pickled_exception(exc): def get_pickled_exception(exc):
"""Get original exception from exception pickled using """Get original exception from exception pickled using
:meth:`get_pickleable_exception`.""" :meth:`get_pickleable_exception`."""

View File

@ -18,29 +18,26 @@ from celery.utils.compat import THREAD_TIMEOUT_MAX
USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS') USE_FAST_LOCALS = os.environ.get('USE_FAST_LOCALS')
PY3 = sys.version_info[0] == 3 PY3 = sys.version_info[0] == 3
NEW_EVENT = (sys.version_info[0] == 3) and (sys.version_info[1] >= 3)
_Thread = threading.Thread _Thread = threading.Thread
_Event = threading.Event if PY3 else threading._Event _Event = threading.Event if NEW_EVENT else threading._Event
active_count = (getattr(threading, 'active_count', None) or active_count = (getattr(threading, 'active_count', None) or
threading.activeCount) threading.activeCount)
class Event(_Event): if sys.version_info < (2, 6):
if not hasattr(_Event, 'is_set'): # pragma: no cover class Event(_Event): # pragma: no cover
is_set = _Event.isSet is_set = _Event.isSet
class Thread(_Thread): # pragma: no cover
class Thread(_Thread):
if not hasattr(_Thread, 'is_alive'): # pragma: no cover
is_alive = _Thread.isAlive is_alive = _Thread.isAlive
if not hasattr(_Thread, 'daemon'): # pragma: no cover
daemon = property(_Thread.isDaemon, _Thread.setDaemon) daemon = property(_Thread.isDaemon, _Thread.setDaemon)
if not hasattr(_Thread, 'name'): # pragma: no cover
name = property(_Thread.getName, _Thread.setName) name = property(_Thread.getName, _Thread.setName)
else:
Event = _Event
Thread = _Thread
class bgThread(Thread): class bgThread(Thread):

View File

@ -19,6 +19,7 @@ import time
import traceback import traceback
from functools import partial from functools import partial
from weakref import WeakValueDictionary
from billiard.exceptions import WorkerLostError from billiard.exceptions import WorkerLostError
from billiard.util import Finalize from billiard.util import Finalize
@ -26,6 +27,7 @@ from kombu.syn import detect_environment
from celery import concurrency as _concurrency from celery import concurrency as _concurrency
from celery import platforms from celery import platforms
from celery import signals
from celery.app import app_or_default from celery.app import app_or_default
from celery.app.abstract import configurated, from_config from celery.app.abstract import configurated, from_config
from celery.exceptions import SystemTerminate, TaskRevokedError from celery.exceptions import SystemTerminate, TaskRevokedError
@ -105,6 +107,7 @@ class Pool(bootsteps.StartStopComponent):
add_reader = hub.add_reader add_reader = hub.add_reader
remove = hub.remove remove = hub.remove
now = time.time now = time.time
cache = pool._pool._cache
# did_start_ok will verify that pool processes were able to start, # did_start_ok will verify that pool processes were able to start,
# but this will only work the first time we start, as # but this will only work the first time we start, as
@ -120,25 +123,58 @@ class Pool(bootsteps.StartStopComponent):
for handler, interval in pool.timers.iteritems(): for handler, interval in pool.timers.iteritems():
hub.timer.apply_interval(interval * 1000.0, handler) hub.timer.apply_interval(interval * 1000.0, handler)
def on_timeout_set(R, soft, hard): trefs = pool._tref_for_id = WeakValueDictionary()
def _on_soft_timeout(): def _discard_tref(job):
if hard:
R._tref = apply_at(now() + (hard - soft),
on_hard_timeout, (R, ))
on_soft_timeout(R)
if soft:
R._tref = apply_after(soft * 1000.0, _on_soft_timeout)
elif hard:
R._tref = apply_after(hard * 1000.0,
on_hard_timeout, (R, ))
def on_timeout_cancel(result):
try: try:
result._tref.cancel() tref = trefs.pop(job)
delattr(result, '_tref') tref.cancel()
except AttributeError: del(tref)
pass except (KeyError, AttributeError):
pass # out of scope
def _on_hard_timeout(job):
try:
result = cache[job]
except KeyError:
pass # job ready
else:
on_hard_timeout(result)
finally:
# remove tref
_discard_tref(job)
def _on_soft_timeout(job, soft, hard, hub):
if hard:
trefs[job] = apply_at(
now() + (hard - soft),
_on_hard_timeout, (job, ),
)
try:
result = cache[job]
except KeyError:
pass # job ready
else:
on_soft_timeout(result)
finally:
if not hard:
# remove tref
_discard_tref(job)
def on_timeout_set(R, soft, hard):
if soft:
trefs[R._job] = apply_after(
soft * 1000.0,
_on_soft_timeout, (R._job, soft, hard, hub),
)
elif hard:
trefs[R._job] = apply_after(
hard * 1000.0,
_on_hard_timeout, (R._job, )
)
def on_timeout_cancel(R):
_discard_tref(R._job)
pool.init_callbacks( pool.init_callbacks(
on_process_up=lambda w: add_reader(w.sentinel, maintain_pool), on_process_up=lambda w: add_reader(w.sentinel, maintain_pool),
@ -208,19 +244,18 @@ class Queues(bootsteps.Component):
def create(self, w): def create(self, w):
BucketType = TaskBucket BucketType = TaskBucket
w.start_mediator = not w.disable_rate_limits w.start_mediator = w.pool_cls.requires_mediator
if not w.pool_cls.rlimit_safe: if not w.pool_cls.rlimit_safe:
w.start_mediator = False
BucketType = AsyncTaskBucket BucketType = AsyncTaskBucket
process_task = w.process_task process_task = w.process_task
if w.use_eventloop: if w.use_eventloop:
w.start_mediator = False
BucketType = AsyncTaskBucket BucketType = AsyncTaskBucket
if w.pool_putlocks and w.pool_cls.uses_semaphore: if w.pool_putlocks and w.pool_cls.uses_semaphore:
process_task = w.process_task_sem process_task = w.process_task_sem
if w.disable_rate_limits: if w.disable_rate_limits or not w.start_mediator:
w.ready_queue = FastQueue() w.ready_queue = FastQueue()
w.ready_queue.put = process_task if not w.start_mediator:
w.ready_queue.put = process_task
else: else:
w.ready_queue = BucketType( w.ready_queue = BucketType(
task_registry=w.app.tasks, callback=process_task, worker=w, task_registry=w.app.tasks, callback=process_task, worker=w,
@ -327,7 +362,10 @@ class WorkController(configurated):
self.loglevel = loglevel or self.loglevel self.loglevel = loglevel or self.loglevel
self.hostname = hostname or socket.gethostname() self.hostname = hostname or socket.gethostname()
self.ready_callback = ready_callback self.ready_callback = ready_callback
self._finalize = Finalize(self, self.stop, exitpriority=1) self._finalize = [
Finalize(self, self.stop, exitpriority=1),
Finalize(self, self._send_worker_shutdown, exitpriority=10),
]
self.pidfile = pidfile self.pidfile = pidfile
self.pidlock = None self.pidlock = None
# this connection is not established, only used for params # this connection is not established, only used for params
@ -350,6 +388,9 @@ class WorkController(configurated):
self.components = [] self.components = []
self.namespace = Namespace(app=self.app).apply(self, **kwargs) self.namespace = Namespace(app=self.app).apply(self, **kwargs)
def _send_worker_shutdown(self):
signals.worker_shutdown.send(sender=self)
def start(self): def start(self):
"""Starts the workers main loop.""" """Starts the workers main loop."""
self._state = self.RUN self._state = self.RUN

View File

@ -39,6 +39,12 @@ class AsyncTaskBucket(object):
self.worker = worker self.worker = worker
self.buckets = {} self.buckets = {}
self.refresh() self.refresh()
self._queue = Queue()
self._quick_put = self._queue.put
self.get = self._queue.get
def get(self, *args, **kwargs):
return self._queue.get(*args, **kwargs)
def cont(self, request, bucket, tokens): def cont(self, request, bucket, tokens):
if not bucket.can_consume(tokens): if not bucket.can_consume(tokens):
@ -47,7 +53,7 @@ class AsyncTaskBucket(object):
hold * 1000.0, self.cont, (request, bucket, tokens), hold * 1000.0, self.cont, (request, bucket, tokens),
) )
else: else:
self.callback(request) self._quick_put(request)
def put(self, request): def put(self, request):
name = request.name name = request.name
@ -56,7 +62,7 @@ class AsyncTaskBucket(object):
except KeyError: except KeyError:
bucket = self.add_bucket_for_type(name) bucket = self.add_bucket_for_type(name)
if not bucket: if not bucket:
return self.callback(request) return self._quick_put(request)
return self.cont(request, bucket, 1) return self.cont(request, bucket, 1)
def add_task_type(self, name): def add_task_type(self, name):

View File

@ -80,6 +80,8 @@ import threading
from time import sleep from time import sleep
from Queue import Empty from Queue import Empty
from billiard.common import restart_state
from billiard.exceptions import RestartFreqExceeded
from kombu.syn import _detect_environment from kombu.syn import _detect_environment
from kombu.utils.encoding import safe_repr, safe_str, bytes_t from kombu.utils.encoding import safe_repr, safe_str, bytes_t
from kombu.utils.eventio import READ, WRITE, ERR from kombu.utils.eventio import READ, WRITE, ERR
@ -100,6 +102,13 @@ from .bootsteps import StartStopComponent
from .control import Panel from .control import Panel
from .heartbeat import Heart from .heartbeat import Heart
try:
buffer_t = buffer
except NameError: # pragma: no cover
class buffer_t(object): # noqa
pass
RUN = 0x1 RUN = 0x1
CLOSE = 0x2 CLOSE = 0x2
@ -171,7 +180,7 @@ def debug(msg, *args, **kwargs):
def dump_body(m, body): def dump_body(m, body):
if isinstance(body, buffer): if isinstance(body, buffer_t):
body = bytes_t(body) body = bytes_t(body)
return "%s (%sb)" % (text.truncate(safe_repr(body), 1024), len(m.body)) return "%s (%sb)" % (text.truncate(safe_repr(body), 1024), len(m.body))
@ -348,6 +357,7 @@ class Consumer(object):
conninfo = self.app.connection() conninfo = self.app.connection()
self.connection_errors = conninfo.connection_errors self.connection_errors = conninfo.connection_errors
self.channel_errors = conninfo.channel_errors self.channel_errors = conninfo.channel_errors
self._restart_state = restart_state(maxR=5, maxT=1)
self._does_info = logger.isEnabledFor(logging.INFO) self._does_info = logger.isEnabledFor(logging.INFO)
self.strategies = {} self.strategies = {}
@ -390,6 +400,11 @@ class Consumer(object):
while self._state != CLOSE: while self._state != CLOSE:
self.restart_count += 1 self.restart_count += 1
self.maybe_shutdown() self.maybe_shutdown()
try:
self._restart_state.step()
except RestartFreqExceeded as exc:
crit('Frequent restarts detected: %r', exc, exc_info=1)
sleep(1)
try: try:
self.reset_connection() self.reset_connection()
self.consume_messages() self.consume_messages()
@ -736,6 +751,7 @@ class Consumer(object):
# to the current channel. # to the current channel.
self.ready_queue.clear() self.ready_queue.clear()
self.timer.clear() self.timer.clear()
state.reserved_requests.clear()
# Re-establish the broker connection and setup the task consumer. # Re-establish the broker connection and setup the task consumer.
self.connection = self._open_connection() self.connection = self._open_connection()

View File

@ -36,7 +36,7 @@ class WorkerComponent(StartStopComponent):
w.mediator = None w.mediator = None
def include_if(self, w): def include_if(self, w):
return w.start_mediator and not w.use_eventloop return w.start_mediator
def create(self, w): def create(self, w):
m = w.mediator = self.instantiate(w.mediator_cls, w.ready_queue, m = w.mediator = self.instantiate(w.mediator_cls, w.ready_queue,

View File

@ -1,5 +1,5 @@
VERSION = (1, 1, 1) VERSION = (1, 2, 0)
# Dynamically calculate the version based on VERSION tuple # Dynamically calculate the version based on VERSION tuple
if len(VERSION) > 2 and VERSION[2] is not None: if len(VERSION) > 2 and VERSION[2] is not None:

View File

@ -9,6 +9,7 @@
# (Michal Salaban) # (Michal Salaban)
# #
import six
import operator import operator
from six.moves import reduce from six.moves import reduce
from django.http import HttpResponse, HttpResponseNotFound from django.http import HttpResponse, HttpResponseNotFound
@ -108,8 +109,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin):
other_qs.dup_select_related(queryset) other_qs.dup_select_related(queryset)
other_qs = other_qs.filter(reduce(operator.or_, or_queries)) other_qs = other_qs.filter(reduce(operator.or_, or_queries))
queryset = queryset & other_qs queryset = queryset & other_qs
data = ''.join([u'%s|%s\n' % ( data = ''.join([six.u('%s|%s\n' % (to_string_function(f), f.pk)) for f in queryset])
to_string_function(f), f.pk) for f in queryset])
elif object_pk: elif object_pk:
try: try:
obj = queryset.get(pk=object_pk) obj = queryset.get(pk=object_pk)
@ -139,7 +139,7 @@ class ForeignKeyAutocompleteAdmin(ModelAdmin):
model_name = db_field.rel.to._meta.object_name model_name = db_field.rel.to._meta.object_name
help_text = self.get_help_text(db_field.name, model_name) help_text = self.get_help_text(db_field.name, model_name)
if kwargs.get('help_text'): if kwargs.get('help_text'):
help_text = u'%s %s' % (kwargs['help_text'], help_text) help_text = six.u('%s %s' % (kwargs['help_text'], help_text))
kwargs['widget'] = ForeignKeySearchInput(db_field.rel, self.related_search_fields[db_field.name]) kwargs['widget'] = ForeignKeySearchInput(db_field.rel, self.related_search_fields[db_field.name])
kwargs['help_text'] = help_text kwargs['help_text'] = help_text
return super(ForeignKeyAutocompleteAdmin, self).formfield_for_dbfield(db_field, **kwargs) return super(ForeignKeyAutocompleteAdmin, self).formfield_for_dbfield(db_field, **kwargs)

View File

@ -1,3 +1,4 @@
import six
import django import django
from django import forms from django import forms
from django.conf import settings from django.conf import settings
@ -67,7 +68,7 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget):
if value: if value:
label = self.label_for_value(value) label = self.label_for_value(value)
else: else:
label = u'' label = six.u('')
try: try:
admin_media_prefix = settings.ADMIN_MEDIA_PREFIX admin_media_prefix = settings.ADMIN_MEDIA_PREFIX
@ -92,4 +93,4 @@ class ForeignKeySearchInput(ForeignKeyRawIdWidget):
'django_extensions/widgets/foreignkey_searchinput.html', 'django_extensions/widgets/foreignkey_searchinput.html',
), context)) ), context))
output.reverse() output.reverse()
return mark_safe(u''.join(output)) return mark_safe(six.u(''.join(output)))

View File

@ -3,13 +3,13 @@ Django Extensions additional model fields
""" """
import re import re
import six import six
try: try:
import uuid import uuid
assert uuid HAS_UUID = True
except ImportError: except ImportError:
from django_extensions.utils import uuid HAS_UUID = False
from django.core.exceptions import ImproperlyConfigured
from django.template.defaultfilters import slugify from django.template.defaultfilters import slugify
from django.db.models import DateTimeField, CharField, SlugField from django.db.models import DateTimeField, CharField, SlugField
@ -56,7 +56,7 @@ class AutoSlugField(SlugField):
raise ValueError("missing 'populate_from' argument") raise ValueError("missing 'populate_from' argument")
else: else:
self._populate_from = populate_from self._populate_from = populate_from
self.separator = kwargs.pop('separator', u'-') self.separator = kwargs.pop('separator', six.u('-'))
self.overwrite = kwargs.pop('overwrite', False) self.overwrite = kwargs.pop('overwrite', False)
self.allow_duplicates = kwargs.pop('allow_duplicates', False) self.allow_duplicates = kwargs.pop('allow_duplicates', False)
super(AutoSlugField, self).__init__(*args, **kwargs) super(AutoSlugField, self).__init__(*args, **kwargs)
@ -221,13 +221,15 @@ class UUIDVersionError(Exception):
class UUIDField(CharField): class UUIDField(CharField):
""" UUIDField """ UUIDField
By default uses UUID version 4 (generate from host ID, sequence number and current time) By default uses UUID version 4 (randomly generated UUID).
The field support all uuid versions which are natively supported by the uuid python module. The field support all uuid versions which are natively supported by the uuid python module, except version 2.
For more information see: http://docs.python.org/lib/module-uuid.html For more information see: http://docs.python.org/lib/module-uuid.html
""" """
def __init__(self, verbose_name=None, name=None, auto=True, version=1, node=None, clock_seq=None, namespace=None, **kwargs): def __init__(self, verbose_name=None, name=None, auto=True, version=4, node=None, clock_seq=None, namespace=None, **kwargs):
if not HAS_UUID:
raise ImproperlyConfigured("'uuid' module is required for UUIDField. (Do you have Python 2.5 or higher installed ?)")
kwargs.setdefault('max_length', 36) kwargs.setdefault('max_length', 36)
if auto: if auto:
self.empty_strings_allowed = False self.empty_strings_allowed = False
@ -244,17 +246,6 @@ class UUIDField(CharField):
def get_internal_type(self): def get_internal_type(self):
return CharField.__name__ return CharField.__name__
def contribute_to_class(self, cls, name):
if self.primary_key:
assert not cls._meta.has_auto_field, "A model can't have more than one AutoField: %s %s %s; have %s" % (
self, cls, name, cls._meta.auto_field
)
super(UUIDField, self).contribute_to_class(cls, name)
cls._meta.has_auto_field = True
cls._meta.auto_field = self
else:
super(UUIDField, self).contribute_to_class(cls, name)
def create_uuid(self): def create_uuid(self):
if not self.version or self.version == 4: if not self.version or self.version == 4:
return uuid.uuid4() return uuid.uuid4()
@ -277,7 +268,7 @@ class UUIDField(CharField):
return value return value
else: else:
if self.auto and not value: if self.auto and not value:
value = six.u(self.create_uuid()) value = force_unicode(self.create_uuid())
setattr(model_instance, self.attname, value) setattr(model_instance, self.attname, value)
return value return value

View File

@ -21,8 +21,11 @@ class BaseEncryptedField(models.Field):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if not hasattr(settings, 'ENCRYPTED_FIELD_KEYS_DIR'): if not hasattr(settings, 'ENCRYPTED_FIELD_KEYS_DIR'):
raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR setting to your Keyczar keys directory.') raise ImproperlyConfigured('You must set the ENCRYPTED_FIELD_KEYS_DIR '
self.crypt = keyczar.Crypter.Read(settings.ENCRYPTED_FIELD_KEYS_DIR) 'setting to your Keyczar keys directory.')
crypt_class = self.get_crypt_class()
self.crypt = crypt_class.Read(settings.ENCRYPTED_FIELD_KEYS_DIR)
# Encrypted size is larger than unencrypted # Encrypted size is larger than unencrypted
self.unencrypted_length = max_length = kwargs.get('max_length', None) self.unencrypted_length = max_length = kwargs.get('max_length', None)
@ -34,6 +37,32 @@ class BaseEncryptedField(models.Field):
super(BaseEncryptedField, self).__init__(*args, **kwargs) super(BaseEncryptedField, self).__init__(*args, **kwargs)
def get_crypt_class(self):
"""
Get the Keyczar class to use.
The class can be customized with the ENCRYPTED_FIELD_MODE setting. By default,
this setting is DECRYPT_AND_ENCRYPT. Set this to ENCRYPT to disable decryption.
This is necessary if you are only providing public keys to Keyczar.
Returns:
keyczar.Encrypter if ENCRYPTED_FIELD_MODE is ENCRYPT.
keyczar.Crypter if ENCRYPTED_FIELD_MODE is DECRYPT_AND_ENCRYPT.
Override this method to customize the type of Keyczar class returned.
"""
crypt_type = getattr(settings, 'ENCRYPTED_FIELD_MODE', 'DECRYPT_AND_ENCRYPT')
if crypt_type == 'ENCRYPT':
crypt_class_name = 'Encrypter'
elif crypt_type == 'DECRYPT_AND_ENCRYPT':
crypt_class_name = 'Crypter'
else:
raise ImproperlyConfigured(
'ENCRYPTED_FIELD_MODE must be either DECRYPT_AND_ENCRYPT '
'or ENCRYPT, not %s.' % crypt_type)
return getattr(keyczar, crypt_class_name)
def to_python(self, value): def to_python(self, value):
if isinstance(self.crypt.primary_key, keyczar.keys.RsaPublicKey): if isinstance(self.crypt.primary_key, keyczar.keys.RsaPublicKey):
retval = value retval = value
@ -64,9 +93,8 @@ class BaseEncryptedField(models.Field):
return value return value
class EncryptedTextField(BaseEncryptedField): class EncryptedTextField(six.with_metaclass(models.SubfieldBase,
__metaclass__ = models.SubfieldBase BaseEncryptedField)):
def get_internal_type(self): def get_internal_type(self):
return 'TextField' return 'TextField'
@ -85,9 +113,8 @@ class EncryptedTextField(BaseEncryptedField):
return (field_class, args, kwargs) return (field_class, args, kwargs)
class EncryptedCharField(BaseEncryptedField): class EncryptedCharField(six.with_metaclass(models.SubfieldBase,
__metaclass__ = models.SubfieldBase BaseEncryptedField)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(EncryptedCharField, self).__init__(*args, **kwargs) super(EncryptedCharField, self).__init__(*args, **kwargs)
@ -107,4 +134,3 @@ class EncryptedCharField(BaseEncryptedField):
args, kwargs = introspector(self) args, kwargs = introspector(self)
# That's our definition! # That's our definition!
return (field_class, args, kwargs) return (field_class, args, kwargs)

View File

@ -58,16 +58,13 @@ class JSONList(list):
return dumps(self) return dumps(self)
class JSONField(models.TextField): class JSONField(six.with_metaclass(models.SubfieldBase, models.TextField)):
"""JSONField is a generic textfield that neatly serializes/unserializes """JSONField is a generic textfield that neatly serializes/unserializes
JSON objects seamlessly. Main thingy must be a dict object.""" JSON objects seamlessly. Main thingy must be a dict object."""
# Used so to_python() is called
__metaclass__ = models.SubfieldBase
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
default = kwargs.get('default') default = kwargs.get('default', None)
if not default: if default is None:
kwargs['default'] = '{}' kwargs['default'] = '{}'
elif isinstance(default, (list, dict)): elif isinstance(default, (list, dict)):
kwargs['default'] = dumps(default) kwargs['default'] = dumps(default)

View File

@ -0,0 +1,16 @@
"""
A forwards compatibility module.
Implements some features of Django 1.5 related to the 'Custom User Model' feature
when the application is run with a lower version of Django.
"""
from __future__ import unicode_literals
from django.contrib.auth.models import User
User.USERNAME_FIELD = "username"
User.get_username = lambda self: self.username
def get_user_model():
return User

View File

@ -71,9 +71,9 @@ def orm_item_locator(orm_obj):
for key in clean_dict: for key in clean_dict:
v = clean_dict[key] v = clean_dict[key]
if v is not None and not isinstance(v, (six.string_types, six.integer_types, float, datetime.datetime)): if v is not None and not isinstance(v, (six.string_types, six.integer_types, float, datetime.datetime)):
clean_dict[key] = u"%s" % v clean_dict[key] = six.u("%s" % v)
output = """ locate_object(%s, "%s", %s, "%s", %s, %s ) """ % ( output = """ importer.locate_object(%s, "%s", %s, "%s", %s, %s ) """ % (
original_class, original_pk_name, original_class, original_pk_name,
the_class, pk_name, pk_value, clean_dict the_class, pk_name, pk_value, clean_dict
) )
@ -264,7 +264,7 @@ class InstanceCode(Code):
# Print the save command for our new object # Print the save command for our new object
# e.g. model_name_35.save() # e.g. model_name_35.save()
if code_lines: if code_lines:
code_lines.append("%s = save_or_locate(%s)\n" % (self.variable_name, self.variable_name)) code_lines.append("%s = importer.save_or_locate(%s)\n" % (self.variable_name, self.variable_name))
code_lines += self.get_many_to_many_lines(force=force) code_lines += self.get_many_to_many_lines(force=force)
@ -499,7 +499,6 @@ class Script(Code):
code.insert(2, "") code.insert(2, "")
for key, value in self.context["__extra_imports"].items(): for key, value in self.context["__extra_imports"].items():
code.insert(2, " from %s import %s" % (value, key)) code.insert(2, " from %s import %s" % (value, key))
code.insert(2 + len(self.context["__extra_imports"]), self.locate_object_function)
return code return code
@ -513,11 +512,17 @@ class Script(Code):
# This file has been automatically generated. # This file has been automatically generated.
# Instead of changing it, create a file called import_helper.py # Instead of changing it, create a file called import_helper.py
# which this script has hooks to. # and put there a class called ImportHelper(object) in it.
# #
# On that file, don't forget to add the necessary Django imports # This class will be specially casted so that instead of extending object,
# and take a look at how locate_object() and save_or_locate() # it will actually extend the class BasicImportHelper()
# are implemented here and expected to behave. #
# That means you just have to overload the methods you want to
# change, leaving the other ones inteact.
#
# Something that you might want to do is use transactions, for example.
#
# Also, don't forget to add the necessary Django imports.
# #
# This file was generated with the following command: # This file was generated with the following command:
# %s # %s
@ -530,24 +535,31 @@ class Script(Code):
# you must make sure ./some_folder/__init__.py exists # you must make sure ./some_folder/__init__.py exists
# and run ./manage.py runscript some_folder.some_script # and run ./manage.py runscript some_folder.some_script
from django.db import transaction
IMPORT_HELPER_AVAILABLE = False class BasicImportHelper(object):
try:
import import_helper
IMPORT_HELPER_AVAILABLE = True
except ImportError:
pass
import datetime def pre_import(self):
from decimal import Decimal pass
from django.contrib.contenttypes.models import ContentType
def run(): # You probably want to uncomment on of these two lines
# @transaction.atomic # Django 1.6
# @transaction.commit_on_success # Django <1.6
def run_import(self, import_data):
import_data()
""" % " ".join(sys.argv) def post_import(self):
pass
locate_object_function = """ def locate_similar(self, current_object, search_data):
def locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content): #you will probably want to call this method from save_or_locate()
#example:
#new_obj = self.locate_similar(the_obj, {"national_id": the_obj.national_id } )
the_obj = current_object.__class__.objects.get(**search_data)
return the_obj
def locate_object(self, original_class, original_pk_name, the_class, pk_name, pk_value, obj_content):
#You may change this function to do specific lookup for specific objects #You may change this function to do specific lookup for specific objects
# #
#original_class class of the django orm's object that needs to be located #original_class class of the django orm's object that needs to be located
@ -571,22 +583,55 @@ def run():
#if the_class == StaffGroup: #if the_class == StaffGroup:
# pk_value=8 # pk_value=8
if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "locate_object"):
return import_helper.locate_object(original_class, original_pk_name, the_class, pk_name, pk_value, obj_content)
search_data = { pk_name: pk_value } search_data = { pk_name: pk_value }
the_obj =the_class.objects.get(**search_data) the_obj = the_class.objects.get(**search_data)
#print(the_obj)
return the_obj return the_obj
def save_or_locate(the_obj):
if IMPORT_HELPER_AVAILABLE and hasattr(import_helper, "save_or_locate"): def save_or_locate(self, the_obj):
the_obj = import_helper.save_or_locate(the_obj) #change this if you want to locate the object in the database
else: try:
the_obj.save() the_obj.save()
except:
print("---------------")
print("Error saving the following object:")
print(the_obj.__class__)
print(" ")
print(the_obj.__dict__)
print(" ")
print(the_obj)
print(" ")
print("---------------")
raise
return the_obj return the_obj
"""
importer = None
try:
import import_helper
#we need this so ImportHelper can extend BasicImportHelper, although import_helper.py
#has no knowlodge of this class
importer = type("DynamicImportHelper", (import_helper.ImportHelper, BasicImportHelper ) , {} )()
except ImportError as e:
if str(e) == "No module named import_helper":
importer = BasicImportHelper()
else:
raise
import datetime
from decimal import Decimal
from django.contrib.contenttypes.models import ContentType
def run():
importer.pre_import()
importer.run_import(import_data)
importer.post_import()
def import_data():
""" % " ".join(sys.argv)
# HELPER FUNCTIONS # HELPER FUNCTIONS

View File

@ -1,8 +1,13 @@
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth.models import User, Group try:
from django.contrib.auth import get_user_model # Django 1.5
except ImportError:
from django_extensions.future_1_5 import get_user_model
from django.contrib.auth.models import Group
from optparse import make_option from optparse import make_option
from sys import stdout from sys import stdout
from csv import writer from csv import writer
import six
FORMATS = [ FORMATS = [
'address', 'address',
@ -15,7 +20,7 @@ FORMATS = [
def full_name(first_name, last_name, username, **extra): def full_name(first_name, last_name, username, **extra):
name = u" ".join(n for n in [first_name, last_name] if n) name = six.u(" ").join(n for n in [first_name, last_name] if n)
if not name: if not name:
return username return username
return name return name
@ -42,7 +47,7 @@ class Command(BaseCommand):
raise CommandError("extra arguments supplied") raise CommandError("extra arguments supplied")
group = options['group'] group = options['group']
if group and not Group.objects.filter(name=group).count() == 1: if group and not Group.objects.filter(name=group).count() == 1:
names = u"', '".join(g['name'] for g in Group.objects.values('name')).encode('utf-8') names = six.u("', '").join(g['name'] for g in Group.objects.values('name')).encode('utf-8')
if names: if names:
names = "'" + names + "'." names = "'" + names + "'."
raise CommandError("Unknown group '" + group + "'. Valid group names are: " + names) raise CommandError("Unknown group '" + group + "'. Valid group names are: " + names)
@ -51,6 +56,7 @@ class Command(BaseCommand):
else: else:
outfile = stdout outfile = stdout
User = get_user_model()
qs = User.objects.all().order_by('last_name', 'first_name', 'username', 'email') qs = User.objects.all().order_by('last_name', 'first_name', 'username', 'email')
if group: if group:
qs = qs.filter(group__name=group).distinct() qs = qs.filter(group__name=group).distinct()
@ -61,15 +67,15 @@ class Command(BaseCommand):
"""simple single entry per line in the format of: """simple single entry per line in the format of:
"full name" <my@address.com>; "full name" <my@address.com>;
""" """
out.write(u"\n".join(u'"%s" <%s>;' % (full_name(**ent), ent['email']) out.write(six.u("\n").join(six.u('"%s" <%s>;' % (full_name(**ent), ent['email']))
for ent in qs).encode(self.encoding)) for ent in qs).encode(self.encoding))
out.write("\n") out.write("\n")
def emails(self, qs, out): def emails(self, qs, out):
"""simpler single entry with email only in the format of: """simpler single entry with email only in the format of:
my@address.com, my@address.com,
""" """
out.write(u",\n".join(u'%s' % (ent['email']) for ent in qs).encode(self.encoding)) out.write(six.u(",\n").join(six.u('%s' % (ent['email'])) for ent in qs).encode(self.encoding))
out.write("\n") out.write("\n")
def google(self, qs, out): def google(self, qs, out):

View File

@ -56,7 +56,7 @@ class Command(BaseCommand):
vizdata = ' '.join(dotdata.split("\n")).strip().encode('utf-8') vizdata = ' '.join(dotdata.split("\n")).strip().encode('utf-8')
version = pygraphviz.__version__.rstrip("-svn") version = pygraphviz.__version__.rstrip("-svn")
try: try:
if [int(v) for v in version.split('.')] < (0, 36): if tuple(int(v) for v in version.split('.')) < (0, 36):
# HACK around old/broken AGraph before version 0.36 (ubuntu ships with this old version) # HACK around old/broken AGraph before version 0.36 (ubuntu ships with this old version)
import tempfile import tempfile
tmpfile = tempfile.NamedTemporaryFile() tmpfile = tempfile.NamedTemporaryFile()

View File

@ -1,5 +1,8 @@
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth.models import User try:
from django.contrib.auth import get_user_model # Django 1.5
except ImportError:
from django_extensions.future_1_5 import get_user_model
import getpass import getpass
@ -17,6 +20,7 @@ class Command(BaseCommand):
else: else:
username = getpass.getuser() username = getpass.getuser()
User = get_user_model()
try: try:
u = User.objects.get(username=username) u = User.objects.get(username=username)
except User.DoesNotExist: except User.DoesNotExist:

View File

@ -8,6 +8,12 @@ import urlparse
import xmlrpclib import xmlrpclib
from distutils.version import LooseVersion from distutils.version import LooseVersion
try:
import requests
except ImportError:
print("""The requests library is not installed. To continue:
pip install requests""")
from optparse import make_option from optparse import make_option
from django.core.management.base import NoArgsCommand from django.core.management.base import NoArgsCommand
@ -166,13 +172,25 @@ class Command(NoArgsCommand):
} }
if self.github_api_token: if self.github_api_token:
headers["Authorization"] = "token {0}".format(self.github_api_token) headers["Authorization"] = "token {0}".format(self.github_api_token)
user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/") try:
user, repo = urlparse.urlparse(req_url).path.split("#")[0].strip("/").rstrip("/").split("/")
except (ValueError, IndexError) as e:
print("\nFailed to parse %r: %s\n" % (req_url, e))
continue
try:
#test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers)
test_auth = requests.get("https://api.github.com/django/", headers=headers).json()
except urllib2.HTTPError as e:
print("\n%s\n" % str(e))
return
test_auth = self._urlopen_as_json("https://api.github.com/django/", headers=headers)
if "message" in test_auth and test_auth["message"] == "Bad credentials": if "message" in test_auth and test_auth["message"] == "Bad credentials":
sys.exit("\nGithub API: Bad credentials. Aborting!\n") print("\nGithub API: Bad credentials. Aborting!\n")
return
elif "message" in test_auth and test_auth["message"].startswith("API Rate Limit Exceeded"): elif "message" in test_auth and test_auth["message"].startswith("API Rate Limit Exceeded"):
sys.exit("\nGithub API: Rate Limit Exceeded. Aborting!\n") print("\nGithub API: Rate Limit Exceeded. Aborting!\n")
return
if ".git" in repo: if ".git" in repo:
repo_name, frozen_commit_full = repo.split(".git") repo_name, frozen_commit_full = repo.split(".git")
@ -186,11 +204,14 @@ class Command(NoArgsCommand):
if frozen_commit_sha: if frozen_commit_sha:
branch_url = "https://api.github.com/repos/{0}/{1}/branches".format(user, repo_name) branch_url = "https://api.github.com/repos/{0}/{1}/branches".format(user, repo_name)
branch_data = self._urlopen_as_json(branch_url, headers=headers) #branch_data = self._urlopen_as_json(branch_url, headers=headers)
branch_data = requests.get(branch_url, headers=headers).json()
frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}" \ frozen_commit_url = "https://api.github.com/repos/{0}/{1}/commits/{2}".format(
.format(user, repo_name, frozen_commit_sha) user, repo_name, frozen_commit_sha
frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers) )
#frozen_commit_data = self._urlopen_as_json(frozen_commit_url, headers=headers)
frozen_commit_data = requests.get(frozen_commit_url, headers=headers).json()
if "message" in frozen_commit_data and frozen_commit_data["message"] == "Not Found": if "message" in frozen_commit_data and frozen_commit_data["message"] == "Not Found":
msg = "{0} not found in {1}. Repo may be private.".format(frozen_commit_sha[:10], name) msg = "{0} not found in {1}. Repo may be private.".format(frozen_commit_sha[:10], name)

View File

@ -1,5 +1,8 @@
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth.models import User try:
from django.contrib.auth import get_user_model # Django 1.5
except ImportError:
from django_extensions.future_1_5 import get_user_model
from django.contrib.sessions.models import Session from django.contrib.sessions.models import Session
import re import re
@ -38,6 +41,7 @@ class Command(BaseCommand):
print('No user associated with session') print('No user associated with session')
return return
print("User id: %s" % uid) print("User id: %s" % uid)
User = get_user_model()
try: try:
user = User.objects.get(pk=uid) user = User.objects.get(pk=uid)
except User.DoesNotExist: except User.DoesNotExist:

View File

@ -92,10 +92,10 @@ Type 'yes' to continue, or 'no' to cancel: """ % (settings.DATABASE_NAME,))
if password is None: if password is None:
password = settings.DATABASE_PASSWORD password = settings.DATABASE_PASSWORD
if engine == 'sqlite3': if engine in ('sqlite3', 'spatialite'):
import os import os
try: try:
logging.info("Unlinking sqlite3 database") logging.info("Unlinking %s database" % engine)
os.unlink(settings.DATABASE_NAME) os.unlink(settings.DATABASE_NAME)
except OSError: except OSError:
pass pass

View File

@ -116,6 +116,8 @@ class Command(BaseCommand):
help='Specifies the directory from which to serve admin media.'), help='Specifies the directory from which to serve admin media.'),
make_option('--prof-path', dest='prof_path', default='/tmp', make_option('--prof-path', dest='prof_path', default='/tmp',
help='Specifies the directory which to save profile information in.'), help='Specifies the directory which to save profile information in.'),
make_option('--prof-file', dest='prof_file', default='{path}.{duration:06d}ms.{time}',
help='Set filename format, default if "{path}.{duration:06d}ms.{time}".'),
make_option('--nomedia', action='store_true', dest='no_media', default=False, make_option('--nomedia', action='store_true', dest='no_media', default=False,
help='Do not profile MEDIA_URL and ADMIN_MEDIA_URL'), help='Do not profile MEDIA_URL and ADMIN_MEDIA_URL'),
make_option('--use-cprofile', action='store_true', dest='use_cprofile', default=False, make_option('--use-cprofile', action='store_true', dest='use_cprofile', default=False,
@ -186,6 +188,11 @@ class Command(BaseCommand):
raise SystemExit("Kcachegrind compatible output format required cProfile from Python 2.5") raise SystemExit("Kcachegrind compatible output format required cProfile from Python 2.5")
prof_path = options.get('prof_path', '/tmp') prof_path = options.get('prof_path', '/tmp')
prof_file = options.get('prof_file', '{path}.{duration:06d}ms.{time}')
if not prof_file.format(path='1', duration=2, time=3):
prof_file = '{path}.{duration:06d}ms.{time}'
print("Filename format is wrong. Default format used: '{path}.{duration:06d}ms.{time}'.")
def get_exclude_paths(): def get_exclude_paths():
exclude_paths = [] exclude_paths = []
media_url = getattr(settings, 'MEDIA_URL', None) media_url = getattr(settings, 'MEDIA_URL', None)
@ -225,8 +232,8 @@ class Command(BaseCommand):
kg.output(open(profname, 'w')) kg.output(open(profname, 'w'))
elif USE_CPROFILE: elif USE_CPROFILE:
prof.dump_stats(profname) prof.dump_stats(profname)
profname2 = "%s.%06dms.%d.prof" % (path_name, elapms, time.time()) profname2 = prof_file.format(path=path_name, duration=int(elapms), time=int(time.time()))
profname2 = os.path.join(prof_path, profname2) profname2 = os.path.join(prof_path, "%s.prof" % profname2)
if not USE_CPROFILE: if not USE_CPROFILE:
prof.close() prof.close()
os.rename(profname, profname2) os.rename(profname, profname2)
@ -278,4 +285,3 @@ class Command(BaseCommand):
autoreload.main(inner_run) autoreload.main(inner_run)
else: else:
inner_run() inner_run()

View File

@ -3,15 +3,31 @@ from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import setup_logger, RedirectHandler from django_extensions.management.utils import setup_logger, RedirectHandler
from optparse import make_option from optparse import make_option
import os import os
import re
import socket
import sys import sys
import time import time
try: try:
from django.contrib.staticfiles.handlers import StaticFilesHandler if 'django.contrib.staticfiles' in settings.INSTALLED_APPS:
USE_STATICFILES = 'django.contrib.staticfiles' in settings.INSTALLED_APPS from django.contrib.staticfiles.handlers import StaticFilesHandler
USE_STATICFILES = True
elif 'staticfiles' in settings.INSTALLED_APPS:
from staticfiles.handlers import StaticFilesHandler # noqa
USE_STATICFILES = True
else:
USE_STATICFILES = False
except ImportError: except ImportError:
USE_STATICFILES = False USE_STATICFILES = False
naiveip_re = re.compile(r"""^(?:
(?P<addr>
(?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address
(?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address
(?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN
):)?(?P<port>\d+)$""", re.X)
DEFAULT_PORT = "8000"
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,6 +36,8 @@ from django_extensions.management.technical_response import null_technical_500_r
class Command(BaseCommand): class Command(BaseCommand):
option_list = BaseCommand.option_list + ( option_list = BaseCommand.option_list + (
make_option('--ipv6', '-6', action='store_true', dest='use_ipv6', default=False,
help='Tells Django to use a IPv6 address.'),
make_option('--noreload', action='store_false', dest='use_reloader', default=True, make_option('--noreload', action='store_false', dest='use_reloader', default=True,
help='Tells Django to NOT use the auto-reloader.'), help='Tells Django to NOT use the auto-reloader.'),
make_option('--browser', action='store_true', dest='open_browser', make_option('--browser', action='store_true', dest='open_browser',
@ -103,33 +121,51 @@ class Command(BaseCommand):
from django.views import debug from django.views import debug
debug.technical_500_response = null_technical_500_response debug.technical_500_response = null_technical_500_response
if args: self.use_ipv6 = options.get('use_ipv6')
raise CommandError('Usage is runserver %s' % self.args) if self.use_ipv6 and not socket.has_ipv6:
raise CommandError('Your Python does not support IPv6.')
self._raw_ipv6 = False
if not addrport: if not addrport:
addr = ''
port = '8000'
else:
try: try:
addr, port = addrport.split(':') addrport = settings.RUNSERVERPLUS_SERVER_ADDRESS_PORT
except ValueError: except AttributeError:
addr, port = '', addrport pass
if not addr: if not addrport:
addr = '127.0.0.1' self.addr = ''
self.port = DEFAULT_PORT
if not port.isdigit(): else:
raise CommandError("%r is not a valid port number." % port) m = re.match(naiveip_re, addrport)
if m is None:
raise CommandError('"%s" is not a valid port number '
'or address:port pair.' % addrport)
self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups()
if not self.port.isdigit():
raise CommandError("%r is not a valid port number." %
self.port)
if self.addr:
if _ipv6:
self.addr = self.addr[1:-1]
self.use_ipv6 = True
self._raw_ipv6 = True
elif self.use_ipv6 and not _fqdn:
raise CommandError('"%s" is not a valid IPv6 address.'
% self.addr)
if not self.addr:
self.addr = '::1' if self.use_ipv6 else '127.0.0.1'
threaded = options.get('threaded', False) threaded = options.get('threaded', False)
use_reloader = options.get('use_reloader', True) use_reloader = options.get('use_reloader', True)
open_browser = options.get('open_browser', False) open_browser = options.get('open_browser', False)
cert_path = options.get("cert_path") cert_path = options.get("cert_path")
quit_command = (sys.platform == 'win32') and 'CTRL-BREAK' or 'CONTROL-C' quit_command = (sys.platform == 'win32') and 'CTRL-BREAK' or 'CONTROL-C'
bind_url = "http://%s:%s/" % (
self.addr if not self._raw_ipv6 else '[%s]' % self.addr, self.port)
def inner_run(): def inner_run():
print("Validating models...") print("Validating models...")
self.validate(display_num_errors=True) self.validate(display_num_errors=True)
print("\nDjango version %s, using settings %r" % (django.get_version(), settings.SETTINGS_MODULE)) print("\nDjango version %s, using settings %r" % (django.get_version(), settings.SETTINGS_MODULE))
print("Development server is running at http://%s:%s/" % (addr, port)) print("Development server is running at %s" % (bind_url,))
print("Using the Werkzeug debugger (http://werkzeug.pocoo.org/)") print("Using the Werkzeug debugger (http://werkzeug.pocoo.org/)")
print("Quit the server with %s." % quit_command) print("Quit the server with %s." % quit_command)
path = options.get('admin_media_path', '') path = options.get('admin_media_path', '')
@ -149,8 +185,7 @@ class Command(BaseCommand):
handler = StaticFilesHandler(handler) handler = StaticFilesHandler(handler)
if open_browser: if open_browser:
import webbrowser import webbrowser
url = "http://%s:%s/" % (addr, port) webbrowser.open(bind_url)
webbrowser.open(url)
if cert_path: if cert_path:
""" """
OpenSSL is needed for SSL support. OpenSSL is needed for SSL support.
@ -189,8 +224,8 @@ class Command(BaseCommand):
else: else:
ssl_context = None ssl_context = None
run_simple( run_simple(
addr, self.addr,
int(port), int(self.port),
DebuggedApplication(handler, True), DebuggedApplication(handler, True),
use_reloader=use_reloader, use_reloader=use_reloader,
use_debugger=True, use_debugger=True,

View File

@ -38,7 +38,11 @@ class Command(NoArgsCommand):
if not settings.DEBUG: if not settings.DEBUG:
raise CommandError('Only available in debug mode') raise CommandError('Only available in debug mode')
from django.contrib.auth.models import User, Group try:
from django.contrib.auth import get_user_model # Django 1.5
except ImportError:
from django_extensions.future_1_5 import get_user_model
from django.contrib.auth.models import Group
email = options.get('default_email', DEFAULT_FAKE_EMAIL) email = options.get('default_email', DEFAULT_FAKE_EMAIL)
include_regexp = options.get('include_regexp', None) include_regexp = options.get('include_regexp', None)
exclude_regexp = options.get('exclude_regexp', None) exclude_regexp = options.get('exclude_regexp', None)
@ -47,6 +51,7 @@ class Command(NoArgsCommand):
no_admin = options.get('no_admin', False) no_admin = options.get('no_admin', False)
no_staff = options.get('no_staff', False) no_staff = options.get('no_staff', False)
User = get_user_model()
users = User.objects.all() users = User.objects.all()
if no_admin: if no_admin:
users = users.exclude(is_superuser=True) users = users.exclude(is_superuser=True)

View File

@ -28,7 +28,11 @@ class Command(NoArgsCommand):
if not settings.DEBUG: if not settings.DEBUG:
raise CommandError('Only available in debug mode') raise CommandError('Only available in debug mode')
from django.contrib.auth.models import User try:
from django.contrib.auth import get_user_model # Django 1.5
except ImportError:
from django_extensions.future_1_5 import get_user_model
if options.get('prompt_passwd', False): if options.get('prompt_passwd', False):
from getpass import getpass from getpass import getpass
passwd = getpass('Password: ') passwd = getpass('Password: ')
@ -37,6 +41,7 @@ class Command(NoArgsCommand):
else: else:
passwd = options.get('default_passwd', DEFAULT_FAKE_PASSWORD) passwd = options.get('default_passwd', DEFAULT_FAKE_PASSWORD)
User = get_user_model()
user = User() user = User()
user.set_password(passwd) user.set_password(passwd)
count = User.objects.all().update(password=user.password) count = User.objects.all().update(password=user.password)

View File

@ -1,5 +1,6 @@
from optparse import make_option from optparse import make_option
import sys import sys
import socket
import django import django
from django.core.management.base import CommandError, BaseCommand from django.core.management.base import CommandError, BaseCommand
@ -57,6 +58,7 @@ The envisioned use case is something like this:
dbuser = settings.DATABASE_USER dbuser = settings.DATABASE_USER
dbpass = settings.DATABASE_PASSWORD dbpass = settings.DATABASE_PASSWORD
dbhost = settings.DATABASE_HOST dbhost = settings.DATABASE_HOST
dbclient = socket.gethostname()
# django settings file tells you that localhost should be specified by leaving # django settings file tells you that localhost should be specified by leaving
# the DATABASE_HOST blank # the DATABASE_HOST blank
@ -69,7 +71,7 @@ The envisioned use case is something like this:
""") """)
print("CREATE DATABASE %s CHARACTER SET utf8 COLLATE utf8_bin;" % dbname) print("CREATE DATABASE %s CHARACTER SET utf8 COLLATE utf8_bin;" % dbname)
print("GRANT ALL PRIVILEGES ON %s.* to '%s'@'%s' identified by '%s';" % ( print("GRANT ALL PRIVILEGES ON %s.* to '%s'@'%s' identified by '%s';" % (
dbname, dbuser, dbhost, dbpass dbname, dbuser, dbclient, dbpass
)) ))
elif engine == 'postgresql_psycopg2': elif engine == 'postgresql_psycopg2':
if options.get('drop'): if options.get('drop'):

View File

@ -228,7 +228,7 @@ class SQLDiff(object):
def strip_parameters(self, field_type): def strip_parameters(self, field_type):
if field_type and field_type != 'double precision': if field_type and field_type != 'double precision':
return field_type.split(" ")[0].split("(")[0] return field_type.split(" ")[0].split("(")[0].lower()
return field_type return field_type
def find_unique_missing_in_db(self, meta, table_indexes, table_name): def find_unique_missing_in_db(self, meta, table_indexes, table_name):
@ -289,14 +289,14 @@ class SQLDiff(object):
continue continue
description = db_fields[field.name] description = db_fields[field.name]
model_type = self.strip_parameters(self.get_field_model_type(field)) model_type = self.get_field_model_type(field)
db_type = self.strip_parameters(self.get_field_db_type(description, field)) db_type = self.get_field_db_type(description, field)
# use callback function if defined # use callback function if defined
if func: if func:
model_type, db_type = func(field, description, model_type, db_type) model_type, db_type = func(field, description, model_type, db_type)
if not model_type == db_type: if not self.strip_parameters(db_type) == self.strip_parameters(model_type):
self.add_difference('field-type-differ', table_name, field.name, model_type, db_type) self.add_difference('field-type-differ', table_name, field.name, model_type, db_type)
def find_field_parameter_differ(self, meta, table_description, table_name, func=None): def find_field_parameter_differ(self, meta, table_description, table_name, func=None):

View File

@ -210,9 +210,9 @@ def generate_dot(app_labels, **kwargs):
if skip_field(field): if skip_field(field):
continue continue
if isinstance(field, OneToOneField): if isinstance(field, OneToOneField):
add_relation(field, '[arrowhead=none, arrowtail=none]') add_relation(field, '[arrowhead=none, arrowtail=none, dir=both]')
elif isinstance(field, ForeignKey): elif isinstance(field, ForeignKey):
add_relation(field, '[arrowhead=none, arrowtail=dot]') add_relation(field, '[arrowhead=none, arrowtail=dot, dir=both]')
for field in appmodel._meta.local_many_to_many: for field in appmodel._meta.local_many_to_many:
if skip_field(field): if skip_field(field):
@ -240,7 +240,7 @@ def generate_dot(app_labels, **kwargs):
'type': "inheritance", 'type': "inheritance",
'name': "inheritance", 'name': "inheritance",
'label': l, 'label': l,
'arrows': '[arrowhead=empty, arrowtail=none]', 'arrows': '[arrowhead=empty, arrowtail=none, dir=both]',
'needs_node': True 'needs_node': True
} }
# TODO: seems as if abstract models aren't part of models.getModels, which is why they are printed by this without any attributes. # TODO: seems as if abstract models aren't part of models.getModels, which is why they are printed by this without any attributes.

View File

@ -69,7 +69,7 @@ class AutoSlugField(SlugField):
raise ValueError("missing 'populate_from' argument") raise ValueError("missing 'populate_from' argument")
else: else:
self._populate_from = populate_from self._populate_from = populate_from
self.separator = kwargs.pop('separator', u'-') self.separator = kwargs.pop('separator', six.u('-'))
self.overwrite = kwargs.pop('overwrite', False) self.overwrite = kwargs.pop('overwrite', False)
super(AutoSlugField, self).__init__(*args, **kwargs) super(AutoSlugField, self).__init__(*args, **kwargs)

View File

@ -1,6 +1,7 @@
from django.template import Library from django.template import Library
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
import re import re
import six
register = Library() register = Library()
re_widont = re.compile(r'\s+(\S+\s*)$') re_widont = re.compile(r'\s+(\S+\s*)$')
@ -24,7 +25,7 @@ def widont(value, count=1):
NoEffect NoEffect
""" """
def replace(matchobj): def replace(matchobj):
return u'&nbsp;%s' % matchobj.group(1) return six.u('&nbsp;%s' % matchobj.group(1))
for i in range(count): for i in range(count):
value = re_widont.sub(replace, force_unicode(value)) value = re_widont.sub(replace, force_unicode(value))
return value return value
@ -48,7 +49,7 @@ def widont_html(value):
leading&nbsp;text <p>test me&nbsp;out</p> trailing&nbsp;text leading&nbsp;text <p>test me&nbsp;out</p> trailing&nbsp;text
""" """
def replace(matchobj): def replace(matchobj):
return u'%s&nbsp;%s%s' % matchobj.groups() return six.u('%s&nbsp;%s%s' % matchobj.groups())
return re_widont_html.sub(replace, force_unicode(value)) return re_widont_html.sub(replace, force_unicode(value))
register.filter(widont) register.filter(widont)

View File

@ -25,9 +25,13 @@ class JsonFieldTest(unittest.TestCase):
def testCharFieldCreate(self): def testCharFieldCreate(self):
j = TestModel.objects.create(a=6, j_field=dict(foo='bar')) j = TestModel.objects.create(a=6, j_field=dict(foo='bar'))
self.assertEquals(j.a, 6) self.assertEqual(j.a, 6)
def testDefault(self):
j = TestModel.objects.create(a=1)
self.assertEqual(j.j_field, {})
def testEmptyList(self): def testEmptyList(self):
j = TestModel.objects.create(a=6, j_field=[]) j = TestModel.objects.create(a=6, j_field=[])
self.assertTrue(isinstance(j.j_field, list)) self.assertTrue(isinstance(j.j_field, list))
self.assertEquals(j.j_field, []) self.assertEqual(j.j_field, [])

View File

@ -52,7 +52,7 @@ class DumpScriptTests(TestCase):
tmp_out = StringIO() tmp_out = StringIO()
call_command('dumpscript', 'tests', stdout=tmp_out) call_command('dumpscript', 'tests', stdout=tmp_out)
self.assertTrue('Mike' in tmp_out.getvalue()) # script should go to tmp_out self.assertTrue('Mike' in tmp_out.getvalue()) # script should go to tmp_out
self.assertEquals(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout self.assertEqual(0, len(sys.stdout.getvalue())) # there should not be any output to sys.stdout
tmp_out.close() tmp_out.close()
#---------------------------------------------------------------------- #----------------------------------------------------------------------
@ -65,7 +65,7 @@ class DumpScriptTests(TestCase):
call_command('dumpscript', 'tests', stderr=tmp_err) call_command('dumpscript', 'tests', stderr=tmp_err)
self.assertTrue('Fred' in sys.stdout.getvalue()) # script should still go to stdout self.assertTrue('Fred' in sys.stdout.getvalue()) # script should still go to stdout
self.assertTrue('Name' in tmp_err.getvalue()) # error output should go to tmp_err self.assertTrue('Name' in tmp_err.getvalue()) # error output should go to tmp_err
self.assertEquals(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr self.assertEqual(0, len(sys.stderr.getvalue())) # there should not be any output to sys.stderr
tmp_err.close() tmp_err.close()
#---------------------------------------------------------------------- #----------------------------------------------------------------------

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys import sys
import six
from django.test import TestCase from django.test import TestCase
from django.utils.unittest import skipIf from django.utils.unittest import skipIf
@ -14,21 +15,21 @@ except ImportError:
class TruncateLetterTests(TestCase): class TruncateLetterTests(TestCase):
def test_truncate_more_than_text_length(self): def test_truncate_more_than_text_length(self):
self.assertEquals(u"hello tests", truncate_letters("hello tests", 100)) self.assertEqual(six.u("hello tests"), truncate_letters("hello tests", 100))
def test_truncate_text(self): def test_truncate_text(self):
self.assertEquals(u"hello...", truncate_letters("hello tests", 5)) self.assertEqual(six.u("hello..."), truncate_letters("hello tests", 5))
def test_truncate_with_range(self): def test_truncate_with_range(self):
for i in range(10, -1, -1): for i in range(10, -1, -1):
self.assertEqual( self.assertEqual(
u'hello tests'[:i] + '...', six.u('hello tests'[:i]) + '...',
truncate_letters("hello tests", i) truncate_letters("hello tests", i)
) )
def test_with_non_ascii_characters(self): def test_with_non_ascii_characters(self):
self.assertEquals( self.assertEqual(
u'\u5ce0 (\u3068\u3046\u3052 t\u014dg...', six.u('\u5ce0 (\u3068\u3046\u3052 t\u014dg...'),
truncate_letters("峠 (とうげ tōge - mountain pass)", 10) truncate_letters("峠 (とうげ tōge - mountain pass)", 10)
) )
@ -37,7 +38,7 @@ class UUIDTests(TestCase):
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid3(self): def test_uuid3(self):
# make a UUID using an MD5 hash of a namespace UUID and a name # make a UUID using an MD5 hash of a namespace UUID and a name
self.assertEquals( self.assertEqual(
uuid.UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e'), uuid.UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e'),
uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org') uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org')
) )
@ -45,7 +46,7 @@ class UUIDTests(TestCase):
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid5(self): def test_uuid5(self):
# make a UUID using a SHA-1 hash of a namespace UUID and a name # make a UUID using a SHA-1 hash of a namespace UUID and a name
self.assertEquals( self.assertEqual(
uuid.UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d'), uuid.UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d'),
uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org') uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org')
) )
@ -55,21 +56,21 @@ class UUIDTests(TestCase):
# make a UUID from a string of hex digits (braces and hyphens ignored) # make a UUID from a string of hex digits (braces and hyphens ignored)
x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}') x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
# convert a UUID to a string of hex digits in standard form # convert a UUID to a string of hex digits in standard form
self.assertEquals('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x)) self.assertEqual('00010203-0405-0607-0809-0a0b0c0d0e0f', str(x))
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_uuid_bytes(self): def test_uuid_bytes(self):
# make a UUID from a string of hex digits (braces and hyphens ignored) # make a UUID from a string of hex digits (braces and hyphens ignored)
x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}') x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
# get the raw 16 bytes of the UUID # get the raw 16 bytes of the UUID
self.assertEquals( self.assertEqual(
'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f', '\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f',
x.bytes x.bytes
) )
@skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib') @skipIf(sys.version_info >= (2, 5, 0), 'uuid already in stdlib')
def test_make_uuid_from_byte_string(self): def test_make_uuid_from_byte_string(self):
self.assertEquals( self.assertEqual(
uuid.UUID(bytes='\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f'), uuid.UUID(bytes='\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f'),
uuid.UUID('00010203-0405-0607-0809-0a0b0c0d0e0f') uuid.UUID('00010203-0405-0607-0809-0a0b0c0d0e0f')
) )

View File

@ -1,3 +1,4 @@
import six
from django.conf import settings from django.conf import settings
from django.core.management import call_command from django.core.management import call_command
from django.db.models import loading from django.db.models import loading
@ -36,20 +37,22 @@ class UUIDFieldTest(unittest.TestCase):
settings.INSTALLED_APPS = self.old_installed_apps settings.INSTALLED_APPS = self.old_installed_apps
def testUUIDFieldCreate(self): def testUUIDFieldCreate(self):
j = TestModel_field.objects.create(a=6, uuid_field=u'550e8400-e29b-41d4-a716-446655440000') j = TestModel_field.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000'))
self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000') self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000'))
def testUUIDField_pkCreate(self): def testUUIDField_pkCreate(self):
j = TestModel_pk.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440000') j = TestModel_pk.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440000'))
self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440000') self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440000'))
self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440000') self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440000'))
def testUUIDField_pkAgregateCreate(self): def testUUIDField_pkAgregateCreate(self):
j = TestAgregateModel.objects.create(a=6) j = TestAgregateModel.objects.create(a=6, uuid_field=six.u('550e8400-e29b-41d4-a716-446655440001'))
self.assertEquals(j.a, 6) self.assertEqual(j.a, 6)
self.assertIsInstance(j.pk, six.string_types)
self.assertEqual(len(j.pk), 36)
def testUUIDFieldManyToManyCreate(self): def testUUIDFieldManyToManyCreate(self):
j = TestManyToManyModel.objects.create(uuid_field=u'550e8400-e29b-41d4-a716-446655440010') j = TestManyToManyModel.objects.create(uuid_field=six.u('550e8400-e29b-41d4-a716-446655440010'))
self.assertEquals(j.uuid_field, u'550e8400-e29b-41d4-a716-446655440010') self.assertEqual(j.uuid_field, six.u('550e8400-e29b-41d4-a716-446655440010'))
self.assertEquals(j.pk, u'550e8400-e29b-41d4-a716-446655440010') self.assertEqual(j.pk, six.u('550e8400-e29b-41d4-a716-446655440010'))

View File

@ -17,6 +17,7 @@ import sys
import gzip import gzip
from xml.dom.minidom import * # NOQA from xml.dom.minidom import * # NOQA
import re import re
import six
#Type dictionary translation types SQL -> Django #Type dictionary translation types SQL -> Django
tsd = { tsd = {
@ -75,7 +76,7 @@ def dia2django(archivo):
datos = ppal.getElementsByTagName("dia:diagram")[0].getElementsByTagName("dia:layer")[0].getElementsByTagName("dia:object") datos = ppal.getElementsByTagName("dia:diagram")[0].getElementsByTagName("dia:layer")[0].getElementsByTagName("dia:object")
clases = {} clases = {}
herit = [] herit = []
imports = u"" imports = six.u("")
for i in datos: for i in datos:
#Look for the classes #Look for the classes
if i.getAttribute("type") == "UML - Class": if i.getAttribute("type") == "UML - Class":
@ -165,7 +166,7 @@ def dia2django(archivo):
a = i.getElementsByTagName("dia:string") a = i.getElementsByTagName("dia:string")
for j in a: for j in a:
if len(j.childNodes[0].data[1:-1]): if len(j.childNodes[0].data[1:-1]):
imports += u"from %s.models import *" % j.childNodes[0].data[1:-1] imports += six.u("from %s.models import *" % j.childNodes[0].data[1:-1])
addparentstofks(herit, clases) addparentstofks(herit, clases)
#Ordering the appearance of classes #Ordering the appearance of classes

View File

@ -1,566 +0,0 @@
# flake8:noqa
r"""UUID objects (universally unique identifiers) according to RFC 4122.
This module provides immutable UUID objects (class UUID) and the functions
uuid1(), uuid3(), uuid4(), uuid5() for generating version 1, 3, 4, and 5
UUIDs as specified in RFC 4122.
If all you want is a unique ID, you should probably call uuid1() or uuid4().
Note that uuid1() may compromise privacy since it creates a UUID containing
the computer's network address. uuid4() creates a random UUID.
Typical usage:
>>> import uuid
# make a UUID based on the host ID and current time
>>> uuid.uuid1()
UUID('a8098c1a-f86e-11da-bd1a-00112444be1e')
# make a UUID using an MD5 hash of a namespace UUID and a name
>>> uuid.uuid3(uuid.NAMESPACE_DNS, 'python.org')
UUID('6fa459ea-ee8a-3ca4-894e-db77e160355e')
# make a random UUID
>>> uuid.uuid4()
UUID('16fd2706-8baf-433b-82eb-8c7fada847da')
# make a UUID using a SHA-1 hash of a namespace UUID and a name
>>> uuid.uuid5(uuid.NAMESPACE_DNS, 'python.org')
UUID('886313e1-3b8a-5372-9b90-0c9aee199e5d')
# make a UUID from a string of hex digits (braces and hyphens ignored)
>>> x = uuid.UUID('{00010203-0405-0607-0809-0a0b0c0d0e0f}')
# convert a UUID to a string of hex digits in standard form
>>> str(x)
'00010203-0405-0607-0809-0a0b0c0d0e0f'
# get the raw 16 bytes of the UUID
>>> x.bytes
'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
# make a UUID from a 16-byte string
>>> uuid.UUID(bytes=x.bytes)
UUID('00010203-0405-0607-0809-0a0b0c0d0e0f')
"""
__author__ = 'Ka-Ping Yee <ping@zesty.ca>'
RESERVED_NCS, RFC_4122, RESERVED_MICROSOFT, RESERVED_FUTURE = [
'reserved for NCS compatibility', 'specified in RFC 4122',
'reserved for Microsoft compatibility', 'reserved for future definition'
]
class UUID(object):
"""Instances of the UUID class represent UUIDs as specified in RFC 4122.
UUID objects are immutable, hashable, and usable as dictionary keys.
Converting a UUID to a string with str() yields something in the form
'12345678-1234-1234-1234-123456789abc'. The UUID constructor accepts
five possible forms: a similar string of hexadecimal digits, or a tuple
of six integer fields (with 32-bit, 16-bit, 16-bit, 8-bit, 8-bit, and
48-bit values respectively) as an argument named 'fields', or a string
of 16 bytes (with all the integer fields in big-endian order) as an
argument named 'bytes', or a string of 16 bytes (with the first three
fields in little-endian order) as an argument named 'bytes_le', or a
single 128-bit integer as an argument named 'int'.
UUIDs have these read-only attributes:
bytes the UUID as a 16-byte string (containing the six
integer fields in big-endian byte order)
bytes_le the UUID as a 16-byte string (with time_low, time_mid,
and time_hi_version in little-endian byte order)
fields a tuple of the six integer fields of the UUID,
which are also available as six individual attributes
and two derived attributes:
time_low the first 32 bits of the UUID
time_mid the next 16 bits of the UUID
time_hi_version the next 16 bits of the UUID
clock_seq_hi_variant the next 8 bits of the UUID
clock_seq_low the next 8 bits of the UUID
node the last 48 bits of the UUID
time the 60-bit timestamp
clock_seq the 14-bit sequence number
hex the UUID as a 32-character hexadecimal string
int the UUID as a 128-bit integer
urn the UUID as a URN as specified in RFC 4122
variant the UUID variant (one of the constants RESERVED_NCS,
RFC_4122, RESERVED_MICROSOFT, or RESERVED_FUTURE)
version the UUID version number (1 through 5, meaningful only
when the variant is RFC_4122)
"""
def __init__(self, hex=None, bytes=None, bytes_le=None, fields=None, int=None, version=None):
r"""Create a UUID from either a string of 32 hexadecimal digits,
a string of 16 bytes as the 'bytes' argument, a string of 16 bytes
in little-endian order as the 'bytes_le' argument, a tuple of six
integers (32-bit time_low, 16-bit time_mid, 16-bit time_hi_version,
8-bit clock_seq_hi_variant, 8-bit clock_seq_low, 48-bit node) as
the 'fields' argument, or a single 128-bit integer as the 'int'
argument. When a string of hex digits is given, curly braces,
hyphens, and a URN prefix are all optional. For example, these
expressions all yield the same UUID:
UUID('{12345678-1234-5678-1234-567812345678}')
UUID('12345678123456781234567812345678')
UUID('urn:uuid:12345678-1234-5678-1234-567812345678')
UUID(bytes='\x12\x34\x56\x78'*4)
UUID(bytes_le='\x78\x56\x34\x12\x34\x12\x78\x56' +
'\x12\x34\x56\x78\x12\x34\x56\x78')
UUID(fields=(0x12345678, 0x1234, 0x5678, 0x12, 0x34, 0x567812345678))
UUID(int=0x12345678123456781234567812345678)
Exactly one of 'hex', 'bytes', 'bytes_le', 'fields', or 'int' must
be given. The 'version' argument is optional; if given, the resulting
UUID will have its variant and version set according to RFC 4122,
overriding the given 'hex', 'bytes', 'bytes_le', 'fields', or 'int'.
"""
if [hex, bytes, bytes_le, fields, int].count(None) != 4:
raise TypeError('need one of hex, bytes, bytes_le, fields, or int')
if hex is not None:
hex = hex.replace('urn:', '').replace('uuid:', '')
hex = hex.strip('{}').replace('-', '')
if len(hex) != 32:
raise ValueError('badly formed hexadecimal UUID string')
int = long(hex, 16)
if bytes_le is not None:
if len(bytes_le) != 16:
raise ValueError('bytes_le is not a 16-char string')
bytes = (bytes_le[3] + bytes_le[2] + bytes_le[1] + bytes_le[0] +
bytes_le[5] + bytes_le[4] + bytes_le[7] + bytes_le[6] +
bytes_le[8:])
if bytes is not None:
if len(bytes) != 16:
raise ValueError('bytes is not a 16-char string')
int = long(('%02x' * 16) % tuple(map(ord, bytes)), 16)
if fields is not None:
if len(fields) != 6:
raise ValueError('fields is not a 6-tuple')
(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node) = fields
if not 0 <= time_low < 1 << 32L:
raise ValueError('field 1 out of range (need a 32-bit value)')
if not 0 <= time_mid < 1 << 16L:
raise ValueError('field 2 out of range (need a 16-bit value)')
if not 0 <= time_hi_version < 1 << 16L:
raise ValueError('field 3 out of range (need a 16-bit value)')
if not 0 <= clock_seq_hi_variant < 1 << 8L:
raise ValueError('field 4 out of range (need an 8-bit value)')
if not 0 <= clock_seq_low < 1 << 8L:
raise ValueError('field 5 out of range (need an 8-bit value)')
if not 0 <= node < 1 << 48L:
raise ValueError('field 6 out of range (need a 48-bit value)')
clock_seq = (clock_seq_hi_variant << 8L) | clock_seq_low
int = ((time_low << 96L) | (time_mid << 80L) |
(time_hi_version << 64L) | (clock_seq << 48L) | node)
if int is not None:
if not 0 <= int < 1 << 128L:
raise ValueError('int is out of range (need a 128-bit value)')
if version is not None:
if not 1 <= version <= 5:
raise ValueError('illegal version number')
# Set the variant to RFC 4122.
int &= ~(0xc000 << 48L)
int |= 0x8000 << 48L
# Set the version number.
int &= ~(0xf000 << 64L)
int |= version << 76L
self.__dict__['int'] = int
def __cmp__(self, other):
if isinstance(other, UUID):
return cmp(self.int, other.int)
return NotImplemented
def __hash__(self):
return hash(self.int)
def __int__(self):
return self.int
def __repr__(self):
return 'UUID(%r)' % str(self)
def __setattr__(self, name, value):
raise TypeError('UUID objects are immutable')
def __str__(self):
hex = '%032x' % self.int
return '%s-%s-%s-%s-%s' % (
hex[:8], hex[8:12], hex[12:16], hex[16:20], hex[20:])
def get_bytes(self):
bytes = ''
for shift in range(0, 128, 8):
bytes = chr((self.int >> shift) & 0xff) + bytes
return bytes
bytes = property(get_bytes)
def get_bytes_le(self):
bytes = self.bytes
return (bytes[3] + bytes[2] + bytes[1] + bytes[0] +
bytes[5] + bytes[4] + bytes[7] + bytes[6] + bytes[8:])
bytes_le = property(get_bytes_le)
def get_fields(self):
return (self.time_low, self.time_mid, self.time_hi_version,
self.clock_seq_hi_variant, self.clock_seq_low, self.node)
fields = property(get_fields)
def get_time_low(self):
return self.int >> 96L
time_low = property(get_time_low)
def get_time_mid(self):
return (self.int >> 80L) & 0xffff
time_mid = property(get_time_mid)
def get_time_hi_version(self):
return (self.int >> 64L) & 0xffff
time_hi_version = property(get_time_hi_version)
def get_clock_seq_hi_variant(self):
return (self.int >> 56L) & 0xff
clock_seq_hi_variant = property(get_clock_seq_hi_variant)
def get_clock_seq_low(self):
return (self.int >> 48L) & 0xff
clock_seq_low = property(get_clock_seq_low)
def get_time(self):
return (((self.time_hi_version & 0x0fffL) << 48L) |
(self.time_mid << 32L) | self.time_low)
time = property(get_time)
def get_clock_seq(self):
return (((self.clock_seq_hi_variant & 0x3fL) << 8L) |
self.clock_seq_low)
clock_seq = property(get_clock_seq)
def get_node(self):
return self.int & 0xffffffffffff
node = property(get_node)
def get_hex(self):
return '%032x' % self.int
hex = property(get_hex)
def get_urn(self):
return 'urn:uuid:' + str(self)
urn = property(get_urn)
def get_variant(self):
if not self.int & (0x8000 << 48L):
return RESERVED_NCS
elif not self.int & (0x4000 << 48L):
return RFC_4122
elif not self.int & (0x2000 << 48L):
return RESERVED_MICROSOFT
else:
return RESERVED_FUTURE
variant = property(get_variant)
def get_version(self):
# The version bits are only meaningful for RFC 4122 UUIDs.
if self.variant == RFC_4122:
return int((self.int >> 76L) & 0xf)
version = property(get_version)
def _find_mac(command, args, hw_identifiers, get_index):
import os
for dir in ['', '/sbin/', '/usr/sbin']:
executable = os.path.join(dir, command)
if not os.path.exists(executable):
continue
try:
# LC_ALL to get English output, 2>/dev/null to
# prevent output on stderr
cmd = 'LC_ALL=C %s %s 2>/dev/null' % (executable, args)
pipe = os.popen(cmd)
except IOError:
continue
for line in pipe:
words = line.lower().split()
for i in range(len(words)):
if words[i] in hw_identifiers:
return int(words[get_index(i)].replace(':', ''), 16)
return None
def _ifconfig_getnode():
"""Get the hardware address on Unix by running ifconfig."""
# This works on Linux ('' or '-a'), Tru64 ('-av'), but not all Unixes.
for args in ('', '-a', '-av'):
mac = _find_mac('ifconfig', args, ['hwaddr', 'ether'], lambda i: i + 1)
if mac:
return mac
import socket
ip_addr = socket.gethostbyname(socket.gethostname())
# Try getting the MAC addr from arp based on our IP address (Solaris).
mac = _find_mac('arp', '-an', [ip_addr], lambda i: -1)
if mac:
return mac
# This might work on HP-UX.
mac = _find_mac('lanscan', '-ai', ['lan0'], lambda i: 0)
if mac:
return mac
return None
def _ipconfig_getnode():
"""Get the hardware address on Windows by running ipconfig.exe."""
import os
import re
dirs = ['', r'c:\windows\system32', r'c:\winnt\system32']
try:
import ctypes
buffer = ctypes.create_string_buffer(300)
ctypes.windll.kernel32.GetSystemDirectoryA(buffer, 300)
dirs.insert(0, buffer.value.decode('mbcs'))
except:
pass
for dir in dirs:
try:
pipe = os.popen(os.path.join(dir, 'ipconfig') + ' /all')
except IOError:
continue
for line in pipe:
value = line.split(':')[-1].strip().lower()
if re.match('([0-9a-f][0-9a-f]-){5}[0-9a-f][0-9a-f]', value):
return int(value.replace('-', ''), 16)
def _netbios_getnode():
"""Get the hardware address on Windows using NetBIOS calls.
See http://support.microsoft.com/kb/118623 for details."""
import win32wnet
import netbios
ncb = netbios.NCB()
ncb.Command = netbios.NCBENUM
ncb.Buffer = adapters = netbios.LANA_ENUM()
adapters._pack()
if win32wnet.Netbios(ncb) != 0:
return
adapters._unpack()
for i in range(adapters.length):
ncb.Reset()
ncb.Command = netbios.NCBRESET
ncb.Lana_num = ord(adapters.lana[i])
if win32wnet.Netbios(ncb) != 0:
continue
ncb.Reset()
ncb.Command = netbios.NCBASTAT
ncb.Lana_num = ord(adapters.lana[i])
ncb.Callname = '*'.ljust(16)
ncb.Buffer = status = netbios.ADAPTER_STATUS()
if win32wnet.Netbios(ncb) != 0:
continue
status._unpack()
bytes = map(ord, status.adapter_address)
return ((bytes[0] << 40L) + (bytes[1] << 32L) + (bytes[2] << 24L) +
(bytes[3] << 16L) + (bytes[4] << 8L) + bytes[5])
# Thanks to Thomas Heller for ctypes and for his help with its use here.
# If ctypes is available, use it to find system routines for UUID generation.
_uuid_generate_random = _uuid_generate_time = _UuidCreate = None
try:
import ctypes
import ctypes.util
_buffer = ctypes.create_string_buffer(16)
# The uuid_generate_* routines are provided by libuuid on at least
# Linux and FreeBSD, and provided by libc on Mac OS X.
for libname in ['uuid', 'c']:
try:
lib = ctypes.CDLL(ctypes.util.find_library(libname))
except:
continue
if hasattr(lib, 'uuid_generate_random'):
_uuid_generate_random = lib.uuid_generate_random
if hasattr(lib, 'uuid_generate_time'):
_uuid_generate_time = lib.uuid_generate_time
# On Windows prior to 2000, UuidCreate gives a UUID containing the
# hardware address. On Windows 2000 and later, UuidCreate makes a
# random UUID and UuidCreateSequential gives a UUID containing the
# hardware address. These routines are provided by the RPC runtime.
# NOTE: at least on Tim's WinXP Pro SP2 desktop box, while the last
# 6 bytes returned by UuidCreateSequential are fixed, they don't appear
# to bear any relationship to the MAC address of any network device
# on the box.
try:
lib = ctypes.windll.rpcrt4
except:
lib = None
_UuidCreate = getattr(lib, 'UuidCreateSequential',
getattr(lib, 'UuidCreate', None))
except:
pass
def _unixdll_getnode():
"""Get the hardware address on Unix using ctypes."""
_uuid_generate_time(_buffer)
return UUID(bytes=_buffer.raw).node
def _windll_getnode():
"""Get the hardware address on Windows using ctypes."""
if _UuidCreate(_buffer) == 0:
return UUID(bytes=_buffer.raw).node
def _random_getnode():
"""Get a random node ID, with eighth bit set as suggested by RFC 4122."""
import random
return random.randrange(0, 1 << 48L) | 0x010000000000L
_node = None
def getnode():
"""Get the hardware address as a 48-bit positive integer.
The first time this runs, it may launch a separate program, which could
be quite slow. If all attempts to obtain the hardware address fail, we
choose a random 48-bit number with its eighth bit set to 1 as recommended
in RFC 4122.
"""
global _node
if _node is not None:
return _node
import sys
if sys.platform == 'win32':
getters = [_windll_getnode, _netbios_getnode, _ipconfig_getnode]
else:
getters = [_unixdll_getnode, _ifconfig_getnode]
for getter in getters + [_random_getnode]:
try:
_node = getter()
except:
continue
if _node is not None:
return _node
_last_timestamp = None
def uuid1(node=None, clock_seq=None):
"""Generate a UUID from a host ID, sequence number, and the current time.
If 'node' is not given, getnode() is used to obtain the hardware
address. If 'clock_seq' is given, it is used as the sequence number;
otherwise a random 14-bit sequence number is chosen."""
# When the system provides a version-1 UUID generator, use it (but don't
# use UuidCreate here because its UUIDs don't conform to RFC 4122).
if _uuid_generate_time and node is clock_seq is None:
_uuid_generate_time(_buffer)
return UUID(bytes=_buffer.raw)
global _last_timestamp
import time
nanoseconds = int(time.time() * 1e9)
# 0x01b21dd213814000 is the number of 100-ns intervals between the
# UUID epoch 1582-10-15 00:00:00 and the Unix epoch 1970-01-01 00:00:00.
timestamp = int(nanoseconds / 100) + 0x01b21dd213814000L
if timestamp <= _last_timestamp:
timestamp = _last_timestamp + 1
_last_timestamp = timestamp
if clock_seq is None:
import random
clock_seq = random.randrange(1 << 14L) # instead of stable storage
time_low = timestamp & 0xffffffffL
time_mid = (timestamp >> 32L) & 0xffffL
time_hi_version = (timestamp >> 48L) & 0x0fffL
clock_seq_low = clock_seq & 0xffL
clock_seq_hi_variant = (clock_seq >> 8L) & 0x3fL
if node is None:
node = getnode()
return UUID(fields=(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node), version=1)
def uuid3(namespace, name):
"""Generate a UUID from the MD5 hash of a namespace UUID and a name."""
try:
import hashlib
md5 = hashlib.md5
except ImportError:
from md5 import md5 # NOQA
hash = md5(namespace.bytes + name).digest()
return UUID(bytes=hash[:16], version=3)
def uuid4():
"""Generate a random UUID."""
# When the system provides a version-4 UUID generator, use it.
if _uuid_generate_random:
_uuid_generate_random(_buffer)
return UUID(bytes=_buffer.raw)
# Otherwise, get randomness from urandom or the 'random' module.
try:
import os
return UUID(bytes=os.urandom(16), version=4)
except:
import random
bytes = [chr(random.randrange(256)) for i in range(16)]
return UUID(bytes=bytes, version=4)
def uuid5(namespace, name):
"""Generate a UUID from the SHA-1 hash of a namespace UUID and a name."""
try:
import hashlib
sha = hashlib.sha1
except ImportError:
from sha import sha # NOQA
hash = sha(namespace.bytes + name).digest()
return UUID(bytes=hash[:16], version=5)
# The following standard UUIDs are for use with uuid3() or uuid5().
NAMESPACE_DNS = UUID('6ba7b810-9dad-11d1-80b4-00c04fd430c8')
NAMESPACE_URL = UUID('6ba7b811-9dad-11d1-80b4-00c04fd430c8')
NAMESPACE_OID = UUID('6ba7b812-9dad-11d1-80b4-00c04fd430c8')
NAMESPACE_X500 = UUID('6ba7b814-9dad-11d1-80b4-00c04fd430c8')

View File

@ -5,7 +5,7 @@ from __future__ import absolute_import
import os import os
VERSION = (3, 0, 17) VERSION = (3, 0, 21)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Ask Solem' __author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org' __contact__ = 'ask@celeryproject.org'

View File

@ -8,7 +8,6 @@ from django.contrib.admin import helpers
from django.contrib.admin.views import main as main_views from django.contrib.admin.views import main as main_views
from django.shortcuts import render_to_response from django.shortcuts import render_to_response
from django.template import RequestContext from django.template import RequestContext
from django.utils.encoding import force_unicode
from django.utils.html import escape from django.utils.html import escape
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -22,6 +21,11 @@ from .models import (TaskState, WorkerState,
PeriodicTask, IntervalSchedule, CrontabSchedule) PeriodicTask, IntervalSchedule, CrontabSchedule)
from .humanize import naturaldate from .humanize import naturaldate
try:
from django.utils.encoding import force_text
except ImportError:
from django.utils.encoding import force_unicode as force_text
TASK_STATE_COLORS = {states.SUCCESS: 'green', TASK_STATE_COLORS = {states.SUCCESS: 'green',
states.FAILURE: 'red', states.FAILURE: 'red',
@ -175,7 +179,7 @@ class TaskMonitor(ModelMonitor):
context = { context = {
'title': _('Rate limit selection'), 'title': _('Rate limit selection'),
'queryset': queryset, 'queryset': queryset,
'object_name': force_unicode(opts.verbose_name), 'object_name': force_text(opts.verbose_name),
'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME, 'action_checkbox_name': helpers.ACTION_CHECKBOX_NAME,
'opts': opts, 'opts': opts,
'app_label': app_label, 'app_label': app_label,

View File

@ -20,7 +20,11 @@ from zlib import compress, decompress
from celery.utils.serialization import pickle from celery.utils.serialization import pickle
from django.db import models from django.db import models
from django.utils.encoding import force_unicode
try:
from django.utils.encoding import force_text
except ImportError:
from django.utils.encoding import force_unicode as force_text
DEFAULT_PROTOCOL = 2 DEFAULT_PROTOCOL = 2
@ -77,7 +81,7 @@ class PickledObjectField(models.Field):
def get_db_prep_value(self, value, **kwargs): def get_db_prep_value(self, value, **kwargs):
if value is not None and not isinstance(value, PickledObject): if value is not None and not isinstance(value, PickledObject):
return force_unicode(encode(value, self.compress, self.protocol)) return force_text(encode(value, self.compress, self.protocol))
return value return value
def value_to_string(self, obj): def value_to_string(self, obj):

View File

@ -68,12 +68,12 @@ class ModelEntry(ScheduleEntry):
def _default_now(self): def _default_now(self):
return self.app.now() return self.app.now()
def next(self): def __next__(self):
self.model.last_run_at = self.app.now() self.model.last_run_at = self.app.now()
self.model.total_run_count += 1 self.model.total_run_count += 1
self.model.no_changes = True self.model.no_changes = True
return self.__class__(self.model) return self.__class__(self.model)
__next__ = next # for 2to3 next = __next__ # for 2to3
def save(self): def save(self):
# Object may not be synchronized, so only # Object may not be synchronized, so only

View File

@ -10,7 +10,7 @@ from django.conf import settings
from celery import states from celery import states
from celery.events.state import Task from celery.events.state import Task
from celery.events.snapshot import Polaroid from celery.events.snapshot import Polaroid
from celery.utils.timeutils import maybe_iso8601, timezone from celery.utils.timeutils import maybe_iso8601
from .models import WorkerState, TaskState from .models import WorkerState, TaskState
from .utils import maybe_make_aware from .utils import maybe_make_aware
@ -31,7 +31,7 @@ NOT_SAVED_ATTRIBUTES = frozenset(['name', 'args', 'kwargs', 'eta'])
def aware_tstamp(secs): def aware_tstamp(secs):
"""Event timestamps uses the local timezone.""" """Event timestamps uses the local timezone."""
return timezone.to_local_fallback(datetime.fromtimestamp(secs)) return maybe_make_aware(datetime.fromtimestamp(secs))
class Camera(Polaroid): class Camera(Polaroid):

View File

@ -54,7 +54,8 @@ try:
def make_aware(value): def make_aware(value):
if getattr(settings, 'USE_TZ', False): if getattr(settings, 'USE_TZ', False):
# naive datetimes are assumed to be in UTC. # naive datetimes are assumed to be in UTC.
value = timezone.make_aware(value, timezone.utc) if timezone.is_naive(value):
value = timezone.make_aware(value, timezone.utc)
# then convert to the Django configured timezone. # then convert to the Django configured timezone.
default_tz = timezone.get_default_timezone() default_tz = timezone.get_default_timezone()
value = timezone.localtime(value, default_tz) value = timezone.localtime(value, default_tz)

View File

@ -1,7 +1,7 @@
"""Messaging Framework for Python""" """Messaging Framework for Python"""
from __future__ import absolute_import from __future__ import absolute_import
VERSION = (2, 5, 10) VERSION = (2, 5, 14)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Ask Solem' __author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org' __contact__ = 'ask@celeryproject.org'

View File

@ -1,6 +1,6 @@
""" """
kombu.compression kombu.abstract
================= ==============
Object utilities. Object utilities.

View File

@ -151,8 +151,9 @@ class Connection(object):
password=None, virtual_host=None, port=None, insist=False, password=None, virtual_host=None, port=None, insist=False,
ssl=False, transport=None, connect_timeout=5, ssl=False, transport=None, connect_timeout=5,
transport_options=None, login_method=None, uri_prefix=None, transport_options=None, login_method=None, uri_prefix=None,
heartbeat=0, failover_strategy='round-robin', **kwargs): heartbeat=0, failover_strategy='round-robin',
alt = [] alternates=None, **kwargs):
alt = [] if alternates is None else alternates
# have to spell the args out, just to get nice docstrings :( # have to spell the args out, just to get nice docstrings :(
params = self._initial_params = { params = self._initial_params = {
'hostname': hostname, 'userid': userid, 'hostname': hostname, 'userid': userid,
@ -328,6 +329,29 @@ class Connection(object):
self._debug('closed') self._debug('closed')
self._closed = True self._closed = True
def collect(self, socket_timeout=None):
# amqp requires communication to close, we don't need that just
# to clear out references, Transport._collect can also be implemented
# by other transports that want fast after fork
try:
gc_transport = self._transport._collect
except AttributeError:
_timeo = socket.getdefaulttimeout()
socket.setdefaulttimeout(socket_timeout)
try:
self._close()
except socket.timeout:
pass
finally:
socket.setdefaulttimeout(_timeo)
else:
gc_transport(self._connection)
if self._transport:
self._transport.client = None
self._transport = None
self.declared_entities.clear()
self._connection = None
def release(self): def release(self):
"""Close the connection (if open).""" """Close the connection (if open)."""
self._close() self._close()
@ -522,12 +546,9 @@ class Connection(object):
transport_cls = RESOLVE_ALIASES.get(transport_cls, transport_cls) transport_cls = RESOLVE_ALIASES.get(transport_cls, transport_cls)
D = self.transport.default_connection_params D = self.transport.default_connection_params
if self.alt: hostname = self.hostname or D.get('hostname')
hostname = ";".join(self.alt) if self.uri_prefix:
else: hostname = '%s+%s' % (self.uri_prefix, hostname)
hostname = self.hostname or D.get('hostname')
if self.uri_prefix:
hostname = '%s+%s' % (self.uri_prefix, hostname)
info = (('hostname', hostname), info = (('hostname', hostname),
('userid', self.userid or D.get('userid')), ('userid', self.userid or D.get('userid')),
@ -542,6 +563,10 @@ class Connection(object):
('login_method', self.login_method or D.get('login_method')), ('login_method', self.login_method or D.get('login_method')),
('uri_prefix', self.uri_prefix), ('uri_prefix', self.uri_prefix),
('heartbeat', self.heartbeat)) ('heartbeat', self.heartbeat))
if self.alt:
info += (('alternates', self.alt),)
return info return info
def info(self): def info(self):
@ -910,6 +935,9 @@ class Resource(object):
else: else:
self.close_resource(resource) self.close_resource(resource)
def collect_resource(self, resource):
pass
def force_close_all(self): def force_close_all(self):
"""Closes and removes all resources in the pool (also those in use). """Closes and removes all resources in the pool (also those in use).
@ -919,32 +947,27 @@ class Resource(object):
""" """
dirty = self._dirty dirty = self._dirty
resource = self._resource resource = self._resource
while 1: while 1: # - acquired
try: try:
dres = dirty.pop() dres = dirty.pop()
except KeyError: except KeyError:
break break
try: try:
self.close_resource(dres) self.collect_resource(dres)
except AttributeError: # Issue #78 except AttributeError: # Issue #78
pass pass
while 1: # - available
mutex = getattr(resource, 'mutex', None) # deque supports '.clear', but lists do not, so for that
if mutex: # reason we use pop here, so that the underlying object can
mutex.acquire() # be any object supporting '.pop' and '.append'.
try: try:
while 1: res = resource.queue.pop()
try: except IndexError:
res = resource.queue.pop() break
except IndexError: try:
break self.collect_resource(res)
try: except AttributeError:
self.close_resource(res) pass # Issue #78
except AttributeError:
pass # Issue #78
finally:
if mutex: # pragma: no cover
mutex.release()
if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover
_orig_acquire = acquire _orig_acquire = acquire
@ -993,6 +1016,9 @@ class ConnectionPool(Resource):
def close_resource(self, resource): def close_resource(self, resource):
resource._close() resource._close()
def collect_resource(self, resource, socket_timeout=0.1):
return resource.collect(socket_timeout)
@contextmanager @contextmanager
def acquire_channel(self, block=False): def acquire_channel(self, block=False):
with self.acquire(block=block) as connection: with self.acquire(block=block) as connection:

View File

@ -123,6 +123,7 @@ class Exchange(MaybeChannelBound):
type = 'direct' type = 'direct'
durable = True durable = True
auto_delete = False auto_delete = False
passive = False
delivery_mode = PERSISTENT_DELIVERY_MODE delivery_mode = PERSISTENT_DELIVERY_MODE
attrs = ( attrs = (
@ -130,6 +131,7 @@ class Exchange(MaybeChannelBound):
('type', None), ('type', None),
('arguments', None), ('arguments', None),
('durable', bool), ('durable', bool),
('passive', bool),
('auto_delete', bool), ('auto_delete', bool),
('delivery_mode', lambda m: DELIVERY_MODES.get(m) or m), ('delivery_mode', lambda m: DELIVERY_MODES.get(m) or m),
) )
@ -143,7 +145,7 @@ class Exchange(MaybeChannelBound):
def __hash__(self): def __hash__(self):
return hash('E|%s' % (self.name, )) return hash('E|%s' % (self.name, ))
def declare(self, nowait=False, passive=False): def declare(self, nowait=False, passive=None):
"""Declare the exchange. """Declare the exchange.
Creates the exchange on the broker. Creates the exchange on the broker.
@ -152,6 +154,7 @@ class Exchange(MaybeChannelBound):
response will not be waited for. Default is :const:`False`. response will not be waited for. Default is :const:`False`.
""" """
passive = self.passive if passive is None else passive
if self.name: if self.name:
return self.channel.exchange_declare( return self.channel.exchange_declare(
exchange=self.name, type=self.type, durable=self.durable, exchange=self.name, type=self.type, durable=self.durable,
@ -489,7 +492,7 @@ class Queue(MaybeChannelBound):
self.exchange.declare(nowait) self.exchange.declare(nowait)
self.queue_declare(nowait, passive=False) self.queue_declare(nowait, passive=False)
if self.exchange is not None: if self.exchange and self.exchange.name:
self.queue_bind(nowait) self.queue_bind(nowait)
# - declare extra/multi-bindings. # - declare extra/multi-bindings.
@ -541,8 +544,8 @@ class Queue(MaybeChannelBound):
Returns the message instance if a message was available, Returns the message instance if a message was available,
or :const:`None` otherwise. or :const:`None` otherwise.
:keyword no_ack: If set messages received does not have to :keyword no_ack: If enabled the broker will automatically
be acknowledged. ack messages.
This method provides direct access to the messages in a This method provides direct access to the messages in a
queue using a synchronous dialogue, designed for queue using a synchronous dialogue, designed for
@ -575,8 +578,8 @@ class Queue(MaybeChannelBound):
can use the same consumer tags. If this field is empty can use the same consumer tags. If this field is empty
the server will generate a unique tag. the server will generate a unique tag.
:keyword no_ack: If set messages received does not have to :keyword no_ack: If enabled the broker will automatically ack
be acknowledged. messages.
:keyword nowait: Do not wait for a reply. :keyword nowait: Do not wait for a reply.

View File

@ -276,8 +276,12 @@ class Consumer(object):
#: consume from. #: consume from.
queues = None queues = None
#: Flag for message acknowledgment disabled/enabled. #: Flag for automatic message acknowledgment.
#: Enabled by default. #: If enabled the messages are automatically acknowledged by the
#: broker. This can increase performance but means that you
#: have no control of when the message is removed.
#:
#: Disabled by default.
no_ack = None no_ack = None
#: By default all entities will be declared at instantiation, if you #: By default all entities will be declared at instantiation, if you
@ -399,6 +403,12 @@ class Consumer(object):
pass pass
def add_queue(self, queue): def add_queue(self, queue):
"""Add a queue to the list of queues to consume from.
This will not start consuming from the queue,
for that you will have to call :meth:`consume` after.
"""
queue = queue(self.channel) queue = queue(self.channel)
if self.auto_declare: if self.auto_declare:
queue.declare() queue.declare()
@ -406,9 +416,26 @@ class Consumer(object):
return queue return queue
def add_queue_from_dict(self, queue, **options): def add_queue_from_dict(self, queue, **options):
"""This method is deprecated.
Instead please use::
consumer.add_queue(Queue.from_dict(d))
"""
return self.add_queue(Queue.from_dict(queue, **options)) return self.add_queue(Queue.from_dict(queue, **options))
def consume(self, no_ack=None): def consume(self, no_ack=None):
"""Start consuming messages.
Can be called multiple times, but note that while it
will consume from new queues added since the last call,
it will not cancel consuming from removed queues (
use :meth:`cancel_by_queue`).
:param no_ack: See :attr:`no_ack`.
"""
if self.queues: if self.queues:
no_ack = self.no_ack if no_ack is None else no_ack no_ack = self.no_ack if no_ack is None else no_ack
@ -441,10 +468,12 @@ class Consumer(object):
self.channel.basic_cancel(tag) self.channel.basic_cancel(tag)
def consuming_from(self, queue): def consuming_from(self, queue):
"""Returns :const:`True` if the consumer is currently
consuming from queue'."""
name = queue name = queue
if isinstance(queue, Queue): if isinstance(queue, Queue):
name = queue.name name = queue.name
return any(q.name == name for q in self.queues) return name in self._active_tags
def purge(self): def purge(self):
"""Purge messages from all queues. """Purge messages from all queues.

View File

@ -13,6 +13,7 @@ import socket
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from itertools import count from itertools import count
from time import sleep
from .common import ignore_errors from .common import ignore_errors
from .messaging import Consumer from .messaging import Consumer
@ -158,13 +159,18 @@ class ConsumerMixin(object):
def extra_context(self, connection, channel): def extra_context(self, connection, channel):
yield yield
def run(self): def run(self, _tokens=1):
restart_limit = self.restart_limit
errors = (self.connection.connection_errors +
self.connection.channel_errors)
while not self.should_stop: while not self.should_stop:
try: try:
if self.restart_limit.can_consume(1): if restart_limit.can_consume(_tokens):
for _ in self.consume(limit=None): for _ in self.consume(limit=None):
pass pass
except self.connection.connection_errors: else:
sleep(restart_limit.expected_time(_tokens))
except errors:
warn('Connection to broker lost. ' warn('Connection to broker lost. '
'Trying to re-establish the connection...') 'Trying to re-establish the connection...')

View File

@ -20,6 +20,7 @@ from time import time
from . import Exchange, Queue, Consumer, Producer from . import Exchange, Queue, Consumer, Producer
from .clocks import LamportClock from .clocks import LamportClock
from .common import maybe_declare, oid_from from .common import maybe_declare, oid_from
from .exceptions import InconsistencyError
from .utils import cached_property, kwdict, uuid from .utils import cached_property, kwdict, uuid
REPLY_QUEUE_EXPIRES = 10 REPLY_QUEUE_EXPIRES = 10
@ -215,9 +216,15 @@ class Mailbox(object):
delivery_mode='transient', delivery_mode='transient',
durable=False) durable=False)
producer = Producer(chan, auto_declare=False) producer = Producer(chan, auto_declare=False)
producer.publish(reply, exchange=exchange, routing_key=routing_key, try:
declare=[exchange], headers={ producer.publish(
'ticket': ticket, 'clock': self.clock.forward()}) reply, exchange=exchange, routing_key=routing_key,
declare=[exchange], headers={
'ticket': ticket, 'clock': self.clock.forward(),
},
)
except InconsistencyError:
pass # queue probably deleted and no one is expecting a reply.
def _publish(self, type, arguments, destination=None, def _publish(self, type, arguments, destination=None,
reply_ticket=None, channel=None, timeout=None): reply_ticket=None, channel=None, timeout=None):

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import anyjson import anyjson
import atexit
import os import os
import sys import sys
@ -14,6 +15,24 @@ except ImportError:
anyjson.force_implementation('simplejson') anyjson.force_implementation('simplejson')
def teardown():
# Workaround for multiprocessing bug where logging
# is attempted after global already collected at shutdown.
cancelled = set()
try:
import multiprocessing.util
cancelled.add(multiprocessing.util._exit_function)
except (AttributeError, ImportError):
pass
try:
atexit._exithandlers[:] = [
e for e in atexit._exithandlers if e[0] not in cancelled
]
except AttributeError: # pragma: no cover
pass # Py3 missing _exithandlers
def find_distribution_modules(name=__name__, file=__file__): def find_distribution_modules(name=__name__, file=__file__):
current_dist_depth = len(name.split('.')) - 1 current_dist_depth = len(name.split('.')) - 1
current_dist = os.path.join(os.path.dirname(file), current_dist = os.path.join(os.path.dirname(file),

View File

@ -130,6 +130,10 @@ class test_Exchange(TestCase):
exc = Exchange('foo', 'direct', delivery_mode='transient') exc = Exchange('foo', 'direct', delivery_mode='transient')
self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE) self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE)
def test_set_passive_mode(self):
exc = Exchange('foo', 'direct', passive=True)
self.assertTrue(exc.passive)
def test_set_persistent_delivery_mode(self): def test_set_persistent_delivery_mode(self):
exc = Exchange('foo', 'direct', delivery_mode='persistent') exc = Exchange('foo', 'direct', delivery_mode='persistent')
self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE) self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE)

View File

@ -258,9 +258,13 @@ class test_Consumer(TestCase):
def test_consuming_from(self): def test_consuming_from(self):
consumer = self.connection.Consumer() consumer = self.connection.Consumer()
consumer.queues[:] = [Queue('a'), Queue('b')] consumer.queues[:] = [Queue('a'), Queue('b'), Queue('d')]
consumer._active_tags = {'a': 1, 'b': 2}
self.assertFalse(consumer.consuming_from(Queue('c'))) self.assertFalse(consumer.consuming_from(Queue('c')))
self.assertFalse(consumer.consuming_from('c')) self.assertFalse(consumer.consuming_from('c'))
self.assertFalse(consumer.consuming_from(Queue('d')))
self.assertFalse(consumer.consuming_from('d'))
self.assertTrue(consumer.consuming_from(Queue('a'))) self.assertTrue(consumer.consuming_from(Queue('a')))
self.assertTrue(consumer.consuming_from(Queue('b'))) self.assertTrue(consumer.consuming_from(Queue('b')))
self.assertTrue(consumer.consuming_from('b')) self.assertTrue(consumer.consuming_from('b'))

View File

@ -5,6 +5,8 @@ from __future__ import with_statement
import sys import sys
from base64 import b64decode
from kombu.serialization import (registry, register, SerializerNotInstalled, from kombu.serialization import (registry, register, SerializerNotInstalled,
raw_encode, register_yaml, register_msgpack, raw_encode, register_yaml, register_msgpack,
decode, bytes_t, pickle, pickle_protocol, decode, bytes_t, pickle, pickle_protocol,
@ -53,16 +55,13 @@ unicode: "Th\\xE9 quick brown fox jumps over th\\xE9 lazy dog"
msgpack_py_data = dict(py_data) msgpack_py_data = dict(py_data)
# msgpack only supports tuples
msgpack_py_data['list'] = tuple(msgpack_py_data['list'])
# Unicode chars are lost in transmit :( # Unicode chars are lost in transmit :(
msgpack_py_data['unicode'] = 'Th quick brown fox jumps over th lazy dog' msgpack_py_data['unicode'] = 'Th quick brown fox jumps over th lazy dog'
msgpack_data = """\ msgpack_data = b64decode("""\
\x85\xa3int\n\xa5float\xcb@\t!\xfbS\xc8\xd4\xf1\xa4list\ haNpbnQKpWZsb2F0y0AJIftTyNTxpGxpc3SUpmdlb3JnZaVqZXJyeaZlbGFpbmWlY29zbW+mc3Rya\
\x94\xa6george\xa5jerry\xa6elaine\xa5cosmo\xa6string\xda\ W5n2gArVGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZ6d1bmljb2Rl2g\
\x00+The quick brown fox jumps over the lazy dog\xa7unicode\ ApVGggcXVpY2sgYnJvd24gZm94IGp1bXBzIG92ZXIgdGggbGF6eSBkb2c=\
\xda\x00)Th quick brown fox jumps over th lazy dog\ """)
"""
def say(m): def say(m):

View File

@ -215,7 +215,7 @@ class test_retry_over_time(TestCase):
self.myfun, self.Predicate, self.myfun, self.Predicate,
max_retries=1, errback=self.errback, interval_max=14, max_retries=1, errback=self.errback, interval_max=14,
) )
self.assertEqual(self.index, 2) self.assertEqual(self.index, 1)
# no errback # no errback
self.assertRaises( self.assertRaises(
self.Predicate, utils.retry_over_time, self.Predicate, utils.retry_over_time,
@ -230,7 +230,7 @@ class test_retry_over_time(TestCase):
self.myfun, self.Predicate, self.myfun, self.Predicate,
max_retries=0, errback=self.errback, interval_max=14, max_retries=0, errback=self.errback, interval_max=14,
) )
self.assertEqual(self.index, 1) self.assertEqual(self.index, 0)
class test_cached_property(TestCase): class test_cached_property(TestCase):

View File

@ -3,8 +3,10 @@ from __future__ import with_statement
import sys import sys
from functools import partial
from mock import patch from mock import patch
from nose import SkipTest from nose import SkipTest
from itertools import count
try: try:
import amqp # noqa import amqp # noqa
@ -43,6 +45,7 @@ class test_Channel(TestCase):
pass pass
self.conn = Mock() self.conn = Mock()
self.conn._get_free_channel_id.side_effect = partial(next, count(0))
self.conn.channels = {} self.conn.channels = {}
self.channel = Channel(self.conn, 0) self.channel = Channel(self.conn, 0)

View File

@ -8,7 +8,7 @@ from contextlib import contextmanager
from mock import patch from mock import patch
from nose import SkipTest from nose import SkipTest
from kombu.utils.encoding import safe_str from kombu.utils.encoding import bytes_t, safe_str, default_encoding
from kombu.tests.utils import TestCase from kombu.tests.utils import TestCase
@ -26,16 +26,16 @@ def clean_encoding():
class test_default_encoding(TestCase): class test_default_encoding(TestCase):
@patch('sys.getfilesystemencoding') @patch('sys.getdefaultencoding')
def test_default(self, getfilesystemencoding): def test_default(self, getdefaultencoding):
getfilesystemencoding.return_value = 'ascii' getdefaultencoding.return_value = 'ascii'
with clean_encoding() as encoding: with clean_encoding() as encoding:
enc = encoding.default_encoding() enc = encoding.default_encoding()
if sys.platform.startswith('java'): if sys.platform.startswith('java'):
self.assertEqual(enc, 'utf-8') self.assertEqual(enc, 'utf-8')
else: else:
self.assertEqual(enc, 'ascii') self.assertEqual(enc, 'ascii')
getfilesystemencoding.assert_called_with() getdefaultencoding.assert_called_with()
class test_encoding_utils(TestCase): class test_encoding_utils(TestCase):
@ -60,16 +60,36 @@ class test_encoding_utils(TestCase):
class test_safe_str(TestCase): class test_safe_str(TestCase):
def test_when_str(self): def setUp(self):
self._cencoding = patch('sys.getdefaultencoding')
self._encoding = self._cencoding.__enter__()
self._encoding.return_value = 'ascii'
def tearDown(self):
self._cencoding.__exit__()
def test_when_bytes(self):
self.assertEqual(safe_str('foo'), 'foo') self.assertEqual(safe_str('foo'), 'foo')
def test_when_unicode(self): def test_when_unicode(self):
self.assertIsInstance(safe_str(u'foo'), str) self.assertIsInstance(safe_str(u'foo'), bytes_t)
def test_when_encoding_utf8(self):
with patch('sys.getdefaultencoding') as encoding:
encoding.return_value = 'utf-8'
self.assertEqual(default_encoding(), 'utf-8')
s = u'The quiæk fåx jømps øver the lazy dåg'
res = safe_str(s)
self.assertIsInstance(res, bytes_t)
self.assertGreater(len(res), len(s))
def test_when_containing_high_chars(self): def test_when_containing_high_chars(self):
s = u'The quiæk fåx jømps øver the lazy dåg' with patch('sys.getdefaultencoding') as encoding:
res = safe_str(s) encoding.return_value = 'ascii'
self.assertIsInstance(res, str) s = u'The quiæk fåx jømps øver the lazy dåg'
res = safe_str(s)
self.assertIsInstance(res, bytes_t)
self.assertEqual(len(s), len(res))
def test_when_not_string(self): def test_when_not_string(self):
o = object() o = object()

View File

@ -9,6 +9,7 @@ kombu.transport.librabbitmq
""" """
from __future__ import absolute_import from __future__ import absolute_import
import os
import socket import socket
try: try:
@ -28,6 +29,10 @@ from . import base
DEFAULT_PORT = 5672 DEFAULT_PORT = 5672
NO_SSL_ERROR = """\
ssl not supported by librabbitmq, please use pyamqp:// or stunnel\
"""
class Message(base.Message): class Message(base.Message):
@ -98,22 +103,41 @@ class Transport(base.Transport):
for name, default_value in self.default_connection_params.items(): for name, default_value in self.default_connection_params.items():
if not getattr(conninfo, name, None): if not getattr(conninfo, name, None):
setattr(conninfo, name, default_value) setattr(conninfo, name, default_value)
conn = self.Connection(host=conninfo.host, if conninfo.ssl:
userid=conninfo.userid, raise NotImplementedError(NO_SSL_ERROR)
password=conninfo.password, opts = dict({
virtual_host=conninfo.virtual_host, 'host': conninfo.host,
login_method=conninfo.login_method, 'userid': conninfo.userid,
insist=conninfo.insist, 'password': conninfo.password,
ssl=conninfo.ssl, 'virtual_host': conninfo.virtual_host,
connect_timeout=conninfo.connect_timeout) 'login_method': conninfo.login_method,
'insist': conninfo.insist,
'ssl': conninfo.ssl,
'connect_timeout': conninfo.connect_timeout,
}, **conninfo.transport_options or {})
conn = self.Connection(**opts)
conn.client = self.client conn.client = self.client
self.client.drain_events = conn.drain_events self.client.drain_events = conn.drain_events
return conn return conn
def close_connection(self, connection): def close_connection(self, connection):
"""Close the AMQP broker connection.""" """Close the AMQP broker connection."""
self.client.drain_events = None
connection.close() connection.close()
def _collect(self, connection):
if connection is not None:
for channel in connection.channels.itervalues():
channel.connection = None
try:
os.close(connection.fileno())
except OSError:
pass
connection.channels.clear()
connection.callbacks.clear()
self.client.drain_events = None
self.client = None
def verify_connection(self, connection): def verify_connection(self, connection):
return connection.connected return connection.connected

View File

@ -101,57 +101,42 @@ class Channel(virtual.Channel):
See mongodb uri documentation: See mongodb uri documentation:
http://www.mongodb.org/display/DOCS/Connections http://www.mongodb.org/display/DOCS/Connections
""" """
conninfo = self.connection.client client = self.connection.client
hostname = client.hostname or DEFAULT_HOST
authdb = dbname = client.virtual_host
dbname = None if dbname in ["/", None]:
hostname = None
if not conninfo.hostname:
conninfo.hostname = DEFAULT_HOST
for part in conninfo.hostname.split('/'):
if not hostname:
hostname = 'mongodb://' + part
continue
dbname = part
if '?' in part:
# In case someone is passing options
# to the mongodb connection. Right now
# it is not permitted by kombu
dbname, options = part.split('?')
hostname += '/?' + options
hostname = "%s/%s" % (
hostname, dbname in [None, "/"] and "admin" or dbname,
)
if not dbname or dbname == "/":
dbname = "kombu_default" dbname = "kombu_default"
authdb = "admin"
if not client.userid:
hostname = hostname.replace('/' + client.virtual_host, '/')
else:
hostname = hostname.replace('/' + client.virtual_host,
'/' + authdb)
mongo_uri = 'mongodb://' + hostname
# At this point we expect the hostname to be something like # At this point we expect the hostname to be something like
# (considering replica set form too): # (considering replica set form too):
# #
# mongodb://[username:password@]host1[:port1][,host2[:port2], # mongodb://[username:password@]host1[:port1][,host2[:port2],
# ...[,hostN[:portN]]][/[?options]] # ...[,hostN[:portN]]][/[?options]]
mongoconn = Connection(host=hostname, ssl=conninfo.ssl) mongoconn = Connection(host=mongo_uri, ssl=client.ssl)
database = getattr(mongoconn, dbname)
version = mongoconn.server_info()['version'] version = mongoconn.server_info()['version']
if tuple(map(int, version.split('.')[:2])) < (1, 3): if tuple(map(int, version.split('.')[:2])) < (1, 3):
raise NotImplementedError( raise NotImplementedError(
'Kombu requires MongoDB version 1.3+, but connected to %s' % ( 'Kombu requires MongoDB version 1.3+, but connected to %s' % (
version, )) version, ))
database = getattr(mongoconn, dbname)
# This is done by the connection uri
# if conninfo.userid:
# database.authenticate(conninfo.userid, conninfo.password)
self.db = database self.db = database
col = database.messages col = database.messages
col.ensure_index([('queue', 1), ('_id', 1)], background=True) col.ensure_index([('queue', 1), ('_id', 1)], background=True)
if 'messages.broadcast' not in database.collection_names(): if 'messages.broadcast' not in database.collection_names():
capsize = conninfo.transport_options.get( capsize = (client.transport_options.get('capped_queue_size')
'capped_queue_size') or 100000 or 100000)
database.create_collection('messages.broadcast', database.create_collection('messages.broadcast',
size=capsize, capped=True) size=capsize, capped=True)

View File

@ -75,14 +75,17 @@ class Transport(base.Transport):
channel_errors = (StdChannelError, ) + amqp.Connection.channel_errors channel_errors = (StdChannelError, ) + amqp.Connection.channel_errors
nb_keep_draining = True nb_keep_draining = True
driver_name = "py-amqp" driver_name = 'py-amqp'
driver_type = "amqp" driver_type = 'amqp'
supports_heartbeats = True supports_heartbeats = True
supports_ev = True supports_ev = True
def __init__(self, client, **kwargs): def __init__(self, client, **kwargs):
self.client = client self.client = client
self.default_port = kwargs.get("default_port") or self.default_port self.default_port = kwargs.get('default_port') or self.default_port
def driver_version(self):
return amqp.__version__
def create_channel(self, connection): def create_channel(self, connection):
return connection.channel() return connection.channel()
@ -98,15 +101,18 @@ class Transport(base.Transport):
setattr(conninfo, name, default_value) setattr(conninfo, name, default_value)
if conninfo.hostname == 'localhost': if conninfo.hostname == 'localhost':
conninfo.hostname = '127.0.0.1' conninfo.hostname = '127.0.0.1'
conn = self.Connection(host=conninfo.host, opts = dict({
userid=conninfo.userid, 'host': conninfo.host,
password=conninfo.password, 'userid': conninfo.userid,
login_method=conninfo.login_method, 'password': conninfo.password,
virtual_host=conninfo.virtual_host, 'login_method': conninfo.login_method,
insist=conninfo.insist, 'virtual_host': conninfo.virtual_host,
ssl=conninfo.ssl, 'insist': conninfo.insist,
connect_timeout=conninfo.connect_timeout, 'ssl': conninfo.ssl,
heartbeat=conninfo.heartbeat) 'connect_timeout': conninfo.connect_timeout,
'heartbeat': conninfo.heartbeat,
}, **conninfo.transport_options or {})
conn = self.Connection(**opts)
conn.client = self.client conn.client = self.client
return conn return conn

View File

@ -206,7 +206,7 @@ class MultiChannelPoller(object):
for fd in self._chan_to_sock.itervalues(): for fd in self._chan_to_sock.itervalues():
try: try:
self.poller.unregister(fd) self.poller.unregister(fd)
except KeyError: except (KeyError, ValueError):
pass pass
self._channels.clear() self._channels.clear()
self._fd_to_chan.clear() self._fd_to_chan.clear()
@ -707,11 +707,12 @@ class Channel(virtual.Channel):
return self._queue_cycle[0:active] return self._queue_cycle[0:active]
def _rotate_cycle(self, used): def _rotate_cycle(self, used):
""" """Move most recently used queue to end of list."""
Move most recently used queue to end of list cycle = self._queue_cycle
""" try:
index = self._queue_cycle.index(used) cycle.append(cycle.pop(cycle.index(used)))
self._queue_cycle.append(self._queue_cycle.pop(index)) except ValueError:
pass
def _get_response_error(self): def _get_response_error(self):
from redis import exceptions from redis import exceptions

View File

@ -216,7 +216,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
try: try:
return fun(*args, **kwargs) return fun(*args, **kwargs)
except catch, exc: except catch, exc:
if max_retries is not None and retries > max_retries: if max_retries is not None and retries >= max_retries:
raise raise
if callback: if callback:
callback() callback()

View File

@ -22,7 +22,7 @@ if sys.platform.startswith('java'): # pragma: no cover
else: else:
def default_encoding(): # noqa def default_encoding(): # noqa
return sys.getfilesystemencoding() return sys.getdefaultencoding()
if is_py3k: # pragma: no cover if is_py3k: # pragma: no cover

View File

@ -10,7 +10,7 @@ from __future__ import absolute_import
import errno import errno
import socket import socket
from select import select as _selectf from select import select as _selectf, error as _selecterr
try: try:
from select import epoll from select import epoll
@ -53,6 +53,11 @@ READ = POLL_READ = 0x001
WRITE = POLL_WRITE = 0x004 WRITE = POLL_WRITE = 0x004
ERR = POLL_ERR = 0x008 | 0x010 ERR = POLL_ERR = 0x008 | 0x010
try:
SELECT_BAD_FD = set((errno.EBADF, errno.WSAENOTSOCK))
except AttributeError:
SELECT_BAD_FD = set((errno.EBADF,))
class Poller(object): class Poller(object):
@ -79,11 +84,9 @@ class _epoll(Poller):
def unregister(self, fd): def unregister(self, fd):
try: try:
self._epoll.unregister(fd) self._epoll.unregister(fd)
except socket.error: except (socket.error, ValueError, KeyError):
pass pass
except ValueError: except (IOError, OSError), exc:
pass
except IOError, exc:
if get_errno(exc) != errno.ENOENT: if get_errno(exc) != errno.ENOENT:
raise raise
@ -191,13 +194,31 @@ class _select(Poller):
if events & READ: if events & READ:
self._rfd.add(fd) self._rfd.add(fd)
def _remove_bad(self):
for fd in self._rfd | self._wfd | self._efd:
try:
_selectf([fd], [], [], 0)
except (_selecterr, socket.error), exc:
if get_errno(exc) in SELECT_BAD_FD:
self.unregister(fd)
def unregister(self, fd): def unregister(self, fd):
self._rfd.discard(fd) self._rfd.discard(fd)
self._wfd.discard(fd) self._wfd.discard(fd)
self._efd.discard(fd) self._efd.discard(fd)
def _poll(self, timeout): def _poll(self, timeout):
read, write, error = _selectf(self._rfd, self._wfd, self._efd, timeout) try:
read, write, error = _selectf(
self._rfd, self._wfd, self._efd, timeout,
)
except (_selecterr, socket.error), exc:
if get_errno(exc) == errno.EINTR:
return
elif get_errno(exc) in SELECT_BAD_FD:
return self._remove_bad()
raise
events = {} events = {}
for fd in read: for fd in read:
if not isinstance(fd, int): if not isinstance(fd, int):

View File

@ -1,4 +1,4 @@
__version__ = '2.3.5' __version__ = '2.3.7'
VERSION = __version__ # synonym VERSION = __version__ # synonym

Some files were not shown because too many files have changed in this diff Show More