Updated all vendored third-party packages.

This commit is contained in:
Chris Church
2013-11-14 22:55:03 -05:00
parent cbd6132d4b
commit 7cd2707713
767 changed files with 45175 additions and 28364 deletions

View File

@@ -1,50 +1,51 @@
Local versions of third-party packages required by AWX. Package names and Local versions of third-party packages required by AWX. Package names and
versions are listed below, along with notes on which files are included. versions are listed below, along with notes on which files are included.
amqp==1.2.1 (amqp/*) amqp==1.3.3 (amqp/*)
anyjson==0.3.3 (anyjson/*) anyjson==0.3.3 (anyjson/*)
argparse==1.2.1 (argparse.py, needed for Python 2.6 support) argparse==1.2.1 (argparse.py, needed for Python 2.6 support)
Babel==1.3 (babel/*, excluded bin/pybabel) Babel==1.3 (babel/*, excluded bin/pybabel)
billiard==2.7.3.32 (billiard/*, funtests/*, excluded _billiard.so) billiard==3.3.0.6 (billiard/*, funtests/*, excluded _billiard.so)
boto==2.13.3 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin, boto==2.17.0 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin,
bin/cq, bin/cwutil, bin/dynamodb_dump, bin/dynamodb_load, bin/elbadmin, bin/cq, bin/cwutil, bin/dynamodb_dump, bin/dynamodb_load, bin/elbadmin,
bin/fetch_file, bin/glacier, bin/instance_events, bin/kill_instance, bin/fetch_file, bin/glacier, bin/instance_events, bin/kill_instance,
bin/launch_instance, bin/list_instances, bin/lss3, bin/mturk, bin/launch_instance, bin/list_instances, bin/lss3, bin/mturk,
bin/pyami_sendmail, bin/route53, bin/s3put, bin/sdbadmin, bin/taskadmin) bin/pyami_sendmail, bin/route53, bin/s3put, bin/sdbadmin, bin/taskadmin)
celery==3.0.23 (celery/*, excluded bin/celery* and bin/camqadm) celery==3.1.3 (celery/*, excluded bin/celery*)
d2to1==0.2.11 (d2to1/*) d2to1==0.2.11 (d2to1/*)
distribute==0.7.3 (no files) distribute==0.7.3 (no files)
django-auth-ldap==1.1.4 (django_auth_ldap/*) django-auth-ldap==1.1.6 (django_auth_ldap/*)
django-celery==3.0.23 (djcelery/*, excluded bin/djcelerymon) django-celery==3.1.1 (djcelery/*)
django-extensions==1.2.2 (django_extensions/*) django-extensions==1.2.5 (django_extensions/*)
django-jsonfield==0.9.10 (jsonfield/*) django-jsonfield==0.9.11 (jsonfield/*)
django-taggit==0.10 (taggit/*) django-taggit==0.10 (taggit/*)
djangorestframework==2.3.8 (rest_framework/*) djangorestframework==2.3.8 (rest_framework/*)
httplib2==0.8 (httplib2/*) httplib2==0.8 (httplib2/*)
importlib==1.0.2 (importlib/*, needed for Python 2.6 support) importlib==1.0.2 (importlib/*, needed for Python 2.6 support)
iso8601==0.1.4 (iso8601/*) iso8601==0.1.8 (iso8601/*)
keyring==3.0.5 (keyring/*, excluded bin/keyring) keyring==3.2 (keyring/*, excluded bin/keyring)
kombu==2.5.14 (kombu/*) kombu==3.0.4 (kombu/*)
Markdown==2.3.1 (markdown/*, excluded bin/markdown_py) Markdown==2.3.1 (markdown/*, excluded bin/markdown_py)
mock==1.0.1 (mock.py) mock==1.0.1 (mock.py)
ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support) ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)
os-diskconfig-python-novaclient-ext==0.1.1 (os_diskconfig_python_novaclient_ext/*) os-diskconfig-python-novaclient-ext==0.1.1 (os_diskconfig_python_novaclient_ext/*)
os-networksv2-python-novaclient-ext==0.21 (os_networksv2_python_novaclient_ext.py) os-networksv2-python-novaclient-ext==0.21 (os_networksv2_python_novaclient_ext.py)
pbr==0.5.21 (pbr/*) pbr==0.5.23 (pbr/*)
pexpect==2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py) pexpect==3.0 (pexpect/*, excluded pxssh.py, fdpexpect.py, FSM.py, screen.py,
ANSI.py)
pip==1.4.1 (pip/*, excluded bin/pip*) pip==1.4.1 (pip/*, excluded bin/pip*)
prettytable==0.7.2 (prettytable.py) prettytable==0.7.2 (prettytable.py)
pyrax==1.5.0 (pyrax/*) pyrax==1.6.2 (pyrax/*)
python-dateutil==2.1 (dateutil/*) python-dateutil==2.2 (dateutil/*)
python-novaclient==2.15.0 (novaclient/*, excluded bin/nova) python-novaclient==2.15.0 (novaclient/*, excluded bin/nova)
python-swiftclient==1.6.0 (swiftclient/*, excluded bin/swift) python-swiftclient==1.8.0 (swiftclient/*, excluded bin/swift)
pytz==2013d (pytz/*) pytz==2013.8 (pytz/*)
rackspace-auth-openstack==1.0 (rackspace_auth_openstack/*) rackspace-auth-openstack==1.1 (rackspace_auth_openstack/*)
rackspace-novaclient==1.3 (no files) rackspace-novaclient==1.3 (no files)
rax-default-network-flags-python-novaclient-ext==0.1.3 (rax_default_network_flags_python_novaclient_ext/*) rax-default-network-flags-python-novaclient-ext==0.1.3 (rax_default_network_flags_python_novaclient_ext/*)
rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*) rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*)
requests==2.0.0 (requests/*) requests==2.0.1 (requests/*)
setuptools==1.1.6 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*) setuptools==1.3.2 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*)
simplejson==3.3.0 (simplejson/*, excluded simplejson/_speedups.so) simplejson==3.3.1 (simplejson/*, excluded simplejson/_speedups.so)
six==1.4.1 (six.py) six==1.4.1 (six.py)
South==0.8.2 (south/*) South==0.8.3 (south/*)

View File

@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import from __future__ import absolute_import
VERSION = (1, 2, 1) VERSION = (1, 3, 3)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:]) __version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Barry Pederson' __author__ = 'Barry Pederson'
__maintainer__ = 'Ask Solem' __maintainer__ = 'Ask Solem'
@@ -61,6 +61,7 @@ from .exceptions import ( # noqa
error_for_code, error_for_code,
__all__ as _all_exceptions, __all__ as _all_exceptions,
) )
from .utils import promise # noqa
__all__ = [ __all__ = [
'Connection', 'Connection',

View File

@@ -24,6 +24,7 @@ from warnings import warn
from .abstract_channel import AbstractChannel from .abstract_channel import AbstractChannel
from .exceptions import ChannelError, ConsumerCancelled, error_for_code from .exceptions import ChannelError, ConsumerCancelled, error_for_code
from .five import Queue from .five import Queue
from .protocol import basic_return_t, queue_declare_ok_t
from .serialization import AMQPWriter from .serialization import AMQPWriter
__all__ = ['Channel'] __all__ = ['Channel']
@@ -80,6 +81,12 @@ class Channel(AbstractChannel):
self.events = defaultdict(set) self.events = defaultdict(set)
self.no_ack_consumers = set() self.no_ack_consumers = set()
# set first time basic_publish_confirm is called
# and publisher confirms are enabled for this channel.
self._confirm_selected = False
if self.connection.confirm_publish:
self.basic_publish = self.basic_publish_confirm
self._x_open() self._x_open()
def _do_close(self): def _do_close(self):
@@ -1272,10 +1279,11 @@ class Channel(AbstractChannel):
this count. this count.
""" """
queue = args.read_shortstr() return queue_declare_ok_t(
message_count = args.read_long() args.read_shortstr(),
consumer_count = args.read_long() args.read_long(),
return queue, message_count, consumer_count args.read_long(),
)
def queue_delete(self, queue='', def queue_delete(self, queue='',
if_unused=False, if_empty=False, nowait=False): if_unused=False, if_empty=False, nowait=False):
@@ -1875,6 +1883,7 @@ class Channel(AbstractChannel):
exchange = args.read_shortstr() exchange = args.read_shortstr()
routing_key = args.read_shortstr() routing_key = args.read_shortstr()
msg.channel = self
msg.delivery_info = { msg.delivery_info = {
'consumer_tag': consumer_tag, 'consumer_tag': consumer_tag,
'delivery_tag': delivery_tag, 'delivery_tag': delivery_tag,
@@ -1883,8 +1892,11 @@ class Channel(AbstractChannel):
'routing_key': routing_key, 'routing_key': routing_key,
} }
fun = self.callbacks.get(consumer_tag, None) try:
if fun is not None: fun = self.callbacks[consumer_tag]
except KeyError:
pass
else:
fun(msg) fun(msg)
def basic_get(self, queue='', no_ack=False): def basic_get(self, queue='', no_ack=False):
@@ -2015,6 +2027,7 @@ class Channel(AbstractChannel):
routing_key = args.read_shortstr() routing_key = args.read_shortstr()
message_count = args.read_long() message_count = args.read_long()
msg.channel = self
msg.delivery_info = { msg.delivery_info = {
'delivery_tag': delivery_tag, 'delivery_tag': delivery_tag,
'redelivered': redelivered, 'redelivered': redelivered,
@@ -2024,8 +2037,8 @@ class Channel(AbstractChannel):
} }
return msg return msg
def basic_publish(self, msg, exchange='', routing_key='', def _basic_publish(self, msg, exchange='', routing_key='',
mandatory=False, immediate=False): mandatory=False, immediate=False):
"""Publish a message """Publish a message
This method publishes a message to a specific exchange. The This method publishes a message to a specific exchange. The
@@ -2099,6 +2112,15 @@ class Channel(AbstractChannel):
args.write_bit(immediate) args.write_bit(immediate)
self._send_method((60, 40), args, msg) self._send_method((60, 40), args, msg)
basic_publish = _basic_publish
def basic_publish_confirm(self, *args, **kwargs):
if not self._confirm_selected:
self._confirm_selected = True
self.confirm_select()
ret = self._basic_publish(*args, **kwargs)
self.wait([(60, 80)])
return ret
def basic_qos(self, prefetch_size, prefetch_count, a_global): def basic_qos(self, prefetch_size, prefetch_count, a_global):
"""Specify quality of service """Specify quality of service
@@ -2334,14 +2356,13 @@ class Channel(AbstractChannel):
message was published. message was published.
""" """
reply_code = args.read_short() self.returned_messages.put(basic_return_t(
reply_text = args.read_shortstr() args.read_short(),
exchange = args.read_shortstr() args.read_shortstr(),
routing_key = args.read_shortstr() args.read_shortstr(),
args.read_shortstr(),
self.returned_messages.put( msg,
(reply_code, reply_text, exchange, routing_key, msg) ))
)
############# #############
# #

View File

@@ -89,7 +89,7 @@ class Connection(AbstractChannel):
virtual_host='/', locale='en_US', client_properties=None, virtual_host='/', locale='en_US', client_properties=None,
ssl=False, connect_timeout=None, channel_max=None, ssl=False, connect_timeout=None, channel_max=None,
frame_max=None, heartbeat=0, on_blocked=None, frame_max=None, heartbeat=0, on_blocked=None,
on_unblocked=None, **kwargs): on_unblocked=None, confirm_publish=False, **kwargs):
"""Create a connection to the specified host, which should be """Create a connection to the specified host, which should be
a 'host[:port]', such as 'localhost', or '1.2.3.4:5672' a 'host[:port]', such as 'localhost', or '1.2.3.4:5672'
(defaults to 'localhost', if a port is not specified then (defaults to 'localhost', if a port is not specified then
@@ -127,6 +127,8 @@ class Connection(AbstractChannel):
self.frame_max = frame_max self.frame_max = frame_max
self.heartbeat = heartbeat self.heartbeat = heartbeat
self.confirm_publish = confirm_publish
# Callbacks # Callbacks
self.on_blocked = on_blocked self.on_blocked = on_blocked
self.on_unblocked = on_unblocked self.on_unblocked = on_unblocked
@@ -163,6 +165,10 @@ class Connection(AbstractChannel):
return self._x_open(virtual_host) return self._x_open(virtual_host)
@property
def connected(self):
return self.transport and self.transport.connected
def _do_close(self): def _do_close(self):
try: try:
self.transport.close() self.transport.close()

View File

@@ -47,7 +47,9 @@ class AMQPError(Exception):
reply_text, method_sig, self.method_name) reply_text, method_sig, self.method_name)
def __str__(self): def __str__(self):
return '{0.method}: ({0.reply_code}) {0.reply_text}'.format(self) if self.method:
return '{0.method}: ({0.reply_code}) {0.reply_text}'.format(self)
return self.reply_text or '<AMQPError: unknown error>'
@property @property
def method(self): def method(self):

View File

@@ -46,7 +46,7 @@ _CONTENT_METHODS = [
class _PartialMessage(object): class _PartialMessage(object):
"""Helper class to build up a multi-frame method.""" """Helper class to build up a multi-frame method."""
def __init__(self, method_sig, args): def __init__(self, method_sig, args, channel):
self.method_sig = method_sig self.method_sig = method_sig
self.args = args self.args = args
self.msg = Message() self.msg = Message()
@@ -147,7 +147,9 @@ class MethodReader(object):
# #
# Save what we've got so far and wait for the content-header # Save what we've got so far and wait for the content-header
# #
self.partial_messages[channel] = _PartialMessage(method_sig, args) self.partial_messages[channel] = _PartialMessage(
method_sig, args, channel,
)
self.expected_types[channel] = 2 self.expected_types[channel] = 2
else: else:
self._quick_put((channel, method_sig, args, None)) self._quick_put((channel, method_sig, args, None))

View File

@@ -0,0 +1,13 @@
from __future__ import absolute_import
from collections import namedtuple
queue_declare_ok_t = namedtuple(
'queue_declare_ok_t', ('queue', 'message_count', 'consumer_count'),
)
basic_return_t = namedtuple(
'basic_return_t',
('reply_code', 'reply_text', 'exchange', 'routing_key', 'message'),
)

View File

@@ -49,6 +49,9 @@ except:
from struct import pack, unpack from struct import pack, unpack
from .exceptions import UnexpectedFrame from .exceptions import UnexpectedFrame
from .utils import get_errno, set_cloexec
_UNAVAIL = errno.EAGAIN, errno.EINTR
AMQP_PORT = 5672 AMQP_PORT = 5672
@@ -63,8 +66,10 @@ IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?')
class _AbstractTransport(object): class _AbstractTransport(object):
"""Common superclass for TCP and SSL transports""" """Common superclass for TCP and SSL transports"""
connected = False
def __init__(self, host, connect_timeout): def __init__(self, host, connect_timeout):
self.connected = True
msg = None msg = None
port = AMQP_PORT port = AMQP_PORT
@@ -85,6 +90,10 @@ class _AbstractTransport(object):
af, socktype, proto, canonname, sa = res af, socktype, proto, canonname, sa = res
try: try:
self.sock = socket.socket(af, socktype, proto) self.sock = socket.socket(af, socktype, proto)
try:
set_cloexec(self.sock, True)
except NotImplementedError:
pass
self.sock.settimeout(connect_timeout) self.sock.settimeout(connect_timeout)
self.sock.connect(sa) self.sock.connect(sa)
except socket.error as exc: except socket.error as exc:
@@ -99,13 +108,18 @@ class _AbstractTransport(object):
# Didn't connect, return the most recent error message # Didn't connect, return the most recent error message
raise socket.error(last_err) raise socket.error(last_err)
self.sock.settimeout(None) try:
self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1) self.sock.settimeout(None)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._setup_transport() self._setup_transport()
self._write(AMQP_PROTOCOL_HEADER) self._write(AMQP_PROTOCOL_HEADER)
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
def __del__(self): def __del__(self):
try: try:
@@ -141,12 +155,20 @@ class _AbstractTransport(object):
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close() self.sock.close()
self.sock = None self.sock = None
self.connected = False
def read_frame(self, unpack=unpack): def read_frame(self, unpack=unpack):
read = self._read read = self._read
frame_type, channel, size = unpack('>BHI', read(7, True)) try:
payload = read(size) frame_type, channel, size = unpack('>BHI', read(7, True))
ch = ord(read(1)) payload = read(size)
ch = ord(read(1))
except socket.timeout:
raise
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
if ch == 206: # '\xce' if ch == 206: # '\xce'
return frame_type, channel, payload return frame_type, channel, payload
else: else:
@@ -155,10 +177,17 @@ class _AbstractTransport(object):
def write_frame(self, frame_type, channel, payload): def write_frame(self, frame_type, channel, payload):
size = len(payload) size = len(payload)
self._write(pack( try:
'>BHI%dsB' % size, self._write(pack(
frame_type, channel, size, payload, 0xce, '>BHI%dsB' % size,
)) frame_type, channel, size, payload, 0xce,
))
except socket.timeout:
raise
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
class SSLTransport(_AbstractTransport): class SSLTransport(_AbstractTransport):
@@ -200,19 +229,22 @@ class SSLTransport(_AbstractTransport):
# to get the exact number of bytes wanted. # to get the exact number of bytes wanted.
recv = self._quick_recv recv = self._quick_recv
rbuf = self._read_buffer rbuf = self._read_buffer
while len(rbuf) < n: try:
try: while len(rbuf) < n:
s = recv(131072) # see note above try:
except socket.error as exc: s = recv(131072) # see note above
# ssl.sock.read may cause ENOENT if the except socket.error as exc:
# operation couldn't be performed (Issue celery#1414). # ssl.sock.read may cause ENOENT if the
if not initial and exc.errno in _errnos: # operation couldn't be performed (Issue celery#1414).
continue if not initial and exc.errno in _errnos:
raise exc continue
if not s: raise
raise IOError('Socket closed') if not s:
rbuf += s raise IOError('Socket closed')
rbuf += s
except:
self._read_buffer = rbuf
raise
result, self._read_buffer = rbuf[:n], rbuf[n:] result, self._read_buffer = rbuf[:n], rbuf[n:]
return result return result
@@ -240,16 +272,20 @@ class TCPTransport(_AbstractTransport):
"""Read exactly n bytes from the socket""" """Read exactly n bytes from the socket"""
recv = self._quick_recv recv = self._quick_recv
rbuf = self._read_buffer rbuf = self._read_buffer
while len(rbuf) < n: try:
try: while len(rbuf) < n:
s = recv(131072) try:
except socket.error as exc: s = recv(131072)
if not initial and exc.errno in _errnos: except socket.error as exc:
continue if not initial and exc.errno in _errnos:
raise continue
if not s: raise
raise IOError('Socket closed') if not s:
rbuf += s raise IOError('Socket closed')
rbuf += s
except:
self._read_buffer = rbuf
raise
result, self._read_buffer = rbuf[:n], rbuf[n:] result, self._read_buffer = rbuf[:n], rbuf[n:]
return result return result

View File

@@ -2,6 +2,11 @@ from __future__ import absolute_import
import sys import sys
try:
import fcntl
except ImportError:
fcntl = None # noqa
class promise(object): class promise(object):
if not hasattr(sys, 'pypy_version_info'): if not hasattr(sys, 'pypy_version_info'):
@@ -59,3 +64,36 @@ class promise(object):
def noop(): def noop():
return promise(lambda *a, **k: None) return promise(lambda *a, **k: None)
try:
from os import set_cloexec # Python 3.4?
except ImportError:
def set_cloexec(fd, cloexec): # noqa
try:
FD_CLOEXEC = fcntl.FD_CLOEXEC
except AttributeError:
raise NotImplementedError(
'close-on-exec flag not supported on this platform',
)
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
if cloexec:
flags |= FD_CLOEXEC
else:
flags &= ~FD_CLOEXEC
return fcntl.fcntl(fd, fcntl.F_SETFD, flags)
def get_errno(exc):
""":exc:`socket.error` and :exc:`IOError` first got
the ``.errno`` attribute in Py2.7"""
try:
return exc.errno
except AttributeError:
try:
# e.args = (errno, reason)
if isinstance(exc.args, tuple) and len(exc.args) == 2:
return exc.args[0]
except AttributeError:
pass
return 0

View File

@@ -18,9 +18,8 @@
# #
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
VERSION = (2, 7, 3, 32) VERSION = (3, 3, 0, 6)
__version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:]) __version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:])
__author__ = 'R Oudkerk / Python Software Foundation' __author__ = 'R Oudkerk / Python Software Foundation'
__author_email__ = 'python-dev@python.org' __author_email__ = 'python-dev@python.org'
@@ -90,15 +89,12 @@ def Manager():
return m return m
def Pipe(duplex=True): def Pipe(duplex=True, rnonblock=False, wnonblock=False):
''' '''
Returns two connection object connected by a pipe Returns two connection object connected by a pipe
''' '''
if sys.version_info[0] == 3: from billiard.connection import Pipe
from multiprocessing.connection import Pipe return Pipe(duplex, rnonblock, wnonblock)
else:
from billiard._connection import Pipe
return Pipe(duplex)
def cpu_count(): def cpu_count():
@@ -241,7 +237,11 @@ def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None,
Returns a process pool object Returns a process pool object
''' '''
from .pool import Pool from .pool import Pool
return Pool(processes, initializer, initargs, maxtasksperchild) return Pool(processes, initializer, initargs, maxtasksperchild,
timeout, soft_timeout, lost_worker_timeout,
max_restarts, max_restart_freq, on_process_up,
on_process_down, on_timeout_set, on_timeout_cancel,
threads, semaphore, putlocks, allow_restart)
def RawValue(typecode_or_type, *args): def RawValue(typecode_or_type, *args):

View File

@@ -8,7 +8,6 @@
# #
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['Client', 'Listener', 'Pipe'] __all__ = ['Client', 'Listener', 'Pipe']
@@ -21,11 +20,13 @@ import tempfile
import itertools import itertools
from . import AuthenticationError from . import AuthenticationError
from . import reduction
from ._ext import _billiard, win32 from ._ext import _billiard, win32
from .compat import get_errno from .compat import get_errno, bytes, setblocking
from .util import get_temp_dir, Finalize, sub_debug, debug from .five import monotonic
from .forking import duplicate, close from .forking import duplicate, close
from .compat import bytes from .reduction import ForkingPickler
from .util import get_temp_dir, Finalize, sub_debug, debug
try: try:
WindowsError = WindowsError # noqa WindowsError = WindowsError # noqa
@@ -36,6 +37,9 @@ except NameError:
# global set later # global set later
xmlrpclib = None xmlrpclib = None
Connection = getattr(_billiard, 'Connection', None)
PipeConnection = getattr(_billiard, 'PipeConnection', None)
# #
# #
@@ -60,11 +64,11 @@ if sys.platform == 'win32':
def _init_timeout(timeout=CONNECTION_TIMEOUT): def _init_timeout(timeout=CONNECTION_TIMEOUT):
return time.time() + timeout return monotonic() + timeout
def _check_timeout(t): def _check_timeout(t):
return time.time() > t return monotonic() > t
# #
# #
@@ -81,7 +85,7 @@ def arbitrary_address(family):
return tempfile.mktemp(prefix='listener-', dir=get_temp_dir()) return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
elif family == 'AF_PIPE': elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), _mmap_counter.next())) (os.getpid(), next(_mmap_counter)))
else: else:
raise ValueError('unrecognized family') raise ValueError('unrecognized family')
@@ -183,26 +187,32 @@ def Client(address, family=None, authkey=None):
if sys.platform != 'win32': if sys.platform != 'win32':
def Pipe(duplex=True): def Pipe(duplex=True, rnonblock=False, wnonblock=False):
''' '''
Returns pair of connection objects at either end of a pipe Returns pair of connection objects at either end of a pipe
''' '''
if duplex: if duplex:
s1, s2 = socket.socketpair() s1, s2 = socket.socketpair()
c1 = _billiard.Connection(os.dup(s1.fileno())) s1.setblocking(not rnonblock)
c2 = _billiard.Connection(os.dup(s2.fileno())) s2.setblocking(not wnonblock)
c1 = Connection(os.dup(s1.fileno()))
c2 = Connection(os.dup(s2.fileno()))
s1.close() s1.close()
s2.close() s2.close()
else: else:
fd1, fd2 = os.pipe() fd1, fd2 = os.pipe()
c1 = _billiard.Connection(fd1, writable=False) if rnonblock:
c2 = _billiard.Connection(fd2, readable=False) setblocking(fd1, 0)
if wnonblock:
setblocking(fd2, 0)
c1 = Connection(fd1, writable=False)
c2 = Connection(fd2, readable=False)
return c1, c2 return c1, c2
else: else:
def Pipe(duplex=True): # noqa def Pipe(duplex=True, rnonblock=False, wnonblock=False): # noqa
''' '''
Returns pair of connection objects at either end of a pipe Returns pair of connection objects at either end of a pipe
''' '''
@@ -231,12 +241,12 @@ else:
try: try:
win32.ConnectNamedPipe(h1, win32.NULL) win32.ConnectNamedPipe(h1, win32.NULL)
except WindowsError, e: except WindowsError as exc:
if e.args[0] != win32.ERROR_PIPE_CONNECTED: if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
raise raise
c1 = _billiard.PipeConnection(h1, writable=duplex) c1 = PipeConnection(h1, writable=duplex)
c2 = _billiard.PipeConnection(h2, readable=duplex) c2 = PipeConnection(h2, readable=duplex)
return c1, c2 return c1, c2
@@ -275,7 +285,7 @@ class SocketListener(object):
def accept(self): def accept(self):
s, self._last_accepted = self._socket.accept() s, self._last_accepted = self._socket.accept()
fd = duplicate(s.fileno()) fd = duplicate(s.fileno())
conn = _billiard.Connection(fd) conn = Connection(fd)
s.close() s.close()
return conn return conn
@@ -296,7 +306,7 @@ def SocketClient(address):
while 1: while 1:
try: try:
s.connect(address) s.connect(address)
except socket.error, exc: except socket.error as exc:
if get_errno(exc) != errno.ECONNREFUSED or _check_timeout(t): if get_errno(exc) != errno.ECONNREFUSED or _check_timeout(t):
debug('failed to connect to address %s', address) debug('failed to connect to address %s', address)
raise raise
@@ -307,7 +317,7 @@ def SocketClient(address):
raise raise
fd = duplicate(s.fileno()) fd = duplicate(s.fileno())
conn = _billiard.Connection(fd) conn = Connection(fd)
s.close() s.close()
return conn return conn
@@ -352,10 +362,10 @@ if sys.platform == 'win32':
handle = self._handle_queue.pop(0) handle = self._handle_queue.pop(0)
try: try:
win32.ConnectNamedPipe(handle, win32.NULL) win32.ConnectNamedPipe(handle, win32.NULL)
except WindowsError, e: except WindowsError as exc:
if e.args[0] != win32.ERROR_PIPE_CONNECTED: if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
raise raise
return _billiard.PipeConnection(handle) return PipeConnection(handle)
@staticmethod @staticmethod
def _finalize_pipe_listener(queue, address): def _finalize_pipe_listener(queue, address):
@@ -375,8 +385,8 @@ if sys.platform == 'win32':
address, win32.GENERIC_READ | win32.GENERIC_WRITE, address, win32.GENERIC_READ | win32.GENERIC_WRITE,
0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL, 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL,
) )
except WindowsError, e: except WindowsError as exc:
if e.args[0] not in ( if exc.args[0] not in (
win32.ERROR_SEM_TIMEOUT, win32.ERROR_SEM_TIMEOUT,
win32.ERROR_PIPE_BUSY) or _check_timeout(t): win32.ERROR_PIPE_BUSY) or _check_timeout(t):
raise raise
@@ -388,7 +398,7 @@ if sys.platform == 'win32':
win32.SetNamedPipeHandleState( win32.SetNamedPipeHandleState(
h, win32.PIPE_READMODE_MESSAGE, None, None h, win32.PIPE_READMODE_MESSAGE, None, None
) )
return _billiard.PipeConnection(h) return PipeConnection(h)
# #
# Authentication stuff # Authentication stuff
@@ -471,3 +481,12 @@ def XmlClient(*args, **kwds):
global xmlrpclib global xmlrpclib
import xmlrpclib # noqa import xmlrpclib # noqa
return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads) return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
if sys.platform == 'win32':
ForkingPickler.register(socket.socket, reduction.reduce_socket)
ForkingPickler.register(Connection, reduction.reduce_connection)
ForkingPickler.register(PipeConnection, reduction.reduce_pipe_connection)
else:
ForkingPickler.register(socket.socket, reduction.reduce_socket)
ForkingPickler.register(Connection, reduction.reduce_connection)

View File

@@ -0,0 +1,955 @@
#
# A higher level module for using sockets (or Windows named pipes)
#
# multiprocessing/connection.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
__all__ = ['Client', 'Listener', 'Pipe', 'wait']
import io
import os
import sys
import select
import socket
import struct
import errno
import tempfile
import itertools
import _multiprocessing
from .compat import setblocking
from .exceptions import AuthenticationError, BufferTooShort
from .five import monotonic
from .util import get_temp_dir, Finalize, sub_debug
from .reduction import ForkingPickler
try:
import _winapi
from _winapi import (
WAIT_OBJECT_0,
WAIT_ABANDONED_0,
WAIT_TIMEOUT,
INFINITE,
)
except ImportError:
if sys.platform == 'win32':
raise
_winapi = None
#
#
#
BUFSIZE = 8192
# A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20.
_mmap_counter = itertools.count()
default_family = 'AF_INET'
families = ['AF_INET']
if hasattr(socket, 'AF_UNIX'):
default_family = 'AF_UNIX'
families += ['AF_UNIX']
if sys.platform == 'win32':
default_family = 'AF_PIPE'
families += ['AF_PIPE']
def _init_timeout(timeout=CONNECTION_TIMEOUT):
return monotonic() + timeout
def _check_timeout(t):
return monotonic() > t
def arbitrary_address(family):
'''
Return an arbitrary free address for the given family
'''
if family == 'AF_INET':
return ('localhost', 0)
elif family == 'AF_UNIX':
return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), next(_mmap_counter)))
else:
raise ValueError('unrecognized family')
def _validate_family(family):
'''
Checks if the family is valid for the current environment.
'''
if sys.platform != 'win32' and family == 'AF_PIPE':
raise ValueError('Family %s is not recognized.' % family)
if sys.platform == 'win32' and family == 'AF_UNIX':
# double check
if not hasattr(socket, family):
raise ValueError('Family %s is not recognized.' % family)
def address_type(address):
'''
Return the types of the address
This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE'
'''
if type(address) == tuple:
return 'AF_INET'
elif type(address) is str and address.startswith('\\\\'):
return 'AF_PIPE'
elif type(address) is str:
return 'AF_UNIX'
else:
raise ValueError('address type of %r unrecognized' % address)
#
# Connection classes
#
class _ConnectionBase:
    """Base class for connection objects.

    Implements the user-facing API (``send``/``recv``,
    ``send_bytes``/``recv_bytes``, ``poll``, ``close``) on top of the
    low-level ``_send_bytes``/``_recv_bytes``/``_poll``/``_close``
    primitives that subclasses supply for a socket fd (Unix) or a
    named-pipe handle (Windows).
    """
    # Class-level default so __del__ is safe even if __init__ raised
    # before the instance attribute was assigned.
    _handle = None
    def __init__(self, handle, readable=True, writable=True):
        # __index__() both validates and normalises the handle to an int.
        handle = handle.__index__()
        if handle < 0:
            raise ValueError("invalid handle")
        if not readable and not writable:
            raise ValueError(
                "at least one of `readable` and `writable` must be True")
        self._handle = handle
        self._readable = readable
        self._writable = writable
    # XXX should we use util.Finalize instead of a __del__?
    def __del__(self):
        if self._handle is not None:
            self._close()
    def _check_closed(self):
        # Raise if close() has already run.
        if self._handle is None:
            raise OSError("handle is closed")
    def _check_readable(self):
        if not self._readable:
            raise OSError("connection is write-only")
    def _check_writable(self):
        if not self._writable:
            raise OSError("connection is read-only")
    def _bad_message_length(self):
        # Called when an incoming message exceeds ``maxlength``: the
        # stream is now out of sync, so disable reading (or close the
        # connection outright if it was read-only).
        if self._writable:
            self._readable = False
        else:
            self.close()
        raise OSError("bad message length")
    @property
    def closed(self):
        """True if the connection is closed"""
        return self._handle is None
    @property
    def readable(self):
        """True if the connection is readable"""
        return self._readable
    @property
    def writable(self):
        """True if the connection is writable"""
        return self._writable
    def fileno(self):
        """File descriptor or handle of the connection"""
        self._check_closed()
        return self._handle
    def close(self):
        """Close the connection"""
        if self._handle is not None:
            try:
                self._close()
            finally:
                # Mark closed even if _close() raised.
                self._handle = None
    def send_bytes(self, buf, offset=0, size=None):
        """Send the bytes data from a bytes-like object"""
        self._check_closed()
        self._check_writable()
        m = memoryview(buf)
        # HACK for byte-indexing of non-bytewise buffers (e.g. array.array)
        if m.itemsize > 1:
            m = memoryview(bytes(m))
        n = len(m)
        if offset < 0:
            raise ValueError("offset is negative")
        if n < offset:
            raise ValueError("buffer length < offset")
        if size is None:
            size = n - offset
        elif size < 0:
            raise ValueError("size is negative")
        elif offset + size > n:
            raise ValueError("buffer length < offset + size")
        self._send_bytes(m[offset:offset + size])
    def send(self, obj):
        """Send a (picklable) object"""
        self._check_closed()
        self._check_writable()
        self._send_bytes(ForkingPickler.dumps(obj))
    def recv_bytes(self, maxlength=None):
        """
        Receive bytes data as a bytes object.
        """
        self._check_closed()
        self._check_readable()
        if maxlength is not None and maxlength < 0:
            raise ValueError("negative maxlength")
        buf = self._recv_bytes(maxlength)
        if buf is None:
            # _recv_bytes signals an over-long message by returning None.
            self._bad_message_length()
        return buf.getvalue()
    def recv_bytes_into(self, buf, offset=0):
        """
        Receive bytes data into a writeable buffer-like object.
        Return the number of bytes read.
        """
        self._check_closed()
        self._check_readable()
        with memoryview(buf) as m:
            # Get bytesize of arbitrary buffer
            itemsize = m.itemsize
            bytesize = itemsize * len(m)
            if offset < 0:
                raise ValueError("negative offset")
            elif offset > bytesize:
                raise ValueError("offset too large")
            result = self._recv_bytes()
            size = result.tell()
            if bytesize < offset + size:
                # Hand the received data back to the caller via the
                # exception payload.
                raise BufferTooShort(result.getvalue())
            # Message can fit in dest
            result.seek(0)
            result.readinto(
                m[offset // itemsize:(offset + size) // itemsize]
            )
            return size
    def recv_payload(self):
        # Receive one message as a raw memory buffer without unpickling.
        return self._recv_bytes().getbuffer()
    def recv(self):
        """Receive a (picklable) object"""
        self._check_closed()
        self._check_readable()
        buf = self._recv_bytes()
        return ForkingPickler.loads(buf.getbuffer())
    def poll(self, timeout=0.0):
        """Whether there is any input available to be read"""
        self._check_closed()
        self._check_readable()
        return self._poll(timeout)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
if _winapi:
    class PipeConnection(_ConnectionBase):
        """
        Connection class based on a Windows named pipe.
        Overlapped I/O is used, so the handles must have been created
        with FILE_FLAG_OVERLAPPED.
        """
        # True when wait() already consumed a zero-length message from
        # the pipe; the next _recv_bytes()/_poll() must report it.
        _got_empty_message = False
        def _close(self, _CloseHandle=_winapi.CloseHandle):
            _CloseHandle(self._handle)
        def _send_bytes(self, buf):
            ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True)
            try:
                if err == _winapi.ERROR_IO_PENDING:
                    waitres = _winapi.WaitForMultipleObjects(
                        [ov.event], False, INFINITE)
                    assert waitres == WAIT_OBJECT_0
            except:
                # Cancel the pending overlapped write before propagating
                # (e.g. on KeyboardInterrupt).
                ov.cancel()
                raise
            finally:
                nwritten, err = ov.GetOverlappedResult(True)
                assert err == 0
                assert nwritten == len(buf)
        def _recv_bytes(self, maxsize=None):
            if self._got_empty_message:
                self._got_empty_message = False
                return io.BytesIO()
            else:
                # Read an initial chunk of at most 128 bytes;
                # ERROR_MORE_DATA means the message continues.
                bsize = 128 if maxsize is None else min(maxsize, 128)
                try:
                    ov, err = _winapi.ReadFile(self._handle, bsize,
                                               overlapped=True)
                    try:
                        if err == _winapi.ERROR_IO_PENDING:
                            waitres = _winapi.WaitForMultipleObjects(
                                [ov.event], False, INFINITE)
                            assert waitres == WAIT_OBJECT_0
                    except:
                        ov.cancel()
                        raise
                    finally:
                        # The finally block returns on success, so the
                        # trailing RuntimeError is only reached when an
                        # exception interrupted the read.
                        nread, err = ov.GetOverlappedResult(True)
                        if err == 0:
                            f = io.BytesIO()
                            f.write(ov.getbuffer())
                            return f
                        elif err == _winapi.ERROR_MORE_DATA:
                            return self._get_more_data(ov, maxsize)
                except OSError as e:
                    if e.winerror == _winapi.ERROR_BROKEN_PIPE:
                        # Peer closed the pipe: report end-of-stream.
                        raise EOFError
                    else:
                        raise
            raise RuntimeError(
                "shouldn't get here; expected KeyboardInterrupt"
            )
        def _poll(self, timeout):
            # Fast path: data (or a consumed empty message) already there.
            if (self._got_empty_message or
                    _winapi.PeekNamedPipe(self._handle)[0] != 0):
                return True
            return bool(wait([self], timeout))
        def _get_more_data(self, ov, maxsize):
            # Collect the remainder of a message after ERROR_MORE_DATA.
            buf = ov.getbuffer()
            f = io.BytesIO()
            f.write(buf)
            left = _winapi.PeekNamedPipe(self._handle)[1]
            assert left > 0
            if maxsize is not None and len(buf) + left > maxsize:
                self._bad_message_length()
            ov, err = _winapi.ReadFile(self._handle, left, overlapped=True)
            rbytes, err = ov.GetOverlappedResult(True)
            assert err == 0
            assert rbytes == left
            f.write(ov.getbuffer())
            return f
class Connection(_ConnectionBase):
    """
    Connection class based on an arbitrary file descriptor (Unix only), or
    a socket handle (Windows).
    """
    # Pick the raw read/write/close primitives once at class-creation
    # time, depending on the platform.
    if _winapi:
        def _close(self, _close=_multiprocessing.closesocket):
            _close(self._handle)
        _write = _multiprocessing.send
        _read = _multiprocessing.recv
    else:
        def _close(self, _close=os.close): # noqa
            _close(self._handle)
        _write = os.write
        _read = os.read
    def send_offset(self, buf, offset, write=_write):
        # Single write() starting at ``offset``; returns the number of
        # bytes actually written (may be a short write).
        return write(self._handle, buf[offset:])
    def _send(self, buf, write=_write):
        # Write the whole buffer, retrying on EINTR and short writes.
        remaining = len(buf)
        while True:
            try:
                n = write(self._handle, buf)
            except OSError as exc:
                if exc.errno == errno.EINTR:
                    continue
                raise
            remaining -= n
            if remaining == 0:
                break
            buf = buf[n:]
    def setblocking(self, blocking):
        # Delegates to the module-level setblocking() helper (defined
        # elsewhere in this file).
        setblocking(self._handle, blocking)
    def _recv(self, size, read=_read):
        # Read exactly ``size`` bytes into a BytesIO, retrying on EINTR.
        buf = io.BytesIO()
        handle = self._handle
        remaining = size
        while remaining > 0:
            try:
                chunk = read(handle, remaining)
            except OSError as exc:
                if exc.errno == errno.EINTR:
                    continue
                raise
            n = len(chunk)
            if n == 0:
                # EOF: clean if nothing read yet, an error mid-message.
                if remaining == size:
                    raise EOFError
                else:
                    raise OSError("got end of file during message")
            buf.write(chunk)
            remaining -= n
        return buf
    def _send_bytes(self, buf):
        # For wire compatibility with 3.2 and lower
        n = len(buf)
        self._send(struct.pack("!i", n))
        # The condition is necessary to avoid "broken pipe" errors
        # when sending a 0-length buffer if the other end closed the pipe.
        if n > 0:
            self._send(buf)
    def _recv_bytes(self, maxsize=None):
        # Wire format: 4-byte big-endian signed length, then payload.
        buf = self._recv(4)
        size, = struct.unpack("!i", buf.getvalue())
        if maxsize is not None and size > maxsize:
            # Over-long message: caller maps None to _bad_message_length().
            return None
        return self._recv(size)
    def _poll(self, timeout):
        r = wait([self], timeout)
        return bool(r)
#
# Public functions
#
class Listener(object):
    '''
    Returns a listener object.
    This is a wrapper for a bound socket which is 'listening' for
    connections, or for a Windows named pipe.
    '''
    def __init__(self, address=None, family=None, backlog=1, authkey=None):
        # Infer the family from the address when not supplied, falling
        # back to the platform default; pick an address to match.
        if not family:
            family = address_type(address) if address else default_family
        if not address:
            address = arbitrary_address(family)
        _validate_family(family)
        if family == 'AF_PIPE':
            self._listener = PipeListener(address, backlog)
        else:
            self._listener = SocketListener(address, family, backlog)
        if authkey is not None and not isinstance(authkey, bytes):
            raise TypeError('authkey should be a byte string')
        self._authkey = authkey
    def accept(self):
        '''
        Accept a connection on the bound socket or named pipe of `self`.
        Returns a `Connection` object.
        '''
        if self._listener is None:
            raise OSError('listener is closed')
        conn = self._listener.accept()
        # Run the hmac handshake when an authkey was supplied.
        if self._authkey:
            deliver_challenge(conn, self._authkey)
            answer_challenge(conn, self._authkey)
        return conn
    def close(self):
        '''
        Close the bound socket or named pipe of `self`.
        '''
        listener = self._listener
        if listener is not None:
            listener.close()
            self._listener = None
    @property
    def address(self):
        return self._listener._address
    @property
    def last_accepted(self):
        return self._listener._last_accepted
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
def Client(address, family=None, authkey=None):
    '''
    Returns a connection to the address of a `Listener`
    '''
    if not family:
        family = address_type(address)
    _validate_family(family)
    # Pick the transport-specific connect function.
    connect = PipeClient if family == 'AF_PIPE' else SocketClient
    conn = connect(address)
    if authkey is not None and not isinstance(authkey, bytes):
        raise TypeError('authkey should be a byte string')
    if authkey is not None:
        # Client side answers first, then challenges the server back.
        answer_challenge(conn, authkey)
        deliver_challenge(conn, authkey)
    return conn
if sys.platform != 'win32':
    def Pipe(duplex=True, rnonblock=False, wnonblock=False):
        '''
        Returns pair of connection objects at either end of a pipe
        '''
        if duplex:
            # Duplex: a connected socket pair; both ends read and write.
            s1, s2 = socket.socketpair()
            s1.setblocking(not rnonblock)
            s2.setblocking(not wnonblock)
            c1 = Connection(s1.detach())
            c2 = Connection(s2.detach())
        else:
            # Simplex: an os.pipe(); c1 is read-only, c2 write-only.
            fd1, fd2 = os.pipe()
            if rnonblock:
                setblocking(fd1, 0)
            if wnonblock:
                setblocking(fd2, 0)
            c1 = Connection(fd1, writable=False)
            c2 = Connection(fd2, readable=False)
        return c1, c2
else:
    def Pipe(duplex=True, rnonblock=False, wnonblock=False): # noqa
        '''
        Returns pair of connection objects at either end of a pipe
        '''
        # Windows: message-mode named pipe with overlapped I/O.
        # NOTE(review): rnonblock/wnonblock appear to be ignored on this
        # branch.
        address = arbitrary_address('AF_PIPE')
        if duplex:
            openmode = _winapi.PIPE_ACCESS_DUPLEX
            access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
            obsize, ibsize = BUFSIZE, BUFSIZE
        else:
            openmode = _winapi.PIPE_ACCESS_INBOUND
            access = _winapi.GENERIC_WRITE
            obsize, ibsize = 0, BUFSIZE
        h1 = _winapi.CreateNamedPipe(
            address, openmode | _winapi.FILE_FLAG_OVERLAPPED |
            _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE,
            _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
            _winapi.PIPE_WAIT,
            1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL
        )
        h2 = _winapi.CreateFile(
            address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
            _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
        )
        _winapi.SetNamedPipeHandleState(
            h2, _winapi.PIPE_READMODE_MESSAGE, None, None
        )
        # Complete the server side of the connect handshake.
        overlapped = _winapi.ConnectNamedPipe(h1, overlapped=True)
        _, err = overlapped.GetOverlappedResult(True)
        assert err == 0
        c1 = PipeConnection(h1, writable=duplex)
        c2 = PipeConnection(h2, readable=duplex)
        return c1, c2
#
# Definitions for connections based on sockets
#
class SocketListener(object):
    '''
    Representation of a socket which is bound to an address and listening
    '''
    def __init__(self, address, family, backlog=1):
        sock = socket.socket(getattr(socket, family))
        self._socket = sock
        try:
            # SO_REUSEADDR has different semantics on Windows (issue #2550).
            if os.name == 'posix':
                sock.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
            sock.setblocking(True)
            sock.bind(address)
            sock.listen(backlog)
            self._address = sock.getsockname()
        except OSError:
            # Don't leak the socket if bind/listen failed.
            sock.close()
            raise
        self._family = family
        self._last_accepted = None
        if family == 'AF_UNIX':
            # Remove the socket file at interpreter exit.
            self._unlink = Finalize(
                self, os.unlink, args=(address, ), exitpriority=0
            )
        else:
            self._unlink = None
    def accept(self):
        # Retry the accept() while it is interrupted by signals.
        while True:
            try:
                conn, self._last_accepted = self._socket.accept()
            except OSError as exc:
                if exc.errno != errno.EINTR:
                    raise
            else:
                break
        conn.setblocking(True)
        return Connection(conn.detach())
    def close(self):
        self._socket.close()
        unlink = self._unlink
        if unlink is not None:
            unlink()
def SocketClient(address):
    '''
    Return a connection object connected to the socket given by `address`
    '''
    family = address_type(address)
    sock = socket.socket(getattr(socket, family))
    # The with-block guarantees the socket object is closed; the
    # underlying fd is detached into the Connection first.
    with sock:
        sock.setblocking(True)
        sock.connect(address)
        return Connection(sock.detach())
#
# Definitions for connections based on named pipes
#
if sys.platform == 'win32':
    class PipeListener(object):
        '''
        Representation of a named pipe
        '''
        def __init__(self, address, backlog=None):
            # ``backlog`` is accepted for API symmetry with
            # SocketListener but is ignored for named pipes.
            self._address = address
            # Keep one pre-created pipe instance queued so a client can
            # always connect between accept() calls.
            self._handle_queue = [self._new_handle(first=True)]
            self._last_accepted = None
            sub_debug('listener created with address=%r', self._address)
            # close() doubles as the exit-time finalizer.
            self.close = Finalize(
                self, PipeListener._finalize_pipe_listener,
                args=(self._handle_queue, self._address), exitpriority=0
            )
        def _new_handle(self, first=False):
            # Create a fresh server-side pipe-instance handle.
            flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
            if first:
                flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
            return _winapi.CreateNamedPipe(
                self._address, flags,
                _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
                _winapi.PIPE_WAIT,
                _winapi.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
                _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL
            )
        def accept(self):
            # Queue a replacement instance, then wait for a client on
            # the oldest one.
            self._handle_queue.append(self._new_handle())
            handle = self._handle_queue.pop(0)
            try:
                ov = _winapi.ConnectNamedPipe(handle, overlapped=True)
            except OSError as e:
                if e.winerror != _winapi.ERROR_NO_DATA:
                    raise
                # ERROR_NO_DATA can occur if a client has already connected,
                # written data and then disconnected -- see Issue 14725.
            else:
                try:
                    _winapi.WaitForMultipleObjects([ov.event], False, INFINITE)
                except:
                    ov.cancel()
                    _winapi.CloseHandle(handle)
                    raise
                finally:
                    _, err = ov.GetOverlappedResult(True)
                    assert err == 0
            return PipeConnection(handle)
        @staticmethod
        def _finalize_pipe_listener(queue, address):
            # Close any pipe instances still waiting in the queue.
            sub_debug('closing listener with address=%r', address)
            for handle in queue:
                _winapi.CloseHandle(handle)
    def PipeClient(address,
                   errors=(_winapi.ERROR_SEM_TIMEOUT,
                           _winapi.ERROR_PIPE_BUSY)):
        '''
        Return a connection object connected to the pipe given by `address`
        '''
        t = _init_timeout()
        # Retry while the pipe is busy, until the init timeout expires.
        while 1:
            try:
                _winapi.WaitNamedPipe(address, 1000)
                h = _winapi.CreateFile(
                    address, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0, _winapi.NULL, _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
                )
            except OSError as e:
                if e.winerror not in errors or _check_timeout(t):
                    raise
            else:
                break
        else:
            # NOTE(review): dead code -- ``while 1`` never terminates
            # normally, so this loop-``else`` can never execute.
            raise
        _winapi.SetNamedPipeHandleState(
            h, _winapi.PIPE_READMODE_MESSAGE, None, None
        )
        return PipeConnection(h)
#
# Authentication stuff
#
# Handshake constants.  These values are part of the wire protocol
# shared with other multiprocessing/billiard processes and must not
# change.
MESSAGE_LENGTH = 20
CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'


def deliver_challenge(connection, authkey):
    """Server half of the HMAC challenge/response handshake.

    Sends a random challenge over *connection*, checks the peer's
    HMAC-MD5 digest of it, and replies with WELCOME on success.
    Sends FAILURE and raises AuthenticationError on a bad digest.
    """
    import hmac
    from hashlib import md5
    assert isinstance(authkey, bytes)
    message = os.urandom(MESSAGE_LENGTH)
    connection.send_bytes(CHALLENGE + message)
    # Pass digestmod explicitly: hmac.new()'s implicit MD5 default was
    # removed in Python 3.8.  MD5 itself must be kept for wire
    # compatibility with stdlib multiprocessing peers.
    digest = hmac.new(authkey, message, md5).digest()
    response = connection.recv_bytes(256)        # reject large message
    if response == digest:
        connection.send_bytes(WELCOME)
    else:
        connection.send_bytes(FAILURE)
        raise AuthenticationError('digest received was wrong')


def answer_challenge(connection, authkey):
    """Client half of the handshake started by deliver_challenge().

    Receives the challenge, replies with its HMAC-MD5 digest, and
    raises AuthenticationError unless the server answers WELCOME.
    """
    import hmac
    from hashlib import md5
    assert isinstance(authkey, bytes)
    message = connection.recv_bytes(256)         # reject large message
    assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message
    message = message[len(CHALLENGE):]
    # See deliver_challenge() for why digestmod is explicit.
    digest = hmac.new(authkey, message, md5).digest()
    connection.send_bytes(digest)
    response = connection.recv_bytes(256)        # reject large message
    if response != WELCOME:
        raise AuthenticationError('digest sent was rejected')
#
# Support for using xmlrpclib for serialization
#
class ConnectionWrapper(object):
    """Wrap a connection so objects are serialised with the supplied
    ``dumps``/``loads`` pair instead of pickle."""
    def __init__(self, conn, dumps, loads):
        self._conn = conn
        self._dumps = dumps
        self._loads = loads
        # Re-export the byte-level API of the wrapped connection as-is.
        for name in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
            setattr(self, name, getattr(conn, name))
    def send(self, obj):
        self._conn.send_bytes(self._dumps(obj))
    def recv(self):
        return self._loads(self._conn.recv_bytes())
def _xml_dumps(obj):
    # Serialise *obj* as a one-element XML-RPC params blob (UTF-8 bytes).
    # ``xmlrpclib`` is a module global bound lazily by XmlListener.accept
    # / XmlClient below.
    return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf-8') # noqa
def _xml_loads(s):
    # Inverse of _xml_dumps(); returns the single deserialised object.
    (obj,), method = xmlrpclib.loads(s.decode('utf-8')) # noqa
    return obj
class XmlListener(Listener):
    """Listener whose accepted connections serialise objects with
    xmlrpc instead of pickle."""
    def accept(self):
        global xmlrpclib
        import xmlrpc.client as xmlrpclib # noqa
        obj = Listener.accept(self)
        return ConnectionWrapper(obj, _xml_dumps, _xml_loads)
def XmlClient(*args, **kwds):
    # Client() counterpart of XmlListener.
    global xmlrpclib
    import xmlrpc.client as xmlrpclib # noqa
    return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
#
# Wait
#
if sys.platform == 'win32':
    def _exhaustive_wait(handles, timeout):
        # Return ALL handles which are currently signalled. (Only
        # returning the first signalled might create starvation issues.)
        L = list(handles)
        ready = []
        while L:
            res = _winapi.WaitForMultipleObjects(L, False, timeout)
            if res == WAIT_TIMEOUT:
                break
            elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L):
                res -= WAIT_OBJECT_0
            elif WAIT_ABANDONED_0 <= res < WAIT_ABANDONED_0 + len(L):
                res -= WAIT_ABANDONED_0
            else:
                raise RuntimeError('Should not get here')
            ready.append(L[res])
            # Re-poll the remaining handles with a zero timeout.
            L = L[res+1:]
            timeout = 0
        return ready
    # Errors that mean "the pipe is dead", which counts as readable
    # (the reader will then get EOFError).
    _ready_errors = {_winapi.ERROR_BROKEN_PIPE, _winapi.ERROR_NETNAME_DELETED}
    def wait(object_list, timeout=None):
        '''
        Wait till an object in object_list is ready/readable.
        Returns list of those objects in object_list which are ready/readable.
        '''
        if timeout is None:
            timeout = INFINITE
        elif timeout < 0:
            timeout = 0
        else:
            # Convert seconds to whole milliseconds, rounding to nearest.
            timeout = int(timeout * 1000 + 0.5)
        object_list = list(object_list)
        waithandle_to_obj = {}
        ov_list = []
        ready_objects = set()
        ready_handles = set()
        try:
            for o in object_list:
                try:
                    fileno = getattr(o, 'fileno')
                except AttributeError:
                    # A plain waitable handle (e.g. an event).
                    waithandle_to_obj[o.__index__()] = o
                else:
                    # start an overlapped read of length zero
                    try:
                        ov, err = _winapi.ReadFile(fileno(), 0, True)
                    except OSError as e:
                        err = e.winerror
                        if err not in _ready_errors:
                            raise
                    if err == _winapi.ERROR_IO_PENDING:
                        ov_list.append(ov)
                        waithandle_to_obj[ov.event] = o
                    else:
                        # If o.fileno() is an overlapped pipe handle and
                        # err == 0 then there is a zero length message
                        # in the pipe, but it HAS NOT been consumed.
                        ready_objects.add(o)
                        timeout = 0
            ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
        finally:
            # request that overlapped reads stop
            for ov in ov_list:
                ov.cancel()
            # wait for all overlapped reads to stop
            for ov in ov_list:
                try:
                    _, err = ov.GetOverlappedResult(True)
                except OSError as e:
                    err = e.winerror
                    if err not in _ready_errors:
                        raise
                if err != _winapi.ERROR_OPERATION_ABORTED:
                    o = waithandle_to_obj[ov.event]
                    ready_objects.add(o)
                    if err == 0:
                        # If o.fileno() is an overlapped pipe handle then
                        # a zero length message HAS been consumed.
                        if hasattr(o, '_got_empty_message'):
                            o._got_empty_message = True
        ready_objects.update(waithandle_to_obj[h] for h in ready_handles)
        return [o for o in object_list if o in ready_objects]
else:
    if hasattr(select, 'poll'):
        def _poll(fds, timeout):
            # select.poll() takes milliseconds.
            if timeout is not None:
                timeout = int(timeout * 1000)  # timeout is in milliseconds
            fd_map = {}
            pollster = select.poll()
            for fd in fds:
                pollster.register(fd, select.POLLIN)
                if hasattr(fd, 'fileno'):
                    fd_map[fd.fileno()] = fd
                else:
                    fd_map[fd] = fd
            ls = []
            for fd, event in pollster.poll(timeout):
                if event & select.POLLNVAL:
                    raise ValueError('invalid file descriptor %i' % fd)
                ls.append(fd_map[fd])
            return ls
    else:
        def _poll(fds, timeout): # noqa
            return select.select(fds, [], [], timeout)[0]
    def wait(object_list, timeout=None): # noqa
        '''
        Wait till an object in object_list is ready/readable.
        Returns list of those objects in object_list which are ready/readable.
        '''
        if timeout is not None:
            if timeout <= 0:
                return _poll(object_list, 0)
            else:
                deadline = monotonic() + timeout
        while True:
            try:
                return _poll(object_list, timeout)
            except OSError as e:
                if e.errno != errno.EINTR:
                    raise
            # Interrupted by a signal: retry with the remaining time.
            if timeout is not None:
                timeout = deadline - monotonic()

View File

@@ -4,10 +4,7 @@ import sys
supports_exec = True supports_exec = True
try: from .compat import _winapi as win32 # noqa
import _winapi as win32
except ImportError: # pragma: no cover
win32 = None
if sys.platform.startswith("java"): if sys.platform.startswith("java"):
_billiard = None _billiard = None
@@ -20,11 +17,9 @@ else:
try: try:
Connection = _billiard.Connection Connection = _billiard.Connection
except AttributeError: # Py3 except AttributeError: # Py3
from multiprocessing.connection import Connection # noqa from billiard.connection import Connection # noqa
PipeConnection = getattr(_billiard, "PipeConnection", None) PipeConnection = getattr(_billiard, "PipeConnection", None)
if win32 is None:
win32 = getattr(_billiard, "win32", None) # noqa
def ensure_multiprocessing(): def ensure_multiprocessing():

View File

@@ -0,0 +1,244 @@
#
# Module to allow connection and socket objects to be transferred
# between processes
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
__all__ = []
import os
import sys
import socket
import threading
from pickle import Pickler
from . import current_process
from ._ext import _billiard, win32
from .util import register_after_fork, debug, sub_debug
if not(sys.platform == 'win32' or hasattr(_billiard, 'recvfd')):
raise ImportError('pickling of connections not supported')
close = win32.CloseHandle if sys.platform == 'win32' else os.close
# globals set later
_listener = None
_lock = None
_cache = set()
#
# ForkingPickler
#
class ForkingPickler(Pickler): # noqa
dispatch = Pickler.dispatch.copy()
@classmethod
def register(cls, type, reduce):
def dispatcher(self, obj):
rv = reduce(obj)
self.save_reduce(obj=obj, *rv)
cls.dispatch[type] = dispatcher
def _reduce_method(m): # noqa
if m.__self__ is None:
return getattr, (m.__self__.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
ForkingPickler.register(type(ForkingPickler.save), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
ForkingPickler.register(type(list.append), _reduce_method_descriptor)
ForkingPickler.register(type(int.__add__), _reduce_method_descriptor)
try:
from functools import partial
except ImportError:
pass
else:
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return partial(func, *args, **keywords)
ForkingPickler.register(partial, _reduce_partial)
def dump(obj, file, protocol=None):
ForkingPickler(file, protocol).dump(obj)
#
# Platform specific definitions
#
if sys.platform == 'win32':
# XXX Should this subprocess import be here?
import _subprocess # noqa
def send_handle(conn, handle, destination_pid):
from .forking import duplicate
process_handle = win32.OpenProcess(
win32.PROCESS_ALL_ACCESS, False, destination_pid
)
try:
new_handle = duplicate(handle, process_handle)
conn.send(new_handle)
finally:
close(process_handle)
def recv_handle(conn):
return conn.recv()
else:
def send_handle(conn, handle, destination_pid): # noqa
_billiard.sendfd(conn.fileno(), handle)
def recv_handle(conn): # noqa
return _billiard.recvfd(conn.fileno())
#
# Support for a per-process server thread which caches pickled handles
#
def _reset(obj):
global _lock, _listener, _cache
for h in _cache:
close(h)
_cache.clear()
_lock = threading.Lock()
_listener = None
_reset(None)
register_after_fork(_reset, _reset)
def _get_listener():
global _listener
if _listener is None:
_lock.acquire()
try:
if _listener is None:
from .connection import Listener
debug('starting listener and thread for sending handles')
_listener = Listener(authkey=current_process().authkey)
t = threading.Thread(target=_serve)
t.daemon = True
t.start()
finally:
_lock.release()
return _listener
def _serve():
from .util import is_exiting, sub_warning
while 1:
try:
conn = _listener.accept()
handle_wanted, destination_pid = conn.recv()
_cache.remove(handle_wanted)
send_handle(conn, handle_wanted, destination_pid)
close(handle_wanted)
conn.close()
except:
if not is_exiting():
sub_warning('thread for sharing handles raised exception',
exc_info=True)
#
# Functions to be used for pickling/unpickling objects with handles
#
def reduce_handle(handle):
from .forking import Popen, duplicate
if Popen.thread_is_spawning():
return (None, Popen.duplicate_for_child(handle), True)
dup_handle = duplicate(handle)
_cache.add(dup_handle)
sub_debug('reducing handle %d', handle)
return (_get_listener().address, dup_handle, False)
def rebuild_handle(pickled_data):
from .connection import Client
address, handle, inherited = pickled_data
if inherited:
return handle
sub_debug('rebuilding handle %d', handle)
conn = Client(address, authkey=current_process().authkey)
conn.send((handle, os.getpid()))
new_handle = recv_handle(conn)
conn.close()
return new_handle
#
# Register `_billiard.Connection` with `ForkingPickler`
#
def reduce_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_connection, (rh, conn.readable, conn.writable)
def rebuild_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.Connection(
handle, readable=readable, writable=writable
)
# Register `socket.socket` with `ForkingPickler`
#
def fromfd(fd, family, type_, proto=0):
s = socket.fromfd(fd, family, type_, proto)
if s.__class__ is not socket.socket:
s = socket.socket(_sock=s)
return s
def reduce_socket(s):
reduced_handle = reduce_handle(s.fileno())
return rebuild_socket, (reduced_handle, s.family, s.type, s.proto)
def rebuild_socket(reduced_handle, family, type_, proto):
fd = rebuild_handle(reduced_handle)
_sock = fromfd(fd, family, type_, proto)
close(fd)
return _sock
ForkingPickler.register(socket.socket, reduce_socket)
#
# Register `_billiard.PipeConnection` with `ForkingPickler`
#
if sys.platform == 'win32':
def reduce_pipe_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_pipe_connection, (rh, conn.readable, conn.writable)
def rebuild_pipe_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.PipeConnection(
handle, readable=readable, writable=writable
)

View File

@@ -0,0 +1,249 @@
#
# Module which deals with pickling of objects.
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
import copyreg
import functools
import io
import os
import pickle
import socket
import sys
__all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump']
HAVE_SEND_HANDLE = (sys.platform == 'win32' or
(hasattr(socket, 'CMSG_LEN') and
hasattr(socket, 'SCM_RIGHTS') and
hasattr(socket.socket, 'sendmsg')))
#
# Pickler subclass
#
class ForkingPickler(pickle.Pickler):
'''Pickler subclass used by multiprocessing.'''
_extra_reducers = {}
_copyreg_dispatch_table = copyreg.dispatch_table
def __init__(self, *args):
super().__init__(*args)
self.dispatch_table = self._copyreg_dispatch_table.copy()
self.dispatch_table.update(self._extra_reducers)
@classmethod
def register(cls, type, reduce):
'''Register a reduce function for a type.'''
cls._extra_reducers[type] = reduce
@classmethod
def dumps(cls, obj, protocol=None):
buf = io.BytesIO()
cls(buf, protocol).dump(obj)
return buf.getbuffer()
loads = pickle.loads
register = ForkingPickler.register
def dump(obj, file, protocol=None):
'''Replacement for pickle.dump() using ForkingPickler.'''
ForkingPickler(file, protocol).dump(obj)
#
# Platform specific definitions
#
if sys.platform == 'win32':
# Windows
__all__ += ['DupHandle', 'duplicate', 'steal_handle']
import _winapi
def duplicate(handle, target_process=None, inheritable=False):
'''Duplicate a handle. (target_process is a handle not a pid!)'''
if target_process is None:
target_process = _winapi.GetCurrentProcess()
return _winapi.DuplicateHandle(
_winapi.GetCurrentProcess(), handle, target_process,
0, inheritable, _winapi.DUPLICATE_SAME_ACCESS)
def steal_handle(source_pid, handle):
'''Steal a handle from process identified by source_pid.'''
source_process_handle = _winapi.OpenProcess(
_winapi.PROCESS_DUP_HANDLE, False, source_pid)
try:
return _winapi.DuplicateHandle(
source_process_handle, handle,
_winapi.GetCurrentProcess(), 0, False,
_winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE)
finally:
_winapi.CloseHandle(source_process_handle)
def send_handle(conn, handle, destination_pid):
'''Send a handle over a local connection.'''
dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid)
conn.send(dh)
def recv_handle(conn):
'''Receive a handle over a local connection.'''
return conn.recv().detach()
class DupHandle(object):
'''Picklable wrapper for a handle.'''
def __init__(self, handle, access, pid=None):
if pid is None:
# We just duplicate the handle in the current process and
# let the receiving process steal the handle.
pid = os.getpid()
proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid)
try:
self._handle = _winapi.DuplicateHandle(
_winapi.GetCurrentProcess(),
handle, proc, access, False, 0)
finally:
_winapi.CloseHandle(proc)
self._access = access
self._pid = pid
def detach(self):
'''Get the handle. This should only be called once.'''
# retrieve handle from process which currently owns it
if self._pid == os.getpid():
# The handle has already been duplicated for this process.
return self._handle
# We must steal the handle from the process whose pid is self._pid.
proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False,
self._pid)
try:
return _winapi.DuplicateHandle(
proc, self._handle, _winapi.GetCurrentProcess(),
self._access, False, _winapi.DUPLICATE_CLOSE_SOURCE)
finally:
_winapi.CloseHandle(proc)
else:
# Unix
__all__ += ['DupFd', 'sendfds', 'recvfds']
import array
# On MacOSX we should acknowledge receipt of fds -- see Issue14669
ACKNOWLEDGE = sys.platform == 'darwin'
def sendfds(sock, fds):
'''Send an array of fds over an AF_UNIX socket.'''
fds = array.array('i', fds)
msg = bytes([len(fds) % 256])
sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
if ACKNOWLEDGE and sock.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd')
def recvfds(sock, size):
'''Receive an array of fds over an AF_UNIX socket.'''
a = array.array('i')
bytes_size = a.itemsize * size
msg, ancdata, flags, addr = sock.recvmsg(
1, socket.CMSG_LEN(bytes_size),
)
if not msg and not ancdata:
raise EOFError
try:
if ACKNOWLEDGE:
sock.send(b'A')
if len(ancdata) != 1:
raise RuntimeError(
'received %d items of ancdata' % len(ancdata),
)
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
if (cmsg_level == socket.SOL_SOCKET and
cmsg_type == socket.SCM_RIGHTS):
if len(cmsg_data) % a.itemsize != 0:
raise ValueError
a.frombytes(cmsg_data)
assert len(a) % 256 == msg[0]
return list(a)
except (ValueError, IndexError):
pass
raise RuntimeError('Invalid data received')
def send_handle(conn, handle, destination_pid): # noqa
'''Send a handle over a local connection.'''
fd = conn.fileno()
with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
sendfds(s, [handle])
def recv_handle(conn): # noqa
'''Receive a handle over a local connection.'''
fd = conn.fileno()
with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
return recvfds(s, 1)[0]
def DupFd(fd):
'''Return a wrapper for an fd.'''
from .forking import Popen
return Popen.duplicate_for_child(fd)
#
# Try making some callable types picklable
#
def _reduce_method(m):
if m.__self__ is None:
return getattr, (m.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
class _C:
def f(self):
pass
register(type(_C().f), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
register(type(list.append), _reduce_method_descriptor)
register(type(int.__add__), _reduce_method_descriptor)
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return functools.partial(func, *args, **keywords)
register(functools.partial, _reduce_partial)
#
# Make sockets picklable
#
if sys.platform == 'win32':
def _reduce_socket(s):
from .resource_sharer import DupSocket
return _rebuild_socket, (DupSocket(s),)
def _rebuild_socket(ds):
return ds.detach()
register(socket.socket, _reduce_socket)
else:
def _reduce_socket(s): # noqa
df = DupFd(s.fileno())
return _rebuild_socket, (df, s.family, s.type, s.proto)
def _rebuild_socket(df, family, type, proto): # noqa
fd = df.detach()
return socket.socket(family, type, proto, fileno=fd)
register(socket.socket, _reduce_socket)

View File

@@ -88,7 +88,7 @@ def get_all_processes_pids():
def get_processtree_pids(pid, include_parent=True): def get_processtree_pids(pid, include_parent=True):
"""Return a list with all the pids of a process tree""" """Return a list with all the pids of a process tree"""
parents = get_all_processes_pids() parents = get_all_processes_pids()
all_pids = parents.keys() all_pids = list(parents.keys())
pids = set([pid]) pids = set([pid])
while 1: while 1:
pids_new = pids.copy() pids_new = pids.copy()

View File

@@ -4,10 +4,10 @@ This module contains utilities added by billiard, to keep
"non-core" functionality out of ``.util``.""" "non-core" functionality out of ``.util``."""
from __future__ import absolute_import from __future__ import absolute_import
import os
import signal import signal
import sys import sys
from time import time
import pickle as pypickle import pickle as pypickle
try: try:
import cPickle as cpickle import cPickle as cpickle
@@ -15,6 +15,7 @@ except ImportError: # pragma: no cover
cpickle = None # noqa cpickle = None # noqa
from .exceptions import RestartFreqExceeded from .exceptions import RestartFreqExceeded
from .five import monotonic
if sys.version_info < (2, 6): # pragma: no cover if sys.version_info < (2, 6): # pragma: no cover
# cPickle does not use absolute_imports # cPickle does not use absolute_imports
@@ -36,16 +37,15 @@ else:
except ImportError: except ImportError:
from StringIO import StringIO as BytesIO # noqa from StringIO import StringIO as BytesIO # noqa
EX_SOFTWARE = 70
TERMSIGS = ( TERMSIGS = (
'SIGHUP', 'SIGHUP',
'SIGQUIT', 'SIGQUIT',
'SIGILL',
'SIGTRAP', 'SIGTRAP',
'SIGABRT', 'SIGABRT',
'SIGEMT', 'SIGEMT',
'SIGFPE',
'SIGBUS', 'SIGBUS',
'SIGSEGV',
'SIGSYS', 'SIGSYS',
'SIGPIPE', 'SIGPIPE',
'SIGALRM', 'SIGALRM',
@@ -58,13 +58,33 @@ TERMSIGS = (
'SIGUSR2', 'SIGUSR2',
) )
#: set by signal handlers just before calling exit.
#: if this is true after the sighandler returns it means that something
#: went wrong while terminating the process, and :func:`os._exit`
#: must be called ASAP.
_should_have_exited = [False]
def pickle_loads(s, load=pickle_load): def pickle_loads(s, load=pickle_load):
# used to support buffer objects # used to support buffer objects
return load(BytesIO(s)) return load(BytesIO(s))
def maybe_setsignal(signum, handler):
try:
signal.signal(signum, handler)
except (OSError, AttributeError, ValueError, RuntimeError):
pass
def _shutdown_cleanup(signum, frame): def _shutdown_cleanup(signum, frame):
# we will exit here so if the signal is received a second time
# we can be sure that something is very wrong and we may be in
# a crashing loop.
if _should_have_exited[0]:
os._exit(EX_SOFTWARE)
maybe_setsignal(signum, signal.SIG_DFL)
_should_have_exited[0] = True
sys.exit(-(256 - signum)) sys.exit(-(256 - signum))
@@ -72,11 +92,12 @@ def reset_signals(handler=_shutdown_cleanup):
for sig in TERMSIGS: for sig in TERMSIGS:
try: try:
signum = getattr(signal, sig) signum = getattr(signal, sig)
except AttributeError:
pass
else:
current = signal.getsignal(signum) current = signal.getsignal(signum)
if current is not None and current != signal.SIG_IGN: if current is not None and current != signal.SIG_IGN:
signal.signal(signum, handler) maybe_setsignal(signum, handler)
except (OSError, AttributeError, ValueError, RuntimeError):
pass
class restart_state(object): class restart_state(object):
@@ -87,7 +108,7 @@ class restart_state(object):
self.R, self.T = 0, None self.R, self.T = 0, None
def step(self, now=None): def step(self, now=None):
now = time() if now is None else now now = monotonic() if now is None else now
R = self.R R = self.R
if self.T and now - self.T >= self.maxT: if self.T and now - self.T >= self.maxT:
# maxT passed, reset counter and time passed. # maxT passed, reset counter and time passed.
@@ -98,9 +119,8 @@ class restart_state(object):
# the startup probably went fine (startup restart burst # the startup probably went fine (startup restart burst
# protection) # protection)
if self.R: # pragma: no cover if self.R: # pragma: no cover
pass self.R = 0 # reset in case someone catches the error
self.R = 0 # reset in case someone catches the error raise self.RestartFreqExceeded("%r in %rs" % (R, self.maxT))
raise self.RestartFreqExceeded("%r in %rs" % (R, self.maxT))
# first run sets T # first run sets T
if self.T is None: if self.T is None:
self.T = now self.T = now

View File

@@ -3,13 +3,50 @@ from __future__ import absolute_import
import errno import errno
import os import os
import sys import sys
import __builtin__
from .five import builtins, range
if sys.platform == 'win32':
try:
import _winapi # noqa
except ImportError: # pragma: no cover
try:
from _billiard import win32 as _winapi # noqa
except (ImportError, AttributeError):
from _multiprocessing import win32 as _winapi # noqa
else:
_winapi = None # noqa
try:
buf_t, is_new_buffer = memoryview, True # noqa
except NameError: # Py2.6
buf_t, is_new_buffer = buffer, False # noqa
if hasattr(os, 'write'):
__write__ = os.write
if is_new_buffer:
def send_offset(fd, buf, offset):
return __write__(fd, buf[offset:])
else: # Py2.6
def send_offset(fd, buf, offset): # noqa
return __write__(fd, buf_t(buf, offset))
else: # non-posix platform
def send_offset(fd, buf, offset): # noqa
raise NotImplementedError('send_offset')
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
bytes = bytes bytes = bytes
else: else:
try: try:
_bytes = __builtin__.bytes _bytes = builtins.bytes
except AttributeError: except AttributeError:
_bytes = str _bytes = str
@@ -25,10 +62,10 @@ try:
except AttributeError: except AttributeError:
def closerange(fd_low, fd_high): # noqa def closerange(fd_low, fd_high): # noqa
for fd in reversed(xrange(fd_low, fd_high)): for fd in reversed(range(fd_low, fd_high)):
try: try:
os.close(fd) os.close(fd)
except OSError, exc: except OSError as exc:
if exc.errno != errno.EBADF: if exc.errno != errno.EBADF:
raise raise
@@ -46,3 +83,26 @@ def get_errno(exc):
except AttributeError: except AttributeError:
pass pass
return 0 return 0
if sys.platform == 'win32':
def setblocking(handle, blocking):
raise NotImplementedError('setblocking not implemented on win32')
def isblocking(handle):
raise NotImplementedError('isblocking not implemented on win32')
else:
from os import O_NONBLOCK
from fcntl import fcntl, F_GETFL, F_SETFL
def isblocking(handle): # noqa
return not (fcntl(handle, F_GETFL) & O_NONBLOCK)
def setblocking(handle, blocking): # noqa
flags = fcntl(handle, F_GETFL, 0)
fcntl(
handle, F_SETFL,
flags & (~O_NONBLOCK) if blocking else flags | O_NONBLOCK,
)

View File

@@ -1,11 +1,27 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys import sys
is_pypy = hasattr(sys, 'pypy_version_info')
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
from multiprocessing import connection from . import _connection3 as connection
else: else:
from billiard import _connection as connection # noqa from . import _connection as connection # noqa
if is_pypy:
import _multiprocessing
from .compat import setblocking, send_offset
class Connection(_multiprocessing.Connection):
def send_offset(self, buf, offset):
return send_offset(self.fileno(), buf, offset)
def setblocking(self, blocking):
setblocking(self.fileno(), blocking)
_multiprocessing.Connection = Connection
sys.modules[__name__] = connection sys.modules[__name__] = connection

View File

@@ -50,12 +50,10 @@ import array
from threading import Lock, RLock, Semaphore, BoundedSemaphore from threading import Lock, RLock, Semaphore, BoundedSemaphore
from threading import Event from threading import Event
from Queue import Queue
if sys.version_info[0] == 3: from billiard.five import Queue
from multiprocessing.connection import Pipe
else: from billiard.connection import Pipe
from billiard._connection import Pipe
class DummyProcess(threading.Thread): class DummyProcess(threading.Thread):
@@ -91,7 +89,7 @@ class Condition(_Condition):
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
notify_all = _Condition.notifyAll notify_all = _Condition.notifyAll
else: else:
notify_all = _Condition.notifyAll.im_func notify_all = _Condition.notifyAll.__func__
Process = DummyProcess Process = DummyProcess
@@ -117,7 +115,7 @@ class Namespace(object):
self.__dict__.update(kwds) self.__dict__.update(kwds)
def __repr__(self): def __repr__(self):
items = self.__dict__.items() items = list(self.__dict__.items())
temp = [] temp = []
for name, value in items: for name, value in items:
if not name.startswith('_'): if not name.startswith('_'):

View File

@@ -35,7 +35,7 @@ from __future__ import absolute_import
__all__ = ['Client', 'Listener', 'Pipe'] __all__ = ['Client', 'Listener', 'Pipe']
from Queue import Queue from billiard.five import Queue
families = [None] families = [None]

View File

@@ -32,7 +32,7 @@ class _Frame(object):
class _Object(object): class _Object(object):
def __init__(self, **kw): def __init__(self, **kw):
[setattr(self, k, v) for k, v in kw.iteritems()] [setattr(self, k, v) for k, v in kw.items()]
class _Truncated(object): class _Truncated(object):

View File

@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
"""
celery.five
~~~~~~~~~~~
Compatibility implementations of features
only available in newer Python versions.
"""
from __future__ import absolute_import
############## py3k #########################################################
import sys
PY3 = sys.version_info[0] == 3
try:
reload = reload # noqa
except NameError: # pragma: no cover
from imp import reload # noqa
try:
from UserList import UserList # noqa
except ImportError: # pragma: no cover
from collections import UserList # noqa
try:
from UserDict import UserDict # noqa
except ImportError: # pragma: no cover
from collections import UserDict # noqa
############## time.monotonic ################################################
if sys.version_info < (3, 3):
import platform
SYSTEM = platform.system()
if SYSTEM == 'Darwin':
import ctypes
libSystem = ctypes.CDLL('libSystem.dylib')
CoreServices = ctypes.CDLL(
'/System/Library/Frameworks/CoreServices.framework/CoreServices',
use_errno=True,
)
mach_absolute_time = libSystem.mach_absolute_time
mach_absolute_time.restype = ctypes.c_uint64
absolute_to_nanoseconds = CoreServices.AbsoluteToNanoseconds
absolute_to_nanoseconds.restype = ctypes.c_uint64
absolute_to_nanoseconds.argtypes = [ctypes.c_uint64]
def _monotonic():
return absolute_to_nanoseconds(mach_absolute_time()) * 1e-9
elif SYSTEM == 'Linux':
# from stackoverflow:
# questions/1205722/how-do-i-get-monotonic-time-durations-in-python
import ctypes
import os
CLOCK_MONOTONIC = 1 # see <linux/time.h>
class timespec(ctypes.Structure):
_fields_ = [
('tv_sec', ctypes.c_long),
('tv_nsec', ctypes.c_long),
]
librt = ctypes.CDLL('librt.so.1', use_errno=True)
clock_gettime = librt.clock_gettime
clock_gettime.argtypes = [
ctypes.c_int, ctypes.POINTER(timespec),
]
def _monotonic(): # noqa
t = timespec()
if clock_gettime(CLOCK_MONOTONIC, ctypes.pointer(t)) != 0:
errno_ = ctypes.get_errno()
raise OSError(errno_, os.strerror(errno_))
return t.tv_sec + t.tv_nsec * 1e-9
else:
from time import time as _monotonic
try:
from time import monotonic
except ImportError:
monotonic = _monotonic # noqa
if PY3:
import builtins
from queue import Queue, Empty, Full
from itertools import zip_longest
from io import StringIO, BytesIO
map = map
string = str
string_t = str
long_t = int
text_t = str
range = range
int_types = (int, )
open_fqdn = 'builtins.open'
def items(d):
return d.items()
def keys(d):
return d.keys()
def values(d):
return d.values()
def nextfun(it):
return it.__next__
exec_ = getattr(builtins, 'exec')
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
class WhateverIO(StringIO):
def write(self, data):
if isinstance(data, bytes):
data = data.encode()
StringIO.write(self, data)
else:
import __builtin__ as builtins # noqa
from Queue import Queue, Empty, Full # noqa
from itertools import imap as map, izip_longest as zip_longest # noqa
from StringIO import StringIO # noqa
string = unicode # noqa
string_t = basestring # noqa
text_t = unicode
long_t = long # noqa
range = xrange
int_types = (int, long)
open_fqdn = '__builtin__.open'
def items(d): # noqa
return d.iteritems()
def keys(d): # noqa
return d.iterkeys()
def values(d): # noqa
return d.itervalues()
def nextfun(it): # noqa
return it.next
def exec_(code, globs=None, locs=None):
"""Execute code in a namespace."""
if globs is None:
frame = sys._getframe(1)
globs = frame.f_globals
if locs is None:
locs = frame.f_locals
del frame
elif locs is None:
locs = globs
exec("""exec code in globs, locs""")
exec_("""def reraise(tp, value, tb=None): raise tp, value, tb""")
BytesIO = WhateverIO = StringIO # noqa
def with_metaclass(Type, skip_attrs=set(['__dict__', '__weakref__'])):
"""Class decorator to set metaclass.
Works with both Python 3 and Python 3 and it does not add
an extra class in the lookup order like ``six.with_metaclass`` does
(that is -- it copies the original class instead of using inheritance).
"""
def _clone_with_metaclass(Class):
attrs = dict((key, value) for key, value in items(vars(Class))
if key not in skip_attrs)
return Type(Class.__name__, Class.__bases__, attrs)
return _clone_with_metaclass

View File

@@ -14,12 +14,15 @@ import sys
import signal import signal
import warnings import warnings
from ._ext import Connection, PipeConnection, win32
from pickle import load, HIGHEST_PROTOCOL from pickle import load, HIGHEST_PROTOCOL
from billiard import util, process from billiard import util
from billiard import process
from billiard.five import int_types
from .reduction import dump
from .compat import _winapi as win32
__all__ = ['Popen', 'assert_spawning', 'exit', __all__ = ['Popen', 'assert_spawning', 'exit',
'duplicate', 'close', 'ForkingPickler'] 'duplicate', 'close']
try: try:
WindowsError = WindowsError # noqa WindowsError = WindowsError # noqa
@@ -53,105 +56,16 @@ def assert_spawning(self):
' through inheritance' % type(self).__name__ ' through inheritance' % type(self).__name__
) )
#
# Try making some callable types picklable
#
from pickle import Pickler
if sys.version_info[0] == 3:
from copyreg import dispatch_table
class ForkingPickler(Pickler):
_extra_reducers = {}
def __init__(self, *args, **kwargs):
Pickler.__init__(self, *args, **kwargs)
self.dispatch_table = dispatch_table.copy()
self.dispatch_table.update(self._extra_reducers)
@classmethod
def register(cls, type, reduce):
cls._extra_reducers[type] = reduce
def _reduce_method(m):
if m.__self__ is None:
return getattr, (m.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
class _C:
def f(self):
pass
ForkingPickler.register(type(_C().f), _reduce_method)
else:
class ForkingPickler(Pickler): # noqa
dispatch = Pickler.dispatch.copy()
@classmethod
def register(cls, type, reduce):
def dispatcher(self, obj):
rv = reduce(obj)
self.save_reduce(obj=obj, *rv)
cls.dispatch[type] = dispatcher
def _reduce_method(m): # noqa
if m.im_self is None:
return getattr, (m.im_class, m.im_func.func_name)
else:
return getattr, (m.im_self, m.im_func.func_name)
ForkingPickler.register(type(ForkingPickler.save), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
ForkingPickler.register(type(list.append), _reduce_method_descriptor)
ForkingPickler.register(type(int.__add__), _reduce_method_descriptor)
try:
from functools import partial
except ImportError:
pass
else:
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return partial(func, *args, **keywords)
ForkingPickler.register(partial, _reduce_partial)
def dump(obj, file, protocol=None):
ForkingPickler(file, protocol).dump(obj)
#
# Make (Pipe)Connection picklable
#
def reduce_connection(conn):
# XXX check not necessary since only registered with ForkingPickler
if not Popen.thread_is_spawning():
raise RuntimeError(
'By default %s objects can only be shared between processes\n'
'using inheritance' % type(conn).__name__
)
return type(conn), (Popen.duplicate_for_child(conn.fileno()),
conn.readable, conn.writable)
ForkingPickler.register(Connection, reduce_connection)
if PipeConnection:
ForkingPickler.register(PipeConnection, reduce_connection)
# #
# Unix # Unix
# #
if sys.platform != 'win32': if sys.platform != 'win32':
import thread try:
import thread
except ImportError:
import _thread as thread # noqa
import select import select
WINEXE = False WINEXE = False
@@ -172,6 +86,8 @@ if sys.platform != 'win32':
_tls = thread._local() _tls = thread._local()
def __init__(self, process_obj): def __init__(self, process_obj):
# register reducers
from billiard import connection # noqa
_Django_old_layout_hack__save() _Django_old_layout_hack__save()
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
@@ -265,9 +181,15 @@ if sys.platform != 'win32':
# #
else: else:
import thread try:
import thread
except ImportError:
import _thread as thread # noqa
import msvcrt import msvcrt
import _subprocess try:
import _subprocess
except ImportError:
import _winapi as _subprocess # noqa
# #
# #
@@ -287,10 +209,14 @@ else:
def duplicate(handle, target_process=None, inheritable=False): def duplicate(handle, target_process=None, inheritable=False):
if target_process is None: if target_process is None:
target_process = _subprocess.GetCurrentProcess() target_process = _subprocess.GetCurrentProcess()
return _subprocess.DuplicateHandle( h = _subprocess.DuplicateHandle(
_subprocess.GetCurrentProcess(), handle, target_process, _subprocess.GetCurrentProcess(), handle, target_process,
0, inheritable, _subprocess.DUPLICATE_SAME_ACCESS 0, inheritable, _subprocess.DUPLICATE_SAME_ACCESS
).Detach() )
if sys.version_info[0] < 3 or (
sys.version_info[0] == 3 and sys.version_info[1] < 3):
h = h.Detach()
return h
# #
# We define a Popen class similar to the one from subprocess, but # We define a Popen class similar to the one from subprocess, but
@@ -318,8 +244,9 @@ else:
hp, ht, pid, tid = _subprocess.CreateProcess( hp, ht, pid, tid = _subprocess.CreateProcess(
_python_exe, cmd, None, None, 1, 0, None, None, None _python_exe, cmd, None, None, 1, 0, None, None, None
) )
ht.Close() close(ht) if isinstance(ht, int_types) else ht.Close()
close(rhandle) (close(rhandle) if isinstance(rhandle, int_types)
else rhandle.Close())
# set attributes of self # set attributes of self
self.pid = pid self.pid = pid
@@ -566,22 +493,6 @@ def get_preparation_data(name):
return d return d
#
# Make (Pipe)Connection picklable
#
def reduce_connection(conn):
if not Popen.thread_is_spawning():
raise RuntimeError(
'By default %s objects can only be shared between processes\n'
'using inheritance' % type(conn).__name__
)
return type(conn), (Popen.duplicate_for_child(conn.fileno()),
conn.readable, conn.writable)
ForkingPickler.register(Connection, reduce_connection)
ForkingPickler.register(PipeConnection, reduce_connection)
# #
# Prepare current process # Prepare current process
# #
@@ -659,7 +570,7 @@ def prepare(data):
# Try to make the potentially picklable objects in # Try to make the potentially picklable objects in
# sys.modules['__main__'] realize they are in the main # sys.modules['__main__'] realize they are in the main
# module -- somewhat ugly. # module -- somewhat ugly.
for obj in main_module.__dict__.values(): for obj in list(main_module.__dict__.values()):
try: try:
if obj.__module__ == '__parents_main__': if obj.__module__ == '__parents_main__':
obj.__module__ = '__main__' obj.__module__ = '__main__'

View File

@@ -17,7 +17,8 @@ import itertools
from ._ext import _billiard, win32 from ._ext import _billiard, win32
from .util import Finalize, info, get_temp_dir from .util import Finalize, info, get_temp_dir
from .forking import assert_spawning, ForkingPickler from .forking import assert_spawning
from .reduction import ForkingPickler
__all__ = ['BufferWrapper'] __all__ = ['BufferWrapper']
@@ -38,7 +39,7 @@ if sys.platform == 'win32':
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
self.name = 'pym-%d-%d' % (os.getpid(), Arena._counter.next()) self.name = 'pym-%d-%d' % (os.getpid(), next(Arena._counter))
self.buffer = mmap.mmap(-1, self.size, tagname=self.name) self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
assert win32.GetLastError() == 0, 'tagname already in use' assert win32.GetLastError() == 0, 'tagname already in use'
self._state = (self.size, self.name) self._state = (self.size, self.name)
@@ -65,9 +66,9 @@ else:
if fileno == -1 and not _forking_is_enabled: if fileno == -1 and not _forking_is_enabled:
name = os.path.join( name = os.path.join(
get_temp_dir(), get_temp_dir(),
'pym-%d-%d' % (os.getpid(), self._counter.next())) 'pym-%d-%d' % (os.getpid(), next(self._counter)))
self.fileno = os.open( self.fileno = os.open(
name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0600) name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0o600)
os.unlink(name) os.unlink(name)
os.ftruncate(self.fileno, size) os.ftruncate(self.fileno, size)
self.buffer = mmap.mmap(self.fileno, self.size) self.buffer = mmap.mmap(self.fileno, self.size)

View File

@@ -8,7 +8,6 @@
# Licensed to PSF under a Contributor Agreement. # Licensed to PSF under a Contributor Agreement.
# #
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token'] __all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token']
@@ -19,14 +18,15 @@ __all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token']
import sys import sys
import threading import threading
import array import array
import Queue
from collections import Callable
from traceback import format_exc from traceback import format_exc
from time import time as _time
from . import Process, current_process, active_children, Pool, util, connection from . import Process, current_process, active_children, Pool, util, connection
from .five import Queue, items, monotonic
from .process import AuthenticationString from .process import AuthenticationString
from .forking import exit, Popen, ForkingPickler from .forking import exit, Popen
from .reduction import ForkingPickler
from .util import Finalize, error, info from .util import Finalize, error, info
# #
@@ -123,7 +123,7 @@ def all_methods(obj):
temp = [] temp = []
for name in dir(obj): for name in dir(obj):
func = getattr(obj, name) func = getattr(obj, name)
if callable(func): if isinstance(func, Callable):
temp.append(name) temp.append(name)
return temp return temp
@@ -205,14 +205,14 @@ class Server(object):
msg = ('#RETURN', result) msg = ('#RETURN', result)
try: try:
c.send(msg) c.send(msg)
except Exception, e: except Exception as exc:
try: try:
c.send(('#TRACEBACK', format_exc())) c.send(('#TRACEBACK', format_exc()))
except Exception: except Exception:
pass pass
info('Failure to send message: %r', msg) info('Failure to send message: %r', msg)
info(' ... request was %r', request) info(' ... request was %r', request)
info(' ... exception was %r', e) info(' ... exception was %r', exc)
c.close() c.close()
@@ -245,8 +245,8 @@ class Server(object):
try: try:
res = function(*args, **kwds) res = function(*args, **kwds)
except Exception, e: except Exception as exc:
msg = ('#ERROR', e) msg = ('#ERROR', exc)
else: else:
typeid = gettypeid and gettypeid.get(methodname, None) typeid = gettypeid and gettypeid.get(methodname, None)
if typeid: if typeid:
@@ -280,13 +280,13 @@ class Server(object):
try: try:
try: try:
send(msg) send(msg)
except Exception, e: except Exception:
send(('#UNSERIALIZABLE', repr(msg))) send(('#UNSERIALIZABLE', repr(msg)))
except Exception, e: except Exception as exc:
info('exception in thread serving %r', info('exception in thread serving %r',
threading.currentThread().name) threading.currentThread().name)
info(' ... message was %r', msg) info(' ... message was %r', msg)
info(' ... exception was %r', e) info(' ... exception was %r', exc)
conn.close() conn.close()
sys.exit(1) sys.exit(1)
@@ -314,7 +314,7 @@ class Server(object):
''' '''
with self.mutex: with self.mutex:
result = [] result = []
keys = self.id_to_obj.keys() keys = list(self.id_to_obj.keys())
keys.sort() keys.sort()
for ident in keys: for ident in keys:
if ident != '0': if ident != '0':
@@ -492,7 +492,8 @@ class BaseManager(object):
''' '''
assert self._state.value == State.INITIAL assert self._state.value == State.INITIAL
if initializer is not None and not callable(initializer): if initializer is not None and \
not isinstance(initializer, Callable):
raise TypeError('initializer must be a callable') raise TypeError('initializer must be a callable')
# pipe over which we will retrieve address of server # pipe over which we will retrieve address of server
@@ -641,7 +642,7 @@ class BaseManager(object):
) )
if method_to_typeid: if method_to_typeid:
for key, value in method_to_typeid.items(): for key, value in items(method_to_typeid):
assert type(key) is str, '%r is not a string' % key assert type(key) is str, '%r is not a string' % key
assert type(value) is str, '%r is not a string' % value assert type(value) is str, '%r is not a string' % value
@@ -797,8 +798,8 @@ class BaseProxy(object):
util.debug('DECREF %r', token.id) util.debug('DECREF %r', token.id)
conn = _Client(token.address, authkey=authkey) conn = _Client(token.address, authkey=authkey)
dispatch(conn, None, 'decref', (token.id,)) dispatch(conn, None, 'decref', (token.id,))
except Exception, e: except Exception as exc:
util.debug('... decref failed %s', e) util.debug('... decref failed %s', exc)
else: else:
util.debug('DECREF %r -- manager already shutdown', token.id) util.debug('DECREF %r -- manager already shutdown', token.id)
@@ -815,9 +816,9 @@ class BaseProxy(object):
self._manager = None self._manager = None
try: try:
self._incref() self._incref()
except Exception, e: except Exception as exc:
# the proxy may just be for a manager which has shutdown # the proxy may just be for a manager which has shutdown
info('incref failed: %s', e) info('incref failed: %s', exc)
def __reduce__(self): def __reduce__(self):
kwds = {} kwds = {}
@@ -933,7 +934,7 @@ class Namespace(object):
self.__dict__.update(kwds) self.__dict__.update(kwds)
def __repr__(self): def __repr__(self):
items = self.__dict__.items() items = list(self.__dict__.items())
temp = [] temp = []
for name, value in items: for name, value in items:
if not name.startswith('_'): if not name.startswith('_'):
@@ -1026,13 +1027,13 @@ class ConditionProxy(AcquirerProxy):
if result: if result:
return result return result
if timeout is not None: if timeout is not None:
endtime = _time() + timeout endtime = monotonic() + timeout
else: else:
endtime = None endtime = None
waittime = None waittime = None
while not result: while not result:
if endtime is not None: if endtime is not None:
waittime = endtime - _time() waittime = endtime - monotonic()
if waittime <= 0: if waittime <= 0:
break break
self.wait(waittime) self.wait(waittime)
@@ -1149,8 +1150,8 @@ class SyncManager(BaseManager):
this class. this class.
''' '''
SyncManager.register('Queue', Queue.Queue) SyncManager.register('Queue', Queue)
SyncManager.register('JoinableQueue', Queue.Queue) SyncManager.register('JoinableQueue', Queue)
SyncManager.register('Event', threading.Event, EventProxy) SyncManager.register('Event', threading.Event, EventProxy)
SyncManager.register('Lock', threading.Lock, AcquirerProxy) SyncManager.register('Lock', threading.Lock, AcquirerProxy)
SyncManager.register('RLock', threading.RLock, AcquirerProxy) SyncManager.register('RLock', threading.RLock, AcquirerProxy)

File diff suppressed because it is too large Load Diff

View File

@@ -27,6 +27,7 @@ try:
from _weakrefset import WeakSet from _weakrefset import WeakSet
except ImportError: except ImportError:
WeakSet = None # noqa WeakSet = None # noqa
from .five import items, string_t
try: try:
ORIGINAL_DIR = os.path.abspath(os.getcwd()) ORIGINAL_DIR = os.path.abspath(os.getcwd())
@@ -85,7 +86,7 @@ class Process(object):
def __init__(self, group=None, target=None, name=None, def __init__(self, group=None, target=None, name=None,
args=(), kwargs={}, daemon=None, **_kw): args=(), kwargs={}, daemon=None, **_kw):
assert group is None, 'group argument must be None for now' assert group is None, 'group argument must be None for now'
count = _current_process._counter.next() count = next(_current_process._counter)
self._identity = _current_process._identity + (count,) self._identity = _current_process._identity + (count,)
self._authkey = _current_process._authkey self._authkey = _current_process._authkey
if daemon is not None: if daemon is not None:
@@ -164,7 +165,7 @@ class Process(object):
return self._name return self._name
def _set_name(self, value): def _set_name(self, value):
assert isinstance(name, basestring), 'name must be a string' assert isinstance(name, string_t), 'name must be a string'
self._name = value self._name = value
name = property(_get_name, _set_name) name = property(_get_name, _set_name)
@@ -256,14 +257,17 @@ class Process(object):
_current_process = self _current_process = self
# Re-init logging system. # Re-init logging system.
# Workaround for http://bugs.python.org/issue6721#msg140215 # Workaround for http://bugs.python.org/issue6721/#msg140215
# Python logging module uses RLock() objects which are broken after # Python logging module uses RLock() objects which are broken
# fork. This can result in a deadlock (Celery Issue #496). # after fork. This can result in a deadlock (Celery Issue #496).
logger_names = logging.Logger.manager.loggerDict.keys() loggerDict = logging.Logger.manager.loggerDict
logger_names = list(loggerDict.keys())
logger_names.append(None) # for root logger logger_names.append(None) # for root logger
for name in logger_names: for name in logger_names:
for handler in logging.getLogger(name).handlers: if not name or not isinstance(loggerDict[name],
handler.createLock() logging.PlaceHolder):
for handler in logging.getLogger(name).handlers:
handler.createLock()
logging._lock = threading.RLock() logging._lock = threading.RLock()
try: try:
@@ -279,15 +283,15 @@ class Process(object):
exitcode = 0 exitcode = 0
finally: finally:
util._exit_function() util._exit_function()
except SystemExit, e: except SystemExit as exc:
if not e.args: if not exc.args:
exitcode = 1 exitcode = 1
elif isinstance(e.args[0], int): elif isinstance(exc.args[0], int):
exitcode = e.args[0] exitcode = exc.args[0]
else: else:
sys.stderr.write(str(e.args[0]) + '\n') sys.stderr.write(str(exc.args[0]) + '\n')
_maybe_flush(sys.stderr) _maybe_flush(sys.stderr)
exitcode = 0 if isinstance(e.args[0], str) else 1 exitcode = 0 if isinstance(exc.args[0], str) else 1
except: except:
exitcode = 1 exitcode = 1
if not util.error('Process %s', self.name, exc_info=True): if not util.error('Process %s', self.name, exc_info=True):
@@ -347,7 +351,7 @@ del _MainProcess
_exitcode_to_name = {} _exitcode_to_name = {}
for name, signum in signal.__dict__.items(): for name, signum in items(signal.__dict__):
if name[:3] == 'SIG' and '_' not in name: if name[:3] == 'SIG' and '_' not in name:
_exitcode_to_name[-signum] = name _exitcode_to_name[-signum] = name

View File

@@ -7,7 +7,6 @@
# Licensed to PSF under a Contributor Agreement. # Licensed to PSF under a Contributor Agreement.
# #
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['Queue', 'SimpleQueue', 'JoinableQueue'] __all__ = ['Queue', 'SimpleQueue', 'JoinableQueue']
@@ -15,17 +14,16 @@ import sys
import os import os
import threading import threading
import collections import collections
import time
import weakref import weakref
import errno import errno
from Queue import Empty, Full
from . import Pipe from . import Pipe
from ._ext import _billiard from ._ext import _billiard
from .compat import get_errno from .compat import get_errno
from .five import monotonic
from .synchronize import Lock, BoundedSemaphore, Semaphore, Condition from .synchronize import Lock, BoundedSemaphore, Semaphore, Condition
from .util import debug, error, info, Finalize, register_after_fork from .util import debug, error, info, Finalize, register_after_fork
from .five import Empty, Full
from .forking import assert_spawning from .forking import assert_spawning
@@ -96,12 +94,12 @@ class Queue(object):
else: else:
if block: if block:
deadline = time.time() + timeout deadline = monotonic() + timeout
if not self._rlock.acquire(block, timeout): if not self._rlock.acquire(block, timeout):
raise Empty raise Empty
try: try:
if block: if block:
timeout = deadline - time.time() timeout = deadline - monotonic()
if timeout < 0 or not self._poll(timeout): if timeout < 0 or not self._poll(timeout):
raise Empty raise Empty
elif not self._poll(): elif not self._poll():
@@ -238,7 +236,7 @@ class Queue(object):
send(obj) send(obj)
except IndexError: except IndexError:
pass pass
except Exception, exc: except Exception as exc:
if ignore_epipe and get_errno(exc) == errno.EPIPE: if ignore_epipe and get_errno(exc) == errno.EPIPE:
return return
# Since this runs in a daemon thread the resources it uses # Since this runs in a daemon thread the resources it uses
@@ -306,19 +304,17 @@ class JoinableQueue(Queue):
self._cond.wait() self._cond.wait()
class SimpleQueue(object): class _SimpleQueue(object):
''' '''
Simplified Queue type -- really just a locked pipe Simplified Queue type -- really just a locked pipe
''' '''
def __init__(self): def __init__(self, rnonblock=False, wnonblock=False):
self._reader, self._writer = Pipe(duplex=False) self._reader, self._writer = Pipe(
self._rlock = Lock() duplex=False, rnonblock=rnonblock, wnonblock=wnonblock,
)
self._poll = self._reader.poll self._poll = self._reader.poll
if sys.platform == 'win32': self._rlock = self._wlock = None
self._wlock = None
else:
self._wlock = Lock()
self._make_methods() self._make_methods()
def empty(self): def empty(self):
@@ -337,19 +333,22 @@ class SimpleQueue(object):
try: try:
recv_payload = self._reader.recv_payload recv_payload = self._reader.recv_payload
except AttributeError: except AttributeError:
recv_payload = None # C extension not installed recv_payload = self._reader.recv_bytes
rlock = self._rlock rlock = self._rlock
def get(): if rlock is not None:
with rlock: def get():
return recv() with rlock:
self.get = get return recv()
self.get = get
if recv_payload is not None:
def get_payload(): def get_payload():
with rlock: with rlock:
return recv_payload() return recv_payload()
self.get_payload = get_payload self.get_payload = get_payload
else:
self.get = recv
self.get_payload = recv_payload
if self._wlock is None: if self._wlock is None:
# writes to a message oriented win32 pipe are atomic # writes to a message oriented win32 pipe are atomic
@@ -362,3 +361,12 @@ class SimpleQueue(object):
with wlock: with wlock:
return send(obj) return send(obj)
self.put = put self.put = put
class SimpleQueue(_SimpleQueue):
def __init__(self):
self._reader, self._writer = Pipe(duplex=False)
self._rlock = Lock()
self._wlock = Lock() if sys.platform != 'win32' else None
self._make_methods()

View File

@@ -1,200 +1,10 @@
#
# Module to allow connection and socket objects to be transferred
# between processes
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import from __future__ import absolute_import
__all__ = []
import os
import sys import sys
import socket
import threading
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
from multiprocessing.connection import Client, Listener from . import _reduction3 as reduction
else: else:
from billiard._connection import Client, Listener # noqa from . import _reduction as reduction # noqa
from . import current_process sys.modules[__name__] = reduction
from ._ext import _billiard, win32
from .forking import Popen, duplicate, close, ForkingPickler
from .util import register_after_fork, debug, sub_debug
if not(sys.platform == 'win32' or hasattr(_billiard, 'recvfd')):
raise ImportError('pickling of connections not supported')
# globals set later
_listener = None
_lock = None
_cache = set()
#
# Platform specific definitions
#
if sys.platform == 'win32':
# XXX Should this subprocess import be here?
import _subprocess # noqa
def send_handle(conn, handle, destination_pid):
process_handle = win32.OpenProcess(
win32.PROCESS_ALL_ACCESS, False, destination_pid
)
try:
new_handle = duplicate(handle, process_handle)
conn.send(new_handle)
finally:
close(process_handle)
def recv_handle(conn):
return conn.recv()
else:
def send_handle(conn, handle, destination_pid): # noqa
_billiard.sendfd(conn.fileno(), handle)
def recv_handle(conn): # noqa
return _billiard.recvfd(conn.fileno())
#
# Support for a per-process server thread which caches pickled handles
#
def _reset(obj):
global _lock, _listener, _cache
for h in _cache:
close(h)
_cache.clear()
_lock = threading.Lock()
_listener = None
_reset(None)
register_after_fork(_reset, _reset)
def _get_listener():
global _listener
if _listener is None:
_lock.acquire()
try:
if _listener is None:
debug('starting listener and thread for sending handles')
_listener = Listener(authkey=current_process().authkey)
t = threading.Thread(target=_serve)
t.daemon = True
t.start()
finally:
_lock.release()
return _listener
def _serve():
from .util import is_exiting, sub_warning
while 1:
try:
conn = _listener.accept()
handle_wanted, destination_pid = conn.recv()
_cache.remove(handle_wanted)
send_handle(conn, handle_wanted, destination_pid)
close(handle_wanted)
conn.close()
except:
if not is_exiting():
sub_warning('thread for sharing handles raised exception',
exc_info=True)
#
# Functions to be used for pickling/unpickling objects with handles
#
def reduce_handle(handle):
if Popen.thread_is_spawning():
return (None, Popen.duplicate_for_child(handle), True)
dup_handle = duplicate(handle)
_cache.add(dup_handle)
sub_debug('reducing handle %d', handle)
return (_get_listener().address, dup_handle, False)
def rebuild_handle(pickled_data):
address, handle, inherited = pickled_data
if inherited:
return handle
sub_debug('rebuilding handle %d', handle)
conn = Client(address, authkey=current_process().authkey)
conn.send((handle, os.getpid()))
new_handle = recv_handle(conn)
conn.close()
return new_handle
#
# Register `_billiard.Connection` with `ForkingPickler`
#
def reduce_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_connection, (rh, conn.readable, conn.writable)
def rebuild_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.Connection(
handle, readable=readable, writable=writable
)
ForkingPickler.register(_billiard.Connection, reduce_connection)
#
# Register `socket.socket` with `ForkingPickler`
#
def fromfd(fd, family, type_, proto=0):
s = socket.fromfd(fd, family, type_, proto)
if s.__class__ is not socket.socket:
s = socket.socket(_sock=s)
return s
def reduce_socket(s):
reduced_handle = reduce_handle(s.fileno())
return rebuild_socket, (reduced_handle, s.family, s.type, s.proto)
def rebuild_socket(reduced_handle, family, type_, proto):
fd = rebuild_handle(reduced_handle)
_sock = fromfd(fd, family, type_, proto)
close(fd)
return _sock
ForkingPickler.register(socket.socket, reduce_socket)
#
# Register `_billiard.PipeConnection` with `ForkingPickler`
#
if sys.platform == 'win32':
def reduce_pipe_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_pipe_connection, (rh, conn.readable, conn.writable)
def rebuild_pipe_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.PipeConnection(
handle, readable=readable, writable=writable
)
ForkingPickler.register(_billiard.PipeConnection, reduce_pipe_connection)

View File

@@ -12,7 +12,9 @@ import ctypes
import weakref import weakref
from . import heap, RLock from . import heap, RLock
from .forking import assert_spawning, ForkingPickler from .five import int_types
from .forking import assert_spawning
from .reduction import ForkingPickler
__all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized'] __all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized']
@@ -48,7 +50,7 @@ def RawArray(typecode_or_type, size_or_initializer):
Returns a ctypes array allocated from shared memory Returns a ctypes array allocated from shared memory
''' '''
type_ = typecode_to_type.get(typecode_or_type, typecode_or_type) type_ = typecode_to_type.get(typecode_or_type, typecode_or_type)
if isinstance(size_or_initializer, (int, long)): if isinstance(size_or_initializer, int_types):
type_ = type_ * size_or_initializer type_ = type_ * size_or_initializer
obj = _new_value(type_) obj = _new_value(type_)
ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj)) ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj))
@@ -66,7 +68,8 @@ def Value(typecode_or_type, *args, **kwds):
''' '''
lock = kwds.pop('lock', None) lock = kwds.pop('lock', None)
if kwds: if kwds:
raise ValueError('unrecognized keyword argument(s): %s' % kwds.keys()) raise ValueError(
'unrecognized keyword argument(s): %s' % list(kwds.keys()))
obj = RawValue(typecode_or_type, *args) obj = RawValue(typecode_or_type, *args)
if lock is False: if lock is False:
return obj return obj
@@ -83,7 +86,8 @@ def Array(typecode_or_type, size_or_initializer, **kwds):
''' '''
lock = kwds.pop('lock', None) lock = kwds.pop('lock', None)
if kwds: if kwds:
raise ValueError('unrecognized keyword argument(s): %s' % kwds.keys()) raise ValueError(
'unrecognized keyword argument(s): %s' % list(kwds.keys()))
obj = RawArray(typecode_or_type, size_or_initializer) obj = RawArray(typecode_or_type, size_or_initializer)
if lock is False: if lock is False:
return obj return obj

View File

@@ -19,9 +19,8 @@ import sys
import threading import threading
from time import time as _time
from ._ext import _billiard, ensure_SemLock from ._ext import _billiard, ensure_SemLock
from .five import range, monotonic
from .process import current_process from .process import current_process
from .util import Finalize, register_after_fork, debug from .util import Finalize, register_after_fork, debug
from .forking import assert_spawning, Popen from .forking import assert_spawning, Popen
@@ -36,7 +35,7 @@ ensure_SemLock()
# Constants # Constants
# #
RECURSIVE_MUTEX, SEMAPHORE = range(2) RECURSIVE_MUTEX, SEMAPHORE = list(range(2))
SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX
try: try:
@@ -115,7 +114,7 @@ class SemLock(object):
@staticmethod @staticmethod
def _make_name(): def _make_name():
return '/%s-%s-%s' % (current_process()._semprefix, return '/%s-%s-%s' % (current_process()._semprefix,
os.getpid(), SemLock._counter.next()) os.getpid(), next(SemLock._counter))
class Semaphore(SemLock): class Semaphore(SemLock):
@@ -248,7 +247,7 @@ class Condition(object):
# release lock # release lock
count = self._lock._semlock._count() count = self._lock._semlock._count()
for i in xrange(count): for i in range(count):
self._lock.release() self._lock.release()
try: try:
@@ -259,7 +258,7 @@ class Condition(object):
self._woken_count.release() self._woken_count.release()
# reacquire lock # reacquire lock
for i in xrange(count): for i in range(count):
self._lock.acquire() self._lock.acquire()
return ret return ret
@@ -296,7 +295,7 @@ class Condition(object):
sleepers += 1 sleepers += 1
if sleepers: if sleepers:
for i in xrange(sleepers): for i in range(sleepers):
self._woken_count.acquire() # wait for a sleeper to wake self._woken_count.acquire() # wait for a sleeper to wake
# rezero wait_semaphore in case some timeouts just happened # rezero wait_semaphore in case some timeouts just happened
@@ -308,13 +307,13 @@ class Condition(object):
if result: if result:
return result return result
if timeout is not None: if timeout is not None:
endtime = _time() + timeout endtime = monotonic() + timeout
else: else:
endtime = None endtime = None
waittime = None waittime = None
while not result: while not result:
if endtime is not None: if endtime is not None:
waittime = endtime - _time() waittime = endtime - monotonic()
if waittime <= 0: if waittime <= 0:
break break
self.wait(waittime) self.wait(waittime)

View File

@@ -13,6 +13,9 @@ def teardown():
except (AttributeError, ImportError): except (AttributeError, ImportError):
pass pass
atexit._exithandlers[:] = [ try:
e for e in atexit._exithandlers if e[0] not in cancelled atexit._exithandlers[:] = [
] e for e in atexit._exithandlers if e[0] not in cancelled
]
except AttributeError:
pass

View File

@@ -1,5 +1,4 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import os import os
import signal import signal

View File

@@ -1,5 +1,4 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import re import re
import sys import sys
@@ -13,6 +12,8 @@ except AttributeError:
import unittest2 as unittest # noqa import unittest2 as unittest # noqa
from unittest2.util import safe_repr, unorderable_list_difference # noqa from unittest2.util import safe_repr, unorderable_list_difference # noqa
from billiard.five import string_t, items, values
from .compat import catch_warnings from .compat import catch_warnings
# -- adds assertWarns from recent unittest2, not in Python 2.7. # -- adds assertWarns from recent unittest2, not in Python 2.7.
@@ -25,7 +26,7 @@ class _AssertRaisesBaseContext(object):
self.expected = expected self.expected = expected
self.failureException = test_case.failureException self.failureException = test_case.failureException
self.obj_name = None self.obj_name = None
if isinstance(expected_regex, basestring): if isinstance(expected_regex, string_t):
expected_regex = re.compile(expected_regex) expected_regex = re.compile(expected_regex)
self.expected_regex = expected_regex self.expected_regex = expected_regex
@@ -37,7 +38,7 @@ class _AssertWarnsContext(_AssertRaisesBaseContext):
# The __warningregistry__'s need to be in a pristine state for tests # The __warningregistry__'s need to be in a pristine state for tests
# to work properly. # to work properly.
warnings.resetwarnings() warnings.resetwarnings()
for v in sys.modules.values(): for v in values(sys.modules):
if getattr(v, '__warningregistry__', None): if getattr(v, '__warningregistry__', None):
v.__warningregistry__ = {} v.__warningregistry__ = {}
self.warnings_manager = catch_warnings(record=True) self.warnings_manager = catch_warnings(record=True)
@@ -93,7 +94,7 @@ class Case(unittest.TestCase):
def assertDictContainsSubset(self, expected, actual, msg=None): def assertDictContainsSubset(self, expected, actual, msg=None):
missing, mismatched = [], [] missing, mismatched = [], []
for key, value in expected.iteritems(): for key, value in items(expected):
if key not in actual: if key not in actual:
missing.append(key) missing.append(key)
elif value != actual[key]: elif value != actual[key]:

View File

@@ -10,16 +10,25 @@ from __future__ import absolute_import
import errno import errno
import functools import functools
import itertools
import weakref
import atexit import atexit
import shutil
import tempfile from multiprocessing.util import ( # noqa
import threading # we want threading to install its _afterfork_registry,
# cleanup function before multiprocessing does _afterfork_counter,
_exit_function,
_finalizer_registry,
_finalizer_counter,
Finalize,
ForkAwareLocal,
ForkAwareThreadLock,
get_temp_dir,
is_exiting,
register_after_fork,
_run_after_forkers,
_run_finalizers,
)
from .compat import get_errno from .compat import get_errno
from .process import current_process, active_children
__all__ = [ __all__ = [
'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger', 'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger',
@@ -45,17 +54,6 @@ DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s'
_logger = None _logger = None
_log_to_stderr = False _log_to_stderr = False
#: Support for reinitialization of objects when bootstrapping a child process
_afterfork_registry = weakref.WeakValueDictionary()
_afterfork_counter = itertools.count()
#: Finalization using weakrefs
_finalizer_registry = {}
_finalizer_counter = itertools.count()
#: set to true if the process is shutting down.
_exiting = False
def sub_debug(msg, *args, **kwargs): def sub_debug(msg, *args, **kwargs):
if _logger: if _logger:
@@ -138,195 +136,6 @@ def log_to_stderr(level=None):
return _logger return _logger
def get_temp_dir():
'''
Function returning a temp directory which will be removed on exit
'''
# get name of a temp directory which will be automatically cleaned up
if current_process()._tempdir is None:
tempdir = tempfile.mkdtemp(prefix='pymp-')
info('created temp directory %s', tempdir)
Finalize(None, shutil.rmtree, args=[tempdir], exitpriority=-100)
current_process()._tempdir = tempdir
return current_process()._tempdir
def _run_after_forkers():
items = list(_afterfork_registry.items())
items.sort()
for (index, ident, func), obj in items:
try:
func(obj)
except Exception, e:
info('after forker raised exception %s', e)
def register_after_fork(obj, func):
_afterfork_registry[(_afterfork_counter.next(), id(obj), func)] = obj
class Finalize(object):
'''
Class which supports object finalization using weakrefs
'''
def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None):
assert exitpriority is None or type(exitpriority) is int
if obj is not None:
self._weakref = weakref.ref(obj, self)
else:
assert exitpriority is not None
self._callback = callback
self._args = args
self._kwargs = kwargs or {}
self._key = (exitpriority, _finalizer_counter.next())
_finalizer_registry[self._key] = self
def __call__(self, wr=None,
# Need to bind these locally because the globals
# could've been cleared at shutdown
_finalizer_registry=_finalizer_registry,
sub_debug=sub_debug):
'''
Run the callback unless it has already been called or cancelled
'''
try:
del _finalizer_registry[self._key]
except KeyError:
sub_debug('finalizer no longer registered')
else:
sub_debug(
'finalizer calling %s with args %s and kwargs %s',
self._callback, self._args, self._kwargs,
)
res = self._callback(*self._args, **self._kwargs)
self._weakref = self._callback = self._args = \
self._kwargs = self._key = None
return res
def cancel(self):
'''
Cancel finalization of the object
'''
try:
del _finalizer_registry[self._key]
except KeyError:
pass
else:
self._weakref = self._callback = self._args = \
self._kwargs = self._key = None
def still_active(self):
'''
Return whether this finalizer is still waiting to invoke callback
'''
return self._key in _finalizer_registry
def __repr__(self):
try:
obj = self._weakref()
except (AttributeError, TypeError):
obj = None
if obj is None:
return '<Finalize object, dead>'
x = '<Finalize object, callback=%s' % \
getattr(self._callback, '__name__', self._callback)
if self._args:
x += ', args=' + str(self._args)
if self._kwargs:
x += ', kwargs=' + str(self._kwargs)
if self._key[0] is not None:
x += ', exitprority=' + str(self._key[0])
return x + '>'
def _run_finalizers(minpriority=None,
_finalizer_registry=_finalizer_registry,
sub_debug=sub_debug, error=error):
'''
Run all finalizers whose exit priority is not None and at least minpriority
Finalizers with highest priority are called first; finalizers with
the same priority will be called in reverse order of creation.
'''
if minpriority is None:
f = lambda p: p[0][0] is not None
else:
f = lambda p: p[0][0] is not None and p[0][0] >= minpriority
items = [x for x in _finalizer_registry.items() if f(x)]
items.sort(reverse=True)
for key, finalizer in items:
sub_debug('calling %s', finalizer)
try:
finalizer()
except Exception:
if not error("Error calling finalizer %r", finalizer,
exc_info=True):
import traceback
traceback.print_exc()
if minpriority is None:
_finalizer_registry.clear()
def is_exiting():
'''
Returns true if the process is shutting down
'''
return _exiting or _exiting is None
def _exit_function(info=info, debug=debug,
active_children=active_children,
_run_finalizers=_run_finalizers):
'''
Clean up on exit
'''
global _exiting
info('process shutting down')
debug('running all "atexit" finalizers with priority >= 0')
_run_finalizers(0)
for p in active_children():
if p._daemonic:
info('calling terminate() for daemon %s', p.name)
p._popen.terminate()
for p in active_children():
info('calling join() for process %s', p.name)
p.join()
debug('running the remaining "atexit" finalizers')
_run_finalizers()
atexit.register(_exit_function)
class ForkAwareThreadLock(object):
def __init__(self):
self._lock = threading.Lock()
self.acquire = self._lock.acquire
self.release = self._lock.release
register_after_fork(self, ForkAwareThreadLock.__init__)
class ForkAwareLocal(threading.local):
def __init__(self):
register_after_fork(self, lambda obj: obj.__dict__.clear())
def __reduce__(self):
return type(self), ()
def _eintr_retry(func): def _eintr_retry(func):
''' '''
Automatic retry after EINTR. Automatic retry after EINTR.
@@ -337,7 +146,7 @@ def _eintr_retry(func):
while 1: while 1:
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except OSError, exc: except OSError as exc:
if get_errno(exc) != errno.EINTR: if get_errno(exc) != errno.EINTR:
raise raise
return wrapped return wrapped

View File

@@ -36,7 +36,7 @@ import logging.config
import urlparse import urlparse
from boto.exception import InvalidUriError from boto.exception import InvalidUriError
__version__ = '2.13.3' __version__ = '2.17.0'
Version = __version__ # for backware compatibility Version = __version__ # for backware compatibility
UserAgent = 'Boto/%s Python/%s %s/%s' % ( UserAgent = 'Boto/%s Python/%s %s/%s' % (
@@ -721,6 +721,29 @@ def connect_support(aws_access_key_id=None,
) )
def connect_cloudtrail(aws_access_key_id=None,
aws_secret_access_key=None,
**kwargs):
"""
Connect to AWS CloudTrail
:type aws_access_key_id: string
:param aws_access_key_id: Your AWS Access Key ID
:type aws_secret_access_key: string
:param aws_secret_access_key: Your AWS Secret Access Key
:rtype: :class:`boto.cloudtrail.layer1.CloudtrailConnection`
:return: A connection to the AWS Cloudtrail service
"""
from boto.cloudtrail.layer1 import CloudTrailConnection
return CloudTrailConnection(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
**kwargs
)
def storage_uri(uri_str, default_scheme='file', debug=0, validate=True, def storage_uri(uri_str, default_scheme='file', debug=0, validate=True,
bucket_storage_uri_class=BucketStorageUri, bucket_storage_uri_class=BucketStorageUri,
suppress_consec_slashes=True, is_latest=False): suppress_consec_slashes=True, is_latest=False):

View File

@@ -431,13 +431,17 @@ class HmacAuthV4Handler(AuthHandler, HmacKeys):
parts = http_request.host.split('.') parts = http_request.host.split('.')
if self.region_name is not None: if self.region_name is not None:
region_name = self.region_name region_name = self.region_name
elif parts[1] == 'us-gov': elif len(parts) > 1:
region_name = 'us-gov-west-1' if parts[1] == 'us-gov':
else: region_name = 'us-gov-west-1'
if len(parts) == 3:
region_name = 'us-east-1'
else: else:
region_name = parts[1] if len(parts) == 3:
region_name = 'us-east-1'
else:
region_name = parts[1]
else:
region_name = parts[0]
if self.service_name is not None: if self.service_name is not None:
service_name = self.service_name service_name = self.service_name
else: else:

View File

@@ -191,12 +191,9 @@ class DocumentServiceConnection(object):
session = requests.Session() session = requests.Session()
adapter = requests.adapters.HTTPAdapter( adapter = requests.adapters.HTTPAdapter(
pool_connections=20, pool_connections=20,
pool_maxsize=50 pool_maxsize=50,
max_retries=5
) )
# Now kludge in the right number of retries.
# Once we're requiring ``requests>=1.2.1``, this can become an
# initialization parameter above.
adapter.max_retries = 5
session.mount('http://', adapter) session.mount('http://', adapter)
session.mount('https://', adapter) session.mount('https://', adapter)
r = session.post(url, data=sdf, headers={'Content-Type': 'application/json'}) r = session.post(url, data=sdf, headers={'Content-Type': 'application/json'})

View File

@@ -79,7 +79,7 @@ class SearchResults(object):
class Query(object): class Query(object):
RESULTS_PER_PAGE = 500 RESULTS_PER_PAGE = 500
def __init__(self, q=None, bq=None, rank=None, def __init__(self, q=None, bq=None, rank=None,
@@ -147,7 +147,7 @@ class Query(object):
class SearchConnection(object): class SearchConnection(object):
def __init__(self, domain=None, endpoint=None): def __init__(self, domain=None, endpoint=None):
self.domain = domain self.domain = domain
self.endpoint = endpoint self.endpoint = endpoint
@@ -209,7 +209,7 @@ class SearchConnection(object):
:param facet_sort: Rules used to specify the order in which facet :param facet_sort: Rules used to specify the order in which facet
values should be returned. Allowed values are *alpha*, *count*, values should be returned. Allowed values are *alpha*, *count*,
*max*, *sum*. Use *alpha* to sort alphabetical, and *count* to sort *max*, *sum*. Use *alpha* to sort alphabetical, and *count* to sort
the facet by number of available result. the facet by number of available result.
``{'color': 'alpha', 'size': 'count'}`` ``{'color': 'alpha', 'size': 'count'}``
:type facet_top_n: dict :type facet_top_n: dict
@@ -243,10 +243,10 @@ class SearchConnection(object):
the search string. the search string.
>>> search(bq="'Tim*'") # Return documents with words like Tim or Timothy) >>> search(bq="'Tim*'") # Return documents with words like Tim or Timothy)
Search terms can also be combined. Allowed operators are "and", "or", Search terms can also be combined. Allowed operators are "and", "or",
"not", "field", "optional", "token", "phrase", or "filter" "not", "field", "optional", "token", "phrase", or "filter"
>>> search(bq="(and 'Tim' (field author 'John Smith'))") >>> search(bq="(and 'Tim' (field author 'John Smith'))")
Facets allow you to show classification information about the search Facets allow you to show classification information about the search
@@ -258,12 +258,12 @@ class SearchConnection(object):
With facet_constraints, facet_top_n and facet_sort more complicated With facet_constraints, facet_top_n and facet_sort more complicated
constraints can be specified such as returning the top author out of constraints can be specified such as returning the top author out of
John Smith and Mark Smith who have a document with the word Tim in it. John Smith and Mark Smith who have a document with the word Tim in it.
>>> search(q='Tim', >>> search(q='Tim',
... facet=['Author'], ... facet=['Author'],
... facet_constraints={'author': "'John Smith','Mark Smith'"}, ... facet_constraints={'author': "'John Smith','Mark Smith'"},
... facet=['author'], ... facet=['author'],
... facet_top_n={'author': 1}, ... facet_top_n={'author': 1},
... facet_sort={'author': 'count'}) ... facet_sort={'author': 'count'})
""" """
@@ -300,9 +300,7 @@ class SearchConnection(object):
except AttributeError: except AttributeError:
pass pass
raise SearchServiceException('Authentication error from Amazon%s' % msg) raise SearchServiceException('Authentication error from Amazon%s' % msg)
raise SearchServiceException("Got non-json response from Amazon") raise SearchServiceException("Got non-json response from Amazon. %s" % r.content, query)
data['query'] = query
data['search_service'] = self
if 'messages' in data and 'error' in data: if 'messages' in data and 'error' in data:
for m in data['messages']: for m in data['messages']:
@@ -311,7 +309,10 @@ class SearchConnection(object):
"=> %s" % (params, m['message']), query) "=> %s" % (params, m['message']), query)
elif 'error' in data: elif 'error' in data:
raise SearchServiceException("Unknown error processing search %s" raise SearchServiceException("Unknown error processing search %s"
% (params), query) % json.dumps(data), query)
data['query'] = query
data['search_service'] = self
return SearchResults(**data) return SearchResults(**data)

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2013 Amazon.com, Inc. or its affiliates.
# All Rights Reserved
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
from boto.regioninfo import RegionInfo
def regions():
"""
Get all available regions for the AWS Cloudtrail service.
:rtype: list
:return: A list of :class:`boto.regioninfo.RegionInfo`
"""
from boto.cloudtrail.layer1 import CloudTrailConnection
return [RegionInfo(name='us-east-1',
endpoint='cloudtrail.us-east-1.amazonaws.com',
connection_cls=CloudTrailConnection),
RegionInfo(name='us-west-2',
endpoint='cloudtrail.us-west-2.amazonaws.com',
connection_cls=CloudTrailConnection),
]
def connect_to_region(region_name, **kw_params):
for region in regions():
if region.name == region_name:
return region.connect(**kw_params)
return None

View File

@@ -0,0 +1,86 @@
"""
Exceptions that are specific to the cloudtrail module.
"""
from boto.exception import BotoServerError
class InvalidSnsTopicNameException(BotoServerError):
"""
Raised when an invalid SNS topic name is passed to Cloudtrail.
"""
pass
class InvalidS3BucketNameException(BotoServerError):
"""
Raised when an invalid S3 bucket name is passed to Cloudtrail.
"""
pass
class TrailAlreadyExistsException(BotoServerError):
"""
Raised when the given trail name already exists.
"""
pass
class InsufficientSnsTopicPolicyException(BotoServerError):
"""
Raised when the SNS topic does not allow Cloudtrail to post
messages.
"""
pass
class InvalidTrailNameException(BotoServerError):
"""
Raised when the trail name is invalid.
"""
pass
class InternalErrorException(BotoServerError):
"""
Raised when there was an internal Cloudtrail error.
"""
pass
class TrailNotFoundException(BotoServerError):
"""
Raised when the given trail name is not found.
"""
pass
class S3BucketDoesNotExistException(BotoServerError):
"""
Raised when the given S3 bucket does not exist.
"""
pass
class TrailNotProvidedException(BotoServerError):
"""
Raised when no trail name was provided.
"""
pass
class InvalidS3PrefixException(BotoServerError):
"""
Raised when an invalid key prefix is given.
"""
pass
class MaximumNumberOfTrailsExceededException(BotoServerError):
"""
Raised when no more trails can be created.
"""
pass
class InsufficientS3BucketPolicyException(BotoServerError):
"""
Raised when the S3 bucket does not allow Cloudtrail to
write files into the prefix.
"""
pass

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2013 Amazon.com, Inc. or its affiliates. All Rights Reserved
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
try:
import json
except ImportError:
import simplejson as json
import boto
from boto.connection import AWSQueryConnection
from boto.regioninfo import RegionInfo
from boto.exception import JSONResponseError
from boto.cloudtrail import exceptions
class CloudTrailConnection(AWSQueryConnection):
    """
    AWS CloudTrail

    This is the CloudTrail API Reference. It provides descriptions of
    actions, data types, common parameters, and common errors for
    CloudTrail.

    CloudTrail is a web service that records AWS API calls for your
    AWS account and delivers log files to an Amazon S3 bucket. The
    recorded information includes the identity of the user, the start
    time of the event, the source IP address, the request parameters,
    and the response elements returned by the service.

    As an alternative to using the API, you can use one of the AWS
    SDKs, which consist of libraries and sample code for various
    programming languages and platforms (Java, Ruby, .NET, iOS,
    Android, etc.). The SDKs provide a convenient way to create
    programmatic access to AWSCloudTrail. For example, the SDKs take
    care of cryptographically signing requests, managing errors, and
    retrying requests automatically. For information about the AWS
    SDKs, including how to download and install them, see the Tools
    for Amazon Web Services page.

    See the CloudTrail User Guide for information about the data that
    is included with each event listed in the log files.
    """
    APIVersion = "2013-11-01"
    DefaultRegionName = "us-east-1"
    DefaultRegionEndpoint = "cloudtrail.us-east-1.amazonaws.com"
    ServiceName = "CloudTrail"
    TargetPrefix = "com.amazonaws.cloudtrail.v20131101.CloudTrail_20131101"
    ResponseError = JSONResponseError

    # Maps the '__type' field of a JSON error response to the specific
    # exception class raised for that fault; unrecognized faults fall
    # back to ResponseError in make_request.
    _faults = {
        "InvalidSnsTopicNameException": exceptions.InvalidSnsTopicNameException,
        "InvalidS3BucketNameException": exceptions.InvalidS3BucketNameException,
        "TrailAlreadyExistsException": exceptions.TrailAlreadyExistsException,
        "InsufficientSnsTopicPolicyException": exceptions.InsufficientSnsTopicPolicyException,
        "InvalidTrailNameException": exceptions.InvalidTrailNameException,
        "InternalErrorException": exceptions.InternalErrorException,
        "TrailNotFoundException": exceptions.TrailNotFoundException,
        "S3BucketDoesNotExistException": exceptions.S3BucketDoesNotExistException,
        "TrailNotProvidedException": exceptions.TrailNotProvidedException,
        "InvalidS3PrefixException": exceptions.InvalidS3PrefixException,
        "MaximumNumberOfTrailsExceededException": exceptions.MaximumNumberOfTrailsExceededException,
        "InsufficientS3BucketPolicyException": exceptions.InsufficientS3BucketPolicyException,
    }

    def __init__(self, **kwargs):
        """
        :type region: :class:`boto.regioninfo.RegionInfo`
        :param region: Optional region to connect to; defaults to
            us-east-1.

        All other keyword arguments are passed through to
        :class:`boto.connection.AWSQueryConnection`.
        """
        region = kwargs.pop('region', None)
        if not region:
            region = RegionInfo(self, self.DefaultRegionName,
                                self.DefaultRegionEndpoint)
        # Only fall back to the region endpoint when the caller has not
        # explicitly overridden the host.
        if 'host' not in kwargs:
            kwargs['host'] = region.endpoint
        AWSQueryConnection.__init__(self, **kwargs)
        self.region = region

    def _required_auth_capability(self):
        # CloudTrail requires Signature Version 4 request signing.
        return ['hmac-v4']

    def _make_name_request(self, action, name):
        """
        Build and send a request for an action whose only (optional)
        parameter is a trail name.

        :type action: string
        :param action: The CloudTrail action to invoke.

        :type name: string
        :param name: The trail name, or None to omit the parameter.
        """
        params = {}
        if name is not None:
            params['Name'] = name
        return self.make_request(action=action, body=json.dumps(params))

    def _make_trail_request(self, action, trail):
        """
        Build and send a request for an action whose only (optional)
        parameter is a Trail structure.

        :type action: string
        :param action: The CloudTrail action to invoke.

        :type trail: dict
        :param trail: The Trail structure, or None to omit the
            parameter.
        """
        params = {}
        if trail is not None:
            params['trail'] = trail
        return self.make_request(action=action, body=json.dumps(params))

    def create_trail(self, trail=None):
        """
        From the command line, use create-subscription.

        Creates a trail that specifies the settings for delivery of
        log data to an Amazon S3 bucket. The request includes a Trail
        structure that specifies the following:

        + Trail name.
        + The name of the Amazon S3 bucket to which CloudTrail
          delivers your log files.
        + The name of the Amazon S3 key prefix that precedes each log
          file.
        + The name of the Amazon SNS topic that notifies you that a
          new file is available in your bucket.
        + Whether the log file should include events from global
          services. Currently, the only events included in CloudTrail
          log files are from IAM and AWS STS.

        Returns the appropriate HTTP status code if successful. If
        not, it returns either one of the CommonErrors or a
        FrontEndException with one of the following error codes:

        **MaximumNumberOfTrailsExceeded**
        An attempt was made to create more trails than allowed. You
        can only create one trail for each account in each region.

        **TrailAlreadyExists**
        An attempt was made to create a trail with a name that already
        exists.

        **S3BucketDoesNotExist**
        Specified Amazon S3 bucket does not exist.

        **InsufficientS3BucketPolicy**
        Policy on Amazon S3 bucket does not permit CloudTrail to write
        to your bucket. See the AWS CloudTrail User Guide for the
        required bucket policy.

        **InsufficientSnsTopicPolicy**
        The policy on Amazon SNS topic does not permit CloudTrail to
        write to it. Can also occur when an Amazon SNS topic does not
        exist.

        :type trail: dict
        :param trail: Contains the Trail structure that specifies the
            settings for each trail.
        """
        return self._make_trail_request('CreateTrail', trail)

    def delete_trail(self, name=None):
        """
        Deletes a trail.

        :type name: string
        :param name: The name of a trail to be deleted.
        """
        return self._make_name_request('DeleteTrail', name)

    def describe_trails(self, trail_name_list=None):
        """
        Retrieves the settings for some or all trails associated with
        an account. Returns a list of Trail structures in JSON format.

        :type trail_name_list: list
        :param trail_name_list: The list of Trail object names.
        """
        params = {}
        if trail_name_list is not None:
            params['trailNameList'] = trail_name_list
        return self.make_request(action='DescribeTrails',
                                 body=json.dumps(params))

    def get_trail_status(self, name=None):
        """
        Returns GetTrailStatusResult, which contains a JSON-formatted
        list of information about the trail specified in the request.
        JSON fields include information such as delivery errors,
        Amazon SNS and Amazon S3 errors, and times that logging
        started and stopped for each trail.

        :type name: string
        :param name: The name of the trail for which you are requesting
            the current status.
        """
        return self._make_name_request('GetTrailStatus', name)

    def start_logging(self, name=None):
        """
        Starts the processing of recording user activity events and
        log file delivery for a trail.

        :type name: string
        :param name: The name of the Trail for which CloudTrail logs
            events.
        """
        return self._make_name_request('StartLogging', name)

    def stop_logging(self, name=None):
        """
        Suspends the recording of user activity events and log file
        delivery for the specified trail. Under most circumstances,
        there is no need to use this action. You can update a trail
        without stopping it first. This action is the only way to stop
        logging activity.

        :type name: string
        :param name: Communicates to CloudTrail the name of the Trail
            for which to stop logging events.
        """
        return self._make_name_request('StopLogging', name)

    def update_trail(self, trail=None):
        """
        From the command line, use update-subscription.

        Updates the settings that specify delivery of log files.
        Changes to a trail do not require stopping the CloudTrail
        service. You can use this action to designate an existing
        bucket for log delivery, or to create a new bucket and prefix.
        If the existing bucket has previously been a target for
        CloudTrail log files, an IAM policy exists for the bucket. If
        you create a new bucket using UpdateTrail, you need to apply
        the policy to the bucket using one of the means provided by
        the Amazon S3 service.

        :type trail: dict
        :param trail: Represents the Trail structure that contains the
            CloudTrail setting for an account.
        """
        return self._make_trail_request('UpdateTrail', trail)

    def make_request(self, action, body):
        """
        Send a signed JSON POST request to the CloudTrail endpoint and
        decode the response.

        :type action: string
        :param action: The CloudTrail action name, used to build the
            X-Amz-Target header.

        :type body: string
        :param body: The JSON-encoded request parameters.

        Returns the decoded JSON response on success (None for a 200
        response with an empty body); raises the fault-specific
        exception from ``_faults`` (or ``ResponseError``) on error.
        """
        headers = {
            'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),
            # Use self.host rather than the region endpoint so that a
            # caller-supplied 'host' override in __init__ is honored.
            'Host': self.host,
            'Content-Type': 'application/x-amz-json-1.1',
            'Content-Length': str(len(body)),
        }
        http_request = self.build_base_http_request(
            method='POST', path='/', auth_path='/', params={},
            headers=headers, data=body, host=self.host)
        response = self._mexe(http_request, sender=None,
                              override_num_retries=10)
        response_body = response.read()
        boto.log.debug(response_body)
        if response.status == 200:
            # A successful call with an empty body yields None.
            if response_body:
                return json.loads(response_body)
        else:
            json_body = json.loads(response_body)
            fault_name = json_body.get('__type', None)
            exception_class = self._faults.get(fault_name, self.ResponseError)
            raise exception_class(response.status, response.reason,
                                  body=json_body)

View File

@@ -101,7 +101,7 @@ DEFAULT_CA_CERTS_FILE = os.path.join(os.path.dirname(os.path.abspath(boto.cacert
class HostConnectionPool(object): class HostConnectionPool(object):
""" """
A pool of connections for one remote (host,is_secure). A pool of connections for one remote (host,port,is_secure).
When connections are added to the pool, they are put into a When connections are added to the pool, they are put into a
pending queue. The _mexe method returns connections to the pool pending queue. The _mexe method returns connections to the pool
@@ -145,7 +145,7 @@ class HostConnectionPool(object):
def get(self): def get(self):
""" """
Returns the next connection in this pool that is ready to be Returns the next connection in this pool that is ready to be
reused. Returns None of there aren't any. reused. Returns None if there aren't any.
""" """
# Discard ready connections that are too old. # Discard ready connections that are too old.
self.clean() self.clean()
@@ -234,7 +234,7 @@ class ConnectionPool(object):
STALE_DURATION = 60.0 STALE_DURATION = 60.0
def __init__(self): def __init__(self):
# Mapping from (host,is_secure) to HostConnectionPool. # Mapping from (host,port,is_secure) to HostConnectionPool.
# If a pool becomes empty, it is removed. # If a pool becomes empty, it is removed.
self.host_to_pool = {} self.host_to_pool = {}
# The last time the pool was cleaned. # The last time the pool was cleaned.
@@ -259,7 +259,7 @@ class ConnectionPool(object):
""" """
return sum(pool.size() for pool in self.host_to_pool.values()) return sum(pool.size() for pool in self.host_to_pool.values())
def get_http_connection(self, host, is_secure): def get_http_connection(self, host, port, is_secure):
""" """
Gets a connection from the pool for the named host. Returns Gets a connection from the pool for the named host. Returns
None if there is no connection that can be reused. It's the caller's None if there is no connection that can be reused. It's the caller's
@@ -268,18 +268,18 @@ class ConnectionPool(object):
""" """
self.clean() self.clean()
with self.mutex: with self.mutex:
key = (host, is_secure) key = (host, port, is_secure)
if key not in self.host_to_pool: if key not in self.host_to_pool:
return None return None
return self.host_to_pool[key].get() return self.host_to_pool[key].get()
def put_http_connection(self, host, is_secure, conn): def put_http_connection(self, host, port, is_secure, conn):
""" """
Adds a connection to the pool of connections that can be Adds a connection to the pool of connections that can be
reused for the named host. reused for the named host.
""" """
with self.mutex: with self.mutex:
key = (host, is_secure) key = (host, port, is_secure)
if key not in self.host_to_pool: if key not in self.host_to_pool:
self.host_to_pool[key] = HostConnectionPool() self.host_to_pool[key] = HostConnectionPool()
self.host_to_pool[key].put(conn) self.host_to_pool[key].put(conn)
@@ -486,6 +486,11 @@ class AWSAuthConnection(object):
"2.6 or later.") "2.6 or later.")
self.ca_certificates_file = config.get_value( self.ca_certificates_file = config.get_value(
'Boto', 'ca_certificates_file', DEFAULT_CA_CERTS_FILE) 'Boto', 'ca_certificates_file', DEFAULT_CA_CERTS_FILE)
if port:
self.port = port
else:
self.port = PORTS_BY_SECURITY[is_secure]
self.handle_proxy(proxy, proxy_port, proxy_user, proxy_pass) self.handle_proxy(proxy, proxy_port, proxy_user, proxy_pass)
# define exceptions from httplib that we want to catch and retry # define exceptions from httplib that we want to catch and retry
self.http_exceptions = (httplib.HTTPException, socket.error, self.http_exceptions = (httplib.HTTPException, socket.error,
@@ -513,10 +518,6 @@ class AWSAuthConnection(object):
if not isinstance(debug, (int, long)): if not isinstance(debug, (int, long)):
debug = 0 debug = 0
self.debug = config.getint('Boto', 'debug', debug) self.debug = config.getint('Boto', 'debug', debug)
if port:
self.port = port
else:
self.port = PORTS_BY_SECURITY[is_secure]
self.host_header = None self.host_header = None
# Timeout used to tell httplib how long to wait for socket timeouts. # Timeout used to tell httplib how long to wait for socket timeouts.
@@ -551,7 +552,7 @@ class AWSAuthConnection(object):
self.host_header = self.provider.host_header self.host_header = self.provider.host_header
self._pool = ConnectionPool() self._pool = ConnectionPool()
self._connection = (self.server_name(), self.is_secure) self._connection = (self.host, self.port, self.is_secure)
self._last_rs = None self._last_rs = None
self._auth_handler = auth.get_auth_handler( self._auth_handler = auth.get_auth_handler(
host, config, self.provider, self._required_auth_capability()) host, config, self.provider, self._required_auth_capability())
@@ -652,7 +653,7 @@ class AWSAuthConnection(object):
if 'http_proxy' in os.environ and not self.proxy: if 'http_proxy' in os.environ and not self.proxy:
pattern = re.compile( pattern = re.compile(
'(?:http://)?' \ '(?:http://)?' \
'(?:(?P<user>\w+):(?P<pass>.*)@)?' \ '(?:(?P<user>[\w\-\.]+):(?P<pass>.*)@)?' \
'(?P<host>[\w\-\.]+)' \ '(?P<host>[\w\-\.]+)' \
'(?::(?P<port>\d+))?' '(?::(?P<port>\d+))?'
) )
@@ -680,12 +681,12 @@ class AWSAuthConnection(object):
self.no_proxy = os.environ.get('no_proxy', '') or os.environ.get('NO_PROXY', '') self.no_proxy = os.environ.get('no_proxy', '') or os.environ.get('NO_PROXY', '')
self.use_proxy = (self.proxy != None) self.use_proxy = (self.proxy != None)
def get_http_connection(self, host, is_secure): def get_http_connection(self, host, port, is_secure):
conn = self._pool.get_http_connection(host, is_secure) conn = self._pool.get_http_connection(host, port, is_secure)
if conn is not None: if conn is not None:
return conn return conn
else: else:
return self.new_http_connection(host, is_secure) return self.new_http_connection(host, port, is_secure)
def skip_proxy(self, host): def skip_proxy(self, host):
if not self.no_proxy: if not self.no_proxy:
@@ -703,16 +704,29 @@ class AWSAuthConnection(object):
return False return False
def new_http_connection(self, host, is_secure): def new_http_connection(self, host, port, is_secure):
if self.use_proxy and not is_secure and \
not self.skip_proxy(host):
host = '%s:%d' % (self.proxy, int(self.proxy_port))
if host is None: if host is None:
host = self.server_name() host = self.server_name()
# Make sure the host is really just the host, not including
# the port number
host = host.split(':', 1)[0]
http_connection_kwargs = self.http_connection_kwargs.copy()
# Connection factories below expect a port keyword argument
http_connection_kwargs['port'] = port
# Override host with proxy settings if needed
if self.use_proxy and not is_secure and \
not self.skip_proxy(host):
host = self.proxy
http_connection_kwargs['port'] = int(self.proxy_port)
if is_secure: if is_secure:
boto.log.debug( boto.log.debug(
'establishing HTTPS connection: host=%s, kwargs=%s', 'establishing HTTPS connection: host=%s, kwargs=%s',
host, self.http_connection_kwargs) host, http_connection_kwargs)
if self.use_proxy and not self.skip_proxy(host): if self.use_proxy and not self.skip_proxy(host):
connection = self.proxy_ssl(host, is_secure and 443 or 80) connection = self.proxy_ssl(host, is_secure and 443 or 80)
elif self.https_connection_factory: elif self.https_connection_factory:
@@ -720,35 +734,35 @@ class AWSAuthConnection(object):
elif self.https_validate_certificates and HAVE_HTTPS_CONNECTION: elif self.https_validate_certificates and HAVE_HTTPS_CONNECTION:
connection = https_connection.CertValidatingHTTPSConnection( connection = https_connection.CertValidatingHTTPSConnection(
host, ca_certs=self.ca_certificates_file, host, ca_certs=self.ca_certificates_file,
**self.http_connection_kwargs) **http_connection_kwargs)
else: else:
connection = httplib.HTTPSConnection(host, connection = httplib.HTTPSConnection(host,
**self.http_connection_kwargs) **http_connection_kwargs)
else: else:
boto.log.debug('establishing HTTP connection: kwargs=%s' % boto.log.debug('establishing HTTP connection: kwargs=%s' %
self.http_connection_kwargs) http_connection_kwargs)
if self.https_connection_factory: if self.https_connection_factory:
# even though the factory says https, this is too handy # even though the factory says https, this is too handy
# to not be able to allow overriding for http also. # to not be able to allow overriding for http also.
connection = self.https_connection_factory(host, connection = self.https_connection_factory(host,
**self.http_connection_kwargs) **http_connection_kwargs)
else: else:
connection = httplib.HTTPConnection(host, connection = httplib.HTTPConnection(host,
**self.http_connection_kwargs) **http_connection_kwargs)
if self.debug > 1: if self.debug > 1:
connection.set_debuglevel(self.debug) connection.set_debuglevel(self.debug)
# self.connection must be maintained for backwards-compatibility # self.connection must be maintained for backwards-compatibility
# however, it must be dynamically pulled from the connection pool # however, it must be dynamically pulled from the connection pool
# set a private variable which will enable that # set a private variable which will enable that
if host.split(':')[0] == self.host and is_secure == self.is_secure: if host.split(':')[0] == self.host and is_secure == self.is_secure:
self._connection = (host, is_secure) self._connection = (host, port, is_secure)
# Set the response class of the http connection to use our custom # Set the response class of the http connection to use our custom
# class. # class.
connection.response_class = HTTPResponse connection.response_class = HTTPResponse
return connection return connection
def put_http_connection(self, host, is_secure, connection): def put_http_connection(self, host, port, is_secure, connection):
self._pool.put_http_connection(host, is_secure, connection) self._pool.put_http_connection(host, port, is_secure, connection)
def proxy_ssl(self, host=None, port=None): def proxy_ssl(self, host=None, port=None):
if host and port: if host and port:
@@ -841,6 +855,7 @@ class AWSAuthConnection(object):
boto.log.debug('Data: %s' % request.body) boto.log.debug('Data: %s' % request.body)
boto.log.debug('Headers: %s' % request.headers) boto.log.debug('Headers: %s' % request.headers)
boto.log.debug('Host: %s' % request.host) boto.log.debug('Host: %s' % request.host)
boto.log.debug('Port: %s' % request.port)
boto.log.debug('Params: %s' % request.params) boto.log.debug('Params: %s' % request.params)
response = None response = None
body = None body = None
@@ -850,7 +865,8 @@ class AWSAuthConnection(object):
else: else:
num_retries = override_num_retries num_retries = override_num_retries
i = 0 i = 0
connection = self.get_http_connection(request.host, self.is_secure) connection = self.get_http_connection(request.host, request.port,
self.is_secure)
while i <= num_retries: while i <= num_retries:
# Use binary exponential backoff to desynchronize client requests. # Use binary exponential backoff to desynchronize client requests.
next_sleep = random.random() * (2 ** i) next_sleep = random.random() * (2 ** i)
@@ -858,6 +874,12 @@ class AWSAuthConnection(object):
# we now re-sign each request before it is retried # we now re-sign each request before it is retried
boto.log.debug('Token: %s' % self.provider.security_token) boto.log.debug('Token: %s' % self.provider.security_token)
request.authorize(connection=self) request.authorize(connection=self)
# Only force header for non-s3 connections, because s3 uses
# an older signing method + bucket resource URLs that include
# the port info. All others should be now be up to date and
# not include the port.
if 's3' not in self._required_auth_capability():
request.headers['Host'] = self.host.split(':', 1)[0]
if callable(sender): if callable(sender):
response = sender(connection, request.method, request.path, response = sender(connection, request.method, request.path,
request.body, request.headers) request.body, request.headers)
@@ -880,31 +902,45 @@ class AWSAuthConnection(object):
boto.log.debug(msg) boto.log.debug(msg)
time.sleep(next_sleep) time.sleep(next_sleep)
continue continue
if response.status == 500 or response.status == 503: if response.status in [500, 502, 503, 504]:
msg = 'Received %d response. ' % response.status msg = 'Received %d response. ' % response.status
msg += 'Retrying in %3.1f seconds' % next_sleep msg += 'Retrying in %3.1f seconds' % next_sleep
boto.log.debug(msg) boto.log.debug(msg)
body = response.read() body = response.read()
elif response.status < 300 or response.status >= 400 or \ elif response.status < 300 or response.status >= 400 or \
not location: not location:
self.put_http_connection(request.host, self.is_secure, # don't return connection to the pool if response contains
connection) # Connection:close header, because the connection has been
# closed and default reconnect behavior may do something
# different than new_http_connection. Also, it's probably
# less efficient to try to reuse a closed connection.
conn_header_value = response.getheader('connection')
if conn_header_value == 'close':
connection.close()
else:
self.put_http_connection(request.host, request.port,
self.is_secure, connection)
return response return response
else: else:
scheme, request.host, request.path, \ scheme, request.host, request.path, \
params, query, fragment = urlparse.urlparse(location) params, query, fragment = urlparse.urlparse(location)
if query: if query:
request.path += '?' + query request.path += '?' + query
# urlparse can return both host and port in netloc, so if
# that's the case we need to split them up properly
if ':' in request.host:
request.host, request.port = request.host.split(':', 1)
msg = 'Redirecting: %s' % scheme + '://' msg = 'Redirecting: %s' % scheme + '://'
msg += request.host + request.path msg += request.host + request.path
boto.log.debug(msg) boto.log.debug(msg)
connection = self.get_http_connection(request.host, connection = self.get_http_connection(request.host,
request.port,
scheme == 'https') scheme == 'https')
response = None response = None
continue continue
except PleaseRetryException, e: except PleaseRetryException, e:
boto.log.debug('encountered a retry exception: %s' % e) boto.log.debug('encountered a retry exception: %s' % e)
connection = self.new_http_connection(request.host, connection = self.new_http_connection(request.host, request.port,
self.is_secure) self.is_secure)
response = e.response response = e.response
except self.http_exceptions, e: except self.http_exceptions, e:
@@ -913,10 +949,10 @@ class AWSAuthConnection(object):
boto.log.debug( boto.log.debug(
'encountered unretryable %s exception, re-raising' % 'encountered unretryable %s exception, re-raising' %
e.__class__.__name__) e.__class__.__name__)
raise e raise
boto.log.debug('encountered %s exception, reconnecting' % \ boto.log.debug('encountered %s exception, reconnecting' % \
e.__class__.__name__) e.__class__.__name__)
connection = self.new_http_connection(request.host, connection = self.new_http_connection(request.host, request.port,
self.is_secure) self.is_secure)
time.sleep(next_sleep) time.sleep(next_sleep)
i += 1 i += 1
@@ -927,7 +963,7 @@ class AWSAuthConnection(object):
if response: if response:
raise BotoServerError(response.status, response.reason, body) raise BotoServerError(response.status, response.reason, body)
elif e: elif e:
raise e raise
else: else:
msg = 'Please report this exception as a Boto Issue!' msg = 'Please report this exception as a Boto Issue!'
raise BotoClientError(msg) raise BotoClientError(msg)
@@ -1006,7 +1042,7 @@ class AWSQueryConnection(AWSAuthConnection):
def make_request(self, action, params=None, path='/', verb='GET'): def make_request(self, action, params=None, path='/', verb='GET'):
http_request = self.build_base_http_request(verb, path, None, http_request = self.build_base_http_request(verb, path, None,
params, {}, '', params, {}, '',
self.server_name()) self.host)
if action: if action:
http_request.params['Action'] = action http_request.params['Action'] = action
if self.APIVersion: if self.APIVersion:

View File

@@ -50,11 +50,11 @@ class Item(dict):
if range_key == None: if range_key == None:
range_key = attrs.get(self._range_key_name, None) range_key = attrs.get(self._range_key_name, None)
self[self._range_key_name] = range_key self[self._range_key_name] = range_key
self._updates = {}
for key, value in attrs.items(): for key, value in attrs.items():
if key != self._hash_key_name and key != self._range_key_name: if key != self._hash_key_name and key != self._range_key_name:
self[key] = value self[key] = value
self.consumed_units = 0 self.consumed_units = 0
self._updates = {}
@property @property
def hash_key(self): def hash_key(self):

View File

@@ -277,6 +277,10 @@ class Dynamizer(object):
if len(attr) > 1 or not attr: if len(attr) > 1 or not attr:
return attr return attr
dynamodb_type = attr.keys()[0] dynamodb_type = attr.keys()[0]
if dynamodb_type.lower() == dynamodb_type:
# It's not an actual type, just a single character attr that
# overlaps with the DDB types. Return it.
return attr
try: try:
decoder = getattr(self, '_decode_%s' % dynamodb_type.lower()) decoder = getattr(self, '_decode_%s' % dynamodb_type.lower())
except AttributeError: except AttributeError:

View File

@@ -21,7 +21,11 @@
# #
from binascii import crc32 from binascii import crc32
import json try:
import json
except ImportError:
import simplejson as json
import boto import boto
from boto.connection import AWSQueryConnection from boto.connection import AWSQueryConnection
from boto.regioninfo import RegionInfo from boto.regioninfo import RegionInfo
@@ -67,7 +71,11 @@ class DynamoDBConnection(AWSQueryConnection):
if reg.name == region_name: if reg.name == region_name:
region = reg region = reg
break break
kwargs['host'] = region.endpoint
# Only set host if it isn't manually overwritten
if 'host' not in kwargs:
kwargs['host'] = region.endpoint
AWSQueryConnection.__init__(self, **kwargs) AWSQueryConnection.__init__(self, **kwargs)
self.region = region self.region = region
self._validate_checksums = boto.config.getbool( self._validate_checksums = boto.config.getbool(
@@ -1467,13 +1475,13 @@ class DynamoDBConnection(AWSQueryConnection):
def make_request(self, action, body): def make_request(self, action, body):
headers = { headers = {
'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action), 'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),
'Host': self.region.endpoint, 'Host': self.host,
'Content-Type': 'application/x-amz-json-1.0', 'Content-Type': 'application/x-amz-json-1.0',
'Content-Length': str(len(body)), 'Content-Length': str(len(body)),
} }
http_request = self.build_base_http_request( http_request = self.build_base_http_request(
method='POST', path='/', auth_path='/', params={}, method='POST', path='/', auth_path='/', params={},
headers=headers, data=body) headers=headers, data=body, host=self.host)
response = self._mexe(http_request, sender=None, response = self._mexe(http_request, sender=None,
override_num_retries=self.NumberRetries, override_num_retries=self.NumberRetries,
retry_handler=self._retry_handler) retry_handler=self._retry_handler)

View File

@@ -418,6 +418,45 @@ class Table(object):
item.load(item_data) item.load(item_data)
return item return item
def lookup(self, *args, **kwargs):
"""
Look up an entry in DynamoDB. This is mostly backwards compatible
with boto.dynamodb. Unlike get_item, it takes hash_key and range_key first,
although you may still specify keyword arguments instead.
Also unlike the get_item command, if the returned item has no keys
(i.e., it does not exist in DynamoDB), a None result is returned, instead
of an empty key object.
Example::
>>> user = users.lookup(username)
>>> user = users.lookup(username, consistent=True)
>>> app = apps.lookup('my_customer_id', 'my_app_id')
"""
if not self.schema:
self.describe()
for x, arg in enumerate(args):
kwargs[self.schema[x].name] = arg
ret = self.get_item(**kwargs)
if not ret.keys():
return None
return ret
def new_item(self, *args):
"""
Returns a new, blank item
This is mostly for consistency with boto.dynamodb
"""
if not self.schema:
self.describe()
data = {}
for x, arg in enumerate(args):
data[self.schema[x].name] = arg
return Item(self, data=data)
def put_item(self, data, overwrite=False): def put_item(self, data, overwrite=False):
""" """
Saves an entire item to DynamoDB. Saves an entire item to DynamoDB.
@@ -1164,4 +1203,4 @@ class BatchTable(object):
self.handle_unprocessed(resp) self.handle_unprocessed(resp)
boto.log.info( boto.log.info(
"%s unprocessed items left" % len(self._unprocessed) "%s unprocessed items left" % len(self._unprocessed)
) )

View File

@@ -241,6 +241,10 @@ class AutoScaleConnection(AWSQueryConnection):
params['EbsOptimized'] = 'true' params['EbsOptimized'] = 'true'
else: else:
params['EbsOptimized'] = 'false' params['EbsOptimized'] = 'false'
if launch_config.associate_public_ip_address is True:
params['AssociatePublicIpAddress'] = 'true'
elif launch_config.associate_public_ip_address is False:
params['AssociatePublicIpAddress'] = 'false'
return self.get_object('CreateLaunchConfiguration', params, return self.get_object('CreateLaunchConfiguration', params,
Request, verb='POST') Request, verb='POST')
@@ -492,15 +496,19 @@ class AutoScaleConnection(AWSQueryConnection):
If no group name or list of policy names are provided, all If no group name or list of policy names are provided, all
available policies are returned. available policies are returned.
:type as_name: str :type as_group: str
:param as_name: The name of the :param as_group: The name of the
:class:`boto.ec2.autoscale.group.AutoScalingGroup` to filter for. :class:`boto.ec2.autoscale.group.AutoScalingGroup` to filter for.
:type names: list :type policy_names: list
:param names: List of policy names which should be searched for. :param policy_names: List of policy names which should be searched for.
:type max_records: int :type max_records: int
:param max_records: Maximum amount of groups to return. :param max_records: Maximum amount of groups to return.
:type next_token: str
:param next_token: If you have more results than can be returned
at once, pass in this parameter to page through all results.
""" """
params = {} params = {}
if as_group: if as_group:
@@ -681,9 +689,9 @@ class AutoScaleConnection(AWSQueryConnection):
Configures an Auto Scaling group to send notifications when Configures an Auto Scaling group to send notifications when
specified events take place. specified events take place.
:type as_group: str or :type autoscale_group: str or
:class:`boto.ec2.autoscale.group.AutoScalingGroup` object :class:`boto.ec2.autoscale.group.AutoScalingGroup` object
:param as_group: The Auto Scaling group to put notification :param autoscale_group: The Auto Scaling group to put notification
configuration on. configuration on.
:type topic: str :type topic: str
@@ -692,7 +700,12 @@ class AutoScaleConnection(AWSQueryConnection):
:type notification_types: list :type notification_types: list
:param notification_types: The type of events that will trigger :param notification_types: The type of events that will trigger
the notification. the notification. Valid types are:
'autoscaling:EC2_INSTANCE_LAUNCH',
'autoscaling:EC2_INSTANCE_LAUNCH_ERROR',
'autoscaling:EC2_INSTANCE_TERMINATE',
'autoscaling:EC2_INSTANCE_TERMINATE_ERROR',
'autoscaling:TEST_NOTIFICATION'
""" """
name = autoscale_group name = autoscale_group
@@ -704,6 +717,29 @@ class AutoScaleConnection(AWSQueryConnection):
self.build_list_params(params, notification_types, 'NotificationTypes') self.build_list_params(params, notification_types, 'NotificationTypes')
return self.get_status('PutNotificationConfiguration', params) return self.get_status('PutNotificationConfiguration', params)
def delete_notification_configuration(self, autoscale_group, topic):
"""
Deletes notifications created by put_notification_configuration.
:type autoscale_group: str or
:class:`boto.ec2.autoscale.group.AutoScalingGroup` object
:param autoscale_group: The Auto Scaling group to put notification
configuration on.
:type topic: str
:param topic: The Amazon Resource Name (ARN) of the Amazon Simple
Notification Service (SNS) topic.
"""
name = autoscale_group
if isinstance(autoscale_group, AutoScalingGroup):
name = autoscale_group.name
params = {'AutoScalingGroupName': name,
'TopicARN': topic}
return self.get_status('DeleteNotificationConfiguration', params)
def set_instance_health(self, instance_id, health_status, def set_instance_health(self, instance_id, health_status,
should_respect_grace_period=True): should_respect_grace_period=True):
""" """

View File

@@ -148,6 +148,9 @@ class AutoScalingGroup(object):
:type vpc_zone_identifier: str :type vpc_zone_identifier: str
:param vpc_zone_identifier: The subnet identifier of the Virtual :param vpc_zone_identifier: The subnet identifier of the Virtual
Private Cloud. Private Cloud.
:type tags: list
:param tags: List of :class:`boto.ec2.autoscale.tag.Tag`s
:type termination_policies: list :type termination_policies: list
:param termination_policies: A list of termination policies. Valid values :param termination_policies: A list of termination policies. Valid values
@@ -296,12 +299,23 @@ class AutoScalingGroup(object):
def put_notification_configuration(self, topic, notification_types): def put_notification_configuration(self, topic, notification_types):
""" """
Configures an Auto Scaling group to send notifications when Configures an Auto Scaling group to send notifications when
specified events take place. specified events take place. Valid notification types are:
'autoscaling:EC2_INSTANCE_LAUNCH',
'autoscaling:EC2_INSTANCE_LAUNCH_ERROR',
'autoscaling:EC2_INSTANCE_TERMINATE',
'autoscaling:EC2_INSTANCE_TERMINATE_ERROR',
'autoscaling:TEST_NOTIFICATION'
""" """
return self.connection.put_notification_configuration(self, return self.connection.put_notification_configuration(self,
topic, topic,
notification_types) notification_types)
def delete_notification_configuration(self, topic):
"""
Deletes notifications created by put_notification_configuration.
"""
return self.connection.delete_notification_configuration(self, topic)
def suspend_processes(self, scaling_processes=None): def suspend_processes(self, scaling_processes=None):
""" """
Suspends Auto Scaling processes for an Auto Scaling group. Suspends Auto Scaling processes for an Auto Scaling group.

View File

@@ -94,7 +94,8 @@ class LaunchConfiguration(object):
instance_type='m1.small', kernel_id=None, instance_type='m1.small', kernel_id=None,
ramdisk_id=None, block_device_mappings=None, ramdisk_id=None, block_device_mappings=None,
instance_monitoring=False, spot_price=None, instance_monitoring=False, spot_price=None,
instance_profile_name=None, ebs_optimized=False): instance_profile_name=None, ebs_optimized=False,
associate_public_ip_address=None):
""" """
A launch configuration. A launch configuration.
@@ -109,8 +110,9 @@ class LaunchConfiguration(object):
:param key_name: The name of the EC2 key pair. :param key_name: The name of the EC2 key pair.
:type security_groups: list :type security_groups: list
:param security_groups: Names of the security groups with which to :param security_groups: Names or security group id's of the security
associate the EC2 instances. groups with which to associate the EC2 instances or VPC instances,
respectively.
:type user_data: str :type user_data: str
:param user_data: The user data available to launched EC2 instances. :param user_data: The user data available to launched EC2 instances.
@@ -144,6 +146,10 @@ class LaunchConfiguration(object):
:type ebs_optimized: bool :type ebs_optimized: bool
:param ebs_optimized: Specifies whether the instance is optimized :param ebs_optimized: Specifies whether the instance is optimized
for EBS I/O (true) or not (false). for EBS I/O (true) or not (false).
:type associate_public_ip_address: bool
:param associate_public_ip_address: Used for Auto Scaling groups that launch instances into an Amazon Virtual Private Cloud.
Specifies whether to assign a public IP address to each instance launched in a Amazon VPC.
""" """
self.connection = connection self.connection = connection
self.name = name self.name = name
@@ -163,6 +169,7 @@ class LaunchConfiguration(object):
self.instance_profile_name = instance_profile_name self.instance_profile_name = instance_profile_name
self.launch_configuration_arn = None self.launch_configuration_arn = None
self.ebs_optimized = ebs_optimized self.ebs_optimized = ebs_optimized
self.associate_public_ip_address = associate_public_ip_address
def __repr__(self): def __repr__(self):
return 'LaunchConfiguration:%s' % self.name return 'LaunchConfiguration:%s' % self.name

View File

@@ -55,11 +55,11 @@ class Tag(object):
self.key = value self.key = value
elif name == 'Value': elif name == 'Value':
self.value = value self.value = value
elif name == 'PropogateAtLaunch': elif name == 'PropagateAtLaunch':
if value.lower() == 'true': if value.lower() == 'true':
self.propogate_at_launch = True self.propagate_at_launch = True
else: else:
self.propogate_at_launch = False self.propagate_at_launch = False
elif name == 'ResourceId': elif name == 'ResourceId':
self.resource_id = value self.resource_id = value
elif name == 'ResourceType': elif name == 'ResourceType':

View File

@@ -95,7 +95,7 @@ class MetricAlarm(object):
statistic is applied. statistic is applied.
:type evaluation_periods: int :type evaluation_periods: int
:param evaluation_period: The number of periods over which data is :param evaluation_periods: The number of periods over which data is
compared to the specified threshold. compared to the specified threshold.
:type unit: str :type unit: str
@@ -112,9 +112,16 @@ class MetricAlarm(object):
:type description: str :type description: str
:param description: Description of MetricAlarm :param description: Description of MetricAlarm
:type dimensions: list of dicts :type dimensions: dict
:param description: Dimensions of alarm, such as: :param dimensions: A dictionary of dimension key/values where
[{'InstanceId':['i-0123456,i-0123457']}] the key is the dimension name and the value
is either a scalar value or an iterator
of values to be associated with that
dimension.
Example: {
'InstanceId': ['i-0123456', 'i-0123457'],
'LoadBalancerName': 'test-lb'
}
:type alarm_actions: list of strs :type alarm_actions: list of strs
:param alarm_actions: A list of the ARNs of the actions to take in :param alarm_actions: A list of the ARNs of the actions to take in

View File

@@ -69,7 +69,7 @@ from boto.exception import EC2ResponseError
class EC2Connection(AWSQueryConnection): class EC2Connection(AWSQueryConnection):
APIVersion = boto.config.get('Boto', 'ec2_version', '2013-07-15') APIVersion = boto.config.get('Boto', 'ec2_version', '2013-10-01')
DefaultRegionName = boto.config.get('Boto', 'ec2_region_name', 'us-east-1') DefaultRegionName = boto.config.get('Boto', 'ec2_region_name', 'us-east-1')
DefaultRegionEndpoint = boto.config.get('Boto', 'ec2_region_endpoint', DefaultRegionEndpoint = boto.config.get('Boto', 'ec2_region_endpoint',
'ec2.us-east-1.amazonaws.com') 'ec2.us-east-1.amazonaws.com')
@@ -260,7 +260,7 @@ class EC2Connection(AWSQueryConnection):
def register_image(self, name=None, description=None, image_location=None, def register_image(self, name=None, description=None, image_location=None,
architecture=None, kernel_id=None, ramdisk_id=None, architecture=None, kernel_id=None, ramdisk_id=None,
root_device_name=None, block_device_map=None, root_device_name=None, block_device_map=None,
dry_run=False): dry_run=False, virtualization_type=None):
""" """
Register an image. Register an image.
@@ -293,6 +293,12 @@ class EC2Connection(AWSQueryConnection):
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
:type virtualization_type: string
:param virtualization_type: The virutalization_type of the image.
Valid choices are:
* paravirtual
* hvm
:rtype: string :rtype: string
:return: The new image id :return: The new image id
""" """
@@ -315,6 +321,9 @@ class EC2Connection(AWSQueryConnection):
block_device_map.ec2_build_list_params(params) block_device_map.ec2_build_list_params(params)
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
if virtualization_type:
params['VirtualizationType'] = virtualization_type
rs = self.get_object('RegisterImage', params, ResultSet, verb='POST') rs = self.get_object('RegisterImage', params, ResultSet, verb='POST')
image_id = getattr(rs, 'imageId', None) image_id = getattr(rs, 'imageId', None)
return image_id return image_id
@@ -355,7 +364,8 @@ class EC2Connection(AWSQueryConnection):
return result return result
def create_image(self, instance_id, name, def create_image(self, instance_id, name,
description=None, no_reboot=False, dry_run=False): description=None, no_reboot=False,
block_device_mapping=None, dry_run=False):
""" """
Will create an AMI from the instance in the running or stopped Will create an AMI from the instance in the running or stopped
state. state.
@@ -377,6 +387,10 @@ class EC2Connection(AWSQueryConnection):
responsibility of maintaining file system integrity is responsibility of maintaining file system integrity is
left to the owner of the instance. left to the owner of the instance.
:type block_device_mapping: :class:`boto.ec2.blockdevicemapping.BlockDeviceMapping`
:param block_device_mapping: A BlockDeviceMapping data structure
describing the EBS volumes associated with the Image.
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
@@ -389,6 +403,8 @@ class EC2Connection(AWSQueryConnection):
params['Description'] = description params['Description'] = description
if no_reboot: if no_reboot:
params['NoReboot'] = 'true' params['NoReboot'] = 'true'
if block_device_mapping:
block_device_mapping.ec2_build_list_params(params)
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
img = self.get_object('CreateImage', params, Image, verb='POST') img = self.get_object('CreateImage', params, Image, verb='POST')
@@ -1500,7 +1516,7 @@ class EC2Connection(AWSQueryConnection):
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_list('CancelSpotInstanceRequests', params, return self.get_list('CancelSpotInstanceRequests', params,
[('item', Instance)], verb='POST') [('item', SpotInstanceRequest)], verb='POST')
def get_spot_datafeed_subscription(self, dry_run=False): def get_spot_datafeed_subscription(self, dry_run=False):
""" """
@@ -2189,17 +2205,17 @@ class EC2Connection(AWSQueryConnection):
present, only the Snapshots associated with present, only the Snapshots associated with
these snapshot ids will be returned. these snapshot ids will be returned.
:type owner: str :type owner: str or list
:param owner: If present, only the snapshots owned by the specified user :param owner: If present, only the snapshots owned by the specified user(s)
will be returned. Valid values are: will be returned. Valid values are:
* self * self
* amazon * amazon
* AWS Account ID * AWS Account ID
:type restorable_by: str :type restorable_by: str or list
:param restorable_by: If present, only the snapshots that are restorable :param restorable_by: If present, only the snapshots that are restorable
by the specified account id will be returned. by the specified account id(s) will be returned.
:type filters: dict :type filters: dict
:param filters: Optional filters that can be used to limit :param filters: Optional filters that can be used to limit
@@ -2220,10 +2236,11 @@ class EC2Connection(AWSQueryConnection):
params = {} params = {}
if snapshot_ids: if snapshot_ids:
self.build_list_params(params, snapshot_ids, 'SnapshotId') self.build_list_params(params, snapshot_ids, 'SnapshotId')
if owner: if owner:
params['Owner'] = owner self.build_list_params(params, owner, 'Owner')
if restorable_by: if restorable_by:
params['RestorableBy'] = restorable_by self.build_list_params(params, restorable_by, 'RestorableBy')
if filters: if filters:
self.build_filter_params(params, filters) self.build_filter_params(params, filters)
if dry_run: if dry_run:

View File

@@ -188,13 +188,13 @@ class ELBConnection(AWSQueryConnection):
(LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol, (LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol,
SSLCertificateId). SSLCertificateId).
Where; Where:
- LoadBalancerPortNumber and InstancePortNumber are integer - LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535. values between 1 and 65535
- Protocol and InstanceProtocol is a string containing either 'TCP', - Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS' 'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into - SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM AWS IAM
:rtype: :class:`boto.ec2.elb.loadbalancer.LoadBalancer` :rtype: :class:`boto.ec2.elb.loadbalancer.LoadBalancer`
:return: The newly created :return: The newly created
@@ -272,13 +272,13 @@ class ELBConnection(AWSQueryConnection):
(LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol, (LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol,
SSLCertificateId). SSLCertificateId).
Where; Where:
- LoadBalancerPortNumber and InstancePortNumber are integer - LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535. values between 1 and 65535
- Protocol and InstanceProtocol is a string containing either 'TCP', - Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS' 'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into - SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM AWS IAM
:return: The status of the request :return: The status of the request
""" """

View File

@@ -342,7 +342,7 @@ class LoadBalancer(object):
""" """
if isinstance(subnets, str) or isinstance(subnets, unicode): if isinstance(subnets, str) or isinstance(subnets, unicode):
subnets = [subnets] subnets = [subnets]
new_subnets = self.connection.detach_lb_to_subnets(self.name, subnets) new_subnets = self.connection.detach_lb_from_subnets(self.name, subnets)
self.subnets = new_subnets self.subnets = new_subnets
def apply_security_groups(self, security_groups): def apply_security_groups(self, security_groups):

View File

@@ -340,14 +340,6 @@ class Instance(TaggedEC2Object):
self.ami_launch_index = value self.ami_launch_index = value
elif name == 'previousState': elif name == 'previousState':
self.previous_state = value self.previous_state = value
elif name == 'name':
self.state = value
elif name == 'code':
try:
self.state_code = int(value)
except ValueError:
boto.log.warning('Error converting code (%s) to int' % value)
self.state_code = value
elif name == 'instanceType': elif name == 'instanceType':
self.instance_type = value self.instance_type = value
elif name == 'rootDeviceName': elif name == 'rootDeviceName':

View File

@@ -234,11 +234,12 @@ class PriceSchedule(object):
class ReservedInstancesConfiguration(object): class ReservedInstancesConfiguration(object):
def __init__(self, connection=None, availability_zone=None, platform=None, def __init__(self, connection=None, availability_zone=None, platform=None,
instance_count=None): instance_count=None, instance_type=None):
self.connection = connection self.connection = connection
self.availability_zone = availability_zone self.availability_zone = availability_zone
self.platform = platform self.platform = platform
self.instance_count = instance_count self.instance_count = instance_count
self.instance_type = instance_type
def startElement(self, name, attrs, connection): def startElement(self, name, attrs, connection):
return None return None
@@ -250,6 +251,8 @@ class ReservedInstancesConfiguration(object):
self.platform = value self.platform = value
elif name == 'instanceCount': elif name == 'instanceCount':
self.instance_count = int(value) self.instance_count = int(value)
elif name == 'instanceType':
self.instance_type = value
else: else:
setattr(self, name, value) setattr(self, name, value)
@@ -271,12 +274,14 @@ class ModifyReservedInstancesResult(object):
class ModificationResult(object): class ModificationResult(object):
def __init__(self, connection=None, modification_id=None, def __init__(self, connection=None, modification_id=None,
availability_zone=None, platform=None, instance_count=None): availability_zone=None, platform=None, instance_count=None,
instance_type=None):
self.connection = connection self.connection = connection
self.modification_id = modification_id self.modification_id = modification_id
self.availability_zone = availability_zone self.availability_zone = availability_zone
self.platform = platform self.platform = platform
self.instance_count = instance_count self.instance_count = instance_count
self.instance_type = instance_type
def startElement(self, name, attrs, connection): def startElement(self, name, attrs, connection):
return None return None
@@ -290,6 +295,8 @@ class ModificationResult(object):
self.platform = value self.platform = value
elif name == 'instanceCount': elif name == 'instanceCount':
self.instance_count = int(value) self.instance_count = int(value)
elif name == 'instanceType':
self.instance_type = value
else: else:
setattr(self, name, value) setattr(self, name, value)

View File

@@ -123,6 +123,9 @@ class SecurityGroup(TaggedEC2Object):
only changes the local version of the object. No information only changes the local version of the object. No information
is sent to EC2. is sent to EC2.
""" """
if not self.rules:
raise ValueError("The security group has no rules")
target_rule = None target_rule = None
for rule in self.rules: for rule in self.rules:
if rule.ip_protocol == ip_protocol: if rule.ip_protocol == ip_protocol:
@@ -136,9 +139,9 @@ class SecurityGroup(TaggedEC2Object):
if grant.cidr_ip == cidr_ip: if grant.cidr_ip == cidr_ip:
target_grant = grant target_grant = grant
if target_grant: if target_grant:
rule.grants.remove(target_grant, dry_run=dry_run) rule.grants.remove(target_grant)
if len(rule.grants) == 0: if len(rule.grants) == 0:
self.rules.remove(target_rule, dry_run=dry_run) self.rules.remove(target_rule)
def authorize(self, ip_protocol=None, from_port=None, to_port=None, def authorize(self, ip_protocol=None, from_port=None, to_port=None,
cidr_ip=None, src_group=None, dry_run=False): cidr_ip=None, src_group=None, dry_run=False):

View File

@@ -387,8 +387,8 @@ class ElasticTranscoderConnection(AWSAuthConnection):
:param description: A description of the preset. :param description: A description of the preset.
:type container: string :type container: string
:param container: The container type for the output file. This value :param container: The container type for the output file. Valid values
must be `mp4`. include `mp3`, `mp4`, `ogg`, `ts`, and `webm`.
:type video: dict :type video: dict
:param video: A section of the request body that specifies the video :param video: A section of the request body that specifies the video

View File

@@ -43,25 +43,25 @@ def regions():
endpoint='elasticmapreduce.us-east-1.amazonaws.com', endpoint='elasticmapreduce.us-east-1.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='us-west-1', RegionInfo(name='us-west-1',
endpoint='elasticmapreduce.us-west-1.amazonaws.com', endpoint='us-west-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='us-west-2', RegionInfo(name='us-west-2',
endpoint='elasticmapreduce.us-west-2.amazonaws.com', endpoint='us-west-2.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='ap-northeast-1', RegionInfo(name='ap-northeast-1',
endpoint='elasticmapreduce.ap-northeast-1.amazonaws.com', endpoint='ap-northeast-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='ap-southeast-1', RegionInfo(name='ap-southeast-1',
endpoint='elasticmapreduce.ap-southeast-1.amazonaws.com', endpoint='ap-southeast-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='ap-southeast-2', RegionInfo(name='ap-southeast-2',
endpoint='elasticmapreduce.ap-southeast-2.amazonaws.com', endpoint='ap-southeast-2.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='eu-west-1', RegionInfo(name='eu-west-1',
endpoint='elasticmapreduce.eu-west-1.amazonaws.com', endpoint='eu-west-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
RegionInfo(name='sa-east-1', RegionInfo(name='sa-east-1',
endpoint='elasticmapreduce.sa-east-1.amazonaws.com', endpoint='sa-east-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection), connection_cls=EmrConnection),
] ]

View File

@@ -28,9 +28,12 @@ import types
import boto import boto
import boto.utils import boto.utils
from boto.ec2.regioninfo import RegionInfo from boto.ec2.regioninfo import RegionInfo
from boto.emr.emrobject import JobFlow, RunJobFlowResponse from boto.emr.emrobject import AddInstanceGroupsResponse, BootstrapActionList, \
from boto.emr.emrobject import AddInstanceGroupsResponse Cluster, ClusterSummaryList, HadoopStep, \
from boto.emr.emrobject import ModifyInstanceGroupsResponse InstanceGroupList, InstanceList, JobFlow, \
JobFlowStepList, \
ModifyInstanceGroupsResponse, \
RunJobFlowResponse, StepSummaryList
from boto.emr.step import JarStep from boto.emr.step import JarStep
from boto.connection import AWSQueryConnection from boto.connection import AWSQueryConnection
from boto.exception import EmrResponseError from boto.exception import EmrResponseError
@@ -65,10 +68,30 @@ class EmrConnection(AWSQueryConnection):
https_connection_factory, path, https_connection_factory, path,
security_token, security_token,
validate_certs=validate_certs) validate_certs=validate_certs)
# Many of the EMR hostnames are of the form:
# <region>.<service_name>.amazonaws.com
# rather than the more common:
# <service_name>.<region>.amazonaws.com
# so we need to explicitly set the region_name and service_name
# for the SigV4 signing.
self.auth_region_name = self.region.name
self.auth_service_name = 'elasticmapreduce'
def _required_auth_capability(self): def _required_auth_capability(self):
return ['hmac-v4'] return ['hmac-v4']
def describe_cluster(self, cluster_id):
"""
Describes an Elastic MapReduce cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
"""
params = {
'ClusterId': cluster_id
}
return self.get_object('DescribeCluster', params, Cluster)
def describe_jobflow(self, jobflow_id): def describe_jobflow(self, jobflow_id):
""" """
Describes a single Elastic MapReduce job flow Describes a single Elastic MapReduce job flow
@@ -111,6 +134,139 @@ class EmrConnection(AWSQueryConnection):
return self.get_list('DescribeJobFlows', params, [('member', JobFlow)]) return self.get_list('DescribeJobFlows', params, [('member', JobFlow)])
def describe_step(self, cluster_id, step_id):
"""
Describe an Elastic MapReduce step
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type step_id: str
:param step_id: The step id of interest
"""
params = {
'ClusterId': cluster_id,
'StepId': step_id
}
return self.get_object('DescribeStep', params, HadoopStep)
def list_bootstrap_actions(self, cluster_id, marker=None):
"""
Get a list of bootstrap actions for an Elastic MapReduce cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
return self.get_object('ListBootstrapActions', params, BootstrapActionList)
def list_clusters(self, created_after=None, created_before=None,
cluster_states=None, marker=None):
"""
List Elastic MapReduce clusters with optional filtering
:type created_after: datetime
:param created_after: Bound on cluster creation time
:type created_before: datetime
:param created_before: Bound on cluster creation time
:type cluster_states: list
:param cluster_states: Bound on cluster states
:type marker: str
:param marker: Pagination marker
"""
params = {}
if created_after:
params['CreatedAfter'] = created_after.strftime(
boto.utils.ISO8601)
if created_before:
params['CreatedBefore'] = created_before.strftime(
boto.utils.ISO8601)
if marker:
params['Marker'] = marker
if cluster_states:
self.build_list_params(params, cluster_states, 'ClusterStates.member')
return self.get_object('ListClusters', params, ClusterSummaryList)
def list_instance_groups(self, cluster_id, marker=None):
"""
List EC2 instance groups in a cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
return self.get_object('ListInstanceGroups', params, InstanceGroupList)
def list_instances(self, cluster_id, instance_group_id=None,
instance_group_types=None, marker=None):
"""
List EC2 instances in a cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type instance_group_id: str
:param instance_group_id: The EC2 instance group id of interest
:type instance_group_types: list
:param instance_group_types: Filter by EC2 instance group type
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if instance_group_id:
params['InstanceGroupId'] = instance_group_id
if marker:
params['Marker'] = marker
if instance_group_types:
self.build_list_params(params, instance_group_types,
'InstanceGroupTypeList.member')
return self.get_object('ListInstances', params, InstanceList)
def list_steps(self, cluster_id, step_states=None, marker=None):
"""
List cluster steps
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type step_states: list
:param step_states: Filter by step states
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
if step_states:
self.build_list_params(params, step_states, 'StepStateList.member')
self.get_object('ListSteps', params, StepSummaryList)
def terminate_jobflow(self, jobflow_id): def terminate_jobflow(self, jobflow_id):
""" """
Terminate an Elastic MapReduce job flow Terminate an Elastic MapReduce job flow
@@ -150,7 +306,7 @@ class EmrConnection(AWSQueryConnection):
params.update(self._build_step_list(step_args)) params.update(self._build_step_list(step_args))
return self.get_object( return self.get_object(
'AddJobFlowSteps', params, RunJobFlowResponse, verb='POST') 'AddJobFlowSteps', params, JobFlowStepList, verb='POST')
def add_instance_groups(self, jobflow_id, instance_groups): def add_instance_groups(self, jobflow_id, instance_groups):
""" """

View File

@@ -60,11 +60,29 @@ class Arg(EmrObject):
self.value = value self.value = value
class StepId(Arg):
pass
class JobFlowStepList(EmrObject):
def __ini__(self, connection=None):
self.connection = connection
self.stepids = None
def startElement(self, name, attrs, connection):
if name == 'StepIds':
self.stepids = ResultSet([('member', StepId)])
return self.stepids
else:
return None
class BootstrapAction(EmrObject): class BootstrapAction(EmrObject):
Fields = set([ Fields = set([
'Args', 'Args',
'Name', 'Name',
'Path', 'Path',
'ScriptPath',
]) ])
def startElement(self, name, attrs, connection): def startElement(self, name, attrs, connection):
@@ -174,3 +192,281 @@ class JobFlow(EmrObject):
return self.bootstrapactions return self.bootstrapactions
else: else:
return None return None
class ClusterTimeline(EmrObject):
Fields = set([
'CreationDateTime',
'ReadyDateTime',
'EndDateTime'
])
class ClusterStatus(EmrObject):
Fields = set([
'State',
'StateChangeReason',
'Timeline'
])
def __init__(self, connection=None):
self.connection = connection
self.timeline = None
def startElement(self, name, attrs, connection):
if name == 'Timeline':
self.timeline = ClusterTimeline()
return self.timeline
else:
return None
class Ec2InstanceAttributes(EmrObject):
Fields = set([
'Ec2KeyName',
'Ec2SubnetId',
'Ec2AvailabilityZone',
'IamInstanceProfile'
])
class Application(EmrObject):
Fields = set([
'Name',
'Version',
'Args',
'AdditionalInfo'
])
class Cluster(EmrObject):
Fields = set([
'Id',
'Name',
'LogUri',
'RequestedAmiVersion',
'RunningAmiVersion',
'AutoTerminate',
'TerminationProtected',
'VisibleToAllUsers'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
self.ec2instanceattributes = None
self.applications = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
elif name == 'EC2InstanceAttributes':
self.ec2instanceattributes = Ec2InstanceAttributes()
return self.ec2instanceattributes
elif name == 'Applications':
self.applications = ResultSet([('member', Application)])
else:
return None
class ClusterSummary(Cluster):
Fields = set([
'Id',
'Name'
])
class ClusterSummaryList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection):
self.connection = connection
self.clusters = None
def startElement(self, name, attrs, connection):
if name == 'Clusters':
self.clusters = ResultSet([('member', ClusterSummary)])
return self.clusters
else:
return None
class StepConfig(EmrObject):
Fields = set([
'Jar'
'MainClass'
])
def __init__(self, connection=None):
self.connection = connection
self.properties = None
self.args = None
def startElement(self, name, attrs, connection):
if name == 'Properties':
self.properties = ResultSet([('member', KeyValue)])
return self.properties
elif name == 'Args':
self.args = ResultSet([('member', Arg)])
return self.args
else:
return None
class HadoopStep(EmrObject):
Fields = set([
'Id',
'Name',
'ActionOnFailure'
])
def __init__(self, connection=None):
self.connection = connection
self.config = None
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Config':
self.config = StepConfig()
return self.config
elif name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class InstanceGroupInfo(EmrObject):
Fields = set([
'Id',
'Name',
'Market',
'InstanceGroupType',
'BidPrice',
'InstanceType',
'RequestedInstanceCount',
'RunningInstanceCount'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class InstanceGroupList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection=None):
self.connection = connection
self.instancegroups = None
def startElement(self, name, attrs, connection):
if name == 'InstanceGroups':
self.instancegroups = ResultSet([('member', InstanceGroupInfo)])
return self.instancegroups
else:
return None
class InstanceInfo(EmrObject):
    """Network identifiers and status of a single cluster instance."""

    Fields = {
        'Id',
        'Ec2InstanceId',
        'PublicDnsName',
        'PublicIpAddress',
        'PrivateDnsName',
        'PrivateIpAddress',
    }

    def __init__(self, connection=None):
        self.connection = connection
        # Populated when the <Status> child element is parsed.
        self.status = None

    def startElement(self, name, attrs, connection):
        if name != 'Status':
            return None
        self.status = ClusterStatus()
        return self.status
class InstanceList(EmrObject):
    """Paged container of InstanceInfo members plus a pagination Marker."""

    Fields = {'Marker'}

    def __init__(self, connection=None):
        self.connection = connection
        # Set once the <Instances> container is encountered.
        self.instances = None

    def startElement(self, name, attrs, connection):
        if name != 'Instances':
            return None
        self.instances = ResultSet([('member', InstanceInfo)])
        return self.instances
class StepSummary(EmrObject):
    """Abbreviated view of a step: id, name, and status."""

    Fields = {'Id', 'Name'}

    def __init__(self, connection=None):
        self.connection = connection
        # Populated when the <Status> child element is parsed.
        self.status = None

    def startElement(self, name, attrs, connection):
        if name != 'Status':
            return None
        self.status = ClusterStatus()
        return self.status
class StepSummaryList(EmrObject):
    """Paged container of StepSummary members plus a pagination Marker."""

    Fields = {'Marker'}

    def __init__(self, connection=None):
        self.connection = connection
        # Set once the <Steps> container is encountered.
        self.steps = None

    def startElement(self, name, attrs, connection):
        if name != 'Steps':
            return None
        self.steps = ResultSet([('member', StepSummary)])
        return self.steps
class BootstrapActionList(EmrObject):
    """Paged container of BootstrapAction members plus a pagination Marker."""

    Fields = {'Marker'}

    def __init__(self, connection=None):
        self.connection = connection
        # Set once the <BootstrapActions> container is encountered.
        self.actions = None

    def startElement(self, name, attrs, connection):
        if name != 'BootstrapActions':
            return None
        self.actions = ResultSet([('member', BootstrapAction)])
        return self.actions

View File

@@ -47,6 +47,9 @@ def regions():
RegionInfo(name='eu-west-1', RegionInfo(name='eu-west-1',
endpoint='glacier.eu-west-1.amazonaws.com', endpoint='glacier.eu-west-1.amazonaws.com',
connection_cls=Layer2), connection_cls=Layer2),
RegionInfo(name='ap-southeast-2',
endpoint='glacier.ap-southeast-2.amazonaws.com',
connection_cls=Layer2),
] ]

View File

@@ -19,12 +19,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE. # IN THE SOFTWARE.
import re
import urllib import urllib
import xml.sax import xml.sax
import boto import boto
from boto import handler from boto import handler
from boto.resultset import ResultSet from boto.resultset import ResultSet
from boto.exception import GSResponseError
from boto.exception import InvalidAclError from boto.exception import InvalidAclError
from boto.gs.acl import ACL, CannedACLStrings from boto.gs.acl import ACL, CannedACLStrings
from boto.gs.acl import SupportedPermissions as GSPermissions from boto.gs.acl import SupportedPermissions as GSPermissions
@@ -41,6 +43,7 @@ DEF_OBJ_ACL = 'defaultObjectAcl'
STANDARD_ACL = 'acl' STANDARD_ACL = 'acl'
CORS_ARG = 'cors' CORS_ARG = 'cors'
LIFECYCLE_ARG = 'lifecycle' LIFECYCLE_ARG = 'lifecycle'
ERROR_DETAILS_REGEX = re.compile(r'<Details>(?P<details>.*)</Details>')
class Bucket(S3Bucket): class Bucket(S3Bucket):
"""Represents a Google Cloud Storage bucket.""" """Represents a Google Cloud Storage bucket."""
@@ -99,9 +102,16 @@ class Bucket(S3Bucket):
if response_headers: if response_headers:
for rk, rv in response_headers.iteritems(): for rk, rv in response_headers.iteritems():
query_args_l.append('%s=%s' % (rk, urllib.quote(rv))) query_args_l.append('%s=%s' % (rk, urllib.quote(rv)))
try:
key, resp = self._get_key_internal(key_name, headers, key, resp = self._get_key_internal(key_name, headers,
query_args_l=query_args_l) query_args_l=query_args_l)
except GSResponseError, e:
if e.status == 403 and 'Forbidden' in e.reason:
# If we failed getting an object, let the user know which object
# failed rather than just returning a generic 403.
e.reason = ("Access denied to 'gs://%s/%s'." %
(self.name, key_name))
raise
return key return key
def copy_key(self, new_key_name, src_bucket_name, src_key_name, def copy_key(self, new_key_name, src_bucket_name, src_key_name,
@@ -312,6 +322,14 @@ class Bucket(S3Bucket):
headers=headers) headers=headers)
body = response.read() body = response.read()
if response.status != 200: if response.status != 200:
if response.status == 403:
match = ERROR_DETAILS_REGEX.search(body)
details = match.group('details') if match else None
if details:
details = (('<Details>%s. Note that Full Control access'
' is required to access ACLs.</Details>') %
details)
body = re.sub(ERROR_DETAILS_REGEX, details, body)
raise self.connection.provider.storage_response_error( raise self.connection.provider.storage_response_error(
response.status, response.reason, body) response.status, response.reason, body)
return body return body

View File

@@ -482,7 +482,7 @@ class ResumableUploadHandler(object):
# pool connections) because httplib requires a new HTTP connection per # pool connections) because httplib requires a new HTTP connection per
# transaction. (Without this, calling http_conn.getresponse() would get # transaction. (Without this, calling http_conn.getresponse() would get
# "ResponseNotReady".) # "ResponseNotReady".)
http_conn = conn.new_http_connection(self.tracker_uri_host, http_conn = conn.new_http_connection(self.tracker_uri_host, conn.port,
conn.is_secure) conn.is_secure)
http_conn.set_debuglevel(conn.debug) http_conn.set_debuglevel(conn.debug)

View File

@@ -38,6 +38,8 @@ class XmlHandler(xml.sax.ContentHandler):
def endElement(self, name): def endElement(self, name):
self.nodes[-1][1].endElement(name, self.current_text, self.connection) self.nodes[-1][1].endElement(name, self.current_text, self.connection)
if self.nodes[-1][0] == name: if self.nodes[-1][0] == name:
if hasattr(self.nodes[-1][1], 'endNode'):
self.nodes[-1][1].endNode(self.connection)
self.nodes.pop() self.nodes.pop()
self.current_text = '' self.current_text = ''

View File

@@ -836,7 +836,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user :param user_name: The username of the user
:type serial_number: string :type serial_number: string
:param seriasl_number: The serial number which uniquely identifies :param serial_number: The serial number which uniquely identifies
the MFA device. the MFA device.
:type auth_code_1: string :type auth_code_1: string
@@ -862,7 +862,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user :param user_name: The username of the user
:type serial_number: string :type serial_number: string
:param seriasl_number: The serial number which uniquely identifies :param serial_number: The serial number which uniquely identifies
the MFA device. the MFA device.
""" """
@@ -879,7 +879,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user :param user_name: The username of the user
:type serial_number: string :type serial_number: string
:param seriasl_number: The serial number which uniquely identifies :param serial_number: The serial number which uniquely identifies
the MFA device. the MFA device.
:type auth_code_1: string :type auth_code_1: string

View File

@@ -34,10 +34,11 @@ class SSHClient(object):
def __init__(self, server, def __init__(self, server,
host_key_file='~/.ssh/known_hosts', host_key_file='~/.ssh/known_hosts',
uname='root', ssh_pwd=None): uname='root', timeout=None, ssh_pwd=None):
self.server = server self.server = server
self.host_key_file = host_key_file self.host_key_file = host_key_file
self.uname = uname self.uname = uname
self._timeout = timeout
self._pkey = paramiko.RSAKey.from_private_key_file(server.ssh_key_file, self._pkey = paramiko.RSAKey.from_private_key_file(server.ssh_key_file,
password=ssh_pwd) password=ssh_pwd)
self._ssh_client = paramiko.SSHClient() self._ssh_client = paramiko.SSHClient()
@@ -52,7 +53,8 @@ class SSHClient(object):
try: try:
self._ssh_client.connect(self.server.hostname, self._ssh_client.connect(self.server.hostname,
username=self.uname, username=self.uname,
pkey=self._pkey) pkey=self._pkey,
timeout=self._timeout)
return return
except socket.error, (value, message): except socket.error, (value, message):
if value in (51, 61, 111): if value in (51, 61, 111):

View File

@@ -37,15 +37,16 @@ api_version_path = {
'Products': ('2011-10-01', 'SellerId', '/Products/2011-10-01'), 'Products': ('2011-10-01', 'SellerId', '/Products/2011-10-01'),
'Sellers': ('2011-07-01', 'SellerId', '/Sellers/2011-07-01'), 'Sellers': ('2011-07-01', 'SellerId', '/Sellers/2011-07-01'),
'Inbound': ('2010-10-01', 'SellerId', 'Inbound': ('2010-10-01', 'SellerId',
'/FulfillmentInboundShipment/2010-10-01'), '/FulfillmentInboundShipment/2010-10-01'),
'Outbound': ('2010-10-01', 'SellerId', 'Outbound': ('2010-10-01', 'SellerId',
'/FulfillmentOutboundShipment/2010-10-01'), '/FulfillmentOutboundShipment/2010-10-01'),
'Inventory': ('2010-10-01', 'SellerId', 'Inventory': ('2010-10-01', 'SellerId',
'/FulfillmentInventory/2010-10-01'), '/FulfillmentInventory/2010-10-01'),
} }
content_md5 = lambda c: base64.encodestring(hashlib.md5(c).digest()).strip() content_md5 = lambda c: base64.encodestring(hashlib.md5(c).digest()).strip()
decorated_attrs = ('action', 'response', 'section', decorated_attrs = ('action', 'response', 'section',
'quota', 'restore', 'version') 'quota', 'restore', 'version')
api_call_map = {}
def add_attrs_from(func, to): def add_attrs_from(func, to):
@@ -67,7 +68,7 @@ def structured_lists(*fields):
kw.pop(key) kw.pop(key)
return func(self, *args, **kw) return func(self, *args, **kw)
wrapper.__doc__ = "{0}\nLists: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nLists: {1}".format(func.__doc__,
', '.join(fields)) ', '.join(fields))
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -101,7 +102,7 @@ def destructure_object(value, into={}, prefix=''):
destructure_object(attr, into=into, prefix=prefix + '.' + name) destructure_object(attr, into=into, prefix=prefix + '.' + name)
elif filter(lambda x: isinstance(value, x), (list, set, tuple)): elif filter(lambda x: isinstance(value, x), (list, set, tuple)):
for index, element in [(prefix + '.' + str(i + 1), value[i]) for index, element in [(prefix + '.' + str(i + 1), value[i])
for i in range(len(value))]: for i in range(len(value))]:
destructure_object(element, into=into, prefix=index) destructure_object(element, into=into, prefix=index)
elif isinstance(value, bool): elif isinstance(value, bool):
into[prefix] = str(value).lower() into[prefix] = str(value).lower()
@@ -118,7 +119,7 @@ def structured_objects(*fields):
destructure_object(kw.pop(field), into=kw, prefix=field) destructure_object(kw.pop(field), into=kw, prefix=field)
return func(*args, **kw) return func(*args, **kw)
wrapper.__doc__ = "{0}\nObjects: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nObjects: {1}".format(func.__doc__,
', '.join(fields)) ', '.join(fields))
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -137,7 +138,7 @@ def requires(*groups):
return func(*args, **kw) return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups]) message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\nRequired: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nRequired: {1}".format(func.__doc__,
message) message)
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -156,7 +157,7 @@ def exclusive(*groups):
return func(*args, **kw) return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups]) message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\nEither: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nEither: {1}".format(func.__doc__,
message) message)
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -175,8 +176,8 @@ def dependent(field, *groups):
return func(*args, **kw) return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups]) message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\n{1} requires: {2}".format(func.__doc__, wrapper.__doc__ = "{0}\n{1} requires: {2}".format(func.__doc__,
field, field,
message) message)
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -192,7 +193,7 @@ def requires_some_of(*fields):
raise KeyError(message) raise KeyError(message)
return func(*args, **kw) return func(*args, **kw)
wrapper.__doc__ = "{0}\nSome Required: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nSome Required: {1}".format(func.__doc__,
', '.join(fields)) ', '.join(fields))
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -206,7 +207,7 @@ def boolean_arguments(*fields):
kw[field] = str(kw[field]).lower() kw[field] = str(kw[field]).lower()
return func(*args, **kw) return func(*args, **kw)
wrapper.__doc__ = "{0}\nBooleans: {1}".format(func.__doc__, wrapper.__doc__ = "{0}\nBooleans: {1}".format(func.__doc__,
', '.join(fields)) ', '.join(fields))
return add_attrs_from(func, to=wrapper) return add_attrs_from(func, to=wrapper)
return decorator return decorator
@@ -237,6 +238,7 @@ def api_action(section, quota, restore, *api):
wrapper.__doc__ = "MWS {0}/{1} API call; quota={2} restore={3:.2f}\n" \ wrapper.__doc__ = "MWS {0}/{1} API call; quota={2} restore={3:.2f}\n" \
"{4}".format(action, version, quota, restore, "{4}".format(action, version, quota, restore,
func.__doc__) func.__doc__)
api_call_map[action] = func.func_name
return wrapper return wrapper
return decorator return decorator
@@ -260,7 +262,8 @@ class MWSConnection(AWSQueryConnection):
Modelled off of the inherited get_object/make_request flow. Modelled off of the inherited get_object/make_request flow.
""" """
request = self.build_base_http_request('POST', path, None, data=body, request = self.build_base_http_request('POST', path, None, data=body,
params=params, headers=headers, host=self.server_name()) params=params, headers=headers,
host=self.host)
response = self._mexe(request, override_num_retries=None) response = self._mexe(request, override_num_retries=None)
body = response.read() body = response.read()
boto.log.debug(body) boto.log.debug(body)
@@ -275,6 +278,9 @@ class MWSConnection(AWSQueryConnection):
digest = response.getheader('Content-MD5') digest = response.getheader('Content-MD5')
assert content_md5(body) == digest assert content_md5(body) == digest
return body return body
return self._parse_response(cls, body)
def _parse_response(self, cls, body):
obj = cls(self) obj = cls(self)
h = XmlHandler(obj, self) h = XmlHandler(obj, self)
xml.sax.parseString(body, h) xml.sax.parseString(body, h)
@@ -285,13 +291,10 @@ class MWSConnection(AWSQueryConnection):
The named method can be in CamelCase or underlined_lower_case. The named method can be in CamelCase or underlined_lower_case.
This is the complement to MWSConnection.any_call.action This is the complement to MWSConnection.any_call.action
""" """
# this looks ridiculous but it should be better than regex
action = '_' in name and string.capwords(name, '_') or name action = '_' in name and string.capwords(name, '_') or name
attribs = [getattr(self, m) for m in dir(self)] if action in api_call_map:
ismethod = lambda m: type(m) is type(self.method_for) return getattr(self, api_call_map[action])
ismatch = lambda m: getattr(m, 'action', None) == action return None
method = filter(ismatch, filter(ismethod, attribs))
return method and method[0] or None
def iter_call(self, call, *args, **kw): def iter_call(self, call, *args, **kw):
"""Pass a call name as the first argument and a generator """Pass a call name as the first argument and a generator
@@ -322,7 +325,7 @@ class MWSConnection(AWSQueryConnection):
"""Uploads a feed for processing by Amazon MWS. """Uploads a feed for processing by Amazon MWS.
""" """
return self.post_request(path, kw, response, body=body, return self.post_request(path, kw, response, body=body,
headers=headers) headers=headers)
@structured_lists('FeedSubmissionIdList.Id', 'FeedTypeList.Type', @structured_lists('FeedSubmissionIdList.Id', 'FeedTypeList.Type',
'FeedProcessingStatusList.Status') 'FeedProcessingStatusList.Status')
@@ -365,10 +368,10 @@ class MWSConnection(AWSQueryConnection):
def get_service_status(self, **kw): def get_service_status(self, **kw):
"""Instruct the user on how to get service status. """Instruct the user on how to get service status.
""" """
sections = ', '.join(map(str.lower, api_version_path.keys()))
message = "Use {0}.get_(section)_service_status(), " \ message = "Use {0}.get_(section)_service_status(), " \
"where (section) is one of the following: " \ "where (section) is one of the following: " \
"{1}".format(self.__class__.__name__, "{1}".format(self.__class__.__name__, sections)
', '.join(map(str.lower, api_version_path.keys())))
raise AttributeError(message) raise AttributeError(message)
@structured_lists('MarketplaceIdList.Id') @structured_lists('MarketplaceIdList.Id')
@@ -583,6 +586,14 @@ class MWSConnection(AWSQueryConnection):
""" """
return self.post_request(path, kw, response) return self.post_request(path, kw, response)
@requires(['PackageNumber'])
@api_action('Outbound', 30, 0.5)
def get_package_tracking_details(self, path, response, **kw):
"""Returns delivery tracking information for a package in
an outbound shipment for a Multi-Channel Fulfillment order.
"""
return self.post_request(path, kw, response)
@structured_objects('Address', 'Items') @structured_objects('Address', 'Items')
@requires(['Address', 'Items']) @requires(['Address', 'Items'])
@api_action('Outbound', 30, 0.5) @api_action('Outbound', 30, 0.5)
@@ -659,8 +670,8 @@ class MWSConnection(AWSQueryConnection):
frame that you specify. frame that you specify.
""" """
toggle = set(('FulfillmentChannel.Channel.1', toggle = set(('FulfillmentChannel.Channel.1',
'OrderStatus.Status.1', 'PaymentMethod.1', 'OrderStatus.Status.1', 'PaymentMethod.1',
'LastUpdatedAfter', 'LastUpdatedBefore')) 'LastUpdatedAfter', 'LastUpdatedBefore'))
for do, dont in { for do, dont in {
'BuyerEmail': toggle.union(['SellerOrderId']), 'BuyerEmail': toggle.union(['SellerOrderId']),
'SellerOrderId': toggle.union(['BuyerEmail']), 'SellerOrderId': toggle.union(['BuyerEmail']),
@@ -804,7 +815,7 @@ class MWSConnection(AWSQueryConnection):
@requires(['NextToken']) @requires(['NextToken'])
@api_action('Sellers', 15, 60) @api_action('Sellers', 15, 60)
def list_marketplace_participations_by_next_token(self, path, response, def list_marketplace_participations_by_next_token(self, path, response,
**kw): **kw):
"""Returns the next page of marketplaces and participations """Returns the next page of marketplaces and participations
using the NextToken value that was returned by your using the NextToken value that was returned by your
previous request to either ListMarketplaceParticipations previous request to either ListMarketplaceParticipations

View File

@@ -33,20 +33,30 @@ class ComplexType(dict):
class DeclarativeType(object): class DeclarativeType(object):
def __init__(self, _hint=None, **kw): def __init__(self, _hint=None, **kw):
self._value = None
if _hint is not None: if _hint is not None:
self._hint = _hint self._hint = _hint
else: return
class JITResponse(ResponseElement):
pass class JITResponse(ResponseElement):
self._hint = JITResponse pass
for name, value in kw.items(): self._hint = JITResponse
setattr(self._hint, name, value) self._hint.__name__ = 'JIT_{0}/{1}'.format(self.__class__.__name__,
self._value = None hex(id(self._hint))[2:])
for name, value in kw.items():
setattr(self._hint, name, value)
def __repr__(self):
parent = getattr(self, '_parent', None)
return '<{0}_{1}/{2}_{3}>'.format(self.__class__.__name__,
parent and parent._name or '?',
getattr(self, '_name', '?'),
hex(id(self.__class__)))
def setup(self, parent, name, *args, **kw): def setup(self, parent, name, *args, **kw):
self._parent = parent self._parent = parent
self._name = name self._name = name
self._clone = self.__class__(self._hint) self._clone = self.__class__(_hint=self._hint)
self._clone._parent = parent self._clone._parent = parent
self._clone._name = name self._clone._name = name
setattr(self._parent, self._name, self._clone) setattr(self._parent, self._name, self._clone)
@@ -58,10 +68,7 @@ class DeclarativeType(object):
raise NotImplemented raise NotImplemented
def teardown(self, *args, **kw): def teardown(self, *args, **kw):
if self._value is None: setattr(self._parent, self._name, self._value)
delattr(self._parent, self._name)
else:
setattr(self._parent, self._name, self._value)
class Element(DeclarativeType): class Element(DeclarativeType):
@@ -78,11 +85,6 @@ class SimpleList(DeclarativeType):
DeclarativeType.__init__(self, *args, **kw) DeclarativeType.__init__(self, *args, **kw)
self._value = [] self._value = []
def teardown(self, *args, **kw):
if self._value == []:
self._value = None
DeclarativeType.teardown(self, *args, **kw)
def start(self, *args, **kw): def start(self, *args, **kw):
return None return None
@@ -93,35 +95,46 @@ class SimpleList(DeclarativeType):
class ElementList(SimpleList): class ElementList(SimpleList):
def start(self, *args, **kw): def start(self, *args, **kw):
value = self._hint(parent=self._parent, **kw) value = self._hint(parent=self._parent, **kw)
self._value += [value] self._value.append(value)
return self._value[-1] return value
def end(self, *args, **kw): def end(self, *args, **kw):
pass pass
class MemberList(ElementList): class MemberList(Element):
def __init__(self, *args, **kw): def __init__(self, _member=None, _hint=None, *args, **kw):
self._this = kw.get('this') message = 'Invalid `member` specification in {0}'.format(self.__class__.__name__)
ElementList.__init__(self, *args, **kw) assert 'member' not in kw, message
if _member is None:
def start(self, attrs={}, **kw): if _hint is None:
Class = self._this or self._parent._type_for(self._name, attrs) Element.__init__(self, *args, member=ElementList(**kw))
if issubclass(self._hint, ResponseElement): else:
ListClass = ElementList Element.__init__(self, _hint=_hint)
else: else:
ListClass = SimpleList if _hint is None:
setattr(Class, Class._member, ListClass(self._hint)) if issubclass(_member, DeclarativeType):
self._value = Class(attrs=attrs, parent=self._parent, **kw) member = _member(**kw)
return self._value else:
member = ElementList(_member, **kw)
Element.__init__(self, *args, member=member)
else:
message = 'Nonsensical {0} hint {1!r}'.format(self.__class__.__name__,
_hint)
raise AssertionError(message)
def end(self, *args, **kw): def teardown(self, *args, **kw):
self._value = getattr(self._value, self._value._member) if self._value is None:
ElementList.end(self, *args, **kw) self._value = []
else:
if isinstance(self._value.member, DeclarativeType):
self._value.member = []
self._value = self._value.member
Element.teardown(self, *args, **kw)
def ResponseFactory(action): def ResponseFactory(action, force=None):
result = globals().get(action + 'Result', ResponseElement) result = force or globals().get(action + 'Result', ResponseElement)
class MWSResponse(Response): class MWSResponse(Response):
_name = action + 'Response' _name = action + 'Response'
@@ -141,18 +154,17 @@ def strip_namespace(func):
class ResponseElement(dict): class ResponseElement(dict):
_override = {} _override = {}
_member = 'member'
_name = None _name = None
_namespace = None _namespace = None
def __init__(self, connection=None, name=None, parent=None, attrs={}): def __init__(self, connection=None, name=None, parent=None, attrs=None):
if parent is not None and self._namespace is None: if parent is not None and self._namespace is None:
self._namespace = parent._namespace self._namespace = parent._namespace
if connection is not None: if connection is not None:
self._connection = connection self._connection = connection
self._name = name or self._name or self.__class__.__name__ self._name = name or self._name or self.__class__.__name__
self._declared('setup', attrs=attrs) self._declared('setup', attrs=attrs)
dict.__init__(self, attrs.copy()) dict.__init__(self, attrs and attrs.copy() or {})
def _declared(self, op, **kw): def _declared(self, op, **kw):
def inherit(obj): def inherit(obj):
@@ -177,7 +189,7 @@ class ResponseElement(dict):
do_show = lambda pair: not pair[0].startswith('_') do_show = lambda pair: not pair[0].startswith('_')
attrs = filter(do_show, self.__dict__.items()) attrs = filter(do_show, self.__dict__.items())
name = self.__class__.__name__ name = self.__class__.__name__
if name == 'JITResponse': if name.startswith('JIT_'):
name = '^{0}^'.format(self._name or '') name = '^{0}^'.format(self._name or '')
elif name == 'MWSResponse': elif name == 'MWSResponse':
name = '^{0}^'.format(self._name or name) name = '^{0}^'.format(self._name or name)
@@ -192,7 +204,7 @@ class ResponseElement(dict):
attribute = getattr(self, name, None) attribute = getattr(self, name, None)
if isinstance(attribute, DeclarativeType): if isinstance(attribute, DeclarativeType):
return attribute.start(name=name, attrs=attrs, return attribute.start(name=name, attrs=attrs,
connection=connection) connection=connection)
elif attrs.getLength(): elif attrs.getLength():
setattr(self, name, ComplexType(attrs.copy())) setattr(self, name, ComplexType(attrs.copy()))
else: else:
@@ -316,7 +328,7 @@ class CreateInboundShipmentPlanResult(ResponseElement):
class ListInboundShipmentsResult(ResponseElement): class ListInboundShipmentsResult(ResponseElement):
ShipmentData = MemberList(Element(ShipFromAddress=Element())) ShipmentData = MemberList(ShipFromAddress=Element())
class ListInboundShipmentsByNextTokenResult(ListInboundShipmentsResult): class ListInboundShipmentsByNextTokenResult(ListInboundShipmentsResult):
@@ -334,8 +346,8 @@ class ListInboundShipmentItemsByNextTokenResult(ListInboundShipmentItemsResult):
class ListInventorySupplyResult(ResponseElement): class ListInventorySupplyResult(ResponseElement):
InventorySupplyList = MemberList( InventorySupplyList = MemberList(
EarliestAvailability=Element(), EarliestAvailability=Element(),
SupplyDetail=MemberList(\ SupplyDetail=MemberList(
EarliestAvailabileToPick=Element(), EarliestAvailableToPick=Element(),
LatestAvailableToPick=Element(), LatestAvailableToPick=Element(),
) )
) )
@@ -431,13 +443,9 @@ class FulfillmentPreviewItem(ResponseElement):
class FulfillmentPreview(ResponseElement): class FulfillmentPreview(ResponseElement):
EstimatedShippingWeight = Element(ComplexWeight) EstimatedShippingWeight = Element(ComplexWeight)
EstimatedFees = MemberList(\ EstimatedFees = MemberList(Amount=Element(ComplexAmount))
Element(\
Amount=Element(ComplexAmount),
),
)
UnfulfillablePreviewItems = MemberList(FulfillmentPreviewItem) UnfulfillablePreviewItems = MemberList(FulfillmentPreviewItem)
FulfillmentPreviewShipments = MemberList(\ FulfillmentPreviewShipments = MemberList(
FulfillmentPreviewItems=MemberList(FulfillmentPreviewItem), FulfillmentPreviewItems=MemberList(FulfillmentPreviewItem),
) )
@@ -448,15 +456,14 @@ class GetFulfillmentPreviewResult(ResponseElement):
class FulfillmentOrder(ResponseElement): class FulfillmentOrder(ResponseElement):
DestinationAddress = Element() DestinationAddress = Element()
NotificationEmailList = MemberList(str) NotificationEmailList = MemberList(SimpleList)
class GetFulfillmentOrderResult(ResponseElement): class GetFulfillmentOrderResult(ResponseElement):
FulfillmentOrder = Element(FulfillmentOrder) FulfillmentOrder = Element(FulfillmentOrder)
FulfillmentShipment = MemberList(Element(\ FulfillmentShipment = MemberList(
FulfillmentShipmentItem=MemberList(), FulfillmentShipmentItem=MemberList(),
FulfillmentShipmentPackage=MemberList(), FulfillmentShipmentPackage=MemberList(),
)
) )
FulfillmentOrderItem = MemberList() FulfillmentOrderItem = MemberList()
@@ -469,6 +476,11 @@ class ListAllFulfillmentOrdersByNextTokenResult(ListAllFulfillmentOrdersResult):
pass pass
class GetPackageTrackingDetailsResult(ResponseElement):
ShipToAddress = Element()
TrackingEvents = MemberList(EventAddress=Element())
class Image(ResponseElement): class Image(ResponseElement):
pass pass
@@ -533,17 +545,17 @@ class Product(ResponseElement):
_namespace = 'ns2' _namespace = 'ns2'
Identifiers = Element(MarketplaceASIN=Element(), Identifiers = Element(MarketplaceASIN=Element(),
SKUIdentifier=Element()) SKUIdentifier=Element())
AttributeSets = Element(\ AttributeSets = Element(
ItemAttributes=ElementList(ItemAttributes), ItemAttributes=ElementList(ItemAttributes),
) )
Relationships = Element(\ Relationships = Element(
VariationParent=ElementList(VariationRelationship), VariationParent=ElementList(VariationRelationship),
) )
CompetitivePricing = ElementList(CompetitivePricing) CompetitivePricing = ElementList(CompetitivePricing)
SalesRankings = Element(\ SalesRankings = Element(
SalesRank=ElementList(SalesRank), SalesRank=ElementList(SalesRank),
) )
LowestOfferListings = Element(\ LowestOfferListings = Element(
LowestOfferListing=ElementList(LowestOfferListing), LowestOfferListing=ElementList(LowestOfferListing),
) )
@@ -569,6 +581,10 @@ class GetMatchingProductForIdResult(ListMatchingProductsResult):
pass pass
class GetMatchingProductForIdResponse(ResponseResultList):
_ResultClass = GetMatchingProductForIdResult
class GetCompetitivePricingForSKUResponse(ProductsBulkOperationResponse): class GetCompetitivePricingForSKUResponse(ProductsBulkOperationResponse):
pass pass
@@ -607,9 +623,9 @@ class GetProductCategoriesForASINResult(GetProductCategoriesResult):
class Order(ResponseElement): class Order(ResponseElement):
OrderTotal = Element(ComplexMoney) OrderTotal = Element(ComplexMoney)
ShippingAddress = Element() ShippingAddress = Element()
PaymentExecutionDetail = Element(\ PaymentExecutionDetail = Element(
PaymentExecutionDetailItem=ElementList(\ PaymentExecutionDetailItem=ElementList(
PaymentExecutionDetailItem=Element(\ PaymentExecutionDetailItem=Element(
Payment=Element(ComplexMoney) Payment=Element(ComplexMoney)
) )
) )

View File

@@ -80,11 +80,51 @@ class OpsWorksConnection(AWSQueryConnection):
def _required_auth_capability(self): def _required_auth_capability(self):
return ['hmac-v4'] return ['hmac-v4']
def assign_volume(self, volume_id, instance_id=None):
"""
Assigns one of the stack's registered Amazon EBS volumes to a
specified instance. The volume must first be registered with
the stack by calling RegisterVolume. For more information, see
``_.
:type volume_id: string
:param volume_id: The volume ID.
:type instance_id: string
:param instance_id: The instance ID.
"""
params = {'VolumeId': volume_id, }
if instance_id is not None:
params['InstanceId'] = instance_id
return self.make_request(action='AssignVolume',
body=json.dumps(params))
def associate_elastic_ip(self, elastic_ip, instance_id=None):
"""
Associates one of the stack's registered Elastic IP addresses
with a specified instance. The address must first be
registered with the stack by calling RegisterElasticIp. For
more information, see ``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
:type instance_id: string
:param instance_id: The instance ID.
"""
params = {'ElasticIp': elastic_ip, }
if instance_id is not None:
params['InstanceId'] = instance_id
return self.make_request(action='AssociateElasticIp',
body=json.dumps(params))
def attach_elastic_load_balancer(self, elastic_load_balancer_name, def attach_elastic_load_balancer(self, elastic_load_balancer_name,
layer_id): layer_id):
""" """
Attaches an Elastic Load Balancing instance to a specified Attaches an Elastic Load Balancing load balancer to a
layer. specified layer.
You must create the Elastic Load Balancing instance You must create the Elastic Load Balancing instance
separately, by using the Elastic Load Balancing console, API, separately, by using the Elastic Load Balancing console, API,
@@ -136,8 +176,8 @@ class OpsWorksConnection(AWSQueryConnection):
will be launched into this VPC, and you cannot change the ID later. will be launched into this VPC, and you cannot change the ID later.
+ If your account supports EC2 Classic, the default value is no VPC. + If your account supports EC2 Classic, the default value is no VPC.
+ If you account does not support EC2 Classic, the default value is the + If your account does not support EC2 Classic, the default value is
default VPC for the specified region. the default VPC for the specified region.
If the VPC ID corresponds to a default VPC and you have specified If the VPC ID corresponds to a default VPC and you have specified
@@ -559,7 +599,8 @@ class OpsWorksConnection(AWSQueryConnection):
custom_instance_profile_arn=None, custom_instance_profile_arn=None,
custom_security_group_ids=None, packages=None, custom_security_group_ids=None, packages=None,
volume_configurations=None, enable_auto_healing=None, volume_configurations=None, enable_auto_healing=None,
auto_assign_elastic_ips=None, custom_recipes=None, auto_assign_elastic_ips=None,
auto_assign_public_ips=None, custom_recipes=None,
install_updates_on_boot=None): install_updates_on_boot=None):
""" """
Creates a layer. For more information, see `How to Create a Creates a layer. For more information, see `How to Create a
@@ -629,7 +670,13 @@ class OpsWorksConnection(AWSQueryConnection):
:type auto_assign_elastic_ips: boolean :type auto_assign_elastic_ips: boolean
:param auto_assign_elastic_ips: Whether to automatically assign an :param auto_assign_elastic_ips: Whether to automatically assign an
`Elastic IP address`_ to the layer. `Elastic IP address`_ to the layer's instances. For more
information, see `How to Edit a Layer`_.
:type auto_assign_public_ips: boolean
:param auto_assign_public_ips: For stacks that are running in a VPC,
whether to automatically assign a public IP address to the layer's
instances. For more information, see `How to Edit a Layer`_.
:type custom_recipes: dict :type custom_recipes: dict
:param custom_recipes: A `LayerCustomRecipes` object that specifies the :param custom_recipes: A `LayerCustomRecipes` object that specifies the
@@ -668,6 +715,8 @@ class OpsWorksConnection(AWSQueryConnection):
params['EnableAutoHealing'] = enable_auto_healing params['EnableAutoHealing'] = enable_auto_healing
if auto_assign_elastic_ips is not None: if auto_assign_elastic_ips is not None:
params['AutoAssignElasticIps'] = auto_assign_elastic_ips params['AutoAssignElasticIps'] = auto_assign_elastic_ips
if auto_assign_public_ips is not None:
params['AutoAssignPublicIps'] = auto_assign_public_ips
if custom_recipes is not None: if custom_recipes is not None:
params['CustomRecipes'] = custom_recipes params['CustomRecipes'] = custom_recipes
if install_updates_on_boot is not None: if install_updates_on_boot is not None:
@@ -700,8 +749,8 @@ class OpsWorksConnection(AWSQueryConnection):
into this VPC, and you cannot change the ID later. into this VPC, and you cannot change the ID later.
+ If your account supports EC2 Classic, the default value is no VPC. + If your account supports EC2 Classic, the default value is no VPC.
+ If you account does not support EC2 Classic, the default value is the + If your account does not support EC2 Classic, the default value is
default VPC for the specified region. the default VPC for the specified region.
If the VPC ID corresponds to a default VPC and you have specified If the VPC ID corresponds to a default VPC and you have specified
@@ -954,6 +1003,33 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DeleteUserProfile', return self.make_request(action='DeleteUserProfile',
body=json.dumps(params)) body=json.dumps(params))
def deregister_elastic_ip(self, elastic_ip):
"""
Deregisters a specified Elastic IP address. The address can
then be registered by another stack. For more information, see
``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
"""
params = {'ElasticIp': elastic_ip, }
return self.make_request(action='DeregisterElasticIp',
body=json.dumps(params))
def deregister_volume(self, volume_id):
"""
Deregisters an Amazon EBS volume. The volume can then be
registered by another stack. For more information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
"""
params = {'VolumeId': volume_id, }
return self.make_request(action='DeregisterVolume',
body=json.dumps(params))
def describe_apps(self, stack_id=None, app_ids=None): def describe_apps(self, stack_id=None, app_ids=None):
""" """
Requests a description of a specified set of apps. Requests a description of a specified set of apps.
@@ -1047,7 +1123,7 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeDeployments', return self.make_request(action='DescribeDeployments',
body=json.dumps(params)) body=json.dumps(params))
def describe_elastic_ips(self, instance_id=None, ips=None): def describe_elastic_ips(self, instance_id=None, stack_id=None, ips=None):
""" """
Describes `Elastic IP addresses`_. Describes `Elastic IP addresses`_.
@@ -1058,6 +1134,11 @@ class OpsWorksConnection(AWSQueryConnection):
`DescribeElasticIps` returns a description of the Elastic IP `DescribeElasticIps` returns a description of the Elastic IP
addresses associated with the specified instance. addresses associated with the specified instance.
:type stack_id: string
:param stack_id: A stack ID. If you include this parameter,
`DescribeElasticIps` returns a description of the Elastic IP
addresses that are registered with the specified stack.
:type ips: list :type ips: list
:param ips: An array of Elastic IP addresses to be described. If you :param ips: An array of Elastic IP addresses to be described. If you
include this parameter, `DescribeElasticIps` returns a description include this parameter, `DescribeElasticIps` returns a description
@@ -1068,6 +1149,8 @@ class OpsWorksConnection(AWSQueryConnection):
params = {} params = {}
if instance_id is not None: if instance_id is not None:
params['InstanceId'] = instance_id params['InstanceId'] = instance_id
if stack_id is not None:
params['StackId'] = stack_id
if ips is not None: if ips is not None:
params['Ips'] = ips params['Ips'] = ips
return self.make_request(action='DescribeElasticIps', return self.make_request(action='DescribeElasticIps',
@@ -1080,8 +1163,8 @@ class OpsWorksConnection(AWSQueryConnection):
You must specify at least one of the parameters. You must specify at least one of the parameters.
:type stack_id: string :type stack_id: string
:param stack_id: A stack ID. The action describes the Elastic Load :param stack_id: A stack ID. The action describes the stack's Elastic
Balancing instances for the stack. Load Balancing instances.
:type layer_ids: list :type layer_ids: list
:param layer_ids: A list of layer IDs. The action describes the Elastic :param layer_ids: A list of layer IDs. The action describes the Elastic
@@ -1130,7 +1213,7 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeInstances', return self.make_request(action='DescribeInstances',
body=json.dumps(params)) body=json.dumps(params))
def describe_layers(self, stack_id, layer_ids=None): def describe_layers(self, stack_id=None, layer_ids=None):
""" """
Requests a description of one or more layers in a specified Requests a description of one or more layers in a specified
stack. stack.
@@ -1146,7 +1229,9 @@ class OpsWorksConnection(AWSQueryConnection):
description of every layer in the specified stack. description of every layer in the specified stack.
""" """
params = {'StackId': stack_id, } params = {}
if stack_id is not None:
params['StackId'] = stack_id
if layer_ids is not None: if layer_ids is not None:
params['LayerIds'] = layer_ids params['LayerIds'] = layer_ids
return self.make_request(action='DescribeLayers', return self.make_request(action='DescribeLayers',
@@ -1285,8 +1370,8 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeUserProfiles', return self.make_request(action='DescribeUserProfiles',
body=json.dumps(params)) body=json.dumps(params))
def describe_volumes(self, instance_id=None, raid_array_id=None, def describe_volumes(self, instance_id=None, stack_id=None,
volume_ids=None): raid_array_id=None, volume_ids=None):
""" """
Describes an instance's Amazon EBS volumes. Describes an instance's Amazon EBS volumes.
@@ -1297,6 +1382,10 @@ class OpsWorksConnection(AWSQueryConnection):
`DescribeVolumes` returns descriptions of the volumes associated `DescribeVolumes` returns descriptions of the volumes associated
with the specified instance. with the specified instance.
:type stack_id: string
:param stack_id: A stack ID. The action describes the stack's
registered Amazon EBS volumes.
:type raid_array_id: string :type raid_array_id: string
:param raid_array_id: The RAID array ID. If you use this parameter, :param raid_array_id: The RAID array ID. If you use this parameter,
`DescribeVolumes` returns descriptions of the volumes associated `DescribeVolumes` returns descriptions of the volumes associated
@@ -1311,6 +1400,8 @@ class OpsWorksConnection(AWSQueryConnection):
params = {} params = {}
if instance_id is not None: if instance_id is not None:
params['InstanceId'] = instance_id params['InstanceId'] = instance_id
if stack_id is not None:
params['StackId'] = stack_id
if raid_array_id is not None: if raid_array_id is not None:
params['RaidArrayId'] = raid_array_id params['RaidArrayId'] = raid_array_id
if volume_ids is not None: if volume_ids is not None:
@@ -1321,7 +1412,7 @@ class OpsWorksConnection(AWSQueryConnection):
def detach_elastic_load_balancer(self, elastic_load_balancer_name, def detach_elastic_load_balancer(self, elastic_load_balancer_name,
layer_id): layer_id):
""" """
Detaches a specified Elastic Load Balancing instance from it's Detaches a specified Elastic Load Balancing instance from its
layer. layer.
:type elastic_load_balancer_name: string :type elastic_load_balancer_name: string
@@ -1340,6 +1431,20 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DetachElasticLoadBalancer', return self.make_request(action='DetachElasticLoadBalancer',
body=json.dumps(params)) body=json.dumps(params))
def disassociate_elastic_ip(self, elastic_ip):
"""
Disassociates an Elastic IP address from its instance. The
address remains registered with the stack. For more
information, see ``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
"""
params = {'ElasticIp': elastic_ip, }
return self.make_request(action='DisassociateElasticIp',
body=json.dumps(params))
def get_hostname_suggestion(self, layer_id): def get_hostname_suggestion(self, layer_id):
""" """
Gets a generated host name for the specified layer, based on Gets a generated host name for the specified layer, based on
@@ -1366,6 +1471,45 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='RebootInstance', return self.make_request(action='RebootInstance',
body=json.dumps(params)) body=json.dumps(params))
def register_elastic_ip(self, elastic_ip, stack_id):
"""
Registers an Elastic IP address with a specified stack. An
address can be registered with only one stack at a time. If
the address is already registered, you must first deregister
it by calling DeregisterElasticIp. For more information, see
``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
:type stack_id: string
:param stack_id: The stack ID.
"""
params = {'ElasticIp': elastic_ip, 'StackId': stack_id, }
return self.make_request(action='RegisterElasticIp',
body=json.dumps(params))
def register_volume(self, stack_id, ec_2_volume_id=None):
"""
Registers an Amazon EBS volume with a specified stack. A
volume can be registered with only one stack at a time. If the
volume is already registered, you must first deregister it by
calling DeregisterVolume. For more information, see ``_.
:type ec_2_volume_id: string
:param ec_2_volume_id: The Amazon EBS volume ID.
:type stack_id: string
:param stack_id: The stack ID.
"""
params = {'StackId': stack_id, }
if ec_2_volume_id is not None:
params['Ec2VolumeId'] = ec_2_volume_id
return self.make_request(action='RegisterVolume',
body=json.dumps(params))
def set_load_based_auto_scaling(self, layer_id, enable=None, def set_load_based_auto_scaling(self, layer_id, enable=None,
up_scaling=None, down_scaling=None): up_scaling=None, down_scaling=None):
""" """
@@ -1511,6 +1655,19 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='StopStack', return self.make_request(action='StopStack',
body=json.dumps(params)) body=json.dumps(params))
def unassign_volume(self, volume_id):
"""
Unassigns an assigned Amazon EBS volume. The volume remains
registered with the stack. For more information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
"""
params = {'VolumeId': volume_id, }
return self.make_request(action='UnassignVolume',
body=json.dumps(params))
def update_app(self, app_id, name=None, description=None, type=None, def update_app(self, app_id, name=None, description=None, type=None,
app_source=None, domains=None, enable_ssl=None, app_source=None, domains=None, enable_ssl=None,
ssl_configuration=None, attributes=None): ssl_configuration=None, attributes=None):
@@ -1568,6 +1725,24 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='UpdateApp', return self.make_request(action='UpdateApp',
body=json.dumps(params)) body=json.dumps(params))
def update_elastic_ip(self, elastic_ip, name=None):
"""
Updates a registered Elastic IP address's name. For more
information, see ``_.
:type elastic_ip: string
:param elastic_ip: The address.
:type name: string
:param name: The new name.
"""
params = {'ElasticIp': elastic_ip, }
if name is not None:
params['Name'] = name
return self.make_request(action='UpdateElasticIp',
body=json.dumps(params))
def update_instance(self, instance_id, layer_ids=None, def update_instance(self, instance_id, layer_ids=None,
instance_type=None, auto_scaling_type=None, instance_type=None, auto_scaling_type=None,
hostname=None, os=None, ami_id=None, hostname=None, os=None, ami_id=None,
@@ -1673,7 +1848,8 @@ class OpsWorksConnection(AWSQueryConnection):
attributes=None, custom_instance_profile_arn=None, attributes=None, custom_instance_profile_arn=None,
custom_security_group_ids=None, packages=None, custom_security_group_ids=None, packages=None,
volume_configurations=None, enable_auto_healing=None, volume_configurations=None, enable_auto_healing=None,
auto_assign_elastic_ips=None, custom_recipes=None, auto_assign_elastic_ips=None,
auto_assign_public_ips=None, custom_recipes=None,
install_updates_on_boot=None): install_updates_on_boot=None):
""" """
Updates a specified layer. Updates a specified layer.
@@ -1718,7 +1894,13 @@ class OpsWorksConnection(AWSQueryConnection):
:type auto_assign_elastic_ips: boolean :type auto_assign_elastic_ips: boolean
:param auto_assign_elastic_ips: Whether to automatically assign an :param auto_assign_elastic_ips: Whether to automatically assign an
`Elastic IP address`_ to the layer. `Elastic IP address`_ to the layer's instances. For more
information, see `How to Edit a Layer`_.
:type auto_assign_public_ips: boolean
:param auto_assign_public_ips: For stacks that are running in a VPC,
whether to automatically assign a public IP address to the layer's
instances. For more information, see `How to Edit a Layer`_.
:type custom_recipes: dict :type custom_recipes: dict
:param custom_recipes: A `LayerCustomRecipes` object that specifies the :param custom_recipes: A `LayerCustomRecipes` object that specifies the
@@ -1756,6 +1938,8 @@ class OpsWorksConnection(AWSQueryConnection):
params['EnableAutoHealing'] = enable_auto_healing params['EnableAutoHealing'] = enable_auto_healing
if auto_assign_elastic_ips is not None: if auto_assign_elastic_ips is not None:
params['AutoAssignElasticIps'] = auto_assign_elastic_ips params['AutoAssignElasticIps'] = auto_assign_elastic_ips
if auto_assign_public_ips is not None:
params['AutoAssignPublicIps'] = auto_assign_public_ips
if custom_recipes is not None: if custom_recipes is not None:
params['CustomRecipes'] = custom_recipes params['CustomRecipes'] = custom_recipes
if install_updates_on_boot is not None: if install_updates_on_boot is not None:
@@ -1934,6 +2118,29 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='UpdateUserProfile', return self.make_request(action='UpdateUserProfile',
body=json.dumps(params)) body=json.dumps(params))
def update_volume(self, volume_id, name=None, mount_point=None):
"""
Updates an Amazon EBS volume's name or mount point. For more
information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
:type name: string
:param name: The new name.
:type mount_point: string
:param mount_point: The new mount point.
"""
params = {'VolumeId': volume_id, }
if name is not None:
params['Name'] = name
if mount_point is not None:
params['MountPoint'] = mount_point
return self.make_request(action='UpdateVolume',
body=json.dumps(params))
def make_request(self, action, body): def make_request(self, action, body):
headers = { headers = {
'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action), 'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),

View File

@@ -45,6 +45,12 @@ def regions():
RegionInfo(name='ap-northeast-1', RegionInfo(name='ap-northeast-1',
endpoint='redshift.ap-northeast-1.amazonaws.com', endpoint='redshift.ap-northeast-1.amazonaws.com',
connection_cls=cls), connection_cls=cls),
RegionInfo(name='ap-southeast-1',
endpoint='redshift.ap-southeast-1.amazonaws.com',
connection_cls=cls),
RegionInfo(name='ap-southeast-2',
endpoint='redshift.ap-southeast-2.amazonaws.com',
connection_cls=cls),
] ]

View File

@@ -188,3 +188,272 @@ class AccessToSnapshotDeniedFault(JSONResponseError):
class UnauthorizedOperationFault(JSONResponseError): class UnauthorizedOperationFault(JSONResponseError):
pass pass
class SnapshotCopyAlreadyDisabled(JSONResponseError):
pass
class ClusterNotFound(JSONResponseError):
pass
class UnknownSnapshotCopyRegion(JSONResponseError):
pass
class InvalidClusterSubnetState(JSONResponseError):
pass
class ReservedNodeQuotaExceeded(JSONResponseError):
pass
class InvalidClusterState(JSONResponseError):
pass
class HsmClientCertificateQuotaExceeded(JSONResponseError):
pass
class SubscriptionCategoryNotFound(JSONResponseError):
pass
class HsmClientCertificateNotFound(JSONResponseError):
pass
class SubscriptionEventIdNotFound(JSONResponseError):
pass
class ClusterSecurityGroupAlreadyExists(JSONResponseError):
pass
class HsmConfigurationAlreadyExists(JSONResponseError):
pass
class NumberOfNodesQuotaExceeded(JSONResponseError):
pass
class ReservedNodeOfferingNotFound(JSONResponseError):
pass
class BucketNotFound(JSONResponseError):
pass
class InsufficientClusterCapacity(JSONResponseError):
pass
class InvalidRestore(JSONResponseError):
pass
class UnauthorizedOperation(JSONResponseError):
pass
class ClusterQuotaExceeded(JSONResponseError):
pass
class InvalidVPCNetworkState(JSONResponseError):
pass
class ClusterSnapshotNotFound(JSONResponseError):
pass
class AuthorizationQuotaExceeded(JSONResponseError):
pass
class InvalidHsmClientCertificateState(JSONResponseError):
pass
class SNSTopicArnNotFound(JSONResponseError):
pass
class ResizeNotFound(JSONResponseError):
pass
class ClusterSubnetGroupNotFound(JSONResponseError):
pass
class SNSNoAuthorization(JSONResponseError):
pass
class ClusterSnapshotQuotaExceeded(JSONResponseError):
pass
class AccessToSnapshotDenied(JSONResponseError):
pass
class InvalidClusterSecurityGroupState(JSONResponseError):
pass
class NumberOfNodesPerClusterLimitExceeded(JSONResponseError):
pass
class ClusterSubnetQuotaExceeded(JSONResponseError):
pass
class SNSInvalidTopic(JSONResponseError):
pass
class ClusterSecurityGroupNotFound(JSONResponseError):
pass
class InvalidElasticIp(JSONResponseError):
pass
class InvalidClusterParameterGroupState(JSONResponseError):
pass
class InvalidHsmConfigurationState(JSONResponseError):
pass
class ClusterAlreadyExists(JSONResponseError):
pass
class HsmConfigurationQuotaExceeded(JSONResponseError):
pass
class ClusterSnapshotAlreadyExists(JSONResponseError):
pass
class SubscriptionSeverityNotFound(JSONResponseError):
pass
class SourceNotFound(JSONResponseError):
pass
class ReservedNodeAlreadyExists(JSONResponseError):
pass
class ClusterSubnetGroupQuotaExceeded(JSONResponseError):
pass
class ClusterParameterGroupNotFound(JSONResponseError):
pass
class InvalidS3BucketName(JSONResponseError):
pass
class InvalidS3KeyPrefix(JSONResponseError):
pass
class SubscriptionAlreadyExist(JSONResponseError):
pass
class HsmConfigurationNotFound(JSONResponseError):
pass
class AuthorizationNotFound(JSONResponseError):
pass
class ClusterSecurityGroupQuotaExceeded(JSONResponseError):
pass
class EventSubscriptionQuotaExceeded(JSONResponseError):
pass
class AuthorizationAlreadyExists(JSONResponseError):
pass
class InvalidClusterSnapshotState(JSONResponseError):
pass
class ClusterParameterGroupQuotaExceeded(JSONResponseError):
pass
class SnapshotCopyDisabled(JSONResponseError):
pass
class ClusterSubnetGroupAlreadyExists(JSONResponseError):
pass
class ReservedNodeNotFound(JSONResponseError):
pass
class HsmClientCertificateAlreadyExists(JSONResponseError):
pass
class InvalidClusterSubnetGroupState(JSONResponseError):
pass
class SubscriptionNotFound(JSONResponseError):
pass
class InsufficientS3BucketPolicy(JSONResponseError):
pass
class ClusterParameterGroupAlreadyExists(JSONResponseError):
pass
class UnsupportedOption(JSONResponseError):
pass
class CopyToRegionDisabled(JSONResponseError):
pass
class SnapshotCopyAlreadyEnabled(JSONResponseError):
pass
class IncompatibleOrderableOptions(JSONResponseError):
pass

File diff suppressed because it is too large Load Diff

View File

@@ -18,22 +18,25 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT # ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, # SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE. # IN THE SOFTWARE.
# #
import xml.sax import exception
import uuid import random
import urllib import urllib
import uuid
import xml.sax
import boto import boto
from boto.connection import AWSAuthConnection from boto.connection import AWSAuthConnection
from boto import handler from boto import handler
import boto.jsonresponse
from boto.route53.record import ResourceRecordSets from boto.route53.record import ResourceRecordSets
from boto.route53.zone import Zone from boto.route53.zone import Zone
import boto.jsonresponse
import exception
HZXML = """<?xml version="1.0" encoding="UTF-8"?> HZXML = """<?xml version="1.0" encoding="UTF-8"?>
<CreateHostedZoneRequest xmlns="%(xmlns)s"> <CreateHostedZoneRequest xmlns="%(xmlns)s">
@@ -43,7 +46,7 @@ HZXML = """<?xml version="1.0" encoding="UTF-8"?>
<Comment>%(comment)s</Comment> <Comment>%(comment)s</Comment>
</HostedZoneConfig> </HostedZoneConfig>
</CreateHostedZoneRequest>""" </CreateHostedZoneRequest>"""
#boto.set_stream_logger('dns') #boto.set_stream_logger('dns')
@@ -60,12 +63,13 @@ class Route53Connection(AWSAuthConnection):
def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
port=None, proxy=None, proxy_port=None, port=None, proxy=None, proxy_port=None,
host=DefaultHost, debug=0, security_token=None, host=DefaultHost, debug=0, security_token=None,
validate_certs=True): validate_certs=True, https_connection_factory=None):
AWSAuthConnection.__init__(self, host, AWSAuthConnection.__init__(self, host,
aws_access_key_id, aws_secret_access_key, aws_access_key_id, aws_secret_access_key,
True, port, proxy, proxy_port, debug=debug, True, port, proxy, proxy_port, debug=debug,
security_token=security_token, security_token=security_token,
validate_certs=validate_certs) validate_certs=validate_certs,
https_connection_factory=https_connection_factory)
def _required_auth_capability(self): def _required_auth_capability(self):
return ['route53'] return ['route53']
@@ -79,7 +83,8 @@ class Route53Connection(AWSAuthConnection):
pairs.append(key + '=' + urllib.quote(str(val))) pairs.append(key + '=' + urllib.quote(str(val)))
path += '?' + '&'.join(pairs) path += '?' + '&'.join(pairs)
return AWSAuthConnection.make_request(self, action, path, return AWSAuthConnection.make_request(self, action, path,
headers, data) headers, data,
retry_handler=self._retry_handler)
# Hosted Zones # Hosted Zones
@@ -118,7 +123,7 @@ class Route53Connection(AWSAuthConnection):
def get_hosted_zone(self, hosted_zone_id): def get_hosted_zone(self, hosted_zone_id):
""" """
Get detailed information about a particular Hosted Zone. Get detailed information about a particular Hosted Zone.
:type hosted_zone_id: str :type hosted_zone_id: str
:param hosted_zone_id: The unique identifier for the Hosted Zone :param hosted_zone_id: The unique identifier for the Hosted Zone
@@ -158,7 +163,7 @@ class Route53Connection(AWSAuthConnection):
""" """
Create a new Hosted Zone. Returns a Python data structure with Create a new Hosted Zone. Returns a Python data structure with
information about the newly created Hosted Zone. information about the newly created Hosted Zone.
:type domain_name: str :type domain_name: str
:param domain_name: The name of the domain. This should be a :param domain_name: The name of the domain. This should be a
fully-specified domain, and should end with a final period fully-specified domain, and should end with a final period
@@ -178,7 +183,7 @@ class Route53Connection(AWSAuthConnection):
use that. use that.
:type comment: str :type comment: str
:param comment: Any comments you want to include about the hosted :param comment: Any comments you want to include about the hosted
zone. zone.
""" """
@@ -204,7 +209,7 @@ class Route53Connection(AWSAuthConnection):
raise exception.DNSServerError(response.status, raise exception.DNSServerError(response.status,
response.reason, response.reason,
body) body)
def delete_hosted_zone(self, hosted_zone_id): def delete_hosted_zone(self, hosted_zone_id):
uri = '/%s/hostedzone/%s' % (self.Version, hosted_zone_id) uri = '/%s/hostedzone/%s' % (self.Version, hosted_zone_id)
response = self.make_request('DELETE', uri) response = self.make_request('DELETE', uri)
@@ -226,7 +231,7 @@ class Route53Connection(AWSAuthConnection):
""" """
Retrieve the Resource Record Sets defined for this Hosted Zone. Retrieve the Resource Record Sets defined for this Hosted Zone.
Returns the raw XML data returned by the Route53 call. Returns the raw XML data returned by the Route53 call.
:type hosted_zone_id: str :type hosted_zone_id: str
:param hosted_zone_id: The unique identifier for the Hosted Zone :param hosted_zone_id: The unique identifier for the Hosted Zone
@@ -401,3 +406,24 @@ class Route53Connection(AWSAuthConnection):
if value and not value[-1] == '.': if value and not value[-1] == '.':
value = "%s." % value value = "%s." % value
return value return value
def _retry_handler(self, response, i, next_sleep):
status = None
boto.log.debug("Saw HTTP status: %s" % response.status)
if response.status == 400:
code = response.getheader('Code')
if code and 'PriorRequestNotComplete' in code:
# This is a case where we need to ignore a 400 error, as
# Route53 returns this. See
# http://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DNSLimitations.html
msg = "%s, retry attempt %s" % (
'PriorRequestNotComplete',
i
)
next_sleep = random.random() * (2 ** i)
i += 1
status = (msg, i, next_sleep)
return status

View File

@@ -63,6 +63,7 @@ class S3WebsiteEndpointTranslate:
trans_region['sa-east-1'] = 's3-website-sa-east-1' trans_region['sa-east-1'] = 's3-website-sa-east-1'
trans_region['ap-northeast-1'] = 's3-website-ap-northeast-1' trans_region['ap-northeast-1'] = 's3-website-ap-northeast-1'
trans_region['ap-southeast-1'] = 's3-website-ap-southeast-1' trans_region['ap-southeast-1'] = 's3-website-ap-southeast-1'
trans_region['ap-southeast-2'] = 's3-website-ap-southeast-2'
@classmethod @classmethod
def translate_region(self, reg): def translate_region(self, reg):
@@ -341,6 +342,11 @@ class Bucket(object):
raise self.connection.provider.storage_response_error( raise self.connection.provider.storage_response_error(
response.status, response.reason, body) response.status, response.reason, body)
def _validate_kwarg_names(self, kwargs, names):
for kwarg in kwargs:
if kwarg not in names:
raise TypeError('Invalid argument %s!' % kwarg)
def get_all_keys(self, headers=None, **params): def get_all_keys(self, headers=None, **params):
""" """
A lower-level method for listing contents of a bucket. This A lower-level method for listing contents of a bucket. This
@@ -370,6 +376,8 @@ class Bucket(object):
:return: The result from S3 listing the keys requested :return: The result from S3 listing the keys requested
""" """
self._validate_kwarg_names(params, ['maxkeys', 'max_keys', 'prefix',
'marker', 'delimiter'])
return self._get_all([('Contents', self.key_class), return self._get_all([('Contents', self.key_class),
('CommonPrefixes', Prefix)], ('CommonPrefixes', Prefix)],
'', headers, **params) '', headers, **params)
@@ -407,6 +415,9 @@ class Bucket(object):
:rtype: ResultSet :rtype: ResultSet
:return: The result from S3 listing the keys requested :return: The result from S3 listing the keys requested
""" """
self._validate_kwarg_names(params, ['maxkeys', 'max_keys', 'prefix',
'key_marker', 'version_id_marker',
'delimiter'])
return self._get_all([('Version', self.key_class), return self._get_all([('Version', self.key_class),
('CommonPrefixes', Prefix), ('CommonPrefixes', Prefix),
('DeleteMarker', DeleteMarker)], ('DeleteMarker', DeleteMarker)],
@@ -450,6 +461,8 @@ class Bucket(object):
:return: The result from S3 listing the uploads requested :return: The result from S3 listing the uploads requested
""" """
self._validate_kwarg_names(params, ['max_uploads', 'key_marker',
'upload_id_marker'])
return self._get_all([('Upload', MultiPartUpload), return self._get_all([('Upload', MultiPartUpload),
('CommonPrefixes', Prefix)], ('CommonPrefixes', Prefix)],
'uploads', headers, **params) 'uploads', headers, **params)
@@ -693,7 +706,8 @@ class Bucket(object):
if self.name == src_bucket_name: if self.name == src_bucket_name:
src_bucket = self src_bucket = self
else: else:
src_bucket = self.connection.get_bucket(src_bucket_name) src_bucket = self.connection.get_bucket(
src_bucket_name, validate=False)
acl = src_bucket.get_xml_acl(src_key_name) acl = src_bucket.get_xml_acl(src_key_name)
if encrypt_key: if encrypt_key:
headers[provider.server_side_encryption_header] = 'AES256' headers[provider.server_side_encryption_header] = 'AES256'
@@ -1300,6 +1314,7 @@ class Bucket(object):
* ErrorDocument * ErrorDocument
* Key : name of object to serve when an error occurs * Key : name of object to serve when an error occurs
""" """
return self.get_website_configuration_with_xml(headers)[0] return self.get_website_configuration_with_xml(headers)[0]
@@ -1320,15 +1335,24 @@ class Bucket(object):
:rtype: 2-Tuple :rtype: 2-Tuple
:returns: 2-tuple containing: :returns: 2-tuple containing:
1) A dictionary containing a Python representation
of the XML response. The overall structure is: 1) A dictionary containing a Python representation \
* WebsiteConfiguration of the XML response. The overall structure is:
* IndexDocument
* Suffix : suffix that is appended to request that * WebsiteConfiguration
is for a "directory" on the website endpoint
* ErrorDocument * IndexDocument
* Key : name of object to serve when an error occurs
2) unparsed XML describing the bucket's website configuration. * Suffix : suffix that is appended to request that \
is for a "directory" on the website endpoint
* ErrorDocument
* Key : name of object to serve when an error occurs
2) unparsed XML describing the bucket's website configuration
""" """
body = self.get_website_configuration_xml(headers=headers) body = self.get_website_configuration_xml(headers=headers)

View File

@@ -264,7 +264,7 @@ class SNSConnection(AWSQueryConnection):
:type protocol: string :type protocol: string
:param protocol: The protocol used to communicate with :param protocol: The protocol used to communicate with
the subscriber. Current choices are: the subscriber. Current choices are:
email|email-json|http|https|sqs email|email-json|http|https|sqs|sms
:type endpoint: string :type endpoint: string
:param endpoint: The location of the endpoint for :param endpoint: The location of the endpoint for
@@ -274,6 +274,7 @@ class SNSConnection(AWSQueryConnection):
* For http, this would be a URL beginning with http * For http, this would be a URL beginning with http
* For https, this would be a URL beginning with https * For https, this would be a URL beginning with https
* For sqs, this would be the ARN of an SQS Queue * For sqs, this would be the ARN of an SQS Queue
* For sms, this would be a phone number of an SMS-enabled device
""" """
params = {'TopicArn': topic, params = {'TopicArn': topic,
'Protocol': protocol, 'Protocol': protocol,

View File

@@ -286,8 +286,8 @@ class SQSConnection(AWSQueryConnection):
:param queue: The Queue from which messages are read. :param queue: The Queue from which messages are read.
:type receipt_handle: str :type receipt_handle: str
:param queue: The receipt handle associated with the message whose :param receipt_handle: The receipt handle associated with the message
visibility timeout will be changed. whose visibility timeout will be changed.
:type visibility_timeout: int :type visibility_timeout: int
:param visibility_timeout: The new value of the message's visibility :param visibility_timeout: The new value of the message's visibility
@@ -337,16 +337,19 @@ class SQSConnection(AWSQueryConnection):
params['QueueNamePrefix'] = prefix params['QueueNamePrefix'] = prefix
return self.get_list('ListQueues', params, [('QueueUrl', Queue)]) return self.get_list('ListQueues', params, [('QueueUrl', Queue)])
def get_queue(self, queue_name): def get_queue(self, queue_name, owner_acct_id=None):
""" """
Retrieves the queue with the given name, or ``None`` if no match Retrieves the queue with the given name, or ``None`` if no match
was found. was found.
:param str queue_name: The name of the queue to retrieve. :param str queue_name: The name of the queue to retrieve.
:param str owner_acct_id: Optionally, the AWS account ID of the account that created the queue.
:rtype: :py:class:`boto.sqs.queue.Queue` or ``None`` :rtype: :py:class:`boto.sqs.queue.Queue` or ``None``
:returns: The requested queue, or ``None`` if no match was found. :returns: The requested queue, or ``None`` if no match was found.
""" """
params = {'QueueName': queue_name} params = {'QueueName': queue_name}
if owner_acct_id:
params['QueueOwnerAWSAccountId']=owner_acct_id
try: try:
return self.get_object('GetQueueUrl', params, Queue) return self.get_object('GetQueueUrl', params, Queue)
except SQSError: except SQSError:

View File

@@ -95,7 +95,7 @@ class RawMessage:
def endElement(self, name, value, connection): def endElement(self, name, value, connection):
if name == 'Body': if name == 'Body':
self.set_body(self.decode(value)) self.set_body(value)
elif name == 'MessageId': elif name == 'MessageId':
self.id = value self.id = value
elif name == 'ReceiptHandle': elif name == 'ReceiptHandle':
@@ -105,6 +105,9 @@ class RawMessage:
else: else:
setattr(self, name, value) setattr(self, name, value)
def endNode(self, connection):
self.set_body(self.decode(self.get_body()))
def encode(self, value): def encode(self, value):
"""Transform body object into serialized byte array format.""" """Transform body object into serialized byte array format."""
return value return value

View File

@@ -188,7 +188,11 @@ class ActivityWorker(Actor):
@wraps(Layer1.poll_for_activity_task) @wraps(Layer1.poll_for_activity_task)
def poll(self, **kwargs): def poll(self, **kwargs):
"""PollForActivityTask.""" """PollForActivityTask."""
task = self._swf.poll_for_activity_task(self.domain, self.task_list, task_list = self.task_list
if 'task_list' in kwargs:
task_list = kwargs.get('task_list')
del kwargs['task_list']
task = self._swf.poll_for_activity_task(self.domain, task_list,
**kwargs) **kwargs)
self.last_tasktoken = task.get('taskToken') self.last_tasktoken = task.get('taskToken')
return task return task
@@ -211,12 +215,14 @@ class Decider(Actor):
@wraps(Layer1.poll_for_decision_task) @wraps(Layer1.poll_for_decision_task)
def poll(self, **kwargs): def poll(self, **kwargs):
"""PollForDecisionTask.""" """PollForDecisionTask."""
result = self._swf.poll_for_decision_task(self.domain, self.task_list, task_list = self.task_list
if 'task_list' in kwargs:
task_list = kwargs.get('task_list')
del kwargs['task_list']
decision_task = self._swf.poll_for_decision_task(self.domain, task_list,
**kwargs) **kwargs)
# Record task token. self.last_tasktoken = decision_task.get('taskToken')
self.last_tasktoken = result.get('taskToken') return decision_task
# Record the last event.
return result
class WorkflowType(SWFBase): class WorkflowType(SWFBase):

View File

@@ -27,6 +27,7 @@ from boto.ec2.connection import EC2Connection
from boto.resultset import ResultSet from boto.resultset import ResultSet
from boto.vpc.vpc import VPC from boto.vpc.vpc import VPC
from boto.vpc.customergateway import CustomerGateway from boto.vpc.customergateway import CustomerGateway
from boto.vpc.networkacl import NetworkAcl
from boto.vpc.routetable import RouteTable from boto.vpc.routetable import RouteTable
from boto.vpc.internetgateway import InternetGateway from boto.vpc.internetgateway import InternetGateway
from boto.vpc.vpngateway import VpnGateway, Attachment from boto.vpc.vpngateway import VpnGateway, Attachment
@@ -36,6 +37,7 @@ from boto.vpc.vpnconnection import VpnConnection
from boto.ec2 import RegionData from boto.ec2 import RegionData
from boto.regioninfo import RegionInfo from boto.regioninfo import RegionInfo
def regions(**kw_params): def regions(**kw_params):
""" """
Get all available regions for the EC2 service. Get all available regions for the EC2 service.
@@ -53,9 +55,8 @@ def regions(**kw_params):
connection_cls=VPCConnection) connection_cls=VPCConnection)
regions.append(region) regions.append(region)
regions.append(RegionInfo(name='us-gov-west-1', regions.append(RegionInfo(name='us-gov-west-1',
endpoint=RegionData[region_name], endpoint=RegionData[region_name],
connection_cls=VPCConnection) connection_cls=VPCConnection))
)
return regions return regions
@@ -117,20 +118,26 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_list('DescribeVpcs', params, [('item', VPC)]) return self.get_list('DescribeVpcs', params, [('item', VPC)])
def create_vpc(self, cidr_block, dry_run=False): def create_vpc(self, cidr_block, instance_tenancy=None, dry_run=False):
""" """
Create a new Virtual Private Cloud. Create a new Virtual Private Cloud.
:type cidr_block: str :type cidr_block: str
:param cidr_block: A valid CIDR block :param cidr_block: A valid CIDR block
:type instance_tenancy: str
:param instance_tenancy: The supported tenancy options for instances
launched into the VPC. Valid values are 'default' and 'dedicated'.
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
:rtype: The newly created VPC :rtype: The newly created VPC
:return: A :class:`boto.vpc.vpc.VPC` object :return: A :class:`boto.vpc.vpc.VPC` object
""" """
params = {'CidrBlock' : cidr_block} params = {'CidrBlock': cidr_block}
if instance_tenancy:
params['InstanceTenancy'] = instance_tenancy
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_object('CreateVpc', params, VPC) return self.get_object('CreateVpc', params, VPC)
@@ -266,7 +273,7 @@ class VPCConnection(EC2Connection):
:rtype: bool :rtype: bool
:return: True if successful :return: True if successful
""" """
params = { 'AssociationId': association_id } params = {'AssociationId': association_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('DisassociateRouteTable', params) return self.get_status('DisassociateRouteTable', params)
@@ -284,7 +291,7 @@ class VPCConnection(EC2Connection):
:rtype: The newly created route table :rtype: The newly created route table
:return: A :class:`boto.vpc.routetable.RouteTable` object :return: A :class:`boto.vpc.routetable.RouteTable` object
""" """
params = { 'VpcId': vpc_id } params = {'VpcId': vpc_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_object('CreateRouteTable', params, RouteTable) return self.get_object('CreateRouteTable', params, RouteTable)
@@ -302,13 +309,96 @@ class VPCConnection(EC2Connection):
:rtype: bool :rtype: bool
:return: True if successful :return: True if successful
""" """
params = { 'RouteTableId': route_table_id } params = {'RouteTableId': route_table_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('DeleteRouteTable', params) return self.get_status('DeleteRouteTable', params)
def _replace_route_table_association(self, association_id,
route_table_id, dry_run=False):
"""
Helper function for replace_route_table_association and
replace_route_table_association_with_assoc. Should not be used directly.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: ResultSet
:return: ResultSet of Amazon resposne
"""
params = {
'AssociationId': association_id,
'RouteTableId': route_table_id
}
if dry_run:
params['DryRun'] = 'true'
return self.get_object('ReplaceRouteTableAssociation', params,
ResultSet)
def replace_route_table_assocation(self, association_id,
route_table_id, dry_run=False):
"""
Replaces a route association with a new route table. This can be
used to replace the 'main' route table by using the main route
table association instead of the more common subnet type
association.
NOTE: It may be better to use replace_route_table_association_with_assoc
instead of this function; this function does not return the new
association ID. This function is retained for backwards compatibility.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: bool
:return: True if successful
"""
return self._replace_route_table_association(
association_id, route_table_id, dry_run=dry_run).status
def replace_route_table_association_with_assoc(self, association_id,
route_table_id,
dry_run=False):
"""
Replaces a route association with a new route table. This can be
used to replace the 'main' route table by using the main route
table association instead of the more common subnet type
association. Returns the new association ID.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: str
:return: New association ID
"""
return self._replace_route_table_association(
association_id, route_table_id, dry_run=dry_run).newAssociationId
def create_route(self, route_table_id, destination_cidr_block, def create_route(self, route_table_id, destination_cidr_block,
gateway_id=None, instance_id=None, dry_run=False): gateway_id=None, instance_id=None, interface_id=None,
dry_run=False):
""" """
Creates a new route in the route table within a VPC. The route's target Creates a new route in the route table within a VPC. The route's target
can be either a gateway attached to the VPC or a NAT instance in the can be either a gateway attached to the VPC or a NAT instance in the
@@ -327,6 +417,9 @@ class VPCConnection(EC2Connection):
:type instance_id: str :type instance_id: str
:param instance_id: The ID of a NAT instance in your VPC. :param instance_id: The ID of a NAT instance in your VPC.
:type interface_id: str
:param interface_id: Allows routing to network interface attachments.
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
@@ -342,14 +435,16 @@ class VPCConnection(EC2Connection):
params['GatewayId'] = gateway_id params['GatewayId'] = gateway_id
elif instance_id is not None: elif instance_id is not None:
params['InstanceId'] = instance_id params['InstanceId'] = instance_id
elif interface_id is not None:
params['NetworkInterfaceId'] = interface_id
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('CreateRoute', params) return self.get_status('CreateRoute', params)
def replace_route(self, route_table_id, destination_cidr_block, def replace_route(self, route_table_id, destination_cidr_block,
gateway_id=None, instance_id=None, interface_id=None, gateway_id=None, instance_id=None, interface_id=None,
dry_run=False): dry_run=False):
""" """
Replaces an existing route within a route table in a VPC. Replaces an existing route within a route table in a VPC.
@@ -417,6 +512,271 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('DeleteRoute', params) return self.get_status('DeleteRoute', params)
#Network ACLs
def get_all_network_acls(self, network_acl_ids=None, filters=None):
"""
Retrieve information about your network acls. You can filter results
to return information only about those network acls that match your
search parameters. Otherwise, all network acls associated with your
account are returned.
:type network_acl_ids: list
:param network_acl_ids: A list of strings with the desired network ACL
IDs.
:type filters: list of tuples
:param filters: A list of tuples containing filters. Each tuple
consists of a filter key and a filter value.
:rtype: list
:return: A list of :class:`boto.vpc.networkacl.NetworkAcl`
"""
params = {}
if network_acl_ids:
self.build_list_params(params, network_acl_ids, "NetworkAclId")
if filters:
self.build_filter_params(params, dict(filters))
return self.get_list('DescribeNetworkAcls', params,
[('item', NetworkAcl)])
def associate_network_acl(self, network_acl_id, subnet_id):
"""
Associates a network acl with a specific subnet.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL to associate.
:type subnet_id: str
:param subnet_id: The ID of the subnet to associate with.
:rtype: str
:return: The ID of the association created
"""
acl = self.get_all_network_acls(filters=[('association.subnet-id', subnet_id)])[0]
association = [ association for association in acl.associations if association.subnet_id == subnet_id ][0]
params = {
'AssociationId': association.id,
'NetworkAclId': network_acl_id
}
result = self.get_object('ReplaceNetworkAclAssociation', params, ResultSet)
return result.newAssociationId
def disassociate_network_acl(self, subnet_id, vpc_id=None):
"""
Figures out what the default ACL is for the VPC, and associates
current network ACL with the default.
:type subnet_id: str
:param association_id: The ID of the subnet to which the ACL belongs.
:type vpc_id: str
:param vpc_id: The ID of the VPC to which the ACL/subnet belongs. Queries EC2 if omitted.
:rtype: str
:return: The ID of the association created
"""
if not vpc_id:
vpc_id = self.get_all_subnets([subnet_id])[0].vpc_id
acls = self.get_all_network_acls(filters=[('vpc-id', vpc_id), ('default', 'true')])
default_acl_id = acls[0].id
return self.associate_network_acl(default_acl_id, subnet_id)
def create_network_acl(self, vpc_id):
"""
Creates a new network ACL.
:type vpc_id: str
:param vpc_id: The VPC ID to associate this network ACL with.
:rtype: The newly created network ACL
:return: A :class:`boto.vpc.networkacl.NetworkAcl` object
"""
params = {'VpcId': vpc_id}
return self.get_object('CreateNetworkAcl', params, NetworkAcl)
def delete_network_acl(self, network_acl_id):
"""
Delete a network ACL
:type network_acl_id: str
:param network_acl_id: The ID of the network_acl to delete.
:rtype: bool
:return: True if successful
"""
params = {'NetworkAclId': network_acl_id}
return self.get_status('DeleteNetworkAcl', params)
def create_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action,
cidr_block, egress=None, icmp_code=None, icmp_type=None,
port_range_from=None, port_range_to=None):
"""
Creates a new network ACL entry in a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL for this network ACL entry.
:type rule_number: int
:param rule_number: The rule number to assign to the entry (for example, 100).
:type protocol: int
:param protocol: Valid values: -1 or a protocol number
(http://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml)
:type rule_action: str
:param rule_action: Indicates whether to allow or deny traffic that matches the rule.
:type cidr_block: str
:param cidr_block: The CIDR range to allow or deny, in CIDR notation (for example,
172.16.0.0/24).
:type egress: bool
:param egress: Indicates whether this rule applies to egress traffic from the subnet (true)
or ingress traffic to the subnet (false).
:type icmp_type: int
:param icmp_type: For the ICMP protocol, the ICMP type. You can use -1 to specify
all ICMP types.
:type icmp_code: int
:param icmp_code: For the ICMP protocol, the ICMP code. You can use -1 to specify
all ICMP codes for the given ICMP type.
:type port_range_from: int
:param port_range_from: The first port in the range.
:type port_range_to: int
:param port_range_to: The last port in the range.
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number,
'Protocol': protocol,
'RuleAction': rule_action,
'CidrBlock': cidr_block
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
if icmp_code is not None:
params['Icmp.Code'] = icmp_code
if icmp_type is not None:
params['Icmp.Type'] = icmp_type
if port_range_from is not None:
params['PortRange.From'] = port_range_from
if port_range_to is not None:
params['PortRange.To'] = port_range_to
return self.get_status('CreateNetworkAclEntry', params)
def replace_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action,
cidr_block, egress=None, icmp_code=None, icmp_type=None,
port_range_from=None, port_range_to=None):
"""
Creates a new network ACL entry in a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL for the id you want to replace
:type rule_number: int
:param rule_number: The rule number that you want to replace(for example, 100).
:type protocol: int
:param protocol: Valid values: -1 or a protocol number
(http://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml)
:type rule_action: str
:param rule_action: Indicates whether to allow or deny traffic that matches the rule.
:type cidr_block: str
:param cidr_block: The CIDR range to allow or deny, in CIDR notation (for example,
172.16.0.0/24).
:type egress: bool
:param egress: Indicates whether this rule applies to egress traffic from the subnet (true)
or ingress traffic to the subnet (false).
:type icmp_type: int
:param icmp_type: For the ICMP protocol, the ICMP type. You can use -1 to specify
all ICMP types.
:type icmp_code: int
:param icmp_code: For the ICMP protocol, the ICMP code. You can use -1 to specify
all ICMP codes for the given ICMP type.
:type port_range_from: int
:param port_range_from: The first port in the range.
:type port_range_to: int
:param port_range_to: The last port in the range.
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number,
'Protocol': protocol,
'RuleAction': rule_action,
'CidrBlock': cidr_block
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
if icmp_code is not None:
params['Icmp.Code'] = icmp_code
if icmp_type is not None:
params['Icmp.Type'] = icmp_type
if port_range_from is not None:
params['PortRange.From'] = port_range_from
if port_range_to is not None:
params['PortRange.To'] = port_range_to
return self.get_status('ReplaceNetworkAclEntry', params)
def delete_network_acl_entry(self, network_acl_id, rule_number, egress=None):
"""
Deletes a network ACL entry from a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL with the network ACL entry.
:type rule_number: int
:param rule_number: The rule number for the entry to delete.
:type egress: bool
:param egress: Specifies whether the rule to delete is an egress rule (true)
or ingress rule (false).
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
return self.get_status('DeleteNetworkAclEntry', params)
# Internet Gateways # Internet Gateways
def get_all_internet_gateways(self, internet_gateway_ids=None, def get_all_internet_gateways(self, internet_gateway_ids=None,
@@ -476,7 +836,7 @@ class VPCConnection(EC2Connection):
:rtype: Bool :rtype: Bool
:return: True if successful :return: True if successful
""" """
params = { 'InternetGatewayId': internet_gateway_id } params = {'InternetGatewayId': internet_gateway_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('DeleteInternetGateway', params) return self.get_status('DeleteInternetGateway', params)
@@ -586,7 +946,7 @@ class VPCConnection(EC2Connection):
:param ip_address: Internet-routable IP address for customer's gateway. :param ip_address: Internet-routable IP address for customer's gateway.
Must be a static address. Must be a static address.
:type bgp_asn: str :type bgp_asn: int
:param bgp_asn: Customer gateway's Border Gateway Protocol (BGP) :param bgp_asn: Customer gateway's Border Gateway Protocol (BGP)
Autonomous System Number (ASN) Autonomous System Number (ASN)
@@ -596,9 +956,9 @@ class VPCConnection(EC2Connection):
:rtype: The newly created CustomerGateway :rtype: The newly created CustomerGateway
:return: A :class:`boto.vpc.customergateway.CustomerGateway` object :return: A :class:`boto.vpc.customergateway.CustomerGateway` object
""" """
params = {'Type' : type, params = {'Type': type,
'IpAddress' : ip_address, 'IpAddress': ip_address,
'BgpAsn' : bgp_asn} 'BgpAsn': bgp_asn}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_object('CreateCustomerGateway', params, CustomerGateway) return self.get_object('CreateCustomerGateway', params, CustomerGateway)
@@ -677,7 +1037,7 @@ class VPCConnection(EC2Connection):
:rtype: The newly created VpnGateway :rtype: The newly created VpnGateway
:return: A :class:`boto.vpc.vpngateway.VpnGateway` object :return: A :class:`boto.vpc.vpngateway.VpnGateway` object
""" """
params = {'Type' : type} params = {'Type': type}
if availability_zone: if availability_zone:
params['AvailabilityZone'] = availability_zone params['AvailabilityZone'] = availability_zone
if dry_run: if dry_run:
@@ -719,11 +1079,33 @@ class VPCConnection(EC2Connection):
:return: a :class:`boto.vpc.vpngateway.Attachment` :return: a :class:`boto.vpc.vpngateway.Attachment`
""" """
params = {'VpnGatewayId': vpn_gateway_id, params = {'VpnGatewayId': vpn_gateway_id,
'VpcId' : vpc_id} 'VpcId': vpc_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_object('AttachVpnGateway', params, Attachment) return self.get_object('AttachVpnGateway', params, Attachment)
def detach_vpn_gateway(self, vpn_gateway_id, vpc_id, dry_run=False):
"""
Detaches a VPN gateway from a VPC.
:type vpn_gateway_id: str
:param vpn_gateway_id: The ID of the vpn_gateway to detach
:type vpc_id: str
:param vpc_id: The ID of the VPC you want to detach the gateway from.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: bool
:return: True if successful
"""
params = {'VpnGatewayId': vpn_gateway_id,
'VpcId': vpc_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('DetachVpnGateway', params)
# Subnets # Subnets
def get_all_subnets(self, subnet_ids=None, filters=None, dry_run=False): def get_all_subnets(self, subnet_ids=None, filters=None, dry_run=False):
@@ -784,8 +1166,8 @@ class VPCConnection(EC2Connection):
:rtype: The newly created Subnet :rtype: The newly created Subnet
:return: A :class:`boto.vpc.customergateway.Subnet` object :return: A :class:`boto.vpc.customergateway.Subnet` object
""" """
params = {'VpcId' : vpc_id, params = {'VpcId': vpc_id,
'CidrBlock' : cidr_block} 'CidrBlock': cidr_block}
if availability_zone: if availability_zone:
params['AvailabilityZone'] = availability_zone params['AvailabilityZone'] = availability_zone
if dry_run: if dry_run:
@@ -810,16 +1192,19 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('DeleteSubnet', params) return self.get_status('DeleteSubnet', params)
# DHCP Options # DHCP Options
def get_all_dhcp_options(self, dhcp_options_ids=None, dry_run=False): def get_all_dhcp_options(self, dhcp_options_ids=None, filters=None, dry_run=False):
""" """
Retrieve information about your DhcpOptions. Retrieve information about your DhcpOptions.
:type dhcp_options_ids: list :type dhcp_options_ids: list
:param dhcp_options_ids: A list of strings with the desired DhcpOption ID's :param dhcp_options_ids: A list of strings with the desired DhcpOption ID's
:type filters: list of tuples
:param filters: A list of tuples containing filters. Each tuple
consists of a filter key and a filter value.
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
@@ -829,6 +1214,8 @@ class VPCConnection(EC2Connection):
params = {} params = {}
if dhcp_options_ids: if dhcp_options_ids:
self.build_list_params(params, dhcp_options_ids, 'DhcpOptionsId') self.build_list_params(params, dhcp_options_ids, 'DhcpOptionsId')
if filters:
self.build_filter_params(params, dict(filters))
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_list('DescribeDhcpOptions', params, return self.get_list('DescribeDhcpOptions', params,
@@ -890,19 +1277,19 @@ class VPCConnection(EC2Connection):
if domain_name: if domain_name:
key_counter = insert_option(params, key_counter = insert_option(params,
'domain-name', domain_name) 'domain-name', domain_name)
if domain_name_servers: if domain_name_servers:
key_counter = insert_option(params, key_counter = insert_option(params,
'domain-name-servers', domain_name_servers) 'domain-name-servers', domain_name_servers)
if ntp_servers: if ntp_servers:
key_counter = insert_option(params, key_counter = insert_option(params,
'ntp-servers', ntp_servers) 'ntp-servers', ntp_servers)
if netbios_name_servers: if netbios_name_servers:
key_counter = insert_option(params, key_counter = insert_option(params,
'netbios-name-servers', netbios_name_servers) 'netbios-name-servers', netbios_name_servers)
if netbios_node_type: if netbios_node_type:
key_counter = insert_option(params, key_counter = insert_option(params,
'netbios-node-type', netbios_node_type) 'netbios-node-type', netbios_node_type)
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
@@ -943,7 +1330,7 @@ class VPCConnection(EC2Connection):
:return: True if successful :return: True if successful
""" """
params = {'DhcpOptionsId': dhcp_options_id, params = {'DhcpOptionsId': dhcp_options_id,
'VpcId' : vpc_id} 'VpcId': vpc_id}
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_status('AssociateDhcpOptions', params) return self.get_status('AssociateDhcpOptions', params)
@@ -983,7 +1370,7 @@ class VPCConnection(EC2Connection):
params = {} params = {}
if vpn_connection_ids: if vpn_connection_ids:
self.build_list_params(params, vpn_connection_ids, self.build_list_params(params, vpn_connection_ids,
'Vpn_ConnectionId') 'VpnConnectionId')
if filters: if filters:
self.build_filter_params(params, dict(filters)) self.build_filter_params(params, dict(filters))
if dry_run: if dry_run:
@@ -992,7 +1379,7 @@ class VPCConnection(EC2Connection):
[('item', VpnConnection)]) [('item', VpnConnection)])
def create_vpn_connection(self, type, customer_gateway_id, vpn_gateway_id, def create_vpn_connection(self, type, customer_gateway_id, vpn_gateway_id,
dry_run=False): static_routes_only=None, dry_run=False):
""" """
Create a new VPN Connection. Create a new VPN Connection.
@@ -1006,15 +1393,24 @@ class VPCConnection(EC2Connection):
:type vpn_gateway_id: str :type vpn_gateway_id: str
:param vpn_gateway_id: The ID of the VPN gateway. :param vpn_gateway_id: The ID of the VPN gateway.
:type static_routes_only: bool
:param static_routes_only: Indicates whether the VPN connection
requires static routes. If you are creating a VPN connection
for a device that does not support BGP, you must specify true.
:type dry_run: bool :type dry_run: bool
:param dry_run: Set to True if the operation should not actually run. :param dry_run: Set to True if the operation should not actually run.
:rtype: The newly created VpnConnection :rtype: The newly created VpnConnection
:return: A :class:`boto.vpc.vpnconnection.VpnConnection` object :return: A :class:`boto.vpc.vpnconnection.VpnConnection` object
""" """
params = {'Type' : type, params = {'Type': type,
'CustomerGatewayId' : customer_gateway_id, 'CustomerGatewayId': customer_gateway_id,
'VpnGatewayId' : vpn_gateway_id} 'VpnGatewayId': vpn_gateway_id}
if static_routes_only is not None:
if isinstance(static_routes_only, bool):
static_routes_only = str(static_routes_only).lower()
params['Options.StaticRoutesOnly'] = static_routes_only
if dry_run: if dry_run:
params['DryRun'] = 'true' params['DryRun'] = 'true'
return self.get_object('CreateVpnConnection', params, VpnConnection) return self.get_object('CreateVpnConnection', params, VpnConnection)

View File

@@ -14,7 +14,7 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL- # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT # ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, # SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE. # IN THE SOFTWARE.
@@ -25,6 +25,7 @@ Represents a Customer Gateway
from boto.ec2.ec2object import TaggedEC2Object from boto.ec2.ec2object import TaggedEC2Object
class CustomerGateway(TaggedEC2Object): class CustomerGateway(TaggedEC2Object):
def __init__(self, connection=None): def __init__(self, connection=None):
@@ -37,7 +38,7 @@ class CustomerGateway(TaggedEC2Object):
def __repr__(self): def __repr__(self):
return 'CustomerGateway:%s' % self.id return 'CustomerGateway:%s' % self.id
def endElement(self, name, value, connection): def endElement(self, name, value, connection):
if name == 'customerGatewayId': if name == 'customerGatewayId':
self.id = value self.id = value
@@ -48,7 +49,6 @@ class CustomerGateway(TaggedEC2Object):
elif name == 'state': elif name == 'state':
self.state = value self.state = value
elif name == 'bgpAsn': elif name == 'bgpAsn':
self.bgp_asn = value self.bgp_asn = int(value)
else: else:
setattr(self, name, value) setattr(self, name, value)

View File

@@ -0,0 +1,164 @@
# Copyright (c) 2009-2010 Mitch Garnaat http://garnaat.org/
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
"""
Represents a Network ACL
"""
from boto.ec2.ec2object import TaggedEC2Object
from boto.resultset import ResultSet
class Icmp(object):
    """
    Defines the ICMP code and type.
    """
    def __init__(self, connection=None):
        # Populated while parsing the icmpTypeCode XML element.
        self.code = None
        self.type = None

    def __repr__(self):
        return 'Icmp::code:%s, type:%s)' % ( self.code, self.type)

    def startElement(self, name, attrs, connection):
        pass

    def endElement(self, name, value, connection):
        # Only the two known tags are stored; anything else is ignored.
        if name in ('code', 'type'):
            setattr(self, name, value)
class NetworkAcl(TaggedEC2Object):
    """Represents a VPC network ACL parsed from the EC2 XML response."""

    def __init__(self, connection=None):
        TaggedEC2Object.__init__(self, connection)
        self.id = None
        self.vpc_id = None
        self.network_acl_entries = []
        self.associations = []

    def __repr__(self):
        return 'NetworkAcl:%s' % self.id

    def startElement(self, name, attrs, connection):
        # Give the tag-aware base class first refusal on the element.
        handled = super(NetworkAcl, self).startElement(name, attrs, connection)
        if handled is not None:
            return handled
        if name == 'entrySet':
            self.network_acl_entries = ResultSet([('item', NetworkAclEntry)])
            return self.network_acl_entries
        if name == 'associationSet':
            self.associations = ResultSet([('item', NetworkAclAssociation)])
            return self.associations
        return None

    def endElement(self, name, value, connection):
        # Map the two known tags onto snake_case attributes; unknown
        # tags are stored verbatim, matching the original behavior.
        attr = {'networkAclId': 'id', 'vpcId': 'vpc_id'}.get(name, name)
        setattr(self, attr, value)
class NetworkAclEntry(object):
    """A single rule (entry) within a network ACL."""

    def __init__(self, connection=None):
        self.rule_number = None
        self.protocol = None
        self.rule_action = None
        self.egress = None
        self.cidr_block = None
        self.port_range = PortRange()
        self.icmp = Icmp()

    def __repr__(self):
        return 'Acl:%s' % self.rule_number

    def startElement(self, name, attrs, connection):
        # Nested elements are delegated to their dedicated sub-parsers;
        # unknown elements return None so they are parsed in this scope.
        children = {'portRange': self.port_range, 'icmpTypeCode': self.icmp}
        return children.get(name)

    def endElement(self, name, value, connection):
        # Store only the recognized tags; others are silently dropped
        # (unlike NetworkAcl, which keeps unknown tags).
        field = {
            'cidrBlock': 'cidr_block',
            'egress': 'egress',
            'protocol': 'protocol',
            'ruleAction': 'rule_action',
            'ruleNumber': 'rule_number',
        }.get(name)
        if field is not None:
            setattr(self, field, value)
class NetworkAclAssociation(object):
    """Association between a network ACL and a subnet, parsed from the
    ``associationSet`` items of a DescribeNetworkAcls response.
    """

    def __init__(self, connection=None):
        self.id = None
        self.subnet_id = None
        self.network_acl_id = None

    def __repr__(self):
        return 'NetworkAclAssociation:%s' % self.id

    def startElement(self, name, attrs, connection):
        # No nested elements to delegate.
        return None

    def endElement(self, name, value, connection):
        if name == 'networkAclAssociationId':
            self.id = value
        elif name == 'networkAclId':
            # Bug fix: this previously assigned to ``route_table_id``
            # (a copy/paste from the route-table parser), leaving
            # ``network_acl_id`` permanently None.
            self.network_acl_id = value
        elif name == 'subnetId':
            self.subnet_id = value
class PortRange(object):
    """
    Define the port range for the ACL entry if it is tcp / udp
    """
    def __init__(self, connection=None):
        self.from_port = None
        self.to_port = None

    def __repr__(self):
        return 'PortRange:(%s-%s)' % ( self.from_port, self.to_port)

    def startElement(self, name, attrs, connection):
        pass

    def endElement(self, name, value, connection):
        # XML uses 'from'/'to'; map onto attribute names that are not
        # Python keywords.  Unknown tags are ignored.
        bounds = {'from': 'from_port', 'to': 'to_port'}
        if name in bounds:
            setattr(self, bounds[name], value)

View File

@@ -1,208 +0,0 @@
# -*- coding: utf-8 -*-
"""
celery.__compat__
~~~~~~~~~~~~~~~~~
This module contains utilities to dynamically
recreate modules, either for lazy loading or
to create old modules at runtime instead of
having them litter the source tree.
"""
from __future__ import absolute_import
import operator
import sys
# import fails in python 2.5. fallback to reduce in stdlib
try:
from functools import reduce
except ImportError:
pass
from importlib import import_module
from types import ModuleType
from .local import Proxy
MODULE_DEPRECATED = """
The module %s is deprecated and will be removed in a future version.
"""
DEFAULT_ATTRS = set(['__file__', '__path__', '__doc__', '__all__'])
# im_func is no longer available in Py3.
# instead the unbound method itself can be used.
if sys.version_info[0] == 3:  # pragma: no cover
    def fun_of_method(method):
        # Python 3: an "unbound method" is just the plain function.
        return method
else:
    def fun_of_method(method):  # noqa
        # Python 2: unwrap the unbound method to its underlying function.
        return method.im_func
def getappattr(path):
    """Resolve a dotted attribute *path* on the current Celery app.

    Example: ``getappattr('amqp.get_task_consumer')``.
    """
    # Imported lazily to avoid a circular import at module load time.
    from celery import current_app
    app = current_app
    return app._rgetattr(path)
def _compat_task_decorator(*args, **kwargs):
    """Deprecated ``celery.decorators.task`` shim.

    Forwards to ``current_app.task`` with magic keyword arguments
    enabled by default, as old-style tasks expect.
    """
    from celery import current_app
    if 'accept_magic_kwargs' not in kwargs:
        kwargs['accept_magic_kwargs'] = True
    return current_app.task(*args, **kwargs)
def _compat_periodic_task_decorator(*args, **kwargs):
    """Deprecated ``celery.decorators.periodic_task`` shim.

    Forwards to :func:`celery.task.periodic_task` with magic keyword
    arguments enabled by default, as old-style tasks expect.
    """
    from celery.task import periodic_task
    if 'accept_magic_kwargs' not in kwargs:
        kwargs['accept_magic_kwargs'] = True
    return periodic_task(*args, **kwargs)
# Specification of the deprecated compatibility modules, consumed by
# recreate_module()/get_compat_module() below.  Keys are package names;
# each inner key is a legacy submodule name mapping to either:
#   * a dict of {legacy attribute -> current_app attribute path, or a
#     callable used directly}, or
#   * a plain module path string that is imported and aliased as-is.
COMPAT_MODULES = {
    'celery': {
        'execute': {
            'send_task': 'send_task',
        },
        'decorators': {
            'task': _compat_task_decorator,
            'periodic_task': _compat_periodic_task_decorator,
        },
        'log': {
            'get_default_logger': 'log.get_default_logger',
            'setup_logger': 'log.setup_logger',
            # NOTE: 'loggig' typo is a published key; kept for compat.
            'setup_loggig_subsystem': 'log.setup_logging_subsystem',
            'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
        },
        'messaging': {
            'TaskPublisher': 'amqp.TaskPublisher',
            'TaskConsumer': 'amqp.TaskConsumer',
            'establish_connection': 'connection',
            'with_connection': 'with_default_connection',
            'get_consumer_set': 'amqp.TaskConsumer',
        },
        'registry': {
            'tasks': 'tasks',
        },
    },
    'celery.task': {
        'control': {
            'broadcast': 'control.broadcast',
            'rate_limit': 'control.rate_limit',
            'time_limit': 'control.time_limit',
            'ping': 'control.ping',
            'revoke': 'control.revoke',
            'discard_all': 'control.purge',
            'inspect': 'control.inspect',
        },
        'schedules': 'celery.schedules',
        'chords': 'celery.canvas',
    }
}
class class_property(object):
    """Descriptor wrapping a pair of classmethods so they behave like
    a property that also works when accessed on the class itself."""

    def __init__(self, fget=None, fset=None):
        # The getter is mandatory; both accessors must be classmethods.
        assert fget and isinstance(fget, classmethod)
        assert isinstance(fset, classmethod) if fset else True
        self.__get = fget
        self.__set = fset

        info = fget.__get__(object)  # just need the info attrs.
        self.__doc__ = info.__doc__
        self.__name__ = info.__name__
        self.__module__ = info.__module__

    def __get__(self, obj, type=None):
        # NOTE(review): tests truthiness of `obj`, not `obj is not None`,
        # so a falsy instance would fall through with type=None --
        # presumably instances here are always truthy; confirm if reused.
        if obj and type is None:
            type = obj.__class__
        return self.__get.__get__(obj, type)()

    def __set__(self, obj, value):
        # Class-level assignment: return the descriptor unchanged.
        if obj is None:
            return self
        return self.__set.__get__(obj)(value)
def reclassmethod(method):
    """Rebuild *method* as a fresh classmethod on its underlying
    function (Py2/Py3 differences handled by ``fun_of_method``)."""
    fun = fun_of_method(method)
    return classmethod(fun)
class MagicModule(ModuleType):
    """Module type whose attributes are imported lazily on first access.

    The class attributes below are overridden on the one-off subclass
    generated by :func:`create_module` / :func:`recreate_module`.
    """

    #: compat submodule names attached to this module.
    _compat_modules = ()
    #: maps real module name -> list of attribute names it provides.
    _all_by_module = {}
    #: maps attribute name -> module path imported and bound directly.
    _direct = {}
    #: maps attribute name -> name of the real module defining it.
    _object_origins = {}

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, i.e. on first access.
        if name in self._object_origins:
            module = __import__(self._object_origins[name], None, None, [name])
            # Cache every attribute the real module provides so later
            # lookups bypass __getattr__ entirely.
            for item in self._all_by_module[module.__name__]:
                setattr(self, item, getattr(module, item))
            return getattr(module, name)
        elif name in self._direct:
            # Whole-module alias: import once and cache on self.
            module = __import__(self._direct[name], None, None, [name])
            setattr(self, name, module)
            return module
        return ModuleType.__getattribute__(self, name)

    def __dir__(self):
        return list(set(self.__all__) | DEFAULT_ATTRS)
def create_module(name, attrs, cls_attrs=None, pkg=None,
                  base=MagicModule, prepare_attr=None):
    """Create a module as an instance of a one-off *base* subclass and
    register it in ``sys.modules``.

    :param name: module name; joined with ``pkg.__name__`` when *pkg*
        is given.
    :param attrs: attributes copied into the module namespace, each run
        through *prepare_attr* when provided.
    :param cls_attrs: extra attributes for the generated module class.
    :returns: the newly registered module.
    """
    fqdn = '.'.join([pkg.__name__, name]) if pkg else name
    cls_attrs = {} if cls_attrs is None else cls_attrs

    attrs = dict((attr_name, prepare_attr(attr) if prepare_attr else attr)
                 for attr_name, attr in attrs.iteritems())
    # Instantiate a fresh subclass of `base` named after the module.
    module = sys.modules[fqdn] = type(name, (base, ), cls_attrs)(fqdn)
    module.__dict__.update(attrs)
    return module
def recreate_module(name, compat_modules=(), by_module={}, direct={},
                    base=MagicModule, **attrs):
    """Replace ``sys.modules[name]`` with a lazy :class:`MagicModule`.

    Returns ``(old_module, new_module)`` so the caller can keep a
    reference to the original module alive.
    """
    old_module = sys.modules[name]
    origins = get_origins(by_module)
    # NOTE: the `compat_modules` argument is ignored; the effective set
    # is always looked up from COMPAT_MODULES for this module name.
    compat_modules = COMPAT_MODULES.get(name, ())

    cattrs = dict(
        _compat_modules=compat_modules,
        _all_by_module=by_module, _direct=direct,
        _object_origins=origins,
        # __all__ becomes the union of every lazily reachable name.
        __all__=tuple(set(reduce(
            operator.add,
            [tuple(v) for v in [compat_modules, origins, direct, attrs]],
        ))),
    )
    new_module = create_module(name, attrs, cls_attrs=cattrs, base=base)
    # Attach the deprecated compat submodules (e.g. celery.execute).
    new_module.__dict__.update(dict((mod, get_compat_module(new_module, mod))
                               for mod in compat_modules))
    return old_module, new_module
def get_compat_module(pkg, name):
    """Build (or directly import) the deprecated compat submodule
    ``pkg.name`` from its COMPAT_MODULES specification."""

    def prepare(attr):
        # String specs become lazy proxies resolved against current_app.
        if isinstance(attr, basestring):
            return Proxy(getappattr, (attr, ))
        return attr

    attrs = COMPAT_MODULES[pkg.__name__][name]
    if isinstance(attrs, basestring):
        # Spec is a module path: import it and alias under the compat
        # name (e.g. celery.task.schedules -> celery.schedules).
        fqdn = '.'.join([pkg.__name__, name])
        module = sys.modules[fqdn] = import_module(attrs)
        return module
    attrs['__all__'] = list(attrs)
    return create_module(name, dict(attrs), pkg=pkg, prepare_attr=prepare)
def get_origins(defs):
    """Invert ``{module: [attr, ...]}`` into ``{attr: module}``.

    Later modules win if an attribute name appears more than once.
    """
    origins = {}
    for module, items in defs.iteritems():
        origins.update(dict((item, module) for item in items))
    return origins

View File

@@ -2,45 +2,126 @@
"""Distributed Task Queue""" """Distributed Task Queue"""
# :copyright: (c) 2009 - 2012 Ask Solem and individual contributors, # :copyright: (c) 2009 - 2012 Ask Solem and individual contributors,
# All rights reserved. # All rights reserved.
# :copyright: (c) 2012 VMware, Inc., All rights reserved. # :copyright: (c) 2012-2013 GoPivotal, Inc., All rights reserved.
# :license: BSD (3 Clause), see LICENSE for more details. # :license: BSD (3 Clause), see LICENSE for more details.
from __future__ import absolute_import from __future__ import absolute_import
SERIES = 'Chiastic Slide' SERIES = 'Cipater'
VERSION = (3, 0, 23) VERSION = (3, 1, 3)
__version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:]) __version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:])
__author__ = 'Ask Solem' __author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org' __contact__ = 'ask@celeryproject.org'
__homepage__ = 'http://celeryproject.org' __homepage__ = 'http://celeryproject.org'
__docformat__ = 'restructuredtext' __docformat__ = 'restructuredtext'
__all__ = [ __all__ = [
'Celery', 'bugreport', 'shared_task', 'Task', 'Celery', 'bugreport', 'shared_task', 'task',
'current_app', 'current_task', 'current_app', 'current_task', 'maybe_signature',
'chain', 'chord', 'chunks', 'group', 'subtask', 'chain', 'chord', 'chunks', 'group', 'signature',
'xmap', 'xstarmap', 'uuid', 'VERSION', '__version__', 'xmap', 'xstarmap', 'uuid', 'version', '__version__',
] ]
VERSION_BANNER = '%s (%s)' % (__version__, SERIES) VERSION_BANNER = '{0} ({1})'.format(__version__, SERIES)
# -eof meta- # -eof meta-
import os
import sys
if os.environ.get('C_IMPDEBUG'): # pragma: no cover
from .five import builtins
real_import = builtins.__import__
def debug_import(name, locals=None, globals=None,
fromlist=None, level=-1):
glob = globals or getattr(sys, 'emarfteg_'[::-1])(1).f_globals
importer_name = glob and glob.get('__name__') or 'unknown'
print('-- {0} imports {1}'.format(importer_name, name))
return real_import(name, locals, globals, fromlist, level)
builtins.__import__ = debug_import
# This is never executed, but tricks static analyzers (PyDev, PyCharm,
# pylint, etc.) into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False globals()['kcah_acitats'[::-1].upper()] = False
if STATICA_HACK: if STATICA_HACK: # pragma: no cover
# This is never executed, but tricks static analyzers (PyDev, PyCharm, from celery.app import shared_task # noqa
# pylint, etc.) into knowing the types of these symbols, and what from celery.app.base import Celery # noqa
# they contain. from celery.app.utils import bugreport # noqa
from celery.app.base import Celery from celery.app.task import Task # noqa
from celery.app.utils import bugreport from celery._state import current_app, current_task # noqa
from celery.app.task import Task from celery.canvas import ( # noqa
from celery._state import current_app, current_task chain, chord, chunks, group,
from celery.canvas import ( signature, maybe_signature, xmap, xstarmap, subtask,
chain, chord, chunks, group, subtask, xmap, xstarmap,
) )
from celery.utils import uuid from celery.utils import uuid # noqa
# Eventlet/gevent patching must happen before importing
# anything else, so these tools must be at top-level.
def _find_option_with_arg(argv, short_opts=None, long_opts=None):
"""Search argv for option specifying its short and longopt
alternatives.
Return the value of the option if found.
"""
for i, arg in enumerate(argv):
if arg.startswith('-'):
if long_opts and arg.startswith('--'):
name, _, val = arg.partition('=')
if name in long_opts:
return val
if short_opts and arg in short_opts:
return argv[i + 1]
raise KeyError('|'.join(short_opts or [] + long_opts or []))
def _patch_eventlet():
import eventlet
import eventlet.debug
eventlet.monkey_patch()
EVENTLET_DBLOCK = int(os.environ.get('EVENTLET_NOBLOCK', 0))
if EVENTLET_DBLOCK:
eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK)
def _patch_gevent():
from gevent import monkey, version_info
monkey.patch_all()
if version_info[0] == 0: # pragma: no cover
# Signals aren't working in gevent versions <1.0,
# and are not monkey patched by patch_all()
from gevent import signal as _gevent_signal
_signal = __import__('signal')
_signal.signal = _gevent_signal
def maybe_patch_concurrency(argv=sys.argv,
short_opts=['-P'], long_opts=['--pool'],
patches={'eventlet': _patch_eventlet,
'gevent': _patch_gevent}):
"""With short and long opt alternatives that specify the command line
option to set the pool, this makes sure that anything that needs
to be patched is completed as early as possible.
(e.g. eventlet/gevent monkey patches)."""
try:
pool = _find_option_with_arg(argv, short_opts, long_opts)
except KeyError:
pass
else:
try:
patcher = patches[pool]
except KeyError:
pass
else:
patcher()
# set up eventlet/gevent environments ASAP.
from celery import concurrency
concurrency.get_implementation(pool)
# Lazy loading # Lazy loading
from .__compat__ import recreate_module from .five import recreate_module
old_module, new_module = recreate_module( # pragma: no cover old_module, new_module = recreate_module( # pragma: no cover
__name__, __name__,
@@ -49,7 +130,8 @@ old_module, new_module = recreate_module( # pragma: no cover
'celery.app.task': ['Task'], 'celery.app.task': ['Task'],
'celery._state': ['current_app', 'current_task'], 'celery._state': ['current_app', 'current_task'],
'celery.canvas': ['chain', 'chord', 'chunks', 'group', 'celery.canvas': ['chain', 'chord', 'chunks', 'group',
'subtask', 'xmap', 'xstarmap'], 'signature', 'maybe_signature', 'subtask',
'xmap', 'xstarmap'],
'celery.utils': ['uuid'], 'celery.utils': ['uuid'],
}, },
direct={'task': 'celery.task'}, direct={'task': 'celery.task'},
@@ -58,4 +140,6 @@ old_module, new_module = recreate_module( # pragma: no cover
__author__=__author__, __contact__=__contact__, __author__=__author__, __contact__=__contact__,
__homepage__=__homepage__, __docformat__=__docformat__, __homepage__=__homepage__, __docformat__=__docformat__,
VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER, VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER,
maybe_patch_concurrency=maybe_patch_concurrency,
_find_option_with_arg=_find_option_with_arg,
) )

View File

@@ -2,10 +2,25 @@ from __future__ import absolute_import
import sys import sys
from os.path import basename
def maybe_patch_concurrency(): from . import maybe_patch_concurrency
from celery.platforms import maybe_patch_concurrency
maybe_patch_concurrency(sys.argv, ['-P'], ['--pool']) __all__ = ['main']
DEPRECATED_FMT = """
The {old!r} command is deprecated, please use {new!r} instead:
$ {new_argv}
"""
def _warn_deprecated(new):
print(DEPRECATED_FMT.format(
old=basename(sys.argv[0]), new=new,
new_argv=' '.join([new] + sys.argv[1:])),
)
def main(): def main():
@@ -16,21 +31,24 @@ def main():
def _compat_worker(): def _compat_worker():
maybe_patch_concurrency() maybe_patch_concurrency()
from celery.bin.celeryd import main _warn_deprecated('celery worker')
from celery.bin.worker import main
main() main()
def _compat_multi(): def _compat_multi():
maybe_patch_concurrency() maybe_patch_concurrency()
from celery.bin.celeryd_multi import main _warn_deprecated('celery multi')
from celery.bin.multi import main
main() main()
def _compat_beat(): def _compat_beat():
maybe_patch_concurrency() maybe_patch_concurrency()
from celery.bin.celerybeat import main _warn_deprecated('celery beat')
from celery.bin.beat import main
main() main()
if __name__ == '__main__': if __name__ == '__main__': # pragma: no cover
main() main()

View File

@@ -9,7 +9,7 @@
This module shouldn't be used directly. This module shouldn't be used directly.
""" """
from __future__ import absolute_import from __future__ import absolute_import, print_function
import os import os
import sys import sys
@@ -19,12 +19,26 @@ import weakref
from celery.local import Proxy from celery.local import Proxy
from celery.utils.threads import LocalStack from celery.utils.threads import LocalStack
__all__ = ['set_default_app', 'get_current_app', 'get_current_task',
'get_current_worker_task', 'current_app', 'current_task']
#: Global default app used when no current app. #: Global default app used when no current app.
default_app = None default_app = None
#: List of all app instances (weakrefs), must not be used directly. #: List of all app instances (weakrefs), must not be used directly.
_apps = set() _apps = set()
_task_join_will_block = False
def _set_task_join_will_block(blocks):
global _task_join_will_block
_task_join_will_block = True
def task_join_will_block():
return _task_join_will_block
class _TLS(threading.local): class _TLS(threading.local):
#: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute #: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
@@ -53,10 +67,11 @@ def _get_current_app():
return _tls.current_app or default_app return _tls.current_app or default_app
C_STRICT_APP = os.environ.get('C_STRICT_APP') C_STRICT_APP = os.environ.get('C_STRICT_APP')
if os.environ.get('C_STRICT_APP'): if os.environ.get('C_STRICT_APP'): # pragma: no cover
def get_current_app(): def get_current_app():
raise Exception('USES CURRENT APP')
import traceback import traceback
sys.stderr.write('USES CURRENT_APP\n') print('-- USES CURRENT_APP', file=sys.stderr) # noqa+
traceback.print_stack(file=sys.stderr) traceback.print_stack(file=sys.stderr)
return _get_current_app() return _get_current_app()
else: else:

View File

@@ -7,22 +7,29 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import os import os
from collections import Callable
from celery.local import Proxy from celery.local import Proxy
from celery import _state from celery import _state
from celery._state import ( # noqa from celery._state import (
set_default_app, set_default_app,
get_current_app as current_app, get_current_app as current_app,
get_current_task as current_task, get_current_task as current_task,
_get_active_apps, _get_active_apps,
_task_stack,
) )
from celery.utils import gen_task_name from celery.utils import gen_task_name
from .builtins import shared_task as _shared_task from .builtins import shared_task as _shared_task
from .base import Celery, AppPickler # noqa from .base import Celery, AppPickler
__all__ = ['Celery', 'AppPickler', 'default_app', 'app_or_default',
'bugreport', 'enable_trace', 'disable_trace', 'shared_task',
'set_default_app', 'current_app', 'current_task',
'push_current_task', 'pop_current_task']
#: Proxy always returning the app set as default. #: Proxy always returning the app set as default.
default_app = Proxy(lambda: _state.default_app) default_app = Proxy(lambda: _state.default_app)
@@ -40,8 +47,18 @@ app_or_default = None
default_loader = os.environ.get('CELERY_LOADER') or 'default' # XXX default_loader = os.environ.get('CELERY_LOADER') or 'default' # XXX
def bugreport(): #: Function used to push a task to the thread local stack
return current_app().bugreport() #: keeping track of the currently executing task.
#: You must remember to pop the task after.
push_current_task = _task_stack.push
#: Function used to pop a task from the thread local stack
#: keeping track of the currently executing task.
pop_current_task = _task_stack.pop
def bugreport(app=None):
return (app or current_app()).bugreport()
def _app_or_default(app=None): def _app_or_default(app=None):
@@ -84,8 +101,8 @@ App = Celery # XXX Compat
def shared_task(*args, **kwargs): def shared_task(*args, **kwargs):
"""Task decorator that creates shared tasks, """Create shared tasks (decorator).
and returns a proxy that always returns the task from the current apps Will return a proxy that always takes the task from the current apps
task registry. task registry.
This can be used by library authors to create tasks that will work This can be used by library authors to create tasks that will work
@@ -121,7 +138,7 @@ def shared_task(*args, **kwargs):
with app._finalize_mutex: with app._finalize_mutex:
app._task_from_fun(fun, **options) app._task_from_fun(fun, **options)
# Returns a proxy that always gets the task from the current # Return a proxy that always gets the task from the current
# apps task registry. # apps task registry.
def task_by_cons(): def task_by_cons():
app = current_app() app = current_app()
@@ -131,6 +148,6 @@ def shared_task(*args, **kwargs):
return Proxy(task_by_cons) return Proxy(task_by_cons)
return __inner return __inner
if len(args) == 1 and callable(args[0]): if len(args) == 1 and isinstance(args[0], Callable):
return create_shared_task(**kwargs)(args[0]) return create_shared_task(**kwargs)(args[0])
return create_shared_task(*args, **kwargs) return create_shared_task(*args, **kwargs)

View File

@@ -1,63 +0,0 @@
# -*- coding: utf-8 -*-
"""
celery.app.abstract
~~~~~~~~~~~~~~~~~~~
Abstract class that takes default attribute values
from the configuration.
"""
from __future__ import absolute_import
class from_config(object):
def __init__(self, key=None):
self.key = key
def get_key(self, attr):
return attr if self.key is None else self.key
class _configurated(type):
def __new__(cls, name, bases, attrs):
attrs['__confopts__'] = dict((attr, spec.get_key(attr))
for attr, spec in attrs.iteritems()
if isinstance(spec, from_config))
inherit_from = attrs.get('inherit_confopts', ())
for subcls in bases:
try:
attrs['__confopts__'].update(subcls.__confopts__)
except AttributeError:
pass
for subcls in inherit_from:
attrs['__confopts__'].update(subcls.__confopts__)
attrs = dict((k, v if not isinstance(v, from_config) else None)
for k, v in attrs.iteritems())
return super(_configurated, cls).__new__(cls, name, bases, attrs)
class configurated(object):
__metaclass__ = _configurated
def setup_defaults(self, kwargs, namespace='celery'):
confopts = self.__confopts__
app, find = self.app, self.app.conf.find_value_for_key
for attr, keyname in confopts.iteritems():
try:
value = kwargs[attr]
except KeyError:
value = find(keyname, namespace)
else:
if value is None:
value = find(keyname, namespace)
setattr(self, attr, value)
for attr_name, attr_value in kwargs.iteritems():
if attr_name not in confopts and attr_value is not None:
setattr(self, attr_name, attr_value)
def confopts_as_dict(self):
return dict((key, getattr(self, key)) for key in self.__confopts__)

View File

@@ -12,20 +12,25 @@ from datetime import timedelta
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from kombu import Connection, Consumer, Exchange, Producer, Queue from kombu import Connection, Consumer, Exchange, Producer, Queue
from kombu.common import entry_to_queue from kombu.common import Broadcast
from kombu.pools import ProducerPool from kombu.pools import ProducerPool
from kombu.utils import cached_property, uuid from kombu.utils import cached_property, uuid
from kombu.utils.encoding import safe_repr from kombu.utils.encoding import safe_repr
from kombu.utils.functional import maybe_list
from celery import signals from celery import signals
from celery.five import items, string_t
from celery.utils.text import indent as textindent from celery.utils.text import indent as textindent
from . import app_or_default from . import app_or_default
from . import routes as _routes from . import routes as _routes
__all__ = ['AMQP', 'Queues', 'TaskProducer', 'TaskConsumer']
#: Human readable queue declaration. #: Human readable queue declaration.
QUEUE_FORMAT = """ QUEUE_FORMAT = """
.> %(name)s exchange:%(exchange)s(%(exchange_type)s) binding:%(routing_key)s .> {0.name:<16} exchange={0.exchange.name}({0.exchange.type}) \
key={0.routing_key}
""" """
@@ -46,15 +51,16 @@ class Queues(dict):
_consume_from = None _consume_from = None
def __init__(self, queues=None, default_exchange=None, def __init__(self, queues=None, default_exchange=None,
create_missing=True, ha_policy=None): create_missing=True, ha_policy=None, autoexchange=None):
dict.__init__(self) dict.__init__(self)
self.aliases = WeakValueDictionary() self.aliases = WeakValueDictionary()
self.default_exchange = default_exchange self.default_exchange = default_exchange
self.create_missing = create_missing self.create_missing = create_missing
self.ha_policy = ha_policy self.ha_policy = ha_policy
self.autoexchange = Exchange if autoexchange is None else autoexchange
if isinstance(queues, (tuple, list)): if isinstance(queues, (tuple, list)):
queues = dict((q.name, q) for q in queues) queues = dict((q.name, q) for q in queues)
for name, q in (queues or {}).iteritems(): for name, q in items(queues or {}):
self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q) self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)
def __getitem__(self, name): def __getitem__(self, name):
@@ -79,11 +85,16 @@ class Queues(dict):
def add(self, queue, **kwargs): def add(self, queue, **kwargs):
"""Add new queue. """Add new queue.
:param queue: Name of the queue. The first argument can either be a :class:`kombu.Queue` instance,
:keyword exchange: Name of the exchange. or the name of a queue. If the former the rest of the keyword
:keyword routing_key: Binding key. arguments are ignored, and options are simply taken from the queue
:keyword exchange_type: Type of exchange. instance.
:keyword \*\*options: Additional declaration options.
:param queue: :class:`kombu.Queue` instance or name of the queue.
:keyword exchange: (if named) specifies exchange name.
:keyword routing_key: (if named) specifies binding key.
:keyword exchange_type: (if named) specifies type of exchange.
:keyword \*\*options: (if named) Additional declaration options.
""" """
if not isinstance(queue, Queue): if not isinstance(queue, Queue):
@@ -102,7 +113,7 @@ class Queues(dict):
options['routing_key'] = name options['routing_key'] = name
if self.ha_policy is not None: if self.ha_policy is not None:
self._set_ha_policy(options.setdefault('queue_arguments', {})) self._set_ha_policy(options.setdefault('queue_arguments', {}))
q = self[name] = entry_to_queue(name, **options) q = self[name] = Queue.from_dict(name, **options)
return q return q
def _set_ha_policy(self, args): def _set_ha_policy(self, args):
@@ -117,13 +128,8 @@ class Queues(dict):
active = self.consume_from active = self.consume_from
if not active: if not active:
return '' return ''
info = [ info = [QUEUE_FORMAT.strip().format(q)
QUEUE_FORMAT.strip() % { for _, q in sorted(items(active))]
'name': (name + ':').ljust(12),
'exchange': q.exchange.name,
'exchange_type': q.exchange.type,
'routing_key': q.routing_key}
for name, q in sorted(active.iteritems())]
if indent_first: if indent_first:
return textindent('\n'.join(info), indent) return textindent('\n'.join(info), indent)
return info[0] + '\n' + textindent('\n'.join(info[1:]), indent) return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)
@@ -136,23 +142,37 @@ class Queues(dict):
self._consume_from[q.name] = q self._consume_from[q.name] = q
return q return q
def select_subset(self, wanted): def select(self, include):
"""Sets :attr:`consume_from` by selecting a subset of the """Sets :attr:`consume_from` by selecting a subset of the
currently defined queues. currently defined queues.
:param wanted: List of wanted queue names. :param include: Names of queues to consume from.
Can be iterable or string.
""" """
if wanted: if include:
self._consume_from = dict((name, self[name]) for name in wanted) self._consume_from = dict((name, self[name])
for name in maybe_list(include))
select_subset = select # XXX compat
def select_remove(self, queue): def deselect(self, exclude):
if self._consume_from is None: """Deselect queues so that they will not be consumed from.
self.select_subset(k for k in self if k != queue)
else: :param exclude: Names of queues to avoid consuming from.
self._consume_from.pop(queue, None) Can be iterable or string.
"""
if exclude:
exclude = maybe_list(exclude)
if self._consume_from is None:
# using selection
return self.select(k for k in self if k not in exclude)
# using all queues
for queue in exclude:
self._consume_from.pop(queue, None)
select_remove = deselect # XXX compat
def new_missing(self, name): def new_missing(self, name):
return Queue(name, Exchange(name), name) return Queue(name, self.autoexchange(name), name)
@property @property
def consume_from(self): def consume_from(self):
@@ -189,20 +209,30 @@ class TaskProducer(Producer):
queue=None, now=None, retries=0, chord=None, queue=None, now=None, retries=0, chord=None,
callbacks=None, errbacks=None, routing_key=None, callbacks=None, errbacks=None, routing_key=None,
serializer=None, delivery_mode=None, compression=None, serializer=None, delivery_mode=None, compression=None,
declare=None, **kwargs): reply_to=None, time_limit=None, soft_time_limit=None,
declare=None, headers=None,
send_before_publish=signals.before_task_publish.send,
before_receivers=signals.before_task_publish.receivers,
send_after_publish=signals.after_task_publish.send,
after_receivers=signals.after_task_publish.receivers,
send_task_sent=signals.task_sent.send, # XXX deprecated
sent_receivers=signals.task_sent.receivers,
**kwargs):
"""Send task message.""" """Send task message."""
retry = self.retry if retry is None else retry
qname = queue qname = queue
if queue is None and exchange is None: if queue is None and exchange is None:
queue = self.default_queue queue = self.default_queue
if queue is not None: if queue is not None:
if isinstance(queue, basestring): if isinstance(queue, string_t):
qname, queue = queue, self.queues[queue] qname, queue = queue, self.queues[queue]
else: else:
qname = queue.name qname = queue.name
exchange = exchange or queue.exchange.name exchange = exchange or queue.exchange.name
routing_key = routing_key or queue.routing_key routing_key = routing_key or queue.routing_key
declare = declare or ([queue] if queue else []) if declare is None and queue and not isinstance(queue, Broadcast):
declare = [queue]
# merge default and custom policy # merge default and custom policy
retry = self.retry if retry is None else retry retry = self.retry if retry is None else retry
@@ -218,9 +248,13 @@ class TaskProducer(Producer):
if countdown: # Convert countdown to ETA. if countdown: # Convert countdown to ETA.
now = now or self.app.now() now = now or self.app.now()
eta = now + timedelta(seconds=countdown) eta = now + timedelta(seconds=countdown)
if self.utc:
eta = eta.replace(tzinfo=self.app.timezone)
if isinstance(expires, (int, float)): if isinstance(expires, (int, float)):
now = now or self.app.now() now = now or self.app.now()
expires = now + timedelta(seconds=expires) expires = now + timedelta(seconds=expires)
if self.utc:
expires = expires.replace(tzinfo=self.app.timezone)
eta = eta and eta.isoformat() eta = eta and eta.isoformat()
expires = expires and expires.isoformat() expires = expires and expires.isoformat()
@@ -235,21 +269,44 @@ class TaskProducer(Producer):
'utc': self.utc, 'utc': self.utc,
'callbacks': callbacks, 'callbacks': callbacks,
'errbacks': errbacks, 'errbacks': errbacks,
'timelimit': (time_limit, soft_time_limit),
'taskset': group_id or taskset_id, 'taskset': group_id or taskset_id,
'chord': chord, 'chord': chord,
} }
if before_receivers:
send_before_publish(
sender=task_name, body=body,
exchange=exchange,
routing_key=routing_key,
declare=declare,
headers=headers,
properties=kwargs,
retry_policy=retry_policy,
)
self.publish( self.publish(
body, body,
exchange=exchange, routing_key=routing_key, exchange=exchange, routing_key=routing_key,
serializer=serializer or self.serializer, serializer=serializer or self.serializer,
compression=compression or self.compression, compression=compression or self.compression,
headers=headers,
retry=retry, retry_policy=_rp, retry=retry, retry_policy=_rp,
reply_to=reply_to,
correlation_id=task_id,
delivery_mode=delivery_mode, declare=declare, delivery_mode=delivery_mode, declare=declare,
**kwargs **kwargs
) )
signals.task_sent.send(sender=task_name, **body) if after_receivers:
send_after_publish(sender=task_name, body=body,
exchange=exchange, routing_key=routing_key)
if sent_receivers: # XXX deprecated
send_task_sent(sender=task_name, task_id=task_id,
task=task_name, args=task_args,
kwargs=task_kwargs, eta=eta,
taskset=group_id or taskset_id)
if self.send_sent_event: if self.send_sent_event:
evd = event_dispatcher or self.event_dispatcher evd = event_dispatcher or self.event_dispatcher
exname = exchange or self.exchange exname = exchange or self.exchange
@@ -306,7 +363,7 @@ class TaskConsumer(Consumer):
accept = self.app.conf.CELERY_ACCEPT_CONTENT accept = self.app.conf.CELERY_ACCEPT_CONTENT
super(TaskConsumer, self).__init__( super(TaskConsumer, self).__init__(
channel, channel,
queues or self.app.amqp.queues.consume_from.values(), queues or list(self.app.amqp.queues.consume_from.values()),
accept=accept, accept=accept,
**kw **kw
) )
@@ -329,13 +386,20 @@ class AMQP(object):
#: set by the :attr:`producer_pool`. #: set by the :attr:`producer_pool`.
_producer_pool = None _producer_pool = None
# Exchange class/function used when defining automatic queues.
# E.g. you can use ``autoexchange = lambda n: None`` to use the
# amqp default exchange, which is a shortcut to bypass routing
# and instead send directly to the queue named in the routing key.
autoexchange = None
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
def flush_routes(self): def flush_routes(self):
self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES) self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES)
def Queues(self, queues, create_missing=None, ha_policy=None): def Queues(self, queues, create_missing=None, ha_policy=None,
autoexchange=None):
"""Create new :class:`Queues` instance, using queue defaults """Create new :class:`Queues` instance, using queue defaults
from the current configuration.""" from the current configuration."""
conf = self.app.conf conf = self.app.conf
@@ -347,10 +411,15 @@ class AMQP(object):
queues = (Queue(conf.CELERY_DEFAULT_QUEUE, queues = (Queue(conf.CELERY_DEFAULT_QUEUE,
exchange=self.default_exchange, exchange=self.default_exchange,
routing_key=conf.CELERY_DEFAULT_ROUTING_KEY), ) routing_key=conf.CELERY_DEFAULT_ROUTING_KEY), )
return Queues(queues, self.default_exchange, create_missing, ha_policy) autoexchange = (self.autoexchange if autoexchange is None
else autoexchange)
return Queues(
queues, self.default_exchange, create_missing,
ha_policy, autoexchange,
)
def Router(self, queues=None, create_missing=None): def Router(self, queues=None, create_missing=None):
"""Returns the current task router.""" """Return the current task router."""
return _routes.Router(self.routes, queues or self.queues, return _routes.Router(self.routes, queues or self.queues,
self.app.either('CELERY_CREATE_MISSING_QUEUES', self.app.either('CELERY_CREATE_MISSING_QUEUES',
create_missing), app=self.app) create_missing), app=self.app)
@@ -365,7 +434,7 @@ class AMQP(object):
@cached_property @cached_property
def TaskProducer(self): def TaskProducer(self):
"""Returns publisher used to send tasks. """Return publisher used to send tasks.
You should use `app.send_task` instead. You should use `app.send_task` instead.

View File

@@ -12,15 +12,14 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from celery.utils.functional import firstmethod, mpromise from celery.five import string_t
from celery.utils.functional import firstmethod, mlazy
from celery.utils.imports import instantiate from celery.utils.imports import instantiate
_first_match = firstmethod('annotate') _first_match = firstmethod('annotate')
_first_match_any = firstmethod('annotate_any') _first_match_any = firstmethod('annotate_any')
__all__ = ['MapAnnotation', 'prepare', 'resolve_all']
def resolve_all(anno, task):
return (r for r in (_first_match(anno, task), _first_match_any(anno)) if r)
class MapAnnotation(dict): class MapAnnotation(dict):
@@ -44,8 +43,8 @@ def prepare(annotations):
def expand_annotation(annotation): def expand_annotation(annotation):
if isinstance(annotation, dict): if isinstance(annotation, dict):
return MapAnnotation(annotation) return MapAnnotation(annotation)
elif isinstance(annotation, basestring): elif isinstance(annotation, string_t):
return mpromise(instantiate, annotation) return mlazy(instantiate, annotation)
return annotation return annotation
if annotations is None: if annotations is None:
@@ -53,3 +52,7 @@ def prepare(annotations):
elif not isinstance(annotations, (list, tuple)): elif not isinstance(annotations, (list, tuple)):
annotations = (annotations, ) annotations = (annotations, )
return [expand_annotation(anno) for anno in annotations] return [expand_annotation(anno) for anno in annotations]
def resolve_all(anno, task):
return (x for x in (_first_match(anno, task), _first_match_any(anno)) if x)

View File

@@ -7,36 +7,59 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import os import os
import threading import threading
import warnings import warnings
from collections import deque from collections import Callable, defaultdict, deque
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from functools import wraps from operator import attrgetter
from billiard.util import register_after_fork from billiard.util import register_after_fork
from kombu.clocks import LamportClock from kombu.clocks import LamportClock
from kombu.utils import cached_property from kombu.common import oid_from
from kombu.utils import cached_property, uuid
from celery import platforms from celery import platforms
from celery._state import (
_task_stack, _tls, get_current_app, _register_app, get_current_worker_task,
)
from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
from celery.five import items, values
from celery.loaders import get_loader_cls from celery.loaders import get_loader_cls
from celery.local import PromiseProxy, maybe_evaluate from celery.local import PromiseProxy, maybe_evaluate
from celery._state import _task_stack, _tls, get_current_app, _register_app
from celery.utils.functional import first, maybe_list from celery.utils.functional import first, maybe_list
from celery.utils.imports import instantiate, symbol_by_name from celery.utils.imports import instantiate, symbol_by_name
from celery.utils.log import ensure_process_aware_logger
from celery.utils.objects import mro_lookup
from .annotations import prepare as prepare_annotations from .annotations import prepare as prepare_annotations
from .builtins import shared_task, load_shared_tasks from .builtins import shared_task, load_shared_tasks
from .defaults import DEFAULTS, find_deprecated_settings from .defaults import DEFAULTS, find_deprecated_settings
from .registry import TaskRegistry from .registry import TaskRegistry
from .utils import AppPickler, Settings, bugreport, _unpickle_app from .utils import (
AppPickler, Settings, bugreport, _unpickle_app, _unpickle_app_v2, appstr,
)
__all__ = ['Celery']
_EXECV = os.environ.get('FORKED_BY_MULTIPROCESSING') _EXECV = os.environ.get('FORKED_BY_MULTIPROCESSING')
BUILTIN_FIXUPS = frozenset([
'celery.fixups.django:fixup',
])
ERR_ENVVAR_NOT_SET = """\
The environment variable {0!r} is not set,
and as such the configuration could not be loaded.
Please set this variable and make it point to
a configuration module."""
def app_has_custom(app, attr):
return mro_lookup(app.__class__, attr, stop=(Celery, object),
monkey_patched=[__name__])
def _unpickle_appattr(reverse_name, args): def _unpickle_appattr(reverse_name, args):
@@ -46,6 +69,7 @@ def _unpickle_appattr(reverse_name, args):
class Celery(object): class Celery(object):
#: This is deprecated, use :meth:`reduce_keys` instead
Pickler = AppPickler Pickler = AppPickler
SYSTEM = platforms.SYSTEM SYSTEM = platforms.SYSTEM
@@ -57,6 +81,7 @@ class Celery(object):
loader_cls = 'celery.loaders.app:AppLoader' loader_cls = 'celery.loaders.app:AppLoader'
log_cls = 'celery.app.log:Logging' log_cls = 'celery.app.log:Logging'
control_cls = 'celery.app.control:Control' control_cls = 'celery.app.control:Control'
task_cls = 'celery.app.task:Task'
registry_cls = TaskRegistry registry_cls = TaskRegistry
_pool = None _pool = None
@@ -64,8 +89,7 @@ class Celery(object):
amqp=None, events=None, log=None, control=None, amqp=None, events=None, log=None, control=None,
set_as_current=True, accept_magic_kwargs=False, set_as_current=True, accept_magic_kwargs=False,
tasks=None, broker=None, include=None, changes=None, tasks=None, broker=None, include=None, changes=None,
config_source=None, config_source=None, fixups=None, task_cls=None, **kwargs):
**kwargs):
self.clock = LamportClock() self.clock = LamportClock()
self.main = main self.main = main
self.amqp_cls = amqp or self.amqp_cls self.amqp_cls = amqp or self.amqp_cls
@@ -74,10 +98,13 @@ class Celery(object):
self.loader_cls = loader or self.loader_cls self.loader_cls = loader or self.loader_cls
self.log_cls = log or self.log_cls self.log_cls = log or self.log_cls
self.control_cls = control or self.control_cls self.control_cls = control or self.control_cls
self.task_cls = task_cls or self.task_cls
self.set_as_current = set_as_current self.set_as_current = set_as_current
self.registry_cls = symbol_by_name(self.registry_cls) self.registry_cls = symbol_by_name(self.registry_cls)
self.accept_magic_kwargs = accept_magic_kwargs self.accept_magic_kwargs = accept_magic_kwargs
self.user_options = defaultdict(set)
self._config_source = config_source self._config_source = config_source
self.steps = defaultdict(set)
self.configured = False self.configured = False
self._pending_defaults = deque() self._pending_defaults = deque()
@@ -89,6 +116,11 @@ class Celery(object):
if not isinstance(self._tasks, TaskRegistry): if not isinstance(self._tasks, TaskRegistry):
self._tasks = TaskRegistry(self._tasks or {}) self._tasks = TaskRegistry(self._tasks or {})
# If the class defins a custom __reduce_args__ we need to use
# the old way of pickling apps, which is pickling a list of
# args instead of the new way that pickles a dict of keywords.
self._using_v1_reduce = app_has_custom(self, '__reduce_args__')
# these options are moved to the config to # these options are moved to the config to
# simplify pickling of the app object. # simplify pickling of the app object.
self._preconf = changes or {} self._preconf = changes or {}
@@ -97,6 +129,11 @@ class Celery(object):
if include: if include:
self._preconf['CELERY_IMPORTS'] = include self._preconf['CELERY_IMPORTS'] = include
# Apply fixups.
self.fixups = set(fixups or ())
for fixup in self.fixups | BUILTIN_FIXUPS:
symbol_by_name(fixup)(self)
if self.set_as_current: if self.set_as_current:
self.set_current() self.set_current()
@@ -133,7 +170,7 @@ class Celery(object):
def worker_main(self, argv=None): def worker_main(self, argv=None):
return instantiate( return instantiate(
'celery.bin.celeryd:WorkerCommand', 'celery.bin.worker:worker',
app=self).execute_from_commandline(argv) app=self).execute_from_commandline(argv)
def task(self, *args, **opts): def task(self, *args, **opts):
@@ -145,8 +182,7 @@ class Celery(object):
# the task instance from the current app. # the task instance from the current app.
# Really need a better solution for this :( # Really need a better solution for this :(
from . import shared_task as proxies_to_curapp from . import shared_task as proxies_to_curapp
opts['_force_evaluate'] = True # XXX Py2.5 return proxies_to_curapp(*args, _force_evaluate=True, **opts)
return proxies_to_curapp(*args, **opts)
def inner_create_task_cls(shared=True, filter=None, **opts): def inner_create_task_cls(shared=True, filter=None, **opts):
_filt = filter # stupid 2to3 _filt = filter # stupid 2to3
@@ -162,16 +198,20 @@ class Celery(object):
task = filter(task) task = filter(task)
return task return task
# return a proxy object that is only evaluated when first used if self.finalized or opts.get('_force_evaluate'):
promise = PromiseProxy(self._task_from_fun, (fun, ), opts) ret = self._task_from_fun(fun, **opts)
self._pending.append(promise) else:
# return a proxy object that evaluates on first use
ret = PromiseProxy(self._task_from_fun, (fun, ), opts,
__doc__=fun.__doc__)
self._pending.append(ret)
if _filt: if _filt:
return _filt(promise) return _filt(ret)
return promise return ret
return _create_task_cls return _create_task_cls
if len(args) == 1 and callable(args[0]): if len(args) == 1 and isinstance(args[0], Callable):
return inner_create_task_cls(**opts)(*args) return inner_create_task_cls(**opts)(*args)
if args: if args:
raise TypeError( raise TypeError(
@@ -180,15 +220,16 @@ class Celery(object):
def _task_from_fun(self, fun, **options): def _task_from_fun(self, fun, **options):
base = options.pop('base', None) or self.Task base = options.pop('base', None) or self.Task
bind = options.pop('bind', False)
T = type(fun.__name__, (base, ), dict({ T = type(fun.__name__, (base, ), dict({
'app': self, 'app': self,
'accept_magic_kwargs': False, 'accept_magic_kwargs': False,
'run': staticmethod(fun), 'run': fun if bind else staticmethod(fun),
'_decorated': True,
'__doc__': fun.__doc__, '__doc__': fun.__doc__,
'__module__': fun.__module__}, **options))() '__module__': fun.__module__}, **options))()
task = self._tasks[T.name] # return global instance. task = self._tasks[T.name] # return global instance.
task.bind(self)
return task return task
def finalize(self): def finalize(self):
@@ -201,11 +242,11 @@ class Celery(object):
while pending: while pending:
maybe_evaluate(pending.popleft()) maybe_evaluate(pending.popleft())
for task in self._tasks.itervalues(): for task in values(self._tasks):
task.bind(self) task.bind(self)
def add_defaults(self, fun): def add_defaults(self, fun):
if not callable(fun): if not isinstance(fun, Callable):
d, fun = fun, lambda: d d, fun = fun, lambda: d
if self.configured: if self.configured:
return self.conf.add_defaults(fun()) return self.conf.add_defaults(fun())
@@ -221,56 +262,83 @@ class Celery(object):
if not module_name: if not module_name:
if silent: if silent:
return False return False
raise ImproperlyConfigured(self.error_envvar_not_set % module_name) raise ImproperlyConfigured(ERR_ENVVAR_NOT_SET.format(module_name))
return self.config_from_object(module_name, silent=silent) return self.config_from_object(module_name, silent=silent)
def config_from_cmdline(self, argv, namespace='celery'): def config_from_cmdline(self, argv, namespace='celery'):
self.conf.update(self.loader.cmdline_config_parser(argv, namespace)) self.conf.update(self.loader.cmdline_config_parser(argv, namespace))
def setup_security(self, allowed_serializers=None, key=None, cert=None,
store=None, digest='sha1', serializer='json'):
from celery.security import setup_security
return setup_security(allowed_serializers, key, cert,
store, digest, serializer, app=self)
def autodiscover_tasks(self, packages, related_name='tasks'):
if self.conf.CELERY_FORCE_BILLIARD_LOGGING:
# we'll use billiard's processName instead of
# multiprocessing's one in all the loggers
# created after this call
ensure_process_aware_logger()
self.loader.autodiscover_tasks(packages, related_name)
def send_task(self, name, args=None, kwargs=None, countdown=None, def send_task(self, name, args=None, kwargs=None, countdown=None,
eta=None, task_id=None, producer=None, connection=None, eta=None, task_id=None, producer=None, connection=None,
result_cls=None, expires=None, queues=None, publisher=None, router=None, result_cls=None, expires=None,
link=None, link_error=None, publisher=None, link=None, link_error=None,
**options): add_to_parent=True, reply_to=None, **options):
task_id = task_id or uuid()
producer = producer or publisher # XXX compat producer = producer or publisher # XXX compat
if self.conf.CELERY_ALWAYS_EAGER: # pragma: no cover router = router or self.amqp.router
conf = self.conf
if conf.CELERY_ALWAYS_EAGER: # pragma: no cover
warnings.warn(AlwaysEagerIgnored( warnings.warn(AlwaysEagerIgnored(
'CELERY_ALWAYS_EAGER has no effect on send_task')) 'CELERY_ALWAYS_EAGER has no effect on send_task'))
result_cls = result_cls or self.AsyncResult
router = self.amqp.Router(queues)
options.setdefault('compression',
self.conf.CELERY_MESSAGE_COMPRESSION)
options = router.route(options, name, args, kwargs) options = router.route(options, name, args, kwargs)
with self.producer_or_acquire(producer) as producer: if connection:
return result_cls(producer.publish_task( producer = self.amqp.TaskProducer(connection)
name, args, kwargs, with self.producer_or_acquire(producer) as P:
task_id=task_id, self.backend.on_task_call(P, task_id)
countdown=countdown, eta=eta, task_id = P.publish_task(
callbacks=maybe_list(link), name, args, kwargs, countdown=countdown, eta=eta,
errbacks=maybe_list(link_error), task_id=task_id, expires=expires,
expires=expires, **options callbacks=maybe_list(link), errbacks=maybe_list(link_error),
)) reply_to=reply_to or self.oid, **options
)
result = (result_cls or self.AsyncResult)(task_id)
if add_to_parent:
parent = get_current_worker_task()
if parent:
parent.add_trail(result)
return result
def connection(self, hostname=None, userid=None, def connection(self, hostname=None, userid=None, password=None,
password=None, virtual_host=None, port=None, ssl=None, virtual_host=None, port=None, ssl=None,
insist=None, connect_timeout=None, transport=None, connect_timeout=None, transport=None,
transport_options=None, heartbeat=None, **kwargs): transport_options=None, heartbeat=None,
login_method=None, failover_strategy=None, **kwargs):
conf = self.conf conf = self.conf
return self.amqp.Connection( return self.amqp.Connection(
hostname or conf.BROKER_HOST, hostname or conf.BROKER_URL,
userid or conf.BROKER_USER, userid or conf.BROKER_USER,
password or conf.BROKER_PASSWORD, password or conf.BROKER_PASSWORD,
virtual_host or conf.BROKER_VHOST, virtual_host or conf.BROKER_VHOST,
port or conf.BROKER_PORT, port or conf.BROKER_PORT,
transport=transport or conf.BROKER_TRANSPORT, transport=transport or conf.BROKER_TRANSPORT,
insist=self.either('BROKER_INSIST', insist),
ssl=self.either('BROKER_USE_SSL', ssl), ssl=self.either('BROKER_USE_SSL', ssl),
connect_timeout=self.either(
'BROKER_CONNECTION_TIMEOUT', connect_timeout),
heartbeat=heartbeat, heartbeat=heartbeat,
transport_options=dict(conf.BROKER_TRANSPORT_OPTIONS, login_method=login_method or conf.BROKER_LOGIN_METHOD,
**transport_options or {})) failover_strategy=(
failover_strategy or conf.BROKER_FAILOVER_STRATEGY
),
transport_options=dict(
conf.BROKER_TRANSPORT_OPTIONS, **transport_options or {}
),
connect_timeout=self.either(
'BROKER_CONNECTION_TIMEOUT', connect_timeout
),
)
broker_connection = connection broker_connection = connection
@contextmanager @contextmanager
@@ -296,26 +364,6 @@ class Celery(object):
yield producer yield producer
default_producer = producer_or_acquire # XXX compat default_producer = producer_or_acquire # XXX compat
def with_default_connection(self, fun):
"""With any function accepting a `connection`
keyword argument, establishes a default connection if one is
not already passed to it.
Any automatically established connection will be closed after
the function returns.
**Deprecated**
Use ``with app.connection_or_acquire(connection)`` instead.
"""
@wraps(fun)
def _inner(*args, **kwargs):
connection = kwargs.pop('connection', None)
with self.connection_or_acquire(connection) as c:
return fun(*args, **dict(kwargs, connection=c))
return _inner
def prepare_config(self, c): def prepare_config(self, c):
"""Prepare configuration before it is merged with the defaults.""" """Prepare configuration before it is merged with the defaults."""
return find_deprecated_settings(c) return find_deprecated_settings(c)
@@ -339,7 +387,7 @@ class Celery(object):
) )
def select_queues(self, queues=None): def select_queues(self, queues=None):
return self.amqp.queues.select_subset(queues) return self.amqp.queues.select(queues)
def either(self, default_key, *values): def either(self, default_key, *values):
"""Fallback to the value of a configuration key if none of the """Fallback to the value of a configuration key if none of the
@@ -356,7 +404,12 @@ class Celery(object):
self.loader) self.loader)
return backend(app=self, url=url) return backend(app=self, url=url)
def on_configure(self):
"""Callback calld when the app loads configuration"""
pass
def _get_config(self): def _get_config(self):
self.on_configure()
self.configured = True self.configured = True
s = Settings({}, [self.prepare_config(self.loader.conf), s = Settings({}, [self.prepare_config(self.loader.conf),
deepcopy(DEFAULTS)]) deepcopy(DEFAULTS)])
@@ -364,9 +417,9 @@ class Celery(object):
# load lazy config dict initializers. # load lazy config dict initializers.
pending = self._pending_defaults pending = self._pending_defaults
while pending: while pending:
s.add_defaults(pending.popleft()()) s.add_defaults(maybe_evaluate(pending.popleft()()))
if self._preconf: if self._preconf:
for key, value in self._preconf.iteritems(): for key, value in items(self._preconf):
setattr(s, key, value) setattr(s, key, value)
return s return s
@@ -382,14 +435,20 @@ class Celery(object):
amqp._producer_pool.force_close_all() amqp._producer_pool.force_close_all()
amqp._producer_pool = None amqp._producer_pool = None
def signature(self, *args, **kwargs):
kwargs['app'] = self
return self.canvas.signature(*args, **kwargs)
def create_task_cls(self): def create_task_cls(self):
"""Creates a base task class using default configuration """Creates a base task class using default configuration
taken from this app.""" taken from this app."""
return self.subclass_with_self('celery.app.task:Task', name='Task', return self.subclass_with_self(
attribute='_app', abstract=True) self.task_cls, name='Task', attribute='_app',
keep_reduce=True, abstract=True,
)
def subclass_with_self(self, Class, name=None, attribute='app', def subclass_with_self(self, Class, name=None, attribute='app',
reverse=None, **kw): reverse=None, keep_reduce=False, **kw):
"""Subclass an app-compatible class by setting its app attribute """Subclass an app-compatible class by setting its app attribute
to be this app instance. to be this app instance.
@@ -410,18 +469,24 @@ class Celery(object):
return _unpickle_appattr, (reverse, self.__reduce_args__()) return _unpickle_appattr, (reverse, self.__reduce_args__())
attrs = dict({attribute: self}, __module__=Class.__module__, attrs = dict({attribute: self}, __module__=Class.__module__,
__doc__=Class.__doc__, __reduce__=__reduce__, **kw) __doc__=Class.__doc__, **kw)
if not keep_reduce:
attrs['__reduce__'] = __reduce__
return type(name or Class.__name__, (Class, ), attrs) return type(name or Class.__name__, (Class, ), attrs)
def _rgetattr(self, path): def _rgetattr(self, path):
return reduce(getattr, [self] + path.split('.')) return attrgetter(path)(self)
def __repr__(self): def __repr__(self):
return '<%s %s:0x%x>' % (self.__class__.__name__, return '<{0} {1}>'.format(type(self).__name__, appstr(self))
self.main or '__main__', id(self), )
def __reduce__(self): def __reduce__(self):
if self._using_v1_reduce:
return self.__reduce_v1__()
return (_unpickle_app_v2, (self.__class__, self.__reduce_keys__()))
def __reduce_v1__(self):
# Reduce only pickles the configuration changes, # Reduce only pickles the configuration changes,
# so the default configuration doesn't have to be passed # so the default configuration doesn't have to be passed
# between processes. # between processes.
@@ -430,11 +495,30 @@ class Celery(object):
(self.__class__, self.Pickler) + self.__reduce_args__(), (self.__class__, self.Pickler) + self.__reduce_args__(),
) )
def __reduce_keys__(self):
"""Return keyword arguments used to reconstruct the object
when unpickling."""
return {
'main': self.main,
'changes': self.conf.changes,
'loader': self.loader_cls,
'backend': self.backend_cls,
'amqp': self.amqp_cls,
'events': self.events_cls,
'log': self.log_cls,
'control': self.control_cls,
'accept_magic_kwargs': self.accept_magic_kwargs,
'fixups': self.fixups,
'config_source': self._config_source,
'task_cls': self.task_cls,
}
def __reduce_args__(self): def __reduce_args__(self):
return (self.main, self.conf.changes, self.loader_cls, """Deprecated method, please use :meth:`__reduce_keys__` instead."""
self.backend_cls, self.amqp_cls, self.events_cls, return (self.main, self.conf.changes,
self.log_cls, self.control_cls, self.accept_magic_kwargs, self.loader_cls, self.backend_cls, self.amqp_cls,
self._config_source) self.events_cls, self.log_cls, self.control_cls,
self.accept_magic_kwargs, self._config_source)
@cached_property @cached_property
def Worker(self): def Worker(self):
@@ -448,10 +532,6 @@ class Celery(object):
def Beat(self, **kwargs): def Beat(self, **kwargs):
return self.subclass_with_self('celery.apps.beat:Beat') return self.subclass_with_self('celery.apps.beat:Beat')
@cached_property
def TaskSet(self):
return self.subclass_with_self('celery.task.sets:TaskSet')
@cached_property @cached_property
def Task(self): def Task(self):
return self.create_task_cls() return self.create_task_cls()
@@ -464,12 +544,22 @@ class Celery(object):
def AsyncResult(self): def AsyncResult(self):
return self.subclass_with_self('celery.result:AsyncResult') return self.subclass_with_self('celery.result:AsyncResult')
@cached_property
def ResultSet(self):
return self.subclass_with_self('celery.result:ResultSet')
@cached_property @cached_property
def GroupResult(self): def GroupResult(self):
return self.subclass_with_self('celery.result:GroupResult') return self.subclass_with_self('celery.result:GroupResult')
@cached_property
def TaskSet(self): # XXX compat
"""Deprecated! Please use :class:`celery.group` instead."""
return self.subclass_with_self('celery.task.sets:TaskSet')
@cached_property @cached_property
def TaskSetResult(self): # XXX compat def TaskSetResult(self): # XXX compat
"""Deprecated! Please use :attr:`GroupResult` instead."""
return self.subclass_with_self('celery.result:TaskSetResult') return self.subclass_with_self('celery.result:TaskSetResult')
@property @property
@@ -484,6 +574,10 @@ class Celery(object):
def current_task(self): def current_task(self):
return _task_stack.top return _task_stack.top
@cached_property
def oid(self):
return oid_from(self)
@cached_property @cached_property
def amqp(self): def amqp(self):
return instantiate(self.amqp_cls, app=self) return instantiate(self.amqp_cls, app=self)
@@ -512,8 +606,23 @@ class Celery(object):
def log(self): def log(self):
return instantiate(self.log_cls, app=self) return instantiate(self.log_cls, app=self)
@cached_property
def canvas(self):
from celery import canvas
return canvas
@cached_property @cached_property
def tasks(self): def tasks(self):
self.finalize() self.finalize()
return self._tasks return self._tasks
@cached_property
def timezone(self):
from celery.utils.timeutils import timezone
conf = self.conf
tz = conf.CELERY_TIMEZONE
if not tz:
return (timezone.get_timezone('UTC') if conf.CELERY_ENABLE_UTC
else timezone.local)
return timezone.get_timezone(self.conf.CELERY_TIMEZONE)
App = Celery # compat App = Celery # compat

View File

@@ -8,32 +8,35 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
from collections import deque from collections import deque
from celery._state import get_current_worker_task from celery._state import get_current_worker_task
from celery.utils import uuid from celery.utils import uuid
__all__ = ['shared_task', 'load_shared_tasks']
#: global list of functions defining tasks that should be #: global list of functions defining tasks that should be
#: added to all apps. #: added to all apps.
_shared_tasks = [] _shared_tasks = set()
def shared_task(constructor): def shared_task(constructor):
"""Decorator that specifies that the decorated function is a function """Decorator that specifies a function that generates a built-in task.
that generates a built-in task.
The function will then be called for every new app instance created The function will then be called for every new app instance created
(lazily, so more exactly when the task registry for that app is needed). (lazily, so more exactly when the task registry for that app is needed).
The function must take a single ``app`` argument.
""" """
_shared_tasks.append(constructor) _shared_tasks.add(constructor)
return constructor return constructor
def load_shared_tasks(app): def load_shared_tasks(app):
"""Loads the built-in tasks for an app instance.""" """Create built-in tasks for an app instance."""
for constructor in _shared_tasks: constructors = set(_shared_tasks)
for constructor in constructors:
constructor(app) constructor(app)
@@ -42,17 +45,13 @@ def add_backend_cleanup_task(app):
"""The backend cleanup task can be used to clean up the default result """The backend cleanup task can be used to clean up the default result
backend. backend.
This task is also added do the periodic task schedule so that it is If the configured backend requires periodic cleanup this task is also
run every day at midnight, but :program:`celerybeat` must be running automatically configured to run every day at midnight (requires
for this to be effective. :program:`celery beat` to be running).
Note that not all backends do anything for this, what needs to be
done at cleanup is up to each backend, and some backends
may even clean up in realtime so that a periodic cleanup is not necessary.
""" """
@app.task(name='celery.backend_cleanup',
@app.task(name='celery.backend_cleanup', _force_evaluate=True) shared=False, _force_evaluate=True)
def backend_cleanup(): def backend_cleanup():
app.backend.cleanup() app.backend.cleanup()
return backend_cleanup return backend_cleanup
@@ -60,58 +59,62 @@ def add_backend_cleanup_task(app):
@shared_task @shared_task
def add_unlock_chord_task(app): def add_unlock_chord_task(app):
"""The unlock chord task is used by result backends that doesn't """This task is used by result backends without native chord support.
have native chord support.
It creates a task chain polling the header for completion. It joins chords by creating a task chain polling the header for completion.
""" """
from celery.canvas import subtask from celery.canvas import signature
from celery.exceptions import ChordError from celery.exceptions import ChordError
from celery.result import from_serializable from celery.result import result_from_tuple
default_propagate = app.conf.CELERY_CHORD_PROPAGATES default_propagate = app.conf.CELERY_CHORD_PROPAGATES
@app.task(name='celery.chord_unlock', max_retries=None, @app.task(name='celery.chord_unlock', max_retries=None, shared=False,
default_retry_delay=1, ignore_result=True, _force_evaluate=True) default_retry_delay=1, ignore_result=True, _force_evaluate=True)
def unlock_chord(group_id, callback, interval=None, propagate=None, def unlock_chord(group_id, callback, interval=None, propagate=None,
max_retries=None, result=None, max_retries=None, result=None,
Result=app.AsyncResult, GroupResult=app.GroupResult, Result=app.AsyncResult, GroupResult=app.GroupResult,
from_serializable=from_serializable): result_from_tuple=result_from_tuple):
# if propagate is disabled exceptions raised by chord tasks # if propagate is disabled exceptions raised by chord tasks
# will be sent as part of the result list to the chord callback. # will be sent as part of the result list to the chord callback.
# Since 3.1 propagate will be enabled by default, and instead # Since 3.1 propagate will be enabled by default, and instead
# the chord callback changes state to FAILURE with the # the chord callback changes state to FAILURE with the
# exception set to ChordError. # exception set to ChordError.
propagate = default_propagate if propagate is None else propagate propagate = default_propagate if propagate is None else propagate
if interval is None:
interval = unlock_chord.default_retry_delay
# check if the task group is ready, and if so apply the callback. # check if the task group is ready, and if so apply the callback.
deps = GroupResult( deps = GroupResult(
group_id, group_id,
[from_serializable(r, app=app) for r in result], [result_from_tuple(r, app=app) for r in result],
) )
j = deps.join_native if deps.supports_native_join else deps.join j = deps.join_native if deps.supports_native_join else deps.join
if deps.ready(): if deps.ready():
callback = subtask(callback) callback = signature(callback, app=app)
try: try:
ret = j(propagate=propagate) ret = j(propagate=propagate)
except Exception, exc: except Exception as exc:
try: try:
culprit = deps._failed_join_report().next() culprit = next(deps._failed_join_report())
reason = 'Dependency %s raised %r' % (culprit.id, exc) reason = 'Dependency {0.id} raised {1!r}'.format(
culprit, exc,
)
except StopIteration: except StopIteration:
reason = repr(exc) reason = repr(exc)
app._tasks[callback.task].backend.fail_from_current_stack( app._tasks[callback.task].backend.fail_from_current_stack(
callback.id, exc=ChordError(reason), callback.id, exc=ChordError(reason),
) )
else: else:
try: try:
callback.delay(ret) callback.delay(ret)
except Exception, exc: except Exception as exc:
app._tasks[callback.task].backend.fail_from_current_stack( app._tasks[callback.task].backend.fail_from_current_stack(
callback.id, callback.id,
exc=ChordError('Callback error: %r' % (exc, )), exc=ChordError('Callback error: {0!r}'.format(exc)),
) )
else: else:
return unlock_chord.retry(countdown=interval, return unlock_chord.retry(countdown=interval,
@@ -121,23 +124,23 @@ def add_unlock_chord_task(app):
@shared_task @shared_task
def add_map_task(app): def add_map_task(app):
from celery.canvas import subtask from celery.canvas import signature
@app.task(name='celery.map', _force_evaluate=True) @app.task(name='celery.map', shared=False, _force_evaluate=True)
def xmap(task, it): def xmap(task, it):
task = subtask(task).type task = signature(task, app=app).type
return [task(value) for value in it] return [task(item) for item in it]
return xmap return xmap
@shared_task @shared_task
def add_starmap_task(app): def add_starmap_task(app):
from celery.canvas import subtask from celery.canvas import signature
@app.task(name='celery.starmap', _force_evaluate=True) @app.task(name='celery.starmap', shared=False, _force_evaluate=True)
def xstarmap(task, it): def xstarmap(task, it):
task = subtask(task).type task = signature(task, app=app).type
return [task(*args) for args in it] return [task(*item) for item in it]
return xstarmap return xstarmap
@@ -145,7 +148,7 @@ def add_starmap_task(app):
def add_chunk_task(app): def add_chunk_task(app):
from celery.canvas import chunks as _chunks from celery.canvas import chunks as _chunks
@app.task(name='celery.chunks', _force_evaluate=True) @app.task(name='celery.chunks', shared=False, _force_evaluate=True)
def chunks(task, it, n): def chunks(task, it, n):
return _chunks.apply_chunks(task, it, n) return _chunks.apply_chunks(task, it, n)
return chunks return chunks
@@ -154,19 +157,20 @@ def add_chunk_task(app):
@shared_task @shared_task
def add_group_task(app): def add_group_task(app):
_app = app _app = app
from celery.canvas import maybe_subtask, subtask from celery.canvas import maybe_signature, signature
from celery.result import from_serializable from celery.result import result_from_tuple
class Group(app.Task): class Group(app.Task):
app = _app app = _app
name = 'celery.group' name = 'celery.group'
accept_magic_kwargs = False accept_magic_kwargs = False
_decorated = True
def run(self, tasks, result, group_id, partial_args): def run(self, tasks, result, group_id, partial_args):
app = self.app app = self.app
result = from_serializable(result, app) result = result_from_tuple(result, app)
# any partial args are added to all tasks in the group # any partial args are added to all tasks in the group
taskit = (subtask(task).clone(partial_args) taskit = (signature(task, app=app).clone(partial_args)
for i, task in enumerate(tasks)) for i, task in enumerate(tasks))
if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER: if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
return app.GroupResult( return app.GroupResult(
@@ -178,30 +182,25 @@ def add_group_task(app):
add_to_parent=False) for stask in taskit] add_to_parent=False) for stask in taskit]
parent = get_current_worker_task() parent = get_current_worker_task()
if parent: if parent:
parent.request.children.append(result) parent.add_trail(result)
return result return result
def prepare(self, options, tasks, args, **kwargs): def prepare(self, options, tasks, args, **kwargs):
AsyncResult = self.AsyncResult
options['group_id'] = group_id = ( options['group_id'] = group_id = (
options.setdefault('task_id', uuid())) options.setdefault('task_id', uuid()))
def prepare_member(task): def prepare_member(task):
task = maybe_subtask(task) task = maybe_signature(task, app=self.app)
opts = task.options task.options['group_id'] = group_id
opts['group_id'] = group_id return task, task.freeze()
try:
tid = opts['task_id']
except KeyError:
tid = opts['task_id'] = uuid()
return task, AsyncResult(tid)
try: try:
tasks, results = zip(*[prepare_member(task) for task in tasks]) tasks, res = list(zip(
*[prepare_member(task) for task in tasks]
))
except ValueError: # tasks empty except ValueError: # tasks empty
tasks, results = [], [] tasks, res = [], []
return (tasks, self.app.GroupResult(group_id, results), return (tasks, self.app.GroupResult(group_id, res), group_id, args)
group_id, args)
def apply_async(self, partial_args=(), kwargs={}, **options): def apply_async(self, partial_args=(), kwargs={}, **options):
if self.app.conf.CELERY_ALWAYS_EAGER: if self.app.conf.CELERY_ALWAYS_EAGER:
@@ -210,7 +209,7 @@ def add_group_task(app):
options, args=partial_args, **kwargs options, args=partial_args, **kwargs
) )
super(Group, self).apply_async(( super(Group, self).apply_async((
list(tasks), result.serializable(), gid, args), **options list(tasks), result.as_tuple(), gid, args), **options
) )
return result return result
@@ -223,50 +222,55 @@ def add_group_task(app):
@shared_task @shared_task
def add_chain_task(app): def add_chain_task(app):
from celery.canvas import Signature, chord, group, maybe_subtask from celery.canvas import Signature, chord, group, maybe_signature
_app = app _app = app
class Chain(app.Task): class Chain(app.Task):
app = _app app = _app
name = 'celery.chain' name = 'celery.chain'
accept_magic_kwargs = False accept_magic_kwargs = False
_decorated = True
def prepare_steps(self, args, tasks): def prepare_steps(self, args, tasks):
app = self.app
steps = deque(tasks) steps = deque(tasks)
next_step = prev_task = prev_res = None next_step = prev_task = prev_res = None
tasks, results = [], [] tasks, results = [], []
i = 0 i = 0
while steps: while steps:
# First task get partial args from chain. # First task get partial args from chain.
task = maybe_subtask(steps.popleft()) task = maybe_signature(steps.popleft(), app=app)
task = task.clone() if i else task.clone(args) task = task.clone() if i else task.clone(args)
res = task._freeze() res = task.freeze()
i += 1 i += 1
if isinstance(task, group): if isinstance(task, group) and steps and \
not isinstance(steps[0], group):
# automatically upgrade group(..) | s to chord(group, s) # automatically upgrade group(..) | s to chord(group, s)
try: try:
next_step = steps.popleft() next_step = steps.popleft()
# for chords we freeze by pretending it's a normal # for chords we freeze by pretending it's a normal
# task instead of a group. # task instead of a group.
res = Signature._freeze(task) res = Signature.freeze(next_step)
task = chord(task, body=next_step, task_id=res.task_id) task = chord(task, body=next_step, task_id=res.task_id)
except IndexError: except IndexError:
pass pass # no callback, so keep as group
if prev_task: if prev_task:
# link previous task to this task. # link previous task to this task.
prev_task.link(task) prev_task.link(task)
# set the results parent attribute. # set the results parent attribute.
res.parent = prev_res if not res.parent:
res.parent = prev_res
results.append(res) if not isinstance(prev_task, chord):
tasks.append(task) results.append(res)
tasks.append(task)
prev_task, prev_res = task, res prev_task, prev_res = task, res
return tasks, results return tasks, results
def apply_async(self, args=(), kwargs={}, group_id=None, chord=None, def apply_async(self, args=(), kwargs={}, group_id=None, chord=None,
task_id=None, **options): task_id=None, link=None, link_error=None, **options):
if self.app.conf.CELERY_ALWAYS_EAGER: if self.app.conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, **options) return self.apply(args, kwargs, **options)
options.pop('publisher', None) options.pop('publisher', None)
@@ -279,13 +283,24 @@ def add_chain_task(app):
if task_id: if task_id:
tasks[-1].set(task_id=task_id) tasks[-1].set(task_id=task_id)
result = tasks[-1].type.AsyncResult(task_id) result = tasks[-1].type.AsyncResult(task_id)
# make sure we can do a link() and link_error() on a chain object.
if link:
tasks[-1].set(link=link)
# and if any task in the chain fails, call the errbacks
if link_error:
for task in tasks:
task.set(link_error=link_error)
tasks[0].apply_async() tasks[0].apply_async()
return result return result
def apply(self, args=(), kwargs={}, subtask=maybe_subtask, **options): def apply(self, args=(), kwargs={}, signature=maybe_signature,
**options):
app = self.app
last, fargs = None, args # fargs passed to first task only last, fargs = None, args # fargs passed to first task only
for task in kwargs['tasks']: for task in kwargs['tasks']:
res = subtask(task).clone(fargs).apply(last and (last.get(), )) res = signature(task, app=app).clone(fargs).apply(
last and (last.get(), ),
)
res.parent, last, fargs = last, res, None res.parent, last, fargs = last, res, None
return last return last
return Chain return Chain
@@ -294,10 +309,10 @@ def add_chain_task(app):
@shared_task @shared_task
def add_chord_task(app): def add_chord_task(app):
"""Every chord is executed in a dedicated task, so that the chord """Every chord is executed in a dedicated task, so that the chord
can be used as a subtask, and this generates the task can be used as a signature, and this generates the task
responsible for that.""" responsible for that."""
from celery import group from celery import group
from celery.canvas import maybe_subtask from celery.canvas import maybe_signature
_app = app _app = app
default_propagate = app.conf.CELERY_CHORD_PROPAGATES default_propagate = app.conf.CELERY_CHORD_PROPAGATES
@@ -306,18 +321,22 @@ def add_chord_task(app):
name = 'celery.chord' name = 'celery.chord'
accept_magic_kwargs = False accept_magic_kwargs = False
ignore_result = False ignore_result = False
_decorated = True
def run(self, header, body, partial_args=(), interval=None, def run(self, header, body, partial_args=(), interval=None,
countdown=1, max_retries=None, propagate=None, countdown=1, max_retries=None, propagate=None,
eager=False, **kwargs): eager=False, **kwargs):
app = self.app
propagate = default_propagate if propagate is None else propagate propagate = default_propagate if propagate is None else propagate
group_id = uuid() group_id = uuid()
AsyncResult = self.app.AsyncResult AsyncResult = app.AsyncResult
prepare_member = self._prepare_member prepare_member = self._prepare_member
# - convert back to group if serialized # - convert back to group if serialized
tasks = header.tasks if isinstance(header, group) else header tasks = header.tasks if isinstance(header, group) else header
header = group([maybe_subtask(s).clone() for s in tasks]) header = group([
maybe_signature(s, app=app).clone() for s in tasks
])
# - eager applies the group inline # - eager applies the group inline
if eager: if eager:
return header.apply(args=partial_args, task_id=group_id) return header.apply(args=partial_args, task_id=group_id)
@@ -333,8 +352,9 @@ def add_chord_task(app):
propagate=propagate, propagate=propagate,
result=results) result=results)
# - call the header group, returning the GroupResult. # - call the header group, returning the GroupResult.
# XXX Python 2.5 doesn't allow kwargs after star-args. final_res = header(*partial_args, task_id=group_id)
return header(*partial_args, **{'task_id': group_id})
return final_res
def _prepare_member(self, task, body, group_id): def _prepare_member(self, task, body, group_id):
opts = task.options opts = task.options
@@ -346,23 +366,25 @@ def add_chord_task(app):
opts.update(chord=body, group_id=group_id) opts.update(chord=body, group_id=group_id)
return task_id return task_id
def apply_async(self, args=(), kwargs={}, task_id=None, **options): def apply_async(self, args=(), kwargs={}, task_id=None,
if self.app.conf.CELERY_ALWAYS_EAGER: group_id=None, chord=None, **options):
app = self.app
if app.conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, **options) return self.apply(args, kwargs, **options)
group_id = options.pop('group_id', None)
chord = options.pop('chord', None)
header = kwargs.pop('header') header = kwargs.pop('header')
body = kwargs.pop('body') body = kwargs.pop('body')
header, body = (list(maybe_subtask(header)), header, body = (list(maybe_signature(header, app=app)),
maybe_subtask(body)) maybe_signature(body, app=app))
if group_id: # forward certain options to body
body.set(group_id=group_id) if chord is not None:
if chord: body.options['chord'] = chord
body.set(chord=chord) if group_id is not None:
callback_id = body.options.setdefault('task_id', task_id or uuid()) body.options['group_id'] = group_id
[body.link(s) for s in options.pop('link', [])]
[body.link_error(s) for s in options.pop('link_error', [])]
body_result = body.freeze(task_id)
parent = super(Chord, self).apply_async((header, body, args), parent = super(Chord, self).apply_async((header, body, args),
kwargs, **options) kwargs, **options)
body_result = self.AsyncResult(callback_id)
body_result.parent = parent body_result.parent = parent
return body_result return body_result
@@ -370,6 +392,6 @@ def add_chord_task(app):
body = kwargs['body'] body = kwargs['body']
res = super(Chord, self).apply(args, dict(kwargs, eager=True), res = super(Chord, self).apply(args, dict(kwargs, eager=True),
**options) **options)
return maybe_subtask(body).apply( return maybe_signature(body, app=self.app).apply(
args=(res.get(propagate=propagate).get(), )) args=(res.get(propagate=propagate).get(), ))
return Chord return Chord

View File

@@ -8,17 +8,32 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import warnings
from kombu.pidbox import Mailbox from kombu.pidbox import Mailbox
from kombu.utils import cached_property from kombu.utils import cached_property
from . import app_or_default from celery.exceptions import DuplicateNodenameWarning
__all__ = ['Inspect', 'Control', 'flatten_reply']
W_DUPNODE = """\
Received multiple replies from node name {0!r}.
Please make sure you give each node a unique nodename using the `-n` option.\
"""
def flatten_reply(reply): def flatten_reply(reply):
nodes = {} nodes = {}
seen = set()
for item in reply: for item in reply:
dup = next((nodename in seen for nodename in item), None)
if dup:
warnings.warn(DuplicateNodenameWarning(
W_DUPNODE.format(dup),
))
seen.update(item)
nodes.update(item) nodes.update(item)
return nodes return nodes
@@ -58,6 +73,9 @@ class Inspect(object):
def report(self): def report(self):
return self._request('report') return self._request('report')
def clock(self):
return self._request('clock')
def active(self, safe=False): def active(self, safe=False):
return self._request('dump_active', safe=safe) return self._request('dump_active', safe=safe)
@@ -83,15 +101,30 @@ class Inspect(object):
def active_queues(self): def active_queues(self):
return self._request('active_queues') return self._request('active_queues')
def conf(self): def query_task(self, ids):
return self._request('dump_conf') return self._request('query_task', ids=ids)
def conf(self, with_defaults=False):
return self._request('dump_conf', with_defaults=with_defaults)
def hello(self, from_node, revoked=None):
return self._request('hello', from_node=from_node, revoked=revoked)
def memsample(self):
return self._request('memsample')
def memdump(self, samples=10):
return self._request('memdump', samples=samples)
def objgraph(self, type='Request', n=200, max_depth=10):
return self._request('objgraph', num=n, max_depth=max_depth, type=type)
class Control(object): class Control(object):
Mailbox = Mailbox Mailbox = Mailbox
def __init__(self, app=None): def __init__(self, app=None):
self.app = app_or_default(app) self.app = app
self.mailbox = self.Mailbox('celery', type='fanout', self.mailbox = self.Mailbox('celery', type='fanout',
accept=self.app.conf.CELERY_ACCEPT_CONTENT) accept=self.app.conf.CELERY_ACCEPT_CONTENT)
@@ -112,6 +145,11 @@ class Control(object):
return self.app.amqp.TaskConsumer(conn).purge() return self.app.amqp.TaskConsumer(conn).purge()
discard_all = purge discard_all = purge
def election(self, id, topic, action=None, connection=None):
self.broadcast('election', connection=connection, arguments={
'id': id, 'topic': topic, 'action': action,
})
def revoke(self, task_id, destination=None, terminate=False, def revoke(self, task_id, destination=None, terminate=False,
signal='SIGTERM', **kwargs): signal='SIGTERM', **kwargs):
"""Tell all (or specific) workers to revoke a task by id. """Tell all (or specific) workers to revoke a task by id.
@@ -136,7 +174,7 @@ class Control(object):
def ping(self, destination=None, timeout=1, **kwargs): def ping(self, destination=None, timeout=1, **kwargs):
"""Ping all (or specific) workers. """Ping all (or specific) workers.
Returns answer from alive workers. Will return the list of answers.
See :meth:`broadcast` for supported keyword arguments. See :meth:`broadcast` for supported keyword arguments.
@@ -234,7 +272,7 @@ class Control(object):
Supports the same arguments as :meth:`broadcast`. Supports the same arguments as :meth:`broadcast`.
""" """
return self.broadcast('pool_grow', {}, destination, **kwargs) return self.broadcast('pool_grow', {'n': n}, destination, **kwargs)
def pool_shrink(self, n=1, destination=None, **kwargs): def pool_shrink(self, n=1, destination=None, **kwargs):
"""Tell all (or specific) workers to shrink the pool by ``n``. """Tell all (or specific) workers to shrink the pool by ``n``.
@@ -242,7 +280,7 @@ class Control(object):
Supports the same arguments as :meth:`broadcast`. Supports the same arguments as :meth:`broadcast`.
""" """
return self.broadcast('pool_shrink', {}, destination, **kwargs) return self.broadcast('pool_shrink', {'n': n}, destination, **kwargs)
def broadcast(self, command, arguments=None, destination=None, def broadcast(self, command, arguments=None, destination=None,
connection=None, reply=False, timeout=1, limit=None, connection=None, reply=False, timeout=1, limit=None,

View File

@@ -10,25 +10,28 @@ from __future__ import absolute_import
import sys import sys
from collections import deque from collections import deque, namedtuple
from datetime import timedelta from datetime import timedelta
from celery.five import items
from celery.utils import strtobool from celery.utils import strtobool
from celery.utils.functional import memoize from celery.utils.functional import memoize
__all__ = ['Option', 'NAMESPACES', 'flatten', 'find']
is_jython = sys.platform.startswith('java') is_jython = sys.platform.startswith('java')
is_pypy = hasattr(sys, 'pypy_version_info') is_pypy = hasattr(sys, 'pypy_version_info')
DEFAULT_POOL = 'processes' DEFAULT_POOL = 'prefork'
if is_jython: if is_jython:
DEFAULT_POOL = 'threads' DEFAULT_POOL = 'threads'
elif is_pypy: elif is_pypy:
if sys.pypy_version_info[0:3] < (1, 5, 0): if sys.pypy_version_info[0:3] < (1, 5, 0):
DEFAULT_POOL = 'solo' DEFAULT_POOL = 'solo'
else: else:
DEFAULT_POOL = 'processes' DEFAULT_POOL = 'prefork'
DEFAULT_ACCEPT_CONTENT = ['json', 'pickle', 'msgpack', 'yaml']
DEFAULT_PROCESS_LOG_FMT = """ DEFAULT_PROCESS_LOG_FMT = """
[%(asctime)s: %(levelname)s/%(processName)s] %(message)s [%(asctime)s: %(levelname)s/%(processName)s] %(message)s
""".strip() """.strip()
@@ -36,10 +39,13 @@ DEFAULT_LOG_FMT = '[%(asctime)s: %(levelname)s] %(message)s'
DEFAULT_TASK_LOG_FMT = """[%(asctime)s: %(levelname)s/%(processName)s] \ DEFAULT_TASK_LOG_FMT = """[%(asctime)s: %(levelname)s/%(processName)s] \
%(task_name)s[%(task_id)s]: %(message)s""" %(task_name)s[%(task_id)s]: %(message)s"""
_BROKER_OLD = {'deprecate_by': '2.5', 'remove_by': '4.0', 'alt': 'BROKER_URL'} _BROKER_OLD = {'deprecate_by': '2.5', 'remove_by': '4.0',
'alt': 'BROKER_URL setting'}
_REDIS_OLD = {'deprecate_by': '2.5', 'remove_by': '4.0', _REDIS_OLD = {'deprecate_by': '2.5', 'remove_by': '4.0',
'alt': 'URL form of CELERY_RESULT_BACKEND'} 'alt': 'URL form of CELERY_RESULT_BACKEND'}
searchresult = namedtuple('searchresult', ('namespace', 'key', 'type'))
class Option(object): class Option(object):
alt = None alt = None
@@ -51,15 +57,15 @@ class Option(object):
def __init__(self, default=None, *args, **kwargs): def __init__(self, default=None, *args, **kwargs):
self.default = default self.default = default
self.type = kwargs.get('type') or 'string' self.type = kwargs.get('type') or 'string'
for attr, value in kwargs.iteritems(): for attr, value in items(kwargs):
setattr(self, attr, value) setattr(self, attr, value)
def to_python(self, value): def to_python(self, value):
return self.typemap[self.type](value) return self.typemap[self.type](value)
def __repr__(self): def __repr__(self):
return '<Option: type->%s default->%r>' % (self.type, self.default) return '<Option: type->{0} default->{1!r}>'.format(self.type,
self.default)
NAMESPACES = { NAMESPACES = {
'BROKER': { 'BROKER': {
@@ -67,11 +73,11 @@ NAMESPACES = {
'CONNECTION_TIMEOUT': Option(4, type='float'), 'CONNECTION_TIMEOUT': Option(4, type='float'),
'CONNECTION_RETRY': Option(True, type='bool'), 'CONNECTION_RETRY': Option(True, type='bool'),
'CONNECTION_MAX_RETRIES': Option(100, type='int'), 'CONNECTION_MAX_RETRIES': Option(100, type='int'),
'FAILOVER_STRATEGY': Option(None, type='string'),
'HEARTBEAT': Option(None, type='int'), 'HEARTBEAT': Option(None, type='int'),
'HEARTBEAT_CHECKRATE': Option(3.0, type='int'), 'HEARTBEAT_CHECKRATE': Option(3.0, type='int'),
'LOGIN_METHOD': Option(None, type='string'),
'POOL_LIMIT': Option(10, type='int'), 'POOL_LIMIT': Option(10, type='int'),
'INSIST': Option(False, type='bool',
deprecate_by='2.4', remove_by='4.0'),
'USE_SSL': Option(False, type='bool'), 'USE_SSL': Option(False, type='bool'),
'TRANSPORT': Option(type='string'), 'TRANSPORT': Option(type='string'),
'TRANSPORT_OPTIONS': Option({}, type='dict'), 'TRANSPORT_OPTIONS': Option({}, type='dict'),
@@ -90,24 +96,18 @@ NAMESPACES = {
'WRITE_CONSISTENCY': Option(type='string'), 'WRITE_CONSISTENCY': Option(type='string'),
}, },
'CELERY': { 'CELERY': {
'ACCEPT_CONTENT': Option(None, type='any'), 'ACCEPT_CONTENT': Option(DEFAULT_ACCEPT_CONTENT, type='list'),
'ACKS_LATE': Option(False, type='bool'), 'ACKS_LATE': Option(False, type='bool'),
'ALWAYS_EAGER': Option(False, type='bool'), 'ALWAYS_EAGER': Option(False, type='bool'),
'AMQP_TASK_RESULT_EXPIRES': Option(
type='float', deprecate_by='2.5', remove_by='4.0',
alt='CELERY_TASK_RESULT_EXPIRES'
),
'AMQP_TASK_RESULT_CONNECTION_MAX': Option(
1, type='int', remove_by='2.5', alt='BROKER_POOL_LIMIT',
),
'ANNOTATIONS': Option(type='any'), 'ANNOTATIONS': Option(type='any'),
'FORCE_BILLIARD_LOGGING': Option(True, type='bool'),
'BROADCAST_QUEUE': Option('celeryctl'), 'BROADCAST_QUEUE': Option('celeryctl'),
'BROADCAST_EXCHANGE': Option('celeryctl'), 'BROADCAST_EXCHANGE': Option('celeryctl'),
'BROADCAST_EXCHANGE_TYPE': Option('fanout'), 'BROADCAST_EXCHANGE_TYPE': Option('fanout'),
'CACHE_BACKEND': Option(), 'CACHE_BACKEND': Option(),
'CACHE_BACKEND_OPTIONS': Option({}, type='dict'), 'CACHE_BACKEND_OPTIONS': Option({}, type='dict'),
# chord propagate will be True from v3.1 'CHORD_PROPAGATES': Option(True, type='bool'),
'CHORD_PROPAGATES': Option(False, type='bool'), 'COUCHBASE_BACKEND_SETTINGS': Option(None, type='dict'),
'CREATE_MISSING_QUEUES': Option(True, type='bool'), 'CREATE_MISSING_QUEUES': Option(True, type='bool'),
'DEFAULT_RATE_LIMIT': Option(type='string'), 'DEFAULT_RATE_LIMIT': Option(type='string'),
'DISABLE_RATE_LIMITS': Option(False, type='bool'), 'DISABLE_RATE_LIMITS': Option(False, type='bool'),
@@ -118,7 +118,10 @@ NAMESPACES = {
'DEFAULT_DELIVERY_MODE': Option(2, type='string'), 'DEFAULT_DELIVERY_MODE': Option(2, type='string'),
'EAGER_PROPAGATES_EXCEPTIONS': Option(False, type='bool'), 'EAGER_PROPAGATES_EXCEPTIONS': Option(False, type='bool'),
'ENABLE_UTC': Option(True, type='bool'), 'ENABLE_UTC': Option(True, type='bool'),
'ENABLE_REMOTE_CONTROL': Option(True, type='bool'),
'EVENT_SERIALIZER': Option('json'), 'EVENT_SERIALIZER': Option('json'),
'EVENT_QUEUE_EXPIRES': Option(None, type='float'),
'EVENT_QUEUE_TTL': Option(None, type='float'),
'IMPORTS': Option((), type='tuple'), 'IMPORTS': Option((), type='tuple'),
'INCLUDE': Option((), type='tuple'), 'INCLUDE': Option((), type='tuple'),
'IGNORE_RESULT': Option(False, type='bool'), 'IGNORE_RESULT': Option(False, type='bool'),
@@ -132,20 +135,18 @@ NAMESPACES = {
'REDIS_MAX_CONNECTIONS': Option(type='int'), 'REDIS_MAX_CONNECTIONS': Option(type='int'),
'RESULT_BACKEND': Option(type='string'), 'RESULT_BACKEND': Option(type='string'),
'RESULT_DB_SHORT_LIVED_SESSIONS': Option(False, type='bool'), 'RESULT_DB_SHORT_LIVED_SESSIONS': Option(False, type='bool'),
'RESULT_DB_TABLENAMES': Option(type='dict'),
'RESULT_DBURI': Option(), 'RESULT_DBURI': Option(),
'RESULT_ENGINE_OPTIONS': Option(type='dict'), 'RESULT_ENGINE_OPTIONS': Option(type='dict'),
'RESULT_EXCHANGE': Option('celeryresults'), 'RESULT_EXCHANGE': Option('celeryresults'),
'RESULT_EXCHANGE_TYPE': Option('direct'), 'RESULT_EXCHANGE_TYPE': Option('direct'),
'RESULT_SERIALIZER': Option('pickle'), 'RESULT_SERIALIZER': Option('pickle'),
'RESULT_PERSISTENT': Option(False, type='bool'), 'RESULT_PERSISTENT': Option(None, type='bool'),
'ROUTES': Option(type='any'), 'ROUTES': Option(type='any'),
'SEND_EVENTS': Option(False, type='bool'), 'SEND_EVENTS': Option(False, type='bool'),
'SEND_TASK_ERROR_EMAILS': Option(False, type='bool'), 'SEND_TASK_ERROR_EMAILS': Option(False, type='bool'),
'SEND_TASK_SENT_EVENT': Option(False, type='bool'), 'SEND_TASK_SENT_EVENT': Option(False, type='bool'),
'STORE_ERRORS_EVEN_IF_IGNORED': Option(False, type='bool'), 'STORE_ERRORS_EVEN_IF_IGNORED': Option(False, type='bool'),
'TASK_ERROR_WHITELIST': Option(
(), type='tuple', deprecate_by='2.5', remove_by='4.0',
),
'TASK_PUBLISH_RETRY': Option(True, type='bool'), 'TASK_PUBLISH_RETRY': Option(True, type='bool'),
'TASK_PUBLISH_RETRY_POLICY': Option({ 'TASK_PUBLISH_RETRY_POLICY': Option({
'max_retries': 3, 'max_retries': 3,
@@ -166,22 +167,21 @@ NAMESPACES = {
'WORKER_DIRECT': Option(False, type='bool'), 'WORKER_DIRECT': Option(False, type='bool'),
}, },
'CELERYD': { 'CELERYD': {
'AUTOSCALER': Option('celery.worker.autoscale.Autoscaler'), 'AGENT': Option(None, type='string'),
'AUTORELOADER': Option('celery.worker.autoreload.Autoreloader'), 'AUTOSCALER': Option('celery.worker.autoscale:Autoscaler'),
'BOOT_STEPS': Option((), type='tuple'), 'AUTORELOADER': Option('celery.worker.autoreload:Autoreloader'),
'CONCURRENCY': Option(0, type='int'), 'CONCURRENCY': Option(0, type='int'),
'TIMER': Option(type='string'), 'TIMER': Option(type='string'),
'TIMER_PRECISION': Option(1.0, type='float'), 'TIMER_PRECISION': Option(1.0, type='float'),
'FORCE_EXECV': Option(False, type='bool'), 'FORCE_EXECV': Option(False, type='bool'),
'HIJACK_ROOT_LOGGER': Option(True, type='bool'), 'HIJACK_ROOT_LOGGER': Option(True, type='bool'),
'CONSUMER': Option(type='string'), 'CONSUMER': Option('celery.worker.consumer:Consumer', type='string'),
'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT), 'LOG_FORMAT': Option(DEFAULT_PROCESS_LOG_FMT),
'LOG_COLOR': Option(type='bool'), 'LOG_COLOR': Option(type='bool'),
'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0', 'LOG_LEVEL': Option('WARN', deprecate_by='2.4', remove_by='4.0',
alt='--loglevel argument'), alt='--loglevel argument'),
'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0', 'LOG_FILE': Option(deprecate_by='2.4', remove_by='4.0',
alt='--logfile argument'), alt='--logfile argument'),
'MEDIATOR': Option('celery.worker.mediator.Mediator'),
'MAX_TASKS_PER_CHILD': Option(type='int'), 'MAX_TASKS_PER_CHILD': Option(type='int'),
'POOL': Option(DEFAULT_POOL), 'POOL': Option(DEFAULT_POOL),
'POOL_PUTLOCKS': Option(True, type='bool'), 'POOL_PUTLOCKS': Option(True, type='bool'),
@@ -195,7 +195,7 @@ NAMESPACES = {
}, },
'CELERYBEAT': { 'CELERYBEAT': {
'SCHEDULE': Option({}, type='dict'), 'SCHEDULE': Option({}, type='dict'),
'SCHEDULER': Option('celery.beat.PersistentScheduler'), 'SCHEDULER': Option('celery.beat:PersistentScheduler'),
'SCHEDULE_FILENAME': Option('celerybeat-schedule'), 'SCHEDULE_FILENAME': Option('celerybeat-schedule'),
'MAX_LOOP_INTERVAL': Option(0, type='float'), 'MAX_LOOP_INTERVAL': Option(0, type='float'),
'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0', 'LOG_LEVEL': Option('INFO', deprecate_by='2.4', remove_by='4.0',
@@ -228,7 +228,7 @@ def flatten(d, ns=''):
stack = deque([(ns, d)]) stack = deque([(ns, d)])
while stack: while stack:
name, space = stack.popleft() name, space = stack.popleft()
for key, value in space.iteritems(): for key, value in items(space):
if isinstance(value, dict): if isinstance(value, dict):
stack.append((name + key + '_', value)) stack.append((name + key + '_', value))
else: else:
@@ -240,10 +240,10 @@ def find_deprecated_settings(source):
from celery.utils import warn_deprecated from celery.utils import warn_deprecated
for name, opt in flatten(NAMESPACES): for name, opt in flatten(NAMESPACES):
if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None): if (opt.deprecate_by or opt.remove_by) and getattr(source, name, None):
warn_deprecated(description='The %r setting' % (name, ), warn_deprecated(description='The {0!r} setting'.format(name),
deprecation=opt.deprecate_by, deprecation=opt.deprecate_by,
removal=opt.remove_by, removal=opt.remove_by,
alternative='Use %s instead' % (opt.alt, )) alternative='Use the {0.alt} instead'.format(opt))
return source return source
@@ -252,16 +252,18 @@ def find(name, namespace='celery'):
# - Try specified namespace first. # - Try specified namespace first.
namespace = namespace.upper() namespace = namespace.upper()
try: try:
return namespace, name.upper(), NAMESPACES[namespace][name.upper()] return searchresult(
namespace, name.upper(), NAMESPACES[namespace][name.upper()],
)
except KeyError: except KeyError:
# - Try all the other namespaces. # - Try all the other namespaces.
for ns, keys in NAMESPACES.iteritems(): for ns, keys in items(NAMESPACES):
if ns.upper() == name.upper(): if ns.upper() == name.upper():
return None, ns, keys return searchresult(None, ns, keys)
elif isinstance(keys, dict): elif isinstance(keys, dict):
try: try:
return ns, name.upper(), keys[name.upper()] return searchresult(ns, name.upper(), keys[name.upper()])
except KeyError: except KeyError:
pass pass
# - See if name is a qualname last. # - See if name is a qualname last.
return None, name.upper(), DEFAULTS[name.upper()] return searchresult(None, name.upper(), DEFAULTS[name.upper()])

View File

@@ -16,12 +16,15 @@ import logging
import os import os
import sys import sys
from logging.handlers import WatchedFileHandler
from kombu.log import NullHandler from kombu.log import NullHandler
from kombu.utils.encoding import set_default_encoding_file
from celery import signals from celery import signals
from celery._state import get_current_task from celery._state import get_current_task
from celery.five import class_property, string_t
from celery.utils import isatty from celery.utils import isatty
from celery.utils.compat import WatchedFileHandler
from celery.utils.log import ( from celery.utils.log import (
get_logger, mlevel, get_logger, mlevel,
ColorFormatter, ensure_process_aware_logger, ColorFormatter, ensure_process_aware_logger,
@@ -30,7 +33,7 @@ from celery.utils.log import (
) )
from celery.utils.term import colored from celery.utils.term import colored
is_py3k = sys.version_info[0] == 3 __all__ = ['TaskFormatter', 'Logging']
MP_LOG = os.environ.get('MP_LOG', False) MP_LOG = os.environ.get('MP_LOG', False)
@@ -67,28 +70,33 @@ class Logging(object):
loglevel, logfile, colorize=colorize, loglevel, logfile, colorize=colorize,
) )
if not handled: if not handled:
logger = get_logger('celery.redirected')
if redirect_stdouts: if redirect_stdouts:
self.redirect_stdouts_to_logger(logger, self.redirect_stdouts(redirect_level)
loglevel=redirect_level)
os.environ.update( os.environ.update(
CELERY_LOG_LEVEL=str(loglevel) if loglevel else '', CELERY_LOG_LEVEL=str(loglevel) if loglevel else '',
CELERY_LOG_FILE=str(logfile) if logfile else '', CELERY_LOG_FILE=str(logfile) if logfile else '',
CELERY_LOG_REDIRECT='1' if redirect_stdouts else '', )
CELERY_LOG_REDIRECT_LEVEL=str(redirect_level), return handled
def redirect_stdouts(self, loglevel=None, name='celery.redirected'):
self.redirect_stdouts_to_logger(
get_logger(name), loglevel=loglevel
)
os.environ.update(
CELERY_LOG_REDIRECT='1',
CELERY_LOG_REDIRECT_LEVEL=str(loglevel or ''),
) )
def setup_logging_subsystem(self, loglevel=None, logfile=None, def setup_logging_subsystem(self, loglevel=None, logfile=None,
format=None, colorize=None, **kwargs): format=None, colorize=None, **kwargs):
if Logging._setup: if self.already_setup:
return return
Logging._setup = True self.already_setup = True
loglevel = mlevel(loglevel or self.loglevel) loglevel = mlevel(loglevel or self.loglevel)
format = format or self.format format = format or self.format
colorize = self.supports_color(colorize, logfile) colorize = self.supports_color(colorize, logfile)
reset_multiprocessing_logger() reset_multiprocessing_logger()
if not is_py3k: ensure_process_aware_logger()
ensure_process_aware_logger()
receivers = signals.setup_logging.send( receivers = signals.setup_logging.send(
sender=None, loglevel=loglevel, logfile=logfile, sender=None, loglevel=loglevel, logfile=logfile,
format=format, colorize=colorize, format=format, colorize=colorize,
@@ -121,14 +129,19 @@ class Logging(object):
# then setup the root task logger. # then setup the root task logger.
self.setup_task_loggers(loglevel, logfile, colorize=colorize) self.setup_task_loggers(loglevel, logfile, colorize=colorize)
try:
stream = logging.getLogger().handlers[0].stream
except (AttributeError, IndexError):
pass
else:
set_default_encoding_file(stream)
# This is a hack for multiprocessing's fork+exec, so that # This is a hack for multiprocessing's fork+exec, so that
# logging before Process.run works. # logging before Process.run works.
logfile_name = logfile if isinstance(logfile, basestring) else '' logfile_name = logfile if isinstance(logfile, string_t) else ''
os.environ.update( os.environ.update(_MP_FORK_LOGLEVEL_=str(loglevel),
_MP_FORK_LOGLEVEL_=str(loglevel), _MP_FORK_LOGFILE_=logfile_name,
_MP_FORK_LOGFILE_=logfile_name, _MP_FORK_LOGFORMAT_=format)
_MP_FORK_LOGFORMAT_=format,
)
return receivers return receivers
def _configure_logger(self, logger, logfile, loglevel, def _configure_logger(self, logger, logfile, loglevel,
@@ -145,7 +158,7 @@ class Logging(object):
If `logfile` is not specified, then `sys.stderr` is used. If `logfile` is not specified, then `sys.stderr` is used.
Returns logger object. Will return the base task logger object.
""" """
loglevel = mlevel(loglevel or self.loglevel) loglevel = mlevel(loglevel or self.loglevel)
@@ -229,3 +242,11 @@ class Logging(object):
def get_default_logger(self, name='celery', **kwargs): def get_default_logger(self, name='celery', **kwargs):
return get_logger(name) return get_logger(name)
@class_property
def already_setup(cls):
return cls._setup
@already_setup.setter # noqa
def already_setup(cls, was_setup):
cls._setup = was_setup

View File

@@ -10,7 +10,13 @@ from __future__ import absolute_import
import inspect import inspect
from importlib import import_module
from celery._state import get_current_app
from celery.exceptions import NotRegistered from celery.exceptions import NotRegistered
from celery.five import items
__all__ = ['TaskRegistry']
class TaskRegistry(dict): class TaskRegistry(dict):
@@ -51,10 +57,15 @@ class TaskRegistry(dict):
return self.filter_types('periodic') return self.filter_types('periodic')
def filter_types(self, type): def filter_types(self, type):
return dict((name, task) for name, task in self.iteritems() return dict((name, task) for name, task in items(self)
if getattr(task, 'type', 'regular') == type) if getattr(task, 'type', 'regular') == type)
def _unpickle_task(name): def _unpickle_task(name):
from celery import current_app return get_current_app().tasks[name]
return current_app.tasks[name]
def _unpickle_task_v2(name, module=None):
if module:
import_module(module)
return get_current_app().tasks[name]

View File

@@ -10,10 +10,13 @@
from __future__ import absolute_import from __future__ import absolute_import
from celery.exceptions import QueueNotFound from celery.exceptions import QueueNotFound
from celery.five import string_t
from celery.utils import lpmerge from celery.utils import lpmerge
from celery.utils.functional import firstmethod, mpromise from celery.utils.functional import firstmethod, mlazy
from celery.utils.imports import instantiate from celery.utils.imports import instantiate
__all__ = ['MapRoute', 'Router', 'prepare']
_first_route = firstmethod('route_for_task') _first_route = firstmethod('route_for_task')
@@ -24,9 +27,10 @@ class MapRoute(object):
self.map = map self.map = map
def route_for_task(self, task, *args, **kwargs): def route_for_task(self, task, *args, **kwargs):
route = self.map.get(task) try:
if route: return dict(self.map[task])
return dict(route) except KeyError:
pass
class Router(object): class Router(object):
@@ -51,7 +55,7 @@ class Router(object):
def expand_destination(self, route): def expand_destination(self, route):
# Route can be a queue name: convenient for direct exchanges. # Route can be a queue name: convenient for direct exchanges.
if isinstance(route, basestring): if isinstance(route, string_t):
queue, route = route, {} queue, route = route, {}
else: else:
# can use defaults from configured queue, but override specific # can use defaults from configured queue, but override specific
@@ -62,13 +66,8 @@ class Router(object):
try: try:
Q = self.queues[queue] # noqa Q = self.queues[queue] # noqa
except KeyError: except KeyError:
if not self.create_missing: raise QueueNotFound(
raise QueueNotFound( 'Queue {0!r} missing from CELERY_QUEUES'.format(queue))
'Queue %r is not defined in CELERY_QUEUES' % queue)
for key in 'exchange', 'routing_key':
if route.get(key) is None:
route[key] = queue
Q = self.app.amqp.queues.add(queue, **route)
# needs to be declared by publisher # needs to be declared by publisher
route['queue'] = Q route['queue'] = Q
return route return route
@@ -83,8 +82,8 @@ def prepare(routes):
def expand_route(route): def expand_route(route):
if isinstance(route, dict): if isinstance(route, dict):
return MapRoute(route) return MapRoute(route)
if isinstance(route, basestring): if isinstance(route, string_t):
return mpromise(instantiate, route) return mlazy(instantiate, route)
return route return route
if routes is None: if routes is None:

View File

@@ -7,15 +7,18 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import with_statement
import sys
from billiard.einfo import ExceptionInfo
from celery import current_app from celery import current_app
from celery import states from celery import states
from celery.__compat__ import class_property from celery._state import _task_stack
from celery._state import get_current_worker_task, _task_stack from celery.canvas import signature
from celery.canvas import subtask from celery.exceptions import MaxRetriesExceededError, Reject, Retry
from celery.datastructures import ExceptionInfo from celery.five import class_property, items, with_metaclass
from celery.exceptions import MaxRetriesExceededError, RetryTaskError from celery.local import Proxy
from celery.result import EagerResult from celery.result import EagerResult
from celery.utils import gen_task_name, fun_takes_kwargs, uuid, maybe_reraise from celery.utils import gen_task_name, fun_takes_kwargs, uuid, maybe_reraise
from celery.utils.functional import mattrgetter, maybe_list from celery.utils.functional import mattrgetter, maybe_list
@@ -23,15 +26,57 @@ from celery.utils.imports import instantiate
from celery.utils.mail import ErrorMail from celery.utils.mail import ErrorMail
from .annotations import resolve_all as resolve_all_annotations from .annotations import resolve_all as resolve_all_annotations
from .registry import _unpickle_task from .registry import _unpickle_task_v2
from .utils import appstr
__all__ = ['Context', 'Task']
#: extracts attributes related to publishing a message from an object. #: extracts attributes related to publishing a message from an object.
extract_exec_options = mattrgetter( extract_exec_options = mattrgetter(
'queue', 'routing_key', 'exchange', 'queue', 'routing_key', 'exchange', 'priority', 'expires',
'immediate', 'mandatory', 'priority', 'expires', 'serializer', 'delivery_mode', 'compression', 'time_limit',
'serializer', 'delivery_mode', 'compression', 'soft_time_limit', 'immediate', 'mandatory', # imm+man is deprecated
) )
# We take __repr__ very seriously around here ;)
R_BOUND_TASK = '<class {0.__name__} of {app}{flags}>'
R_UNBOUND_TASK = '<unbound {0.__name__}{flags}>'
R_SELF_TASK = '<@task {0.name} bound to other {0.__self__}>'
R_INSTANCE = '<@task: {0.name} of {app}{flags}>'
class _CompatShared(object):
def __init__(self, name, cons):
self.name = name
self.cons = cons
def __hash__(self):
return hash(self.name)
def __repr__(self):
return '<OldTask: %r>' % (self.name, )
def __call__(self, app):
return self.cons(app)
def _strflags(flags, default=''):
if flags:
return ' ({0})'.format(', '.join(flags))
return default
def _reprtask(task, fmt=None, flags=None):
flags = list(flags) if flags is not None else []
flags.append('v2 compatible') if task.__v2_compat__ else None
if not fmt:
fmt = R_BOUND_TASK if task._app else R_UNBOUND_TASK
return fmt.format(
task, flags=_strflags(flags),
app=appstr(task._app) if task._app else None,
)
class Context(object): class Context(object):
# Default context # Default context
@@ -45,7 +90,10 @@ class Context(object):
eta = None eta = None
expires = None expires = None
is_eager = False is_eager = False
headers = None
delivery_info = None delivery_info = None
reply_to = None
correlation_id = None
taskset = None # compat alias to group taskset = None # compat alias to group
group = None group = None
chord = None chord = None
@@ -53,7 +101,7 @@ class Context(object):
called_directly = True called_directly = True
callbacks = None callbacks = None
errbacks = None errbacks = None
timeouts = None timelimit = None
_children = None # see property _children = None # see property
_protected = 0 _protected = 0
@@ -61,19 +109,16 @@ class Context(object):
self.update(*args, **kwargs) self.update(*args, **kwargs)
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
self.__dict__.update(*args, **kwargs) return self.__dict__.update(*args, **kwargs)
def clear(self): def clear(self):
self.__dict__.clear() return self.__dict__.clear()
def get(self, key, default=None): def get(self, key, default=None):
try: return getattr(self, key, default)
return getattr(self, key)
except AttributeError:
return default
def __repr__(self): def __repr__(self):
return '<Context: %r>' % (vars(self, )) return '<Context: {0!r}>'.format(vars(self))
@property @property
def children(self): def children(self):
@@ -86,33 +131,62 @@ class Context(object):
class TaskType(type): class TaskType(type):
"""Meta class for tasks. """Meta class for tasks.
Automatically registers the task in the task registry, except Automatically registers the task in the task registry (except
if the `abstract` attribute is set. if the :attr:`Task.abstract`` attribute is set).
If no `name` attribute is provided, then no name is automatically If no :attr:`Task.name` attribute is provided, then the name is generated
set to the name of the module it was defined in, and the class name. from the module and class name.
""" """
_creation_count = {} # used by old non-abstract task classes
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
new = super(TaskType, cls).__new__ new = super(TaskType, cls).__new__
task_module = attrs.get('__module__') or '__main__' task_module = attrs.get('__module__') or '__main__'
# - Abstract class: abstract attribute should not be inherited. # - Abstract class: abstract attribute should not be inherited.
if attrs.pop('abstract', None) or not attrs.get('autoregister', True): abstract = attrs.pop('abstract', None)
if abstract or not attrs.get('autoregister', True):
return new(cls, name, bases, attrs) return new(cls, name, bases, attrs)
# The 'app' attribute is now a property, with the real app located # The 'app' attribute is now a property, with the real app located
# in the '_app' attribute. Previously this was a regular attribute, # in the '_app' attribute. Previously this was a regular attribute,
# so we should support classes defining it. # so we should support classes defining it.
_app1, _app2 = attrs.pop('_app', None), attrs.pop('app', None) app = attrs.pop('_app', None) or attrs.pop('app', None)
app = attrs['_app'] = _app1 or _app2 or current_app if not isinstance(app, Proxy) and app is None:
for base in bases:
if base._app:
app = base._app
break
else:
app = current_app._get_current_object()
attrs['_app'] = app
# - Automatically generate missing/empty name. # - Automatically generate missing/empty name.
task_name = attrs.get('name') task_name = attrs.get('name')
if not task_name: if not task_name:
attrs['name'] = task_name = gen_task_name(app, name, task_module) attrs['name'] = task_name = gen_task_name(app, name, task_module)
if not attrs.get('_decorated'):
# non decorated tasks must also be shared in case
# an app is created multiple times due to modules
# imported under multiple names.
# Hairy stuff, here to be compatible with 2.x.
# People should not use non-abstract task classes anymore,
# use the task decorator.
from celery.app.builtins import shared_task
unique_name = '.'.join([task_module, name])
if unique_name not in cls._creation_count:
# the creation count is used as a safety
# so that the same task is not added recursively
# to the set of constructors.
cls._creation_count[unique_name] = 1
shared_task(_CompatShared(
unique_name,
lambda app: TaskType.__new__(cls, name, bases,
dict(attrs, _app=app)),
))
# - Create and register class. # - Create and register class.
# Because of the way import happens (recursively) # Because of the way import happens (recursively)
# we may or may not be the first time the task tries to register # we may or may not be the first time the task tries to register
@@ -126,13 +200,10 @@ class TaskType(type):
return instance.__class__ return instance.__class__
def __repr__(cls): def __repr__(cls):
if cls._app: return _reprtask(cls)
return '<class %s of %s>' % (cls.__name__, cls._app, )
if cls.__v2_compat__:
return '<unbound %s (v2 compatible)>' % (cls.__name__, )
return '<unbound %s>' % (cls.__name__, )
@with_metaclass(TaskType)
class Task(object): class Task(object):
"""Task base class. """Task base class.
@@ -141,7 +212,6 @@ class Task(object):
is overridden). is overridden).
""" """
__metaclass__ = TaskType
__trace__ = None __trace__ = None
__v2_compat__ = False # set by old base in celery.task.base __v2_compat__ = False # set by old base in celery.task.base
@@ -185,6 +255,11 @@ class Task(object):
#: setting. #: setting.
ignore_result = None ignore_result = None
#: If enabled the request will keep track of subtasks started by
#: this task, and this information will be sent with the result
#: (``result.children``).
trail = True
#: When enabled errors will be stored even if the task is otherwise #: When enabled errors will be stored even if the task is otherwise
#: configured to ignore results. #: configured to ignore results.
store_errors_even_if_ignored = None store_errors_even_if_ignored = None
@@ -243,6 +318,8 @@ class Task(object):
#: called. This should probably be deprecated. #: called. This should probably be deprecated.
_default_request = None _default_request = None
_exec_options = None
__bound__ = False __bound__ = False
from_config = ( from_config = (
@@ -266,14 +343,14 @@ class Task(object):
was_bound, self.__bound__ = self.__bound__, True was_bound, self.__bound__ = self.__bound__, True
self._app = app self._app = app
conf = app.conf conf = app.conf
self._exec_options = None # clear option cache
for attr_name, config_name in self.from_config: for attr_name, config_name in self.from_config:
if getattr(self, attr_name, None) is None: if getattr(self, attr_name, None) is None:
setattr(self, attr_name, conf[config_name]) setattr(self, attr_name, conf[config_name])
if self.accept_magic_kwargs is None: if self.accept_magic_kwargs is None:
self.accept_magic_kwargs = app.accept_magic_kwargs self.accept_magic_kwargs = app.accept_magic_kwargs
if self.backend is None: self.backend = app.backend
self.backend = app.backend
# decorate with annotations from config. # decorate with annotations from config.
if not was_bound: if not was_bound:
@@ -295,17 +372,19 @@ class Task(object):
@classmethod @classmethod
def _get_app(self): def _get_app(self):
if not self.__bound__ or self._app is None: if self._app is None:
self._app = current_app
if not self.__bound__:
# The app property's __set__ method is not called # The app property's __set__ method is not called
# if Task.app is set (on the class), so must bind on use. # if Task.app is set (on the class), so must bind on use.
self.bind(current_app) self.bind(self._app)
return self._app return self._app
app = class_property(_get_app, bind) app = class_property(_get_app, bind)
@classmethod @classmethod
def annotate(self): def annotate(self):
for d in resolve_all_annotations(self.app.annotations, self): for d in resolve_all_annotations(self.app.annotations, self):
for key, value in d.iteritems(): for key, value in items(d):
if key.startswith('@'): if key.startswith('@'):
self.add_around(key[1:], value) self.add_around(key[1:], value)
else: else:
@@ -332,17 +411,22 @@ class Task(object):
self.pop_request() self.pop_request()
_task_stack.pop() _task_stack.pop()
# - tasks are pickled into the name of the task only, and the reciever
# - simply grabs it from the local registry.
def __reduce__(self): def __reduce__(self):
return (_unpickle_task, (self.name, ), None) # - tasks are pickled into the name of the task only, and the reciever
# - simply grabs it from the local registry.
# - in later versions the module of the task is also included,
# - and the receiving side tries to import that module so that
# - it will work even if the task has not been registered.
mod = type(self).__module__
mod = mod if mod and mod in sys.modules else None
return (_unpickle_task_v2, (self.name, mod), None)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
"""The body of the task executed by workers.""" """The body of the task executed by workers."""
raise NotImplementedError('Tasks must define the run method.') raise NotImplementedError('Tasks must define the run method.')
def start_strategy(self, app, consumer): def start_strategy(self, app, consumer, **kwargs):
return instantiate(self.Strategy, self, app, consumer) return instantiate(self.Strategy, self, app, consumer, **kwargs)
def delay(self, *args, **kwargs): def delay(self, *args, **kwargs):
"""Star argument version of :meth:`apply_async`. """Star argument version of :meth:`apply_async`.
@@ -357,10 +441,8 @@ class Task(object):
""" """
return self.apply_async(args, kwargs) return self.apply_async(args, kwargs)
def apply_async(self, args=None, kwargs=None, def apply_async(self, args=None, kwargs=None, task_id=None, producer=None,
task_id=None, producer=None, connection=None, router=None, link=None, link_error=None, **options):
link=None, link_error=None, publisher=None,
add_to_parent=True, **options):
"""Apply tasks asynchronously by sending a message. """Apply tasks asynchronously by sending a message.
:keyword args: The positional arguments to pass on to the :keyword args: The positional arguments to pass on to the
@@ -371,14 +453,12 @@ class Task(object):
:keyword countdown: Number of seconds into the future that the :keyword countdown: Number of seconds into the future that the
task should execute. Defaults to immediate task should execute. Defaults to immediate
execution (do not confuse with the execution.
`immediate` flag, as they are unrelated).
:keyword eta: A :class:`~datetime.datetime` object describing :keyword eta: A :class:`~datetime.datetime` object describing
the absolute time and date of when the task should the absolute time and date of when the task should
be executed. May not be specified if `countdown` be executed. May not be specified if `countdown`
is also supplied. (Do not confuse this with the is also supplied.
`immediate` flag, as they are unrelated).
:keyword expires: Either a :class:`int`, describing the number of :keyword expires: Either a :class:`int`, describing the number of
seconds, or a :class:`~datetime.datetime` object seconds, or a :class:`~datetime.datetime` object
@@ -429,70 +509,56 @@ class Task(object):
:func:`kombu.compression.register`. Defaults to :func:`kombu.compression.register`. Defaults to
the :setting:`CELERY_MESSAGE_COMPRESSION` the :setting:`CELERY_MESSAGE_COMPRESSION`
setting. setting.
:keyword link: A single, or a list of subtasks to apply if the :keyword link: A single, or a list of tasks to apply if the
task exits successfully. task exits successfully.
:keyword link_error: A single, or a list of subtasks to apply :keyword link_error: A single, or a list of tasks to apply
if an error occurs while executing the task. if an error occurs while executing the task.
:keyword producer: :class:~@amqp.TaskProducer` instance to use. :keyword producer: :class:~@amqp.TaskProducer` instance to use.
:keyword add_to_parent: If set to True (default) and the task :keyword add_to_parent: If set to True (default) and the task
is applied while executing another task, then the result is applied while executing another task, then the result
will be appended to the parent tasks ``request.children`` will be appended to the parent tasks ``request.children``
attribute. attribute. Trailing can also be disabled by default using the
:attr:`trail` attribute
:keyword publisher: Deprecated alias to ``producer``. :keyword publisher: Deprecated alias to ``producer``.
Also supports all keyword arguments supported by Also supports all keyword arguments supported by
:meth:`kombu.messaging.Producer.publish`. :meth:`kombu.Producer.publish`.
.. note:: .. note::
If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will If the :setting:`CELERY_ALWAYS_EAGER` setting is set, it will
be replaced by a local :func:`apply` call instead. be replaced by a local :func:`apply` call instead.
""" """
producer = producer or publisher
app = self._get_app() app = self._get_app()
router = router or self.app.amqp.router if app.conf.CELERY_ALWAYS_EAGER:
conf = app.conf return self.apply(args, kwargs, task_id=task_id or uuid(),
# add 'self' if this is a bound method.
if self.__self__ is not None:
args = (self.__self__, ) + tuple(args)
if conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, task_id=task_id,
link=link, link_error=link_error, **options) link=link, link_error=link_error, **options)
options = dict(extract_exec_options(self), **options) # add 'self' if this is a "task_method".
options = router.route(options, self.name, args, kwargs) if self.__self__ is not None:
args = args if isinstance(args, tuple) else tuple(args or ())
if connection: args = (self.__self__, ) + args
producer = app.amqp.TaskProducer(connection) return app.send_task(
with app.producer_or_acquire(producer) as P: self.name, args, kwargs, task_id=task_id, producer=producer,
task_id = P.publish_task(self.name, args, kwargs, link=link, link_error=link_error, result_cls=self.AsyncResult,
task_id=task_id, **dict(self._get_exec_options(), **options)
callbacks=maybe_list(link), )
errbacks=maybe_list(link_error),
**options)
result = self.AsyncResult(task_id)
if add_to_parent:
parent = get_current_worker_task()
if parent:
parent.request.children.append(result)
return result
def subtask_from_request(self, request=None, args=None, kwargs=None, def subtask_from_request(self, request=None, args=None, kwargs=None,
**extra_options): **extra_options):
request = self.request if request is None else request request = self.request if request is None else request
args = request.args if args is None else args args = request.args if args is None else args
kwargs = request.kwargs if kwargs is None else kwargs kwargs = request.kwargs if kwargs is None else kwargs
delivery_info = request.delivery_info or {} limit_hard, limit_soft = request.timelimit or (None, None)
options = { options = dict({
'task_id': request.id, 'task_id': request.id,
'link': request.callbacks, 'link': request.callbacks,
'link_error': request.errbacks, 'link_error': request.errbacks,
'exchange': delivery_info.get('exchange'), 'group_id': request.group,
'routing_key': delivery_info.get('routing_key') 'chord': request.chord,
} 'soft_time_limit': limit_soft,
'time_limit': limit_hard,
}, **request.delivery_info or {})
return self.subtask(args, kwargs, options, type=self, **extra_options) return self.subtask(args, kwargs, options, type=self, **extra_options)
def retry(self, args=None, kwargs=None, exc=None, throw=True, def retry(self, args=None, kwargs=None, exc=None, throw=True,
@@ -503,7 +569,7 @@ class Task(object):
:param kwargs: Keyword arguments to retry with. :param kwargs: Keyword arguments to retry with.
:keyword exc: Custom exception to report when the max restart :keyword exc: Custom exception to report when the max restart
limit has been exceeded (default: limit has been exceeded (default:
:exc:`~celery.exceptions.MaxRetriesExceededError`). :exc:`~@MaxRetriesExceededError`).
If this argument is set and retry is called while If this argument is set and retry is called while
an exception was raised (``sys.exc_info()`` is set) an exception was raised (``sys.exc_info()`` is set)
@@ -515,16 +581,19 @@ class Task(object):
:keyword eta: Explicit time and date to run the retry at :keyword eta: Explicit time and date to run the retry at
(must be a :class:`~datetime.datetime` instance). (must be a :class:`~datetime.datetime` instance).
:keyword max_retries: If set, overrides the default retry limit. :keyword max_retries: If set, overrides the default retry limit.
:keyword time_limit: If set, overrides the default time limit.
:keyword soft_time_limit: If set, overrides the default soft
time limit.
:keyword \*\*options: Any extra options to pass on to :keyword \*\*options: Any extra options to pass on to
meth:`apply_async`. meth:`apply_async`.
:keyword throw: If this is :const:`False`, do not raise the :keyword throw: If this is :const:`False`, do not raise the
:exc:`~celery.exceptions.RetryTaskError` exception, :exc:`~@Retry` exception,
that tells the worker to mark the task as being that tells the worker to mark the task as being
retried. Note that this means the task will be retried. Note that this means the task will be
marked as failed if the task raises an exception, marked as failed if the task raises an exception,
or successful if it returns. or successful if it returns.
:raises celery.exceptions.RetryTaskError: To tell the worker that :raises celery.exceptions.Retry: To tell the worker that
the task has been re-sent for retry. This always happens, the task has been re-sent for retry. This always happens,
unless the `throw` keyword argument has been explicitly set unless the `throw` keyword argument has been explicitly set
to :const:`False`, and is considered normal operation. to :const:`False`, and is considered normal operation.
@@ -533,17 +602,20 @@ class Task(object):
.. code-block:: python .. code-block:: python
>>> @task() >>> from imaginary_twitter_lib import Twitter
>>> def tweet(auth, message): >>> from proj.celery import app
>>> @app.task()
... def tweet(auth, message):
... twitter = Twitter(oauth=auth) ... twitter = Twitter(oauth=auth)
... try: ... try:
... twitter.post_status_update(message) ... twitter.post_status_update(message)
... except twitter.FailWhale, exc: ... except twitter.FailWhale as exc:
... # Retry in 5 minutes. ... # Retry in 5 minutes.
... raise tweet.retry(countdown=60 * 5, exc=exc) ... raise tweet.retry(countdown=60 * 5, exc=exc)
Although the task will never return above as `retry` raises an Although the task will never return above as `retry` raises an
exception to notify the worker, we use `return` in front of the retry exception to notify the worker, we use `raise` in front of the retry
to convey that the rest of the block will not be executed. to convey that the rest of the block will not be executed.
""" """
@@ -555,11 +627,12 @@ class Task(object):
# so just raise the original exception. # so just raise the original exception.
if request.called_directly: if request.called_directly:
maybe_reraise() # raise orig stack if PyErr_Occurred maybe_reraise() # raise orig stack if PyErr_Occurred
raise exc or RetryTaskError('Task can be retried', None) raise exc or Retry('Task can be retried', None)
if not eta and countdown is None: if not eta and countdown is None:
countdown = self.default_retry_delay countdown = self.default_retry_delay
is_eager = request.is_eager
S = self.subtask_from_request( S = self.subtask_from_request(
request, args, kwargs, request, args, kwargs,
countdown=countdown, eta=eta, retries=retries, countdown=countdown, eta=eta, retries=retries,
@@ -570,13 +643,18 @@ class Task(object):
if exc: if exc:
maybe_reraise() maybe_reraise()
raise self.MaxRetriesExceededError( raise self.MaxRetriesExceededError(
"""Can't retry %s[%s] args:%s kwargs:%s""" % ( "Can't retry {0}[{1}] args:{2} kwargs:{3}".format(
self.name, request.id, S.args, S.kwargs)) self.name, request.id, S.args, S.kwargs))
# If task was executed eagerly using apply(), # If task was executed eagerly using apply(),
# then the retry must also be executed eagerly. # then the retry must also be executed eagerly.
S.apply().get() if request.is_eager else S.apply_async() try:
ret = RetryTaskError(exc=exc, when=eta or countdown) S.apply().get() if is_eager else S.apply_async()
except Exception as exc:
if is_eager:
raise
raise Reject(exc, requeue=True)
ret = Retry(exc=exc, when=eta or countdown)
if throw: if throw:
raise ret raise ret
return ret return ret
@@ -595,7 +673,7 @@ class Task(object):
""" """
# trace imports Task, so need to import inline. # trace imports Task, so need to import inline.
from celery.task.trace import eager_trace_task from celery.app.trace import eager_trace_task
app = self._get_app() app = self._get_app()
args = args or () args = args or ()
@@ -629,12 +707,13 @@ class Task(object):
'delivery_info': {'is_eager': True}} 'delivery_info': {'is_eager': True}}
supported_keys = fun_takes_kwargs(task.run, default_kwargs) supported_keys = fun_takes_kwargs(task.run, default_kwargs)
extend_with = dict((key, val) extend_with = dict((key, val)
for key, val in default_kwargs.items() for key, val in items(default_kwargs)
if key in supported_keys) if key in supported_keys)
kwargs.update(extend_with) kwargs.update(extend_with)
tb = None tb = None
retval, info = eager_trace_task(task, task_id, args, kwargs, retval, info = eager_trace_task(task, task_id, args, kwargs,
app=self._get_app(),
request=request, propagate=throw) request=request, propagate=throw)
if isinstance(retval, ExceptionInfo): if isinstance(retval, ExceptionInfo):
retval, tb = retval.exception, retval.traceback retval, tb = retval.exception, retval.traceback
@@ -651,10 +730,11 @@ class Task(object):
task_name=self.name, **kwargs) task_name=self.name, **kwargs)
def subtask(self, args=None, *starargs, **starkwargs): def subtask(self, args=None, *starargs, **starkwargs):
"""Returns :class:`~celery.subtask` object for """Return :class:`~celery.signature` object for
this task, wrapping arguments and execution options this task, wrapping arguments and execution options
for a single task invocation.""" for a single task invocation."""
return subtask(self, args, *starargs, **starkwargs) starkwargs.setdefault('app', self.app)
return signature(self, args, *starargs, **starkwargs)
def s(self, *args, **kwargs): def s(self, *args, **kwargs):
"""``.s(*a, **k) -> .subtask(a, k)``""" """``.s(*a, **k) -> .subtask(a, k)``"""
@@ -667,17 +747,17 @@ class Task(object):
def chunks(self, it, n): def chunks(self, it, n):
"""Creates a :class:`~celery.canvas.chunks` task for this task.""" """Creates a :class:`~celery.canvas.chunks` task for this task."""
from celery import chunks from celery import chunks
return chunks(self.s(), it, n) return chunks(self.s(), it, n, app=self.app)
def map(self, it): def map(self, it):
"""Creates a :class:`~celery.canvas.xmap` task from ``it``.""" """Creates a :class:`~celery.canvas.xmap` task from ``it``."""
from celery import xmap from celery import xmap
return xmap(self.s(), it) return xmap(self.s(), it, app=self.app)
def starmap(self, it): def starmap(self, it):
"""Creates a :class:`~celery.canvas.xstarmap` task from ``it``.""" """Creates a :class:`~celery.canvas.xstarmap` task from ``it``."""
from celery import xstarmap from celery import xstarmap
return xstarmap(self.s(), it) return xstarmap(self.s(), it, app=self.app)
def update_state(self, task_id=None, state=None, meta=None): def update_state(self, task_id=None, state=None, meta=None):
"""Update task state. """Update task state.
@@ -719,7 +799,7 @@ class Task(object):
:param args: Original arguments for the retried task. :param args: Original arguments for the retried task.
:param kwargs: Original keyword arguments for the retried task. :param kwargs: Original keyword arguments for the retried task.
:keyword einfo: :class:`~celery.datastructures.ExceptionInfo` :keyword einfo: :class:`~billiard.einfo.ExceptionInfo`
instance, containing the traceback. instance, containing the traceback.
The return value of this handler is ignored. The return value of this handler is ignored.
@@ -738,7 +818,7 @@ class Task(object):
:param kwargs: Original keyword arguments for the task :param kwargs: Original keyword arguments for the task
that failed. that failed.
:keyword einfo: :class:`~celery.datastructures.ExceptionInfo` :keyword einfo: :class:`~billiard.einfo.ExceptionInfo`
instance, containing the traceback. instance, containing the traceback.
The return value of this handler is ignored. The return value of this handler is ignored.
@@ -756,7 +836,7 @@ class Task(object):
:param kwargs: Original keyword arguments for the task :param kwargs: Original keyword arguments for the task
that failed. that failed.
:keyword einfo: :class:`~celery.datastructures.ExceptionInfo` :keyword einfo: :class:`~billiard.einfo.ExceptionInfo`
instance, containing the traceback (if any). instance, containing the traceback (if any).
The return value of this handler is ignored. The return value of this handler is ignored.
@@ -769,6 +849,11 @@ class Task(object):
not getattr(self, 'disable_error_emails', None): not getattr(self, 'disable_error_emails', None):
self.ErrorMail(self, **kwargs).send(context, exc) self.ErrorMail(self, **kwargs).send(context, exc)
def add_trail(self, result):
if self.trail:
self.request.children.append(result)
return result
def push_request(self, *args, **kwargs): def push_request(self, *args, **kwargs):
self.request_stack.push(Context(*args, **kwargs)) self.request_stack.push(Context(*args, **kwargs))
@@ -777,9 +862,7 @@ class Task(object):
def __repr__(self): def __repr__(self):
"""`repr(task)`""" """`repr(task)`"""
if self.__self__: return _reprtask(self, R_SELF_TASK if self.__self__ else R_INSTANCE)
return '<bound task %s of %r>' % (self.name, self.__self__)
return '<@task: %s>' % (self.name, )
def _get_request(self): def _get_request(self):
"""Get current request object.""" """Get current request object."""
@@ -793,6 +876,11 @@ class Task(object):
return req return req
request = property(_get_request) request = property(_get_request)
def _get_exec_options(self):
if self._exec_options is None:
self._exec_options = extract_exec_options(self)
return self._exec_options
@property @property
def __name__(self): def __name__(self):
return self.__class__.__name__ return self.__class__.__name__

Some files were not shown because too many files have changed in this diff Show More