Merge remote-tracking branch 'origin/master' into auditlog

* origin/master:
  AC-637 Credential now requires scm_key_unlock when saving encrypted ssh_key_data.
  AC-626 Removed support for prompting for password and ssh_key_unlock for scm/cloud credentials.
  AC-613 Change rackspace to rax for inventory source field value.
  AC-613 Change rackspace to rax for inventory source field value.
  AC-624 Fix options docs for project update view.
  AC-632 Fix escaping for ansible-playbook command line when also using ssh-agent.
  Update CONTRIBUTING.md
  AC-630 Expose cloud_credentials field for job template and job.
  AC-641 Added pattern to respond to key unlock prompt for project update.
  Updated all vendored third-party packages.
  AC-636 Fix existing projects with scm_type=null to always use empty string. Update validation and tests to ensure None gets automatically coerced to an empty string on saving a project.
  AC-633 js error fixed.
  AC-633 fixed a sort of unrelated js error. The capitalize filter directive attempted to act on a 'null' input error. Added a test to ignore empty/null input.
  AC-633 Fixed 'hast' typo.
  AC-617 changed callback generation icon to a magic wand, which will hopefully satiate jlaska.
  AC-627 Fixed password/ssh_password collision in Credentials.js form. This was also fixed in auditlog branch.
  AC-628 applied credential changes made in add controller to edit controller so that credential/cloud_credential lookups display context-aware  credential lists.
  Moved credentials in tab order. It now follows teams and precedes projects. Based on a suggestion from jlaska.
  AC-609 Fixed issue with help button not displaying correctly on 'select' pages where user can pick an existing object (i.e. users, credentials, etc) to add to a parent object.

Conflicts:
	awx/api/serializers.py
	awx/main/migrations/0025_v14_changes.py
This commit is contained in:
Matthew Jones 2013-11-18 09:18:37 -05:00
commit ba1a113ec3
797 changed files with 46286 additions and 28659 deletions

View File

@ -96,6 +96,7 @@ class ChoiceField(fields.ChoiceField):
# ModelSerializer.
serializers.ChoiceField = ChoiceField
class BaseSerializer(serializers.ModelSerializer):
# add the URL and related resources
@ -187,10 +188,6 @@ class BaseSerializer(serializers.ModelSerializer):
else:
return obj.active
def validate_description(self, attrs, source):
# Description should always be empty string, never null.
attrs[source] = attrs.get(source, None) or ''
return attrs
class UserSerializer(BaseSerializer):
@ -278,6 +275,7 @@ class UserSerializer(BaseSerializer):
def validate_is_superuser(self, attrs, source):
return self._validate_ldap_managed_field(attrs, source)
class OrganizationSerializer(BaseSerializer):
class Meta:
@ -299,6 +297,7 @@ class OrganizationSerializer(BaseSerializer):
))
return res
class ProjectSerializer(BaseSerializer):
playbooks = serializers.Field(source='playbooks', help_text='Array of playbooks available within this project.')
@ -334,27 +333,32 @@ class ProjectSerializer(BaseSerializer):
args=(obj.last_update.pk,))
return res
def _get_scm_type(self, attrs, source=None):
if self.object:
return attrs.get(source or 'scm_type', self.object.scm_type) or u''
else:
return attrs.get(source or 'scm_type', u'') or u''
def validate_local_path(self, attrs, source):
# Don't allow assigning a local_path used by another project.
# Don't allow assigning a local_path when scm_type is set.
valid_local_paths = Project.get_local_path_choices()
if self.object:
scm_type = attrs.get('scm_type', self.object.scm_type)
if not scm_type:
valid_local_paths.append(self.object.local_path)
else:
scm_type = attrs.get('scm_type', '')
scm_type = self._get_scm_type(attrs)
if self.object and not scm_type:
valid_local_paths.append(self.object.local_path)
if scm_type:
attrs.pop(source, None)
if source in attrs and attrs[source] not in valid_local_paths:
raise serializers.ValidationError('Invalid path choice')
return attrs
def validate_scm_type(self, attrs, source):
scm_type = self._get_scm_type(attrs, source)
attrs[source] = scm_type
return attrs
def validate_scm_url(self, attrs, source):
if self.object:
scm_type = attrs.get('scm_type', self.object.scm_type) or ''
else:
scm_type = attrs.get('scm_type', '') or ''
scm_type = self._get_scm_type(attrs)
scm_url = unicode(attrs.get(source, None) or '')
if not scm_type:
return attrs
@ -413,6 +417,7 @@ class ProjectSerializer(BaseSerializer):
# FIXME: Validate combination of SCM URL and credential!
class ProjectPlaybooksSerializer(ProjectSerializer):
class Meta:
@ -423,6 +428,7 @@ class ProjectPlaybooksSerializer(ProjectSerializer):
ret = super(ProjectPlaybooksSerializer, self).to_native(obj)
return ret.get('playbooks', [])
class ProjectUpdateSerializer(BaseSerializer):
class Meta:
@ -441,6 +447,7 @@ class ProjectUpdateSerializer(BaseSerializer):
))
return res
class BaseSerializerWithVariables(BaseSerializer):
def validate_variables(self, attrs, source):
@ -453,6 +460,7 @@ class BaseSerializerWithVariables(BaseSerializer):
raise serializers.ValidationError('Must be valid JSON or YAML')
return attrs
class InventorySerializer(BaseSerializerWithVariables):
class Meta:
@ -481,6 +489,7 @@ class InventorySerializer(BaseSerializerWithVariables):
))
return res
class HostSerializer(BaseSerializerWithVariables):
class Meta:
@ -611,6 +620,7 @@ class GroupSerializer(BaseSerializerWithVariables):
raise serializers.ValidationError('Invalid group name')
return attrs
class GroupTreeSerializer(GroupSerializer):
children = serializers.SerializerMethodField('get_children')
@ -628,6 +638,7 @@ class GroupTreeSerializer(GroupSerializer):
children_qs = obj.children.filter(active=True)
return GroupTreeSerializer(children_qs, many=True).data
class BaseVariableDataSerializer(BaseSerializer):
def to_native(self, obj):
@ -643,24 +654,28 @@ class BaseVariableDataSerializer(BaseSerializer):
data = {'variables': json.dumps(data)}
return super(BaseVariableDataSerializer, self).from_native(data, files)
class InventoryVariableDataSerializer(BaseVariableDataSerializer):
class Meta:
model = Inventory
fields = ('variables',)
class HostVariableDataSerializer(BaseVariableDataSerializer):
class Meta:
model = Host
fields = ('variables',)
class GroupVariableDataSerializer(BaseVariableDataSerializer):
class Meta:
model = Group
fields = ('variables',)
class InventorySourceSerializer(BaseSerializer):
#source_password = serializers.WritableField(required=False, default='')
@ -730,6 +745,7 @@ class InventorySourceSerializer(BaseSerializer):
# FIXME
return attrs
class InventoryUpdateSerializer(BaseSerializer):
class Meta:
@ -749,6 +765,7 @@ class InventoryUpdateSerializer(BaseSerializer):
))
return res
class TeamSerializer(BaseSerializer):
class Meta:
@ -768,6 +785,7 @@ class TeamSerializer(BaseSerializer):
))
return res
class PermissionSerializer(BaseSerializer):
class Meta:
@ -789,6 +807,7 @@ class PermissionSerializer(BaseSerializer):
res['inventory'] = reverse('api:inventory_detail', args=(obj.inventory.pk,))
return res
def validate(self, attrs):
# Can only set either user or team.
if attrs['user'] and attrs['team']:
@ -804,6 +823,7 @@ class PermissionSerializer(BaseSerializer):
'assigning deployment permissions')
return attrs
class CredentialSerializer(BaseSerializer):
# FIXME: may want to make some of these filtered based on user accessing
@ -845,13 +865,15 @@ class CredentialSerializer(BaseSerializer):
res['team'] = reverse('api:team_detail', args=(obj.team.pk,))
return res
class JobTemplateSerializer(BaseSerializer):
class Meta:
model = JobTemplate
fields = BASE_FIELDS + ('job_type', 'inventory', 'project', 'playbook',
'credential', 'forks', 'limit', 'verbosity',
'extra_vars', 'job_tags', 'host_config_key')
'credential', 'cloud_credential', 'forks',
'limit', 'verbosity', 'extra_vars', 'job_tags',
'host_config_key')
def get_related(self, obj):
if obj is None:
@ -864,6 +886,9 @@ class JobTemplateSerializer(BaseSerializer):
))
if obj.credential:
res['credential'] = reverse('api:credential_detail', args=(obj.credential.pk,))
if obj.cloud_credential:
res['cloud_credential'] = reverse('api:credential_detail',
args=(obj.cloud_credential.pk,))
if obj.host_config_key:
res['callback'] = reverse('api:job_template_callback', args=(obj.pk,))
return res
@ -875,6 +900,7 @@ class JobTemplateSerializer(BaseSerializer):
raise serializers.ValidationError('Playbook not found for project')
return attrs
class JobSerializer(BaseSerializer):
passwords_needed_to_start = serializers.Field(source='passwords_needed_to_start')
@ -883,7 +909,7 @@ class JobSerializer(BaseSerializer):
model = Job
fields = ('id', 'url', 'related', 'summary_fields', 'created',
'modified', 'job_template', 'job_type', 'inventory',
'project', 'playbook', 'credential',
'project', 'playbook', 'credential', 'cloud_credential',
'forks', 'limit', 'verbosity', 'extra_vars',
'job_tags', 'launch_type', 'status', 'failed',
'result_stdout', 'result_traceback',
@ -903,6 +929,9 @@ class JobSerializer(BaseSerializer):
))
if obj.job_template:
res['job_template'] = reverse('api:job_template_detail', args=(obj.job_template.pk,))
if obj.cloud_credential:
res['cloud_credential'] = reverse('api:credential_detail',
args=(obj.cloud_credential.pk,))
if obj.can_start or True:
res['start'] = reverse('api:job_start', args=(obj.pk,))
if obj.can_cancel or True:
@ -925,6 +954,8 @@ class JobSerializer(BaseSerializer):
data.setdefault('playbook', job_template.playbook)
if job_template.credential:
data.setdefault('credential', job_template.credential.pk)
if job_template.cloud_credential:
data.setdefault('cloud_credential', job_template.cloud_credential.pk)
data.setdefault('forks', job_template.forks)
data.setdefault('limit', job_template.limit)
data.setdefault('verbosity', job_template.verbosity)
@ -932,6 +963,7 @@ class JobSerializer(BaseSerializer):
data.setdefault('job_tags', job_template.job_tags)
return super(JobSerializer, self).from_native(data, files)
class JobHostSummarySerializer(BaseSerializer):
class Meta:
@ -961,6 +993,7 @@ class JobHostSummarySerializer(BaseSerializer):
pass
return d
class JobEventSerializer(BaseSerializer):
event_display = serializers.Field(source='get_event_display2')
@ -1055,7 +1088,6 @@ class ActivityStreamSerializer(BaseSerializer):
pass
return d
class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField()

View File

@ -1,18 +1,13 @@
# Update Inventory Source
Make a GET request to this resource to determine if the group can be updated
from its inventory source and whether any passwords are required for the
update. The response will include the following fields:
from its inventory source. The response will include the following field:
* `can_start`: Flag indicating if this job can be started (boolean, read-only)
* `passwords_needed_to_update`: Password names required to update from the
inventory source (array, read-only)
* `can_update`: Flag indicating if this inventory source can be updated
(boolean, read-only)
Make a POST request to this resource to update the inventory source. If any
passwords are required, they must be passed via POST data.
If successful, the response status code will be 202. If any required passwords
are not provided, a 400 status code will be returned. If the inventory source
is not defined or cannot be updated, a 405 status code will be returned.
Make a POST request to this resource to update the inventory source. If
successful, the response status code will be 202. If the inventory source is
not defined or cannot be updated, a 405 status code will be returned.
{% include "api/_new_in_awx.md" %}

View File

@ -5,7 +5,8 @@ whether any passwords are required to start the job. The response will include
the following fields:
* `can_start`: Flag indicating if this job can be started (boolean, read-only)
* `passwords_needed_to_start`: Password names required to start the job (array, read-only)
* `passwords_needed_to_start`: Password names required to start the job (array,
read-only)
Make a POST request to this resource to start the job. If any passwords are
required, they must be passed via POST data.

View File

@ -1,18 +1,12 @@
# Update Project
Make a GET request to this resource to determine if the project can be updated
from its SCM source and whether any passwords are required for the update. The
response will include the following fields:
from its SCM source. The response will include the following field:
* `can_start`: Flag indicating if this job can be started (boolean, read-only)
* `passwords_needed_to_update`: Password names required to update the project
(array, read-only)
* `can_update`: Flag indicating if this project can be updated (boolean,
read-only)
Make a POST request to this resource to update the project. If any passwords
are required, they must be passed via POST data.
If successful, the response status code will be 202. If any required passwords
are not provided, a 400 status code will be returned. If the project cannot be
updated, a 405 status code will be returned.
Make a POST request to this resource to update the project. If the project
cannot be updated, a 405 status code will be returned.
{% include "api/_new_in_awx.md" %}

View File

@ -300,8 +300,6 @@ class ProjectUpdateView(GenericAPIView):
data = dict(
can_update=obj.can_update,
)
if obj.scm_type:
data['passwords_needed_to_update'] = obj.scm_passwords_needed
return Response(data)
def post(self, request, *args, **kwargs):
@ -309,8 +307,7 @@ class ProjectUpdateView(GenericAPIView):
if obj.can_update:
project_update = obj.update(**request.DATA)
if not project_update:
data = dict(passwords_needed_to_update=obj.scm_passwords_needed)
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return Response({}, status=status.HTTP_400_BAD_REQUEST)
else:
headers = {'Location': project_update.get_absolute_url()}
return Response(status=status.HTTP_202_ACCEPTED, headers=headers)

View File

@ -1,50 +1,51 @@
Local versions of third-party packages required by AWX. Package names and
versions are listed below, along with notes on which files are included.
amqp==1.2.1 (amqp/*)
amqp==1.3.3 (amqp/*)
anyjson==0.3.3 (anyjson/*)
argparse==1.2.1 (argparse.py, needed for Python 2.6 support)
Babel==1.3 (babel/*, excluded bin/pybabel)
billiard==2.7.3.32 (billiard/*, funtests/*, excluded _billiard.so)
boto==2.13.3 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin,
billiard==3.3.0.6 (billiard/*, funtests/*, excluded _billiard.so)
boto==2.17.0 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin,
bin/cq, bin/cwutil, bin/dynamodb_dump, bin/dynamodb_load, bin/elbadmin,
bin/fetch_file, bin/glacier, bin/instance_events, bin/kill_instance,
bin/launch_instance, bin/list_instances, bin/lss3, bin/mturk,
bin/pyami_sendmail, bin/route53, bin/s3put, bin/sdbadmin, bin/taskadmin)
celery==3.0.23 (celery/*, excluded bin/celery* and bin/camqadm)
celery==3.1.3 (celery/*, excluded bin/celery*)
d2to1==0.2.11 (d2to1/*)
distribute==0.7.3 (no files)
django-auth-ldap==1.1.4 (django_auth_ldap/*)
django-celery==3.0.23 (djcelery/*, excluded bin/djcelerymon)
django-extensions==1.2.2 (django_extensions/*)
django-jsonfield==0.9.10 (jsonfield/*)
django-auth-ldap==1.1.6 (django_auth_ldap/*)
django-celery==3.1.1 (djcelery/*)
django-extensions==1.2.5 (django_extensions/*)
django-jsonfield==0.9.11 (jsonfield/*)
django-taggit==0.10 (taggit/*)
djangorestframework==2.3.8 (rest_framework/*)
httplib2==0.8 (httplib2/*)
importlib==1.0.2 (importlib/*, needed for Python 2.6 support)
iso8601==0.1.4 (iso8601/*)
keyring==3.0.5 (keyring/*, excluded bin/keyring)
kombu==2.5.14 (kombu/*)
iso8601==0.1.8 (iso8601/*)
keyring==3.2 (keyring/*, excluded bin/keyring)
kombu==3.0.4 (kombu/*)
Markdown==2.3.1 (markdown/*, excluded bin/markdown_py)
mock==1.0.1 (mock.py)
ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)
os-diskconfig-python-novaclient-ext==0.1.1 (os_diskconfig_python_novaclient_ext/*)
os-networksv2-python-novaclient-ext==0.21 (os_networksv2_python_novaclient_ext.py)
pbr==0.5.21 (pbr/*)
pexpect==2.4 (pexpect.py, pxssh.py, fdpexpect.py, FSM.py, screen.py, ANSI.py)
pbr==0.5.23 (pbr/*)
pexpect==3.0 (pexpect/*, excluded pxssh.py, fdpexpect.py, FSM.py, screen.py,
ANSI.py)
pip==1.4.1 (pip/*, excluded bin/pip*)
prettytable==0.7.2 (prettytable.py)
pyrax==1.5.0 (pyrax/*)
python-dateutil==2.1 (dateutil/*)
pyrax==1.6.2 (pyrax/*)
python-dateutil==2.2 (dateutil/*)
python-novaclient==2.15.0 (novaclient/*, excluded bin/nova)
python-swiftclient==1.6.0 (swiftclient/*, excluded bin/swift)
pytz==2013d (pytz/*)
rackspace-auth-openstack==1.0 (rackspace_auth_openstack/*)
python-swiftclient==1.8.0 (swiftclient/*, excluded bin/swift)
pytz==2013.8 (pytz/*)
rackspace-auth-openstack==1.1 (rackspace_auth_openstack/*)
rackspace-novaclient==1.3 (no files)
rax-default-network-flags-python-novaclient-ext==0.1.3 (rax_default_network_flags_python_novaclient_ext/*)
rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*)
requests==2.0.0 (requests/*)
setuptools==1.1.6 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*)
simplejson==3.3.0 (simplejson/*, excluded simplejson/_speedups.so)
requests==2.0.1 (requests/*)
setuptools==1.3.2 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*)
simplejson==3.3.1 (simplejson/*, excluded simplejson/_speedups.so)
six==1.4.1 (six.py)
South==0.8.2 (south/*)
South==0.8.3 (south/*)

View File

@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import
VERSION = (1, 2, 1)
VERSION = (1, 3, 3)
__version__ = '.'.join(map(str, VERSION[0:3])) + ''.join(VERSION[3:])
__author__ = 'Barry Pederson'
__maintainer__ = 'Ask Solem'
@ -61,6 +61,7 @@ from .exceptions import ( # noqa
error_for_code,
__all__ as _all_exceptions,
)
from .utils import promise # noqa
__all__ = [
'Connection',

View File

@ -24,6 +24,7 @@ from warnings import warn
from .abstract_channel import AbstractChannel
from .exceptions import ChannelError, ConsumerCancelled, error_for_code
from .five import Queue
from .protocol import basic_return_t, queue_declare_ok_t
from .serialization import AMQPWriter
__all__ = ['Channel']
@ -80,6 +81,12 @@ class Channel(AbstractChannel):
self.events = defaultdict(set)
self.no_ack_consumers = set()
# set first time basic_publish_confirm is called
# and publisher confirms are enabled for this channel.
self._confirm_selected = False
if self.connection.confirm_publish:
self.basic_publish = self.basic_publish_confirm
self._x_open()
def _do_close(self):
@ -1272,10 +1279,11 @@ class Channel(AbstractChannel):
this count.
"""
queue = args.read_shortstr()
message_count = args.read_long()
consumer_count = args.read_long()
return queue, message_count, consumer_count
return queue_declare_ok_t(
args.read_shortstr(),
args.read_long(),
args.read_long(),
)
def queue_delete(self, queue='',
if_unused=False, if_empty=False, nowait=False):
@ -1875,6 +1883,7 @@ class Channel(AbstractChannel):
exchange = args.read_shortstr()
routing_key = args.read_shortstr()
msg.channel = self
msg.delivery_info = {
'consumer_tag': consumer_tag,
'delivery_tag': delivery_tag,
@ -1883,8 +1892,11 @@ class Channel(AbstractChannel):
'routing_key': routing_key,
}
fun = self.callbacks.get(consumer_tag, None)
if fun is not None:
try:
fun = self.callbacks[consumer_tag]
except KeyError:
pass
else:
fun(msg)
def basic_get(self, queue='', no_ack=False):
@ -2015,6 +2027,7 @@ class Channel(AbstractChannel):
routing_key = args.read_shortstr()
message_count = args.read_long()
msg.channel = self
msg.delivery_info = {
'delivery_tag': delivery_tag,
'redelivered': redelivered,
@ -2024,8 +2037,8 @@ class Channel(AbstractChannel):
}
return msg
def basic_publish(self, msg, exchange='', routing_key='',
mandatory=False, immediate=False):
def _basic_publish(self, msg, exchange='', routing_key='',
mandatory=False, immediate=False):
"""Publish a message
This method publishes a message to a specific exchange. The
@ -2099,6 +2112,15 @@ class Channel(AbstractChannel):
args.write_bit(immediate)
self._send_method((60, 40), args, msg)
basic_publish = _basic_publish
def basic_publish_confirm(self, *args, **kwargs):
if not self._confirm_selected:
self._confirm_selected = True
self.confirm_select()
ret = self._basic_publish(*args, **kwargs)
self.wait([(60, 80)])
return ret
def basic_qos(self, prefetch_size, prefetch_count, a_global):
"""Specify quality of service
@ -2334,14 +2356,13 @@ class Channel(AbstractChannel):
message was published.
"""
reply_code = args.read_short()
reply_text = args.read_shortstr()
exchange = args.read_shortstr()
routing_key = args.read_shortstr()
self.returned_messages.put(
(reply_code, reply_text, exchange, routing_key, msg)
)
self.returned_messages.put(basic_return_t(
args.read_short(),
args.read_shortstr(),
args.read_shortstr(),
args.read_shortstr(),
msg,
))
#############
#

View File

@ -89,7 +89,7 @@ class Connection(AbstractChannel):
virtual_host='/', locale='en_US', client_properties=None,
ssl=False, connect_timeout=None, channel_max=None,
frame_max=None, heartbeat=0, on_blocked=None,
on_unblocked=None, **kwargs):
on_unblocked=None, confirm_publish=False, **kwargs):
"""Create a connection to the specified host, which should be
a 'host[:port]', such as 'localhost', or '1.2.3.4:5672'
(defaults to 'localhost', if a port is not specified then
@ -127,6 +127,8 @@ class Connection(AbstractChannel):
self.frame_max = frame_max
self.heartbeat = heartbeat
self.confirm_publish = confirm_publish
# Callbacks
self.on_blocked = on_blocked
self.on_unblocked = on_unblocked
@ -163,6 +165,10 @@ class Connection(AbstractChannel):
return self._x_open(virtual_host)
@property
def connected(self):
return self.transport and self.transport.connected
def _do_close(self):
try:
self.transport.close()

View File

@ -47,7 +47,9 @@ class AMQPError(Exception):
reply_text, method_sig, self.method_name)
def __str__(self):
return '{0.method}: ({0.reply_code}) {0.reply_text}'.format(self)
if self.method:
return '{0.method}: ({0.reply_code}) {0.reply_text}'.format(self)
return self.reply_text or '<AMQPError: unknown error>'
@property
def method(self):

View File

@ -46,7 +46,7 @@ _CONTENT_METHODS = [
class _PartialMessage(object):
"""Helper class to build up a multi-frame method."""
def __init__(self, method_sig, args):
def __init__(self, method_sig, args, channel):
self.method_sig = method_sig
self.args = args
self.msg = Message()
@ -147,7 +147,9 @@ class MethodReader(object):
#
# Save what we've got so far and wait for the content-header
#
self.partial_messages[channel] = _PartialMessage(method_sig, args)
self.partial_messages[channel] = _PartialMessage(
method_sig, args, channel,
)
self.expected_types[channel] = 2
else:
self._quick_put((channel, method_sig, args, None))

View File

@ -0,0 +1,13 @@
from __future__ import absolute_import
from collections import namedtuple
queue_declare_ok_t = namedtuple(
'queue_declare_ok_t', ('queue', 'message_count', 'consumer_count'),
)
basic_return_t = namedtuple(
'basic_return_t',
('reply_code', 'reply_text', 'exchange', 'routing_key', 'message'),
)

View File

@ -49,6 +49,9 @@ except:
from struct import pack, unpack
from .exceptions import UnexpectedFrame
from .utils import get_errno, set_cloexec
_UNAVAIL = errno.EAGAIN, errno.EINTR
AMQP_PORT = 5672
@ -63,8 +66,10 @@ IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?')
class _AbstractTransport(object):
"""Common superclass for TCP and SSL transports"""
connected = False
def __init__(self, host, connect_timeout):
self.connected = True
msg = None
port = AMQP_PORT
@ -85,6 +90,10 @@ class _AbstractTransport(object):
af, socktype, proto, canonname, sa = res
try:
self.sock = socket.socket(af, socktype, proto)
try:
set_cloexec(self.sock, True)
except NotImplementedError:
pass
self.sock.settimeout(connect_timeout)
self.sock.connect(sa)
except socket.error as exc:
@ -99,13 +108,18 @@ class _AbstractTransport(object):
# Didn't connect, return the most recent error message
raise socket.error(last_err)
self.sock.settimeout(None)
self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
try:
self.sock.settimeout(None)
self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self._setup_transport()
self._setup_transport()
self._write(AMQP_PROTOCOL_HEADER)
self._write(AMQP_PROTOCOL_HEADER)
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
def __del__(self):
try:
@ -141,12 +155,20 @@ class _AbstractTransport(object):
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()
self.sock = None
self.connected = False
def read_frame(self, unpack=unpack):
read = self._read
frame_type, channel, size = unpack('>BHI', read(7, True))
payload = read(size)
ch = ord(read(1))
try:
frame_type, channel, size = unpack('>BHI', read(7, True))
payload = read(size)
ch = ord(read(1))
except socket.timeout:
raise
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
if ch == 206: # '\xce'
return frame_type, channel, payload
else:
@ -155,10 +177,17 @@ class _AbstractTransport(object):
def write_frame(self, frame_type, channel, payload):
size = len(payload)
self._write(pack(
'>BHI%dsB' % size,
frame_type, channel, size, payload, 0xce,
))
try:
self._write(pack(
'>BHI%dsB' % size,
frame_type, channel, size, payload, 0xce,
))
except socket.timeout:
raise
except (OSError, IOError, socket.error) as exc:
if get_errno(exc) not in _UNAVAIL:
self.connected = False
raise
class SSLTransport(_AbstractTransport):
@ -200,19 +229,22 @@ class SSLTransport(_AbstractTransport):
# to get the exact number of bytes wanted.
recv = self._quick_recv
rbuf = self._read_buffer
while len(rbuf) < n:
try:
s = recv(131072) # see note above
except socket.error as exc:
# ssl.sock.read may cause ENOENT if the
# operation couldn't be performed (Issue celery#1414).
if not initial and exc.errno in _errnos:
continue
raise exc
if not s:
raise IOError('Socket closed')
rbuf += s
try:
while len(rbuf) < n:
try:
s = recv(131072) # see note above
except socket.error as exc:
# ssl.sock.read may cause ENOENT if the
# operation couldn't be performed (Issue celery#1414).
if not initial and exc.errno in _errnos:
continue
raise
if not s:
raise IOError('Socket closed')
rbuf += s
except:
self._read_buffer = rbuf
raise
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result
@ -240,16 +272,20 @@ class TCPTransport(_AbstractTransport):
"""Read exactly n bytes from the socket"""
recv = self._quick_recv
rbuf = self._read_buffer
while len(rbuf) < n:
try:
s = recv(131072)
except socket.error as exc:
if not initial and exc.errno in _errnos:
continue
raise
if not s:
raise IOError('Socket closed')
rbuf += s
try:
while len(rbuf) < n:
try:
s = recv(131072)
except socket.error as exc:
if not initial and exc.errno in _errnos:
continue
raise
if not s:
raise IOError('Socket closed')
rbuf += s
except:
self._read_buffer = rbuf
raise
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result

View File

@ -2,6 +2,11 @@ from __future__ import absolute_import
import sys
try:
import fcntl
except ImportError:
fcntl = None # noqa
class promise(object):
if not hasattr(sys, 'pypy_version_info'):
@ -59,3 +64,36 @@ class promise(object):
def noop():
return promise(lambda *a, **k: None)
try:
from os import set_cloexec # Python 3.4?
except ImportError:
def set_cloexec(fd, cloexec): # noqa
try:
FD_CLOEXEC = fcntl.FD_CLOEXEC
except AttributeError:
raise NotImplementedError(
'close-on-exec flag not supported on this platform',
)
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
if cloexec:
flags |= FD_CLOEXEC
else:
flags &= ~FD_CLOEXEC
return fcntl.fcntl(fd, fcntl.F_SETFD, flags)
def get_errno(exc):
""":exc:`socket.error` and :exc:`IOError` first got
the ``.errno`` attribute in Py2.7"""
try:
return exc.errno
except AttributeError:
try:
# e.args = (errno, reason)
if isinstance(exc.args, tuple) and len(exc.args) == 2:
return exc.args[0]
except AttributeError:
pass
return 0

View File

@ -18,9 +18,8 @@
#
from __future__ import absolute_import
from __future__ import with_statement
VERSION = (2, 7, 3, 32)
VERSION = (3, 3, 0, 6)
__version__ = ".".join(map(str, VERSION[0:4])) + "".join(VERSION[4:])
__author__ = 'R Oudkerk / Python Software Foundation'
__author_email__ = 'python-dev@python.org'
@ -90,15 +89,12 @@ def Manager():
return m
def Pipe(duplex=True):
def Pipe(duplex=True, rnonblock=False, wnonblock=False):
'''
Returns two connection object connected by a pipe
'''
if sys.version_info[0] == 3:
from multiprocessing.connection import Pipe
else:
from billiard._connection import Pipe
return Pipe(duplex)
from billiard.connection import Pipe
return Pipe(duplex, rnonblock, wnonblock)
def cpu_count():
@ -241,7 +237,11 @@ def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None,
Returns a process pool object
'''
from .pool import Pool
return Pool(processes, initializer, initargs, maxtasksperchild)
return Pool(processes, initializer, initargs, maxtasksperchild,
timeout, soft_timeout, lost_worker_timeout,
max_restarts, max_restart_freq, on_process_up,
on_process_down, on_timeout_set, on_timeout_cancel,
threads, semaphore, putlocks, allow_restart)
def RawValue(typecode_or_type, *args):

View File

@ -8,7 +8,6 @@
#
from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['Client', 'Listener', 'Pipe']
@ -21,11 +20,13 @@ import tempfile
import itertools
from . import AuthenticationError
from . import reduction
from ._ext import _billiard, win32
from .compat import get_errno
from .util import get_temp_dir, Finalize, sub_debug, debug
from .compat import get_errno, bytes, setblocking
from .five import monotonic
from .forking import duplicate, close
from .compat import bytes
from .reduction import ForkingPickler
from .util import get_temp_dir, Finalize, sub_debug, debug
try:
WindowsError = WindowsError # noqa
@ -36,6 +37,9 @@ except NameError:
# global set later
xmlrpclib = None
Connection = getattr(_billiard, 'Connection', None)
PipeConnection = getattr(_billiard, 'PipeConnection', None)
#
#
@ -60,11 +64,11 @@ if sys.platform == 'win32':
def _init_timeout(timeout=CONNECTION_TIMEOUT):
return time.time() + timeout
return monotonic() + timeout
def _check_timeout(t):
return time.time() > t
return monotonic() > t
#
#
@ -81,7 +85,7 @@ def arbitrary_address(family):
return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), _mmap_counter.next()))
(os.getpid(), next(_mmap_counter)))
else:
raise ValueError('unrecognized family')
@ -183,26 +187,32 @@ def Client(address, family=None, authkey=None):
if sys.platform != 'win32':
def Pipe(duplex=True):
def Pipe(duplex=True, rnonblock=False, wnonblock=False):
'''
Returns pair of connection objects at either end of a pipe
'''
if duplex:
s1, s2 = socket.socketpair()
c1 = _billiard.Connection(os.dup(s1.fileno()))
c2 = _billiard.Connection(os.dup(s2.fileno()))
s1.setblocking(not rnonblock)
s2.setblocking(not wnonblock)
c1 = Connection(os.dup(s1.fileno()))
c2 = Connection(os.dup(s2.fileno()))
s1.close()
s2.close()
else:
fd1, fd2 = os.pipe()
c1 = _billiard.Connection(fd1, writable=False)
c2 = _billiard.Connection(fd2, readable=False)
if rnonblock:
setblocking(fd1, 0)
if wnonblock:
setblocking(fd2, 0)
c1 = Connection(fd1, writable=False)
c2 = Connection(fd2, readable=False)
return c1, c2
else:
def Pipe(duplex=True): # noqa
def Pipe(duplex=True, rnonblock=False, wnonblock=False): # noqa
'''
Returns pair of connection objects at either end of a pipe
'''
@ -231,12 +241,12 @@ else:
try:
win32.ConnectNamedPipe(h1, win32.NULL)
except WindowsError, e:
if e.args[0] != win32.ERROR_PIPE_CONNECTED:
except WindowsError as exc:
if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
raise
c1 = _billiard.PipeConnection(h1, writable=duplex)
c2 = _billiard.PipeConnection(h2, readable=duplex)
c1 = PipeConnection(h1, writable=duplex)
c2 = PipeConnection(h2, readable=duplex)
return c1, c2
@ -275,7 +285,7 @@ class SocketListener(object):
def accept(self):
s, self._last_accepted = self._socket.accept()
fd = duplicate(s.fileno())
conn = _billiard.Connection(fd)
conn = Connection(fd)
s.close()
return conn
@ -296,7 +306,7 @@ def SocketClient(address):
while 1:
try:
s.connect(address)
except socket.error, exc:
except socket.error as exc:
if get_errno(exc) != errno.ECONNREFUSED or _check_timeout(t):
debug('failed to connect to address %s', address)
raise
@ -307,7 +317,7 @@ def SocketClient(address):
raise
fd = duplicate(s.fileno())
conn = _billiard.Connection(fd)
conn = Connection(fd)
s.close()
return conn
@ -352,10 +362,10 @@ if sys.platform == 'win32':
handle = self._handle_queue.pop(0)
try:
win32.ConnectNamedPipe(handle, win32.NULL)
except WindowsError, e:
if e.args[0] != win32.ERROR_PIPE_CONNECTED:
except WindowsError as exc:
if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
raise
return _billiard.PipeConnection(handle)
return PipeConnection(handle)
@staticmethod
def _finalize_pipe_listener(queue, address):
@ -375,8 +385,8 @@ if sys.platform == 'win32':
address, win32.GENERIC_READ | win32.GENERIC_WRITE,
0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL,
)
except WindowsError, e:
if e.args[0] not in (
except WindowsError as exc:
if exc.args[0] not in (
win32.ERROR_SEM_TIMEOUT,
win32.ERROR_PIPE_BUSY) or _check_timeout(t):
raise
@ -388,7 +398,7 @@ if sys.platform == 'win32':
win32.SetNamedPipeHandleState(
h, win32.PIPE_READMODE_MESSAGE, None, None
)
return _billiard.PipeConnection(h)
return PipeConnection(h)
#
# Authentication stuff
@ -471,3 +481,12 @@ def XmlClient(*args, **kwds):
global xmlrpclib
import xmlrpclib # noqa
return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
if sys.platform == 'win32':
ForkingPickler.register(socket.socket, reduction.reduce_socket)
ForkingPickler.register(Connection, reduction.reduce_connection)
ForkingPickler.register(PipeConnection, reduction.reduce_pipe_connection)
else:
ForkingPickler.register(socket.socket, reduction.reduce_socket)
ForkingPickler.register(Connection, reduction.reduce_connection)

View File

@ -0,0 +1,955 @@
#
# A higher level module for using sockets (or Windows named pipes)
#
# multiprocessing/connection.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
__all__ = ['Client', 'Listener', 'Pipe', 'wait']
import io
import os
import sys
import select
import socket
import struct
import errno
import tempfile
import itertools
import _multiprocessing
from .compat import setblocking
from .exceptions import AuthenticationError, BufferTooShort
from .five import monotonic
from .util import get_temp_dir, Finalize, sub_debug
from .reduction import ForkingPickler
try:
import _winapi
from _winapi import (
WAIT_OBJECT_0,
WAIT_ABANDONED_0,
WAIT_TIMEOUT,
INFINITE,
)
except ImportError:
if sys.platform == 'win32':
raise
_winapi = None
#
#
#
BUFSIZE = 8192
# A very generous timeout when it comes to local connections...
CONNECTION_TIMEOUT = 20.
_mmap_counter = itertools.count()
default_family = 'AF_INET'
families = ['AF_INET']
if hasattr(socket, 'AF_UNIX'):
default_family = 'AF_UNIX'
families += ['AF_UNIX']
if sys.platform == 'win32':
default_family = 'AF_PIPE'
families += ['AF_PIPE']
def _init_timeout(timeout=CONNECTION_TIMEOUT):
return monotonic() + timeout
def _check_timeout(t):
return monotonic() > t
def arbitrary_address(family):
'''
Return an arbitrary free address for the given family
'''
if family == 'AF_INET':
return ('localhost', 0)
elif family == 'AF_UNIX':
return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
elif family == 'AF_PIPE':
return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' %
(os.getpid(), next(_mmap_counter)))
else:
raise ValueError('unrecognized family')
def _validate_family(family):
'''
Checks if the family is valid for the current environment.
'''
if sys.platform != 'win32' and family == 'AF_PIPE':
raise ValueError('Family %s is not recognized.' % family)
if sys.platform == 'win32' and family == 'AF_UNIX':
# double check
if not hasattr(socket, family):
raise ValueError('Family %s is not recognized.' % family)
def address_type(address):
'''
Return the types of the address
This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE'
'''
if type(address) == tuple:
return 'AF_INET'
elif type(address) is str and address.startswith('\\\\'):
return 'AF_PIPE'
elif type(address) is str:
return 'AF_UNIX'
else:
raise ValueError('address type of %r unrecognized' % address)
#
# Connection classes
#
class _ConnectionBase:
_handle = None
def __init__(self, handle, readable=True, writable=True):
handle = handle.__index__()
if handle < 0:
raise ValueError("invalid handle")
if not readable and not writable:
raise ValueError(
"at least one of `readable` and `writable` must be True")
self._handle = handle
self._readable = readable
self._writable = writable
# XXX should we use util.Finalize instead of a __del__?
def __del__(self):
if self._handle is not None:
self._close()
def _check_closed(self):
if self._handle is None:
raise OSError("handle is closed")
def _check_readable(self):
if not self._readable:
raise OSError("connection is write-only")
def _check_writable(self):
if not self._writable:
raise OSError("connection is read-only")
def _bad_message_length(self):
if self._writable:
self._readable = False
else:
self.close()
raise OSError("bad message length")
@property
def closed(self):
"""True if the connection is closed"""
return self._handle is None
@property
def readable(self):
"""True if the connection is readable"""
return self._readable
@property
def writable(self):
"""True if the connection is writable"""
return self._writable
def fileno(self):
"""File descriptor or handle of the connection"""
self._check_closed()
return self._handle
def close(self):
"""Close the connection"""
if self._handle is not None:
try:
self._close()
finally:
self._handle = None
def send_bytes(self, buf, offset=0, size=None):
"""Send the bytes data from a bytes-like object"""
self._check_closed()
self._check_writable()
m = memoryview(buf)
# HACK for byte-indexing of non-bytewise buffers (e.g. array.array)
if m.itemsize > 1:
m = memoryview(bytes(m))
n = len(m)
if offset < 0:
raise ValueError("offset is negative")
if n < offset:
raise ValueError("buffer length < offset")
if size is None:
size = n - offset
elif size < 0:
raise ValueError("size is negative")
elif offset + size > n:
raise ValueError("buffer length < offset + size")
self._send_bytes(m[offset:offset + size])
def send(self, obj):
"""Send a (picklable) object"""
self._check_closed()
self._check_writable()
self._send_bytes(ForkingPickler.dumps(obj))
def recv_bytes(self, maxlength=None):
"""
Receive bytes data as a bytes object.
"""
self._check_closed()
self._check_readable()
if maxlength is not None and maxlength < 0:
raise ValueError("negative maxlength")
buf = self._recv_bytes(maxlength)
if buf is None:
self._bad_message_length()
return buf.getvalue()
def recv_bytes_into(self, buf, offset=0):
"""
Receive bytes data into a writeable buffer-like object.
Return the number of bytes read.
"""
self._check_closed()
self._check_readable()
with memoryview(buf) as m:
# Get bytesize of arbitrary buffer
itemsize = m.itemsize
bytesize = itemsize * len(m)
if offset < 0:
raise ValueError("negative offset")
elif offset > bytesize:
raise ValueError("offset too large")
result = self._recv_bytes()
size = result.tell()
if bytesize < offset + size:
raise BufferTooShort(result.getvalue())
# Message can fit in dest
result.seek(0)
result.readinto(
m[offset // itemsize:(offset + size) // itemsize]
)
return size
def recv_payload(self):
return self._recv_bytes().getbuffer()
def recv(self):
"""Receive a (picklable) object"""
self._check_closed()
self._check_readable()
buf = self._recv_bytes()
return ForkingPickler.loads(buf.getbuffer())
def poll(self, timeout=0.0):
"""Whether there is any input available to be read"""
self._check_closed()
self._check_readable()
return self._poll(timeout)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.close()
if _winapi:
class PipeConnection(_ConnectionBase):
"""
Connection class based on a Windows named pipe.
Overlapped I/O is used, so the handles must have been created
with FILE_FLAG_OVERLAPPED.
"""
_got_empty_message = False
def _close(self, _CloseHandle=_winapi.CloseHandle):
_CloseHandle(self._handle)
def _send_bytes(self, buf):
ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
waitres = _winapi.WaitForMultipleObjects(
[ov.event], False, INFINITE)
assert waitres == WAIT_OBJECT_0
except:
ov.cancel()
raise
finally:
nwritten, err = ov.GetOverlappedResult(True)
assert err == 0
assert nwritten == len(buf)
def _recv_bytes(self, maxsize=None):
if self._got_empty_message:
self._got_empty_message = False
return io.BytesIO()
else:
bsize = 128 if maxsize is None else min(maxsize, 128)
try:
ov, err = _winapi.ReadFile(self._handle, bsize,
overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
waitres = _winapi.WaitForMultipleObjects(
[ov.event], False, INFINITE)
assert waitres == WAIT_OBJECT_0
except:
ov.cancel()
raise
finally:
nread, err = ov.GetOverlappedResult(True)
if err == 0:
f = io.BytesIO()
f.write(ov.getbuffer())
return f
elif err == _winapi.ERROR_MORE_DATA:
return self._get_more_data(ov, maxsize)
except OSError as e:
if e.winerror == _winapi.ERROR_BROKEN_PIPE:
raise EOFError
else:
raise
raise RuntimeError(
"shouldn't get here; expected KeyboardInterrupt"
)
def _poll(self, timeout):
if (self._got_empty_message or
_winapi.PeekNamedPipe(self._handle)[0] != 0):
return True
return bool(wait([self], timeout))
def _get_more_data(self, ov, maxsize):
buf = ov.getbuffer()
f = io.BytesIO()
f.write(buf)
left = _winapi.PeekNamedPipe(self._handle)[1]
assert left > 0
if maxsize is not None and len(buf) + left > maxsize:
self._bad_message_length()
ov, err = _winapi.ReadFile(self._handle, left, overlapped=True)
rbytes, err = ov.GetOverlappedResult(True)
assert err == 0
assert rbytes == left
f.write(ov.getbuffer())
return f
class Connection(_ConnectionBase):
"""
Connection class based on an arbitrary file descriptor (Unix only), or
a socket handle (Windows).
"""
if _winapi:
def _close(self, _close=_multiprocessing.closesocket):
_close(self._handle)
_write = _multiprocessing.send
_read = _multiprocessing.recv
else:
def _close(self, _close=os.close): # noqa
_close(self._handle)
_write = os.write
_read = os.read
def send_offset(self, buf, offset, write=_write):
return write(self._handle, buf[offset:])
def _send(self, buf, write=_write):
remaining = len(buf)
while True:
try:
n = write(self._handle, buf)
except OSError as exc:
if exc.errno == errno.EINTR:
continue
raise
remaining -= n
if remaining == 0:
break
buf = buf[n:]
def setblocking(self, blocking):
setblocking(self._handle, blocking)
def _recv(self, size, read=_read):
buf = io.BytesIO()
handle = self._handle
remaining = size
while remaining > 0:
try:
chunk = read(handle, remaining)
except OSError as exc:
if exc.errno == errno.EINTR:
continue
raise
n = len(chunk)
if n == 0:
if remaining == size:
raise EOFError
else:
raise OSError("got end of file during message")
buf.write(chunk)
remaining -= n
return buf
def _send_bytes(self, buf):
# For wire compatibility with 3.2 and lower
n = len(buf)
self._send(struct.pack("!i", n))
# The condition is necessary to avoid "broken pipe" errors
# when sending a 0-length buffer if the other end closed the pipe.
if n > 0:
self._send(buf)
def _recv_bytes(self, maxsize=None):
buf = self._recv(4)
size, = struct.unpack("!i", buf.getvalue())
if maxsize is not None and size > maxsize:
return None
return self._recv(size)
def _poll(self, timeout):
r = wait([self], timeout)
return bool(r)
#
# Public functions
#
class Listener(object):
'''
Returns a listener object.
This is a wrapper for a bound socket which is 'listening' for
connections, or for a Windows named pipe.
'''
def __init__(self, address=None, family=None, backlog=1, authkey=None):
family = (family or (address and address_type(address))
or default_family)
address = address or arbitrary_address(family)
_validate_family(family)
if family == 'AF_PIPE':
self._listener = PipeListener(address, backlog)
else:
self._listener = SocketListener(address, family, backlog)
if authkey is not None and not isinstance(authkey, bytes):
raise TypeError('authkey should be a byte string')
self._authkey = authkey
def accept(self):
'''
Accept a connection on the bound socket or named pipe of `self`.
Returns a `Connection` object.
'''
if self._listener is None:
raise OSError('listener is closed')
c = self._listener.accept()
if self._authkey:
deliver_challenge(c, self._authkey)
answer_challenge(c, self._authkey)
return c
def close(self):
'''
Close the bound socket or named pipe of `self`.
'''
if self._listener is not None:
self._listener.close()
self._listener = None
address = property(lambda self: self._listener._address)
last_accepted = property(lambda self: self._listener._last_accepted)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.close()
def Client(address, family=None, authkey=None):
'''
Returns a connection to the address of a `Listener`
'''
family = family or address_type(address)
_validate_family(family)
if family == 'AF_PIPE':
c = PipeClient(address)
else:
c = SocketClient(address)
if authkey is not None and not isinstance(authkey, bytes):
raise TypeError('authkey should be a byte string')
if authkey is not None:
answer_challenge(c, authkey)
deliver_challenge(c, authkey)
return c
if sys.platform != 'win32':
def Pipe(duplex=True, rnonblock=False, wnonblock=False):
'''
Returns pair of connection objects at either end of a pipe
'''
if duplex:
s1, s2 = socket.socketpair()
s1.setblocking(not rnonblock)
s2.setblocking(not wnonblock)
c1 = Connection(s1.detach())
c2 = Connection(s2.detach())
else:
fd1, fd2 = os.pipe()
if rnonblock:
setblocking(fd1, 0)
if wnonblock:
setblocking(fd2, 0)
c1 = Connection(fd1, writable=False)
c2 = Connection(fd2, readable=False)
return c1, c2
else:
def Pipe(duplex=True, rnonblock=False, wnonblock=False): # noqa
'''
Returns pair of connection objects at either end of a pipe
'''
address = arbitrary_address('AF_PIPE')
if duplex:
openmode = _winapi.PIPE_ACCESS_DUPLEX
access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
obsize, ibsize = BUFSIZE, BUFSIZE
else:
openmode = _winapi.PIPE_ACCESS_INBOUND
access = _winapi.GENERIC_WRITE
obsize, ibsize = 0, BUFSIZE
h1 = _winapi.CreateNamedPipe(
address, openmode | _winapi.FILE_FLAG_OVERLAPPED |
_winapi.FILE_FLAG_FIRST_PIPE_INSTANCE,
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
_winapi.PIPE_WAIT,
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL
)
h2 = _winapi.CreateFile(
address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
_winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
)
_winapi.SetNamedPipeHandleState(
h2, _winapi.PIPE_READMODE_MESSAGE, None, None
)
overlapped = _winapi.ConnectNamedPipe(h1, overlapped=True)
_, err = overlapped.GetOverlappedResult(True)
assert err == 0
c1 = PipeConnection(h1, writable=duplex)
c2 = PipeConnection(h2, readable=duplex)
return c1, c2
#
# Definitions for connections based on sockets
#
class SocketListener(object):
'''
Representation of a socket which is bound to an address and listening
'''
def __init__(self, address, family, backlog=1):
self._socket = socket.socket(getattr(socket, family))
try:
# SO_REUSEADDR has different semantics on Windows (issue #2550).
if os.name == 'posix':
self._socket.setsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR, 1)
self._socket.setblocking(True)
self._socket.bind(address)
self._socket.listen(backlog)
self._address = self._socket.getsockname()
except OSError:
self._socket.close()
raise
self._family = family
self._last_accepted = None
if family == 'AF_UNIX':
self._unlink = Finalize(
self, os.unlink, args=(address, ), exitpriority=0
)
else:
self._unlink = None
def accept(self):
while True:
try:
s, self._last_accepted = self._socket.accept()
except OSError as exc:
if exc.errno == errno.EINTR:
continue
raise
else:
break
s.setblocking(True)
return Connection(s.detach())
def close(self):
self._socket.close()
if self._unlink is not None:
self._unlink()
def SocketClient(address):
'''
Return a connection object connected to the socket given by `address`
'''
family = address_type(address)
with socket.socket(getattr(socket, family)) as s:
s.setblocking(True)
s.connect(address)
return Connection(s.detach())
#
# Definitions for connections based on named pipes
#
if sys.platform == 'win32':
class PipeListener(object):
'''
Representation of a named pipe
'''
def __init__(self, address, backlog=None):
self._address = address
self._handle_queue = [self._new_handle(first=True)]
self._last_accepted = None
sub_debug('listener created with address=%r', self._address)
self.close = Finalize(
self, PipeListener._finalize_pipe_listener,
args=(self._handle_queue, self._address), exitpriority=0
)
def _new_handle(self, first=False):
flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
if first:
flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
return _winapi.CreateNamedPipe(
self._address, flags,
_winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
_winapi.PIPE_WAIT,
_winapi.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
_winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL
)
def accept(self):
self._handle_queue.append(self._new_handle())
handle = self._handle_queue.pop(0)
try:
ov = _winapi.ConnectNamedPipe(handle, overlapped=True)
except OSError as e:
if e.winerror != _winapi.ERROR_NO_DATA:
raise
# ERROR_NO_DATA can occur if a client has already connected,
# written data and then disconnected -- see Issue 14725.
else:
try:
_winapi.WaitForMultipleObjects([ov.event], False, INFINITE)
except:
ov.cancel()
_winapi.CloseHandle(handle)
raise
finally:
_, err = ov.GetOverlappedResult(True)
assert err == 0
return PipeConnection(handle)
@staticmethod
def _finalize_pipe_listener(queue, address):
sub_debug('closing listener with address=%r', address)
for handle in queue:
_winapi.CloseHandle(handle)
def PipeClient(address,
errors=(_winapi.ERROR_SEM_TIMEOUT,
_winapi.ERROR_PIPE_BUSY)):
'''
Return a connection object connected to the pipe given by `address`
'''
t = _init_timeout()
while 1:
try:
_winapi.WaitNamedPipe(address, 1000)
h = _winapi.CreateFile(
address, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
0, _winapi.NULL, _winapi.OPEN_EXISTING,
_winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL
)
except OSError as e:
if e.winerror not in errors or _check_timeout(t):
raise
else:
break
else:
raise
_winapi.SetNamedPipeHandleState(
h, _winapi.PIPE_READMODE_MESSAGE, None, None
)
return PipeConnection(h)
#
# Authentication stuff
#
MESSAGE_LENGTH = 20
CHALLENGE = b'#CHALLENGE#'
WELCOME = b'#WELCOME#'
FAILURE = b'#FAILURE#'
def deliver_challenge(connection, authkey):
import hmac
assert isinstance(authkey, bytes)
message = os.urandom(MESSAGE_LENGTH)
connection.send_bytes(CHALLENGE + message)
digest = hmac.new(authkey, message).digest()
response = connection.recv_bytes(256) # reject large message
if response == digest:
connection.send_bytes(WELCOME)
else:
connection.send_bytes(FAILURE)
raise AuthenticationError('digest received was wrong')
def answer_challenge(connection, authkey):
import hmac
assert isinstance(authkey, bytes)
message = connection.recv_bytes(256) # reject large message
assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message
message = message[len(CHALLENGE):]
digest = hmac.new(authkey, message).digest()
connection.send_bytes(digest)
response = connection.recv_bytes(256) # reject large message
if response != WELCOME:
raise AuthenticationError('digest sent was rejected')
#
# Support for using xmlrpclib for serialization
#
class ConnectionWrapper(object):
def __init__(self, conn, dumps, loads):
self._conn = conn
self._dumps = dumps
self._loads = loads
for attr in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
obj = getattr(conn, attr)
setattr(self, attr, obj)
def send(self, obj):
s = self._dumps(obj)
self._conn.send_bytes(s)
def recv(self):
s = self._conn.recv_bytes()
return self._loads(s)
def _xml_dumps(obj):
return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf-8') # noqa
def _xml_loads(s):
(obj,), method = xmlrpclib.loads(s.decode('utf-8')) # noqa
return obj
class XmlListener(Listener):
def accept(self):
global xmlrpclib
import xmlrpc.client as xmlrpclib # noqa
obj = Listener.accept(self)
return ConnectionWrapper(obj, _xml_dumps, _xml_loads)
def XmlClient(*args, **kwds):
global xmlrpclib
import xmlrpc.client as xmlrpclib # noqa
return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
#
# Wait
#
if sys.platform == 'win32':
def _exhaustive_wait(handles, timeout):
# Return ALL handles which are currently signalled. (Only
# returning the first signalled might create starvation issues.)
L = list(handles)
ready = []
while L:
res = _winapi.WaitForMultipleObjects(L, False, timeout)
if res == WAIT_TIMEOUT:
break
elif WAIT_OBJECT_0 <= res < WAIT_OBJECT_0 + len(L):
res -= WAIT_OBJECT_0
elif WAIT_ABANDONED_0 <= res < WAIT_ABANDONED_0 + len(L):
res -= WAIT_ABANDONED_0
else:
raise RuntimeError('Should not get here')
ready.append(L[res])
L = L[res+1:]
timeout = 0
return ready
_ready_errors = {_winapi.ERROR_BROKEN_PIPE, _winapi.ERROR_NETNAME_DELETED}
def wait(object_list, timeout=None):
'''
Wait till an object in object_list is ready/readable.
Returns list of those objects in object_list which are ready/readable.
'''
if timeout is None:
timeout = INFINITE
elif timeout < 0:
timeout = 0
else:
timeout = int(timeout * 1000 + 0.5)
object_list = list(object_list)
waithandle_to_obj = {}
ov_list = []
ready_objects = set()
ready_handles = set()
try:
for o in object_list:
try:
fileno = getattr(o, 'fileno')
except AttributeError:
waithandle_to_obj[o.__index__()] = o
else:
# start an overlapped read of length zero
try:
ov, err = _winapi.ReadFile(fileno(), 0, True)
except OSError as e:
err = e.winerror
if err not in _ready_errors:
raise
if err == _winapi.ERROR_IO_PENDING:
ov_list.append(ov)
waithandle_to_obj[ov.event] = o
else:
# If o.fileno() is an overlapped pipe handle and
# err == 0 then there is a zero length message
# in the pipe, but it HAS NOT been consumed.
ready_objects.add(o)
timeout = 0
ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
finally:
# request that overlapped reads stop
for ov in ov_list:
ov.cancel()
# wait for all overlapped reads to stop
for ov in ov_list:
try:
_, err = ov.GetOverlappedResult(True)
except OSError as e:
err = e.winerror
if err not in _ready_errors:
raise
if err != _winapi.ERROR_OPERATION_ABORTED:
o = waithandle_to_obj[ov.event]
ready_objects.add(o)
if err == 0:
# If o.fileno() is an overlapped pipe handle then
# a zero length message HAS been consumed.
if hasattr(o, '_got_empty_message'):
o._got_empty_message = True
ready_objects.update(waithandle_to_obj[h] for h in ready_handles)
return [o for o in object_list if o in ready_objects]
else:
if hasattr(select, 'poll'):
def _poll(fds, timeout):
if timeout is not None:
timeout = int(timeout * 1000) # timeout is in milliseconds
fd_map = {}
pollster = select.poll()
for fd in fds:
pollster.register(fd, select.POLLIN)
if hasattr(fd, 'fileno'):
fd_map[fd.fileno()] = fd
else:
fd_map[fd] = fd
ls = []
for fd, event in pollster.poll(timeout):
if event & select.POLLNVAL:
raise ValueError('invalid file descriptor %i' % fd)
ls.append(fd_map[fd])
return ls
else:
def _poll(fds, timeout): # noqa
return select.select(fds, [], [], timeout)[0]
def wait(object_list, timeout=None): # noqa
'''
Wait till an object in object_list is ready/readable.
Returns list of those objects in object_list which are ready/readable.
'''
if timeout is not None:
if timeout <= 0:
return _poll(object_list, 0)
else:
deadline = monotonic() + timeout
while True:
try:
return _poll(object_list, timeout)
except OSError as e:
if e.errno != errno.EINTR:
raise
if timeout is not None:
timeout = deadline - monotonic()

View File

@ -4,10 +4,7 @@ import sys
supports_exec = True
try:
import _winapi as win32
except ImportError: # pragma: no cover
win32 = None
from .compat import _winapi as win32 # noqa
if sys.platform.startswith("java"):
_billiard = None
@ -20,11 +17,9 @@ else:
try:
Connection = _billiard.Connection
except AttributeError: # Py3
from multiprocessing.connection import Connection # noqa
from billiard.connection import Connection # noqa
PipeConnection = getattr(_billiard, "PipeConnection", None)
if win32 is None:
win32 = getattr(_billiard, "win32", None) # noqa
def ensure_multiprocessing():

View File

@ -0,0 +1,244 @@
#
# Module to allow connection and socket objects to be transferred
# between processes
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
__all__ = []
import os
import sys
import socket
import threading
from pickle import Pickler
from . import current_process
from ._ext import _billiard, win32
from .util import register_after_fork, debug, sub_debug
if not(sys.platform == 'win32' or hasattr(_billiard, 'recvfd')):
raise ImportError('pickling of connections not supported')
close = win32.CloseHandle if sys.platform == 'win32' else os.close
# globals set later
_listener = None
_lock = None
_cache = set()
#
# ForkingPickler
#
class ForkingPickler(Pickler): # noqa
dispatch = Pickler.dispatch.copy()
@classmethod
def register(cls, type, reduce):
def dispatcher(self, obj):
rv = reduce(obj)
self.save_reduce(obj=obj, *rv)
cls.dispatch[type] = dispatcher
def _reduce_method(m): # noqa
if m.__self__ is None:
return getattr, (m.__self__.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
ForkingPickler.register(type(ForkingPickler.save), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
ForkingPickler.register(type(list.append), _reduce_method_descriptor)
ForkingPickler.register(type(int.__add__), _reduce_method_descriptor)
try:
from functools import partial
except ImportError:
pass
else:
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return partial(func, *args, **keywords)
ForkingPickler.register(partial, _reduce_partial)
def dump(obj, file, protocol=None):
ForkingPickler(file, protocol).dump(obj)
#
# Platform specific definitions
#
if sys.platform == 'win32':
# XXX Should this subprocess import be here?
import _subprocess # noqa
def send_handle(conn, handle, destination_pid):
from .forking import duplicate
process_handle = win32.OpenProcess(
win32.PROCESS_ALL_ACCESS, False, destination_pid
)
try:
new_handle = duplicate(handle, process_handle)
conn.send(new_handle)
finally:
close(process_handle)
def recv_handle(conn):
return conn.recv()
else:
def send_handle(conn, handle, destination_pid): # noqa
_billiard.sendfd(conn.fileno(), handle)
def recv_handle(conn): # noqa
return _billiard.recvfd(conn.fileno())
#
# Support for a per-process server thread which caches pickled handles
#
def _reset(obj):
global _lock, _listener, _cache
for h in _cache:
close(h)
_cache.clear()
_lock = threading.Lock()
_listener = None
_reset(None)
register_after_fork(_reset, _reset)
def _get_listener():
global _listener
if _listener is None:
_lock.acquire()
try:
if _listener is None:
from .connection import Listener
debug('starting listener and thread for sending handles')
_listener = Listener(authkey=current_process().authkey)
t = threading.Thread(target=_serve)
t.daemon = True
t.start()
finally:
_lock.release()
return _listener
def _serve():
from .util import is_exiting, sub_warning
while 1:
try:
conn = _listener.accept()
handle_wanted, destination_pid = conn.recv()
_cache.remove(handle_wanted)
send_handle(conn, handle_wanted, destination_pid)
close(handle_wanted)
conn.close()
except:
if not is_exiting():
sub_warning('thread for sharing handles raised exception',
exc_info=True)
#
# Functions to be used for pickling/unpickling objects with handles
#
def reduce_handle(handle):
from .forking import Popen, duplicate
if Popen.thread_is_spawning():
return (None, Popen.duplicate_for_child(handle), True)
dup_handle = duplicate(handle)
_cache.add(dup_handle)
sub_debug('reducing handle %d', handle)
return (_get_listener().address, dup_handle, False)
def rebuild_handle(pickled_data):
from .connection import Client
address, handle, inherited = pickled_data
if inherited:
return handle
sub_debug('rebuilding handle %d', handle)
conn = Client(address, authkey=current_process().authkey)
conn.send((handle, os.getpid()))
new_handle = recv_handle(conn)
conn.close()
return new_handle
#
# Register `_billiard.Connection` with `ForkingPickler`
#
def reduce_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_connection, (rh, conn.readable, conn.writable)
def rebuild_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.Connection(
handle, readable=readable, writable=writable
)
# Register `socket.socket` with `ForkingPickler`
#
def fromfd(fd, family, type_, proto=0):
s = socket.fromfd(fd, family, type_, proto)
if s.__class__ is not socket.socket:
s = socket.socket(_sock=s)
return s
def reduce_socket(s):
reduced_handle = reduce_handle(s.fileno())
return rebuild_socket, (reduced_handle, s.family, s.type, s.proto)
def rebuild_socket(reduced_handle, family, type_, proto):
fd = rebuild_handle(reduced_handle)
_sock = fromfd(fd, family, type_, proto)
close(fd)
return _sock
ForkingPickler.register(socket.socket, reduce_socket)
#
# Register `_billiard.PipeConnection` with `ForkingPickler`
#
if sys.platform == 'win32':
def reduce_pipe_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_pipe_connection, (rh, conn.readable, conn.writable)
def rebuild_pipe_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.PipeConnection(
handle, readable=readable, writable=writable
)

View File

@ -0,0 +1,249 @@
#
# Module which deals with pickling of objects.
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
import copyreg
import functools
import io
import os
import pickle
import socket
import sys
__all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump']
HAVE_SEND_HANDLE = (sys.platform == 'win32' or
(hasattr(socket, 'CMSG_LEN') and
hasattr(socket, 'SCM_RIGHTS') and
hasattr(socket.socket, 'sendmsg')))
#
# Pickler subclass
#
class ForkingPickler(pickle.Pickler):
'''Pickler subclass used by multiprocessing.'''
_extra_reducers = {}
_copyreg_dispatch_table = copyreg.dispatch_table
def __init__(self, *args):
super().__init__(*args)
self.dispatch_table = self._copyreg_dispatch_table.copy()
self.dispatch_table.update(self._extra_reducers)
@classmethod
def register(cls, type, reduce):
'''Register a reduce function for a type.'''
cls._extra_reducers[type] = reduce
@classmethod
def dumps(cls, obj, protocol=None):
buf = io.BytesIO()
cls(buf, protocol).dump(obj)
return buf.getbuffer()
loads = pickle.loads
register = ForkingPickler.register
def dump(obj, file, protocol=None):
'''Replacement for pickle.dump() using ForkingPickler.'''
ForkingPickler(file, protocol).dump(obj)
#
# Platform specific definitions
#
if sys.platform == 'win32':
# Windows
__all__ += ['DupHandle', 'duplicate', 'steal_handle']
import _winapi
def duplicate(handle, target_process=None, inheritable=False):
'''Duplicate a handle. (target_process is a handle not a pid!)'''
if target_process is None:
target_process = _winapi.GetCurrentProcess()
return _winapi.DuplicateHandle(
_winapi.GetCurrentProcess(), handle, target_process,
0, inheritable, _winapi.DUPLICATE_SAME_ACCESS)
def steal_handle(source_pid, handle):
'''Steal a handle from process identified by source_pid.'''
source_process_handle = _winapi.OpenProcess(
_winapi.PROCESS_DUP_HANDLE, False, source_pid)
try:
return _winapi.DuplicateHandle(
source_process_handle, handle,
_winapi.GetCurrentProcess(), 0, False,
_winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE)
finally:
_winapi.CloseHandle(source_process_handle)
def send_handle(conn, handle, destination_pid):
'''Send a handle over a local connection.'''
dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid)
conn.send(dh)
def recv_handle(conn):
'''Receive a handle over a local connection.'''
return conn.recv().detach()
class DupHandle(object):
'''Picklable wrapper for a handle.'''
def __init__(self, handle, access, pid=None):
if pid is None:
# We just duplicate the handle in the current process and
# let the receiving process steal the handle.
pid = os.getpid()
proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid)
try:
self._handle = _winapi.DuplicateHandle(
_winapi.GetCurrentProcess(),
handle, proc, access, False, 0)
finally:
_winapi.CloseHandle(proc)
self._access = access
self._pid = pid
def detach(self):
'''Get the handle. This should only be called once.'''
# retrieve handle from process which currently owns it
if self._pid == os.getpid():
# The handle has already been duplicated for this process.
return self._handle
# We must steal the handle from the process whose pid is self._pid.
proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False,
self._pid)
try:
return _winapi.DuplicateHandle(
proc, self._handle, _winapi.GetCurrentProcess(),
self._access, False, _winapi.DUPLICATE_CLOSE_SOURCE)
finally:
_winapi.CloseHandle(proc)
else:
# Unix
__all__ += ['DupFd', 'sendfds', 'recvfds']
import array
# On MacOSX we should acknowledge receipt of fds -- see Issue14669
ACKNOWLEDGE = sys.platform == 'darwin'
def sendfds(sock, fds):
'''Send an array of fds over an AF_UNIX socket.'''
fds = array.array('i', fds)
msg = bytes([len(fds) % 256])
sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
if ACKNOWLEDGE and sock.recv(1) != b'A':
raise RuntimeError('did not receive acknowledgement of fd')
def recvfds(sock, size):
'''Receive an array of fds over an AF_UNIX socket.'''
a = array.array('i')
bytes_size = a.itemsize * size
msg, ancdata, flags, addr = sock.recvmsg(
1, socket.CMSG_LEN(bytes_size),
)
if not msg and not ancdata:
raise EOFError
try:
if ACKNOWLEDGE:
sock.send(b'A')
if len(ancdata) != 1:
raise RuntimeError(
'received %d items of ancdata' % len(ancdata),
)
cmsg_level, cmsg_type, cmsg_data = ancdata[0]
if (cmsg_level == socket.SOL_SOCKET and
cmsg_type == socket.SCM_RIGHTS):
if len(cmsg_data) % a.itemsize != 0:
raise ValueError
a.frombytes(cmsg_data)
assert len(a) % 256 == msg[0]
return list(a)
except (ValueError, IndexError):
pass
raise RuntimeError('Invalid data received')
def send_handle(conn, handle, destination_pid): # noqa
'''Send a handle over a local connection.'''
fd = conn.fileno()
with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
sendfds(s, [handle])
def recv_handle(conn): # noqa
'''Receive a handle over a local connection.'''
fd = conn.fileno()
with socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) as s:
return recvfds(s, 1)[0]
def DupFd(fd):
'''Return a wrapper for an fd.'''
from .forking import Popen
return Popen.duplicate_for_child(fd)
#
# Try making some callable types picklable
#
def _reduce_method(m):
if m.__self__ is None:
return getattr, (m.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
class _C:
def f(self):
pass
register(type(_C().f), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
register(type(list.append), _reduce_method_descriptor)
register(type(int.__add__), _reduce_method_descriptor)
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return functools.partial(func, *args, **keywords)
register(functools.partial, _reduce_partial)
#
# Make sockets picklable
#
if sys.platform == 'win32':
def _reduce_socket(s):
from .resource_sharer import DupSocket
return _rebuild_socket, (DupSocket(s),)
def _rebuild_socket(ds):
return ds.detach()
register(socket.socket, _reduce_socket)
else:
def _reduce_socket(s): # noqa
df = DupFd(s.fileno())
return _rebuild_socket, (df, s.family, s.type, s.proto)
def _rebuild_socket(df, family, type, proto): # noqa
fd = df.detach()
return socket.socket(family, type, proto, fileno=fd)
register(socket.socket, _reduce_socket)

View File

@ -88,7 +88,7 @@ def get_all_processes_pids():
def get_processtree_pids(pid, include_parent=True):
"""Return a list with all the pids of a process tree"""
parents = get_all_processes_pids()
all_pids = parents.keys()
all_pids = list(parents.keys())
pids = set([pid])
while 1:
pids_new = pids.copy()

View File

@ -4,10 +4,10 @@ This module contains utilities added by billiard, to keep
"non-core" functionality out of ``.util``."""
from __future__ import absolute_import
import os
import signal
import sys
from time import time
import pickle as pypickle
try:
import cPickle as cpickle
@ -15,6 +15,7 @@ except ImportError: # pragma: no cover
cpickle = None # noqa
from .exceptions import RestartFreqExceeded
from .five import monotonic
if sys.version_info < (2, 6): # pragma: no cover
# cPickle does not use absolute_imports
@ -36,16 +37,15 @@ else:
except ImportError:
from StringIO import StringIO as BytesIO # noqa
EX_SOFTWARE = 70
TERMSIGS = (
'SIGHUP',
'SIGQUIT',
'SIGILL',
'SIGTRAP',
'SIGABRT',
'SIGEMT',
'SIGFPE',
'SIGBUS',
'SIGSEGV',
'SIGSYS',
'SIGPIPE',
'SIGALRM',
@ -58,13 +58,33 @@ TERMSIGS = (
'SIGUSR2',
)
#: set by signal handlers just before calling exit.
#: if this is true after the sighandler returns it means that something
#: went wrong while terminating the process, and :func:`os._exit`
#: must be called ASAP.
_should_have_exited = [False]
def pickle_loads(s, load=pickle_load):
# used to support buffer objects
return load(BytesIO(s))
def maybe_setsignal(signum, handler):
try:
signal.signal(signum, handler)
except (OSError, AttributeError, ValueError, RuntimeError):
pass
def _shutdown_cleanup(signum, frame):
# we will exit here so if the signal is received a second time
# we can be sure that something is very wrong and we may be in
# a crashing loop.
if _should_have_exited[0]:
os._exit(EX_SOFTWARE)
maybe_setsignal(signum, signal.SIG_DFL)
_should_have_exited[0] = True
sys.exit(-(256 - signum))
@ -72,11 +92,12 @@ def reset_signals(handler=_shutdown_cleanup):
for sig in TERMSIGS:
try:
signum = getattr(signal, sig)
except AttributeError:
pass
else:
current = signal.getsignal(signum)
if current is not None and current != signal.SIG_IGN:
signal.signal(signum, handler)
except (OSError, AttributeError, ValueError, RuntimeError):
pass
maybe_setsignal(signum, handler)
class restart_state(object):
@ -87,7 +108,7 @@ class restart_state(object):
self.R, self.T = 0, None
def step(self, now=None):
now = time() if now is None else now
now = monotonic() if now is None else now
R = self.R
if self.T and now - self.T >= self.maxT:
# maxT passed, reset counter and time passed.
@ -98,9 +119,8 @@ class restart_state(object):
# the startup probably went fine (startup restart burst
# protection)
if self.R: # pragma: no cover
pass
self.R = 0 # reset in case someone catches the error
raise self.RestartFreqExceeded("%r in %rs" % (R, self.maxT))
self.R = 0 # reset in case someone catches the error
raise self.RestartFreqExceeded("%r in %rs" % (R, self.maxT))
# first run sets T
if self.T is None:
self.T = now

View File

@ -3,13 +3,50 @@ from __future__ import absolute_import
import errno
import os
import sys
import __builtin__
from .five import builtins, range
if sys.platform == 'win32':
try:
import _winapi # noqa
except ImportError: # pragma: no cover
try:
from _billiard import win32 as _winapi # noqa
except (ImportError, AttributeError):
from _multiprocessing import win32 as _winapi # noqa
else:
_winapi = None # noqa
try:
buf_t, is_new_buffer = memoryview, True # noqa
except NameError: # Py2.6
buf_t, is_new_buffer = buffer, False # noqa
if hasattr(os, 'write'):
__write__ = os.write
if is_new_buffer:
def send_offset(fd, buf, offset):
return __write__(fd, buf[offset:])
else: # Py2.6
def send_offset(fd, buf, offset): # noqa
return __write__(fd, buf_t(buf, offset))
else: # non-posix platform
def send_offset(fd, buf, offset): # noqa
raise NotImplementedError('send_offset')
if sys.version_info[0] == 3:
bytes = bytes
else:
try:
_bytes = __builtin__.bytes
_bytes = builtins.bytes
except AttributeError:
_bytes = str
@ -25,10 +62,10 @@ try:
except AttributeError:
def closerange(fd_low, fd_high): # noqa
for fd in reversed(xrange(fd_low, fd_high)):
for fd in reversed(range(fd_low, fd_high)):
try:
os.close(fd)
except OSError, exc:
except OSError as exc:
if exc.errno != errno.EBADF:
raise
@ -46,3 +83,26 @@ def get_errno(exc):
except AttributeError:
pass
return 0
if sys.platform == 'win32':
def setblocking(handle, blocking):
raise NotImplementedError('setblocking not implemented on win32')
def isblocking(handle):
raise NotImplementedError('isblocking not implemented on win32')
else:
from os import O_NONBLOCK
from fcntl import fcntl, F_GETFL, F_SETFL
def isblocking(handle): # noqa
return not (fcntl(handle, F_GETFL) & O_NONBLOCK)
def setblocking(handle, blocking): # noqa
flags = fcntl(handle, F_GETFL, 0)
fcntl(
handle, F_SETFL,
flags & (~O_NONBLOCK) if blocking else flags | O_NONBLOCK,
)

View File

@ -1,11 +1,27 @@
from __future__ import absolute_import
import sys
is_pypy = hasattr(sys, 'pypy_version_info')
if sys.version_info[0] == 3:
from multiprocessing import connection
from . import _connection3 as connection
else:
from billiard import _connection as connection # noqa
from . import _connection as connection # noqa
if is_pypy:
import _multiprocessing
from .compat import setblocking, send_offset
class Connection(_multiprocessing.Connection):
def send_offset(self, buf, offset):
return send_offset(self.fileno(), buf, offset)
def setblocking(self, blocking):
setblocking(self.fileno(), blocking)
_multiprocessing.Connection = Connection
sys.modules[__name__] = connection

View File

@ -50,12 +50,10 @@ import array
from threading import Lock, RLock, Semaphore, BoundedSemaphore
from threading import Event
from Queue import Queue
if sys.version_info[0] == 3:
from multiprocessing.connection import Pipe
else:
from billiard._connection import Pipe
from billiard.five import Queue
from billiard.connection import Pipe
class DummyProcess(threading.Thread):
@ -91,7 +89,7 @@ class Condition(_Condition):
if sys.version_info[0] == 3:
notify_all = _Condition.notifyAll
else:
notify_all = _Condition.notifyAll.im_func
notify_all = _Condition.notifyAll.__func__
Process = DummyProcess
@ -117,7 +115,7 @@ class Namespace(object):
self.__dict__.update(kwds)
def __repr__(self):
items = self.__dict__.items()
items = list(self.__dict__.items())
temp = []
for name, value in items:
if not name.startswith('_'):

View File

@ -35,7 +35,7 @@ from __future__ import absolute_import
__all__ = ['Client', 'Listener', 'Pipe']
from Queue import Queue
from billiard.five import Queue
families = [None]

View File

@ -32,7 +32,7 @@ class _Frame(object):
class _Object(object):
def __init__(self, **kw):
[setattr(self, k, v) for k, v in kw.iteritems()]
[setattr(self, k, v) for k, v in kw.items()]
class _Truncated(object):

View File

@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
"""
celery.five
~~~~~~~~~~~
Compatibility implementations of features
only available in newer Python versions.
"""
from __future__ import absolute_import
############## py3k #########################################################
import sys
PY3 = sys.version_info[0] == 3
try:
reload = reload # noqa
except NameError: # pragma: no cover
from imp import reload # noqa
try:
from UserList import UserList # noqa
except ImportError: # pragma: no cover
from collections import UserList # noqa
try:
from UserDict import UserDict # noqa
except ImportError: # pragma: no cover
from collections import UserDict # noqa
############## time.monotonic ################################################
if sys.version_info < (3, 3):
import platform
SYSTEM = platform.system()
if SYSTEM == 'Darwin':
import ctypes
libSystem = ctypes.CDLL('libSystem.dylib')
CoreServices = ctypes.CDLL(
'/System/Library/Frameworks/CoreServices.framework/CoreServices',
use_errno=True,
)
mach_absolute_time = libSystem.mach_absolute_time
mach_absolute_time.restype = ctypes.c_uint64
absolute_to_nanoseconds = CoreServices.AbsoluteToNanoseconds
absolute_to_nanoseconds.restype = ctypes.c_uint64
absolute_to_nanoseconds.argtypes = [ctypes.c_uint64]
def _monotonic():
return absolute_to_nanoseconds(mach_absolute_time()) * 1e-9
elif SYSTEM == 'Linux':
# from stackoverflow:
# questions/1205722/how-do-i-get-monotonic-time-durations-in-python
import ctypes
import os
CLOCK_MONOTONIC = 1 # see <linux/time.h>
class timespec(ctypes.Structure):
_fields_ = [
('tv_sec', ctypes.c_long),
('tv_nsec', ctypes.c_long),
]
librt = ctypes.CDLL('librt.so.1', use_errno=True)
clock_gettime = librt.clock_gettime
clock_gettime.argtypes = [
ctypes.c_int, ctypes.POINTER(timespec),
]
def _monotonic(): # noqa
t = timespec()
if clock_gettime(CLOCK_MONOTONIC, ctypes.pointer(t)) != 0:
errno_ = ctypes.get_errno()
raise OSError(errno_, os.strerror(errno_))
return t.tv_sec + t.tv_nsec * 1e-9
else:
from time import time as _monotonic
try:
from time import monotonic
except ImportError:
monotonic = _monotonic # noqa
if PY3:
import builtins
from queue import Queue, Empty, Full
from itertools import zip_longest
from io import StringIO, BytesIO
map = map
string = str
string_t = str
long_t = int
text_t = str
range = range
int_types = (int, )
open_fqdn = 'builtins.open'
def items(d):
return d.items()
def keys(d):
return d.keys()
def values(d):
return d.values()
def nextfun(it):
return it.__next__
exec_ = getattr(builtins, 'exec')
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
class WhateverIO(StringIO):
def write(self, data):
if isinstance(data, bytes):
data = data.encode()
StringIO.write(self, data)
else:
import __builtin__ as builtins # noqa
from Queue import Queue, Empty, Full # noqa
from itertools import imap as map, izip_longest as zip_longest # noqa
from StringIO import StringIO # noqa
string = unicode # noqa
string_t = basestring # noqa
text_t = unicode
long_t = long # noqa
range = xrange
int_types = (int, long)
open_fqdn = '__builtin__.open'
def items(d): # noqa
return d.iteritems()
def keys(d): # noqa
return d.iterkeys()
def values(d): # noqa
return d.itervalues()
def nextfun(it): # noqa
return it.next
def exec_(code, globs=None, locs=None):
"""Execute code in a namespace."""
if globs is None:
frame = sys._getframe(1)
globs = frame.f_globals
if locs is None:
locs = frame.f_locals
del frame
elif locs is None:
locs = globs
exec("""exec code in globs, locs""")
exec_("""def reraise(tp, value, tb=None): raise tp, value, tb""")
BytesIO = WhateverIO = StringIO # noqa
def with_metaclass(Type, skip_attrs=set(['__dict__', '__weakref__'])):
"""Class decorator to set metaclass.
Works with both Python 3 and Python 3 and it does not add
an extra class in the lookup order like ``six.with_metaclass`` does
(that is -- it copies the original class instead of using inheritance).
"""
def _clone_with_metaclass(Class):
attrs = dict((key, value) for key, value in items(vars(Class))
if key not in skip_attrs)
return Type(Class.__name__, Class.__bases__, attrs)
return _clone_with_metaclass

View File

@ -14,12 +14,15 @@ import sys
import signal
import warnings
from ._ext import Connection, PipeConnection, win32
from pickle import load, HIGHEST_PROTOCOL
from billiard import util, process
from billiard import util
from billiard import process
from billiard.five import int_types
from .reduction import dump
from .compat import _winapi as win32
__all__ = ['Popen', 'assert_spawning', 'exit',
'duplicate', 'close', 'ForkingPickler']
'duplicate', 'close']
try:
WindowsError = WindowsError # noqa
@ -53,105 +56,16 @@ def assert_spawning(self):
' through inheritance' % type(self).__name__
)
#
# Try making some callable types picklable
#
from pickle import Pickler
if sys.version_info[0] == 3:
from copyreg import dispatch_table
class ForkingPickler(Pickler):
_extra_reducers = {}
def __init__(self, *args, **kwargs):
Pickler.__init__(self, *args, **kwargs)
self.dispatch_table = dispatch_table.copy()
self.dispatch_table.update(self._extra_reducers)
@classmethod
def register(cls, type, reduce):
cls._extra_reducers[type] = reduce
def _reduce_method(m):
if m.__self__ is None:
return getattr, (m.__class__, m.__func__.__name__)
else:
return getattr, (m.__self__, m.__func__.__name__)
class _C:
def f(self):
pass
ForkingPickler.register(type(_C().f), _reduce_method)
else:
class ForkingPickler(Pickler): # noqa
dispatch = Pickler.dispatch.copy()
@classmethod
def register(cls, type, reduce):
def dispatcher(self, obj):
rv = reduce(obj)
self.save_reduce(obj=obj, *rv)
cls.dispatch[type] = dispatcher
def _reduce_method(m): # noqa
if m.im_self is None:
return getattr, (m.im_class, m.im_func.func_name)
else:
return getattr, (m.im_self, m.im_func.func_name)
ForkingPickler.register(type(ForkingPickler.save), _reduce_method)
def _reduce_method_descriptor(m):
return getattr, (m.__objclass__, m.__name__)
ForkingPickler.register(type(list.append), _reduce_method_descriptor)
ForkingPickler.register(type(int.__add__), _reduce_method_descriptor)
try:
from functools import partial
except ImportError:
pass
else:
def _reduce_partial(p):
return _rebuild_partial, (p.func, p.args, p.keywords or {})
def _rebuild_partial(func, args, keywords):
return partial(func, *args, **keywords)
ForkingPickler.register(partial, _reduce_partial)
def dump(obj, file, protocol=None):
ForkingPickler(file, protocol).dump(obj)
#
# Make (Pipe)Connection picklable
#
def reduce_connection(conn):
# XXX check not necessary since only registered with ForkingPickler
if not Popen.thread_is_spawning():
raise RuntimeError(
'By default %s objects can only be shared between processes\n'
'using inheritance' % type(conn).__name__
)
return type(conn), (Popen.duplicate_for_child(conn.fileno()),
conn.readable, conn.writable)
ForkingPickler.register(Connection, reduce_connection)
if PipeConnection:
ForkingPickler.register(PipeConnection, reduce_connection)
#
# Unix
#
if sys.platform != 'win32':
import thread
try:
import thread
except ImportError:
import _thread as thread # noqa
import select
WINEXE = False
@ -172,6 +86,8 @@ if sys.platform != 'win32':
_tls = thread._local()
def __init__(self, process_obj):
# register reducers
from billiard import connection # noqa
_Django_old_layout_hack__save()
sys.stdout.flush()
sys.stderr.flush()
@ -265,9 +181,15 @@ if sys.platform != 'win32':
#
else:
import thread
try:
import thread
except ImportError:
import _thread as thread # noqa
import msvcrt
import _subprocess
try:
import _subprocess
except ImportError:
import _winapi as _subprocess # noqa
#
#
@ -287,10 +209,14 @@ else:
def duplicate(handle, target_process=None, inheritable=False):
if target_process is None:
target_process = _subprocess.GetCurrentProcess()
return _subprocess.DuplicateHandle(
h = _subprocess.DuplicateHandle(
_subprocess.GetCurrentProcess(), handle, target_process,
0, inheritable, _subprocess.DUPLICATE_SAME_ACCESS
).Detach()
)
if sys.version_info[0] < 3 or (
sys.version_info[0] == 3 and sys.version_info[1] < 3):
h = h.Detach()
return h
#
# We define a Popen class similar to the one from subprocess, but
@ -318,8 +244,9 @@ else:
hp, ht, pid, tid = _subprocess.CreateProcess(
_python_exe, cmd, None, None, 1, 0, None, None, None
)
ht.Close()
close(rhandle)
close(ht) if isinstance(ht, int_types) else ht.Close()
(close(rhandle) if isinstance(rhandle, int_types)
else rhandle.Close())
# set attributes of self
self.pid = pid
@ -566,22 +493,6 @@ def get_preparation_data(name):
return d
#
# Make (Pipe)Connection picklable
#
def reduce_connection(conn):
if not Popen.thread_is_spawning():
raise RuntimeError(
'By default %s objects can only be shared between processes\n'
'using inheritance' % type(conn).__name__
)
return type(conn), (Popen.duplicate_for_child(conn.fileno()),
conn.readable, conn.writable)
ForkingPickler.register(Connection, reduce_connection)
ForkingPickler.register(PipeConnection, reduce_connection)
#
# Prepare current process
#
@ -659,7 +570,7 @@ def prepare(data):
# Try to make the potentially picklable objects in
# sys.modules['__main__'] realize they are in the main
# module -- somewhat ugly.
for obj in main_module.__dict__.values():
for obj in list(main_module.__dict__.values()):
try:
if obj.__module__ == '__parents_main__':
obj.__module__ = '__main__'

View File

@ -17,7 +17,8 @@ import itertools
from ._ext import _billiard, win32
from .util import Finalize, info, get_temp_dir
from .forking import assert_spawning, ForkingPickler
from .forking import assert_spawning
from .reduction import ForkingPickler
__all__ = ['BufferWrapper']
@ -38,7 +39,7 @@ if sys.platform == 'win32':
def __init__(self, size):
self.size = size
self.name = 'pym-%d-%d' % (os.getpid(), Arena._counter.next())
self.name = 'pym-%d-%d' % (os.getpid(), next(Arena._counter))
self.buffer = mmap.mmap(-1, self.size, tagname=self.name)
assert win32.GetLastError() == 0, 'tagname already in use'
self._state = (self.size, self.name)
@ -65,9 +66,9 @@ else:
if fileno == -1 and not _forking_is_enabled:
name = os.path.join(
get_temp_dir(),
'pym-%d-%d' % (os.getpid(), self._counter.next()))
'pym-%d-%d' % (os.getpid(), next(self._counter)))
self.fileno = os.open(
name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0600)
name, os.O_RDWR | os.O_CREAT | os.O_EXCL, 0o600)
os.unlink(name)
os.ftruncate(self.fileno, size)
self.buffer = mmap.mmap(self.fileno, self.size)

View File

@ -8,7 +8,6 @@
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token']
@ -19,14 +18,15 @@ __all__ = ['BaseManager', 'SyncManager', 'BaseProxy', 'Token']
import sys
import threading
import array
import Queue
from collections import Callable
from traceback import format_exc
from time import time as _time
from . import Process, current_process, active_children, Pool, util, connection
from .five import Queue, items, monotonic
from .process import AuthenticationString
from .forking import exit, Popen, ForkingPickler
from .forking import exit, Popen
from .reduction import ForkingPickler
from .util import Finalize, error, info
#
@ -123,7 +123,7 @@ def all_methods(obj):
temp = []
for name in dir(obj):
func = getattr(obj, name)
if callable(func):
if isinstance(func, Callable):
temp.append(name)
return temp
@ -205,14 +205,14 @@ class Server(object):
msg = ('#RETURN', result)
try:
c.send(msg)
except Exception, e:
except Exception as exc:
try:
c.send(('#TRACEBACK', format_exc()))
except Exception:
pass
info('Failure to send message: %r', msg)
info(' ... request was %r', request)
info(' ... exception was %r', e)
info(' ... exception was %r', exc)
c.close()
@ -245,8 +245,8 @@ class Server(object):
try:
res = function(*args, **kwds)
except Exception, e:
msg = ('#ERROR', e)
except Exception as exc:
msg = ('#ERROR', exc)
else:
typeid = gettypeid and gettypeid.get(methodname, None)
if typeid:
@ -280,13 +280,13 @@ class Server(object):
try:
try:
send(msg)
except Exception, e:
except Exception:
send(('#UNSERIALIZABLE', repr(msg)))
except Exception, e:
except Exception as exc:
info('exception in thread serving %r',
threading.currentThread().name)
info(' ... message was %r', msg)
info(' ... exception was %r', e)
info(' ... exception was %r', exc)
conn.close()
sys.exit(1)
@ -314,7 +314,7 @@ class Server(object):
'''
with self.mutex:
result = []
keys = self.id_to_obj.keys()
keys = list(self.id_to_obj.keys())
keys.sort()
for ident in keys:
if ident != '0':
@ -492,7 +492,8 @@ class BaseManager(object):
'''
assert self._state.value == State.INITIAL
if initializer is not None and not callable(initializer):
if initializer is not None and \
not isinstance(initializer, Callable):
raise TypeError('initializer must be a callable')
# pipe over which we will retrieve address of server
@ -641,7 +642,7 @@ class BaseManager(object):
)
if method_to_typeid:
for key, value in method_to_typeid.items():
for key, value in items(method_to_typeid):
assert type(key) is str, '%r is not a string' % key
assert type(value) is str, '%r is not a string' % value
@ -797,8 +798,8 @@ class BaseProxy(object):
util.debug('DECREF %r', token.id)
conn = _Client(token.address, authkey=authkey)
dispatch(conn, None, 'decref', (token.id,))
except Exception, e:
util.debug('... decref failed %s', e)
except Exception as exc:
util.debug('... decref failed %s', exc)
else:
util.debug('DECREF %r -- manager already shutdown', token.id)
@ -815,9 +816,9 @@ class BaseProxy(object):
self._manager = None
try:
self._incref()
except Exception, e:
except Exception as exc:
# the proxy may just be for a manager which has shutdown
info('incref failed: %s', e)
info('incref failed: %s', exc)
def __reduce__(self):
kwds = {}
@ -933,7 +934,7 @@ class Namespace(object):
self.__dict__.update(kwds)
def __repr__(self):
items = self.__dict__.items()
items = list(self.__dict__.items())
temp = []
for name, value in items:
if not name.startswith('_'):
@ -1026,13 +1027,13 @@ class ConditionProxy(AcquirerProxy):
if result:
return result
if timeout is not None:
endtime = _time() + timeout
endtime = monotonic() + timeout
else:
endtime = None
waittime = None
while not result:
if endtime is not None:
waittime = endtime - _time()
waittime = endtime - monotonic()
if waittime <= 0:
break
self.wait(waittime)
@ -1149,8 +1150,8 @@ class SyncManager(BaseManager):
this class.
'''
SyncManager.register('Queue', Queue.Queue)
SyncManager.register('JoinableQueue', Queue.Queue)
SyncManager.register('Queue', Queue)
SyncManager.register('JoinableQueue', Queue)
SyncManager.register('Event', threading.Event, EventProxy)
SyncManager.register('Lock', threading.Lock, AcquirerProxy)
SyncManager.register('RLock', threading.RLock, AcquirerProxy)

File diff suppressed because it is too large Load Diff

View File

@ -27,6 +27,7 @@ try:
from _weakrefset import WeakSet
except ImportError:
WeakSet = None # noqa
from .five import items, string_t
try:
ORIGINAL_DIR = os.path.abspath(os.getcwd())
@ -85,7 +86,7 @@ class Process(object):
def __init__(self, group=None, target=None, name=None,
args=(), kwargs={}, daemon=None, **_kw):
assert group is None, 'group argument must be None for now'
count = _current_process._counter.next()
count = next(_current_process._counter)
self._identity = _current_process._identity + (count,)
self._authkey = _current_process._authkey
if daemon is not None:
@ -164,7 +165,7 @@ class Process(object):
return self._name
def _set_name(self, value):
assert isinstance(name, basestring), 'name must be a string'
assert isinstance(name, string_t), 'name must be a string'
self._name = value
name = property(_get_name, _set_name)
@ -256,14 +257,17 @@ class Process(object):
_current_process = self
# Re-init logging system.
# Workaround for http://bugs.python.org/issue6721#msg140215
# Python logging module uses RLock() objects which are broken after
# fork. This can result in a deadlock (Celery Issue #496).
logger_names = logging.Logger.manager.loggerDict.keys()
# Workaround for http://bugs.python.org/issue6721/#msg140215
# Python logging module uses RLock() objects which are broken
# after fork. This can result in a deadlock (Celery Issue #496).
loggerDict = logging.Logger.manager.loggerDict
logger_names = list(loggerDict.keys())
logger_names.append(None) # for root logger
for name in logger_names:
for handler in logging.getLogger(name).handlers:
handler.createLock()
if not name or not isinstance(loggerDict[name],
logging.PlaceHolder):
for handler in logging.getLogger(name).handlers:
handler.createLock()
logging._lock = threading.RLock()
try:
@ -279,15 +283,15 @@ class Process(object):
exitcode = 0
finally:
util._exit_function()
except SystemExit, e:
if not e.args:
except SystemExit as exc:
if not exc.args:
exitcode = 1
elif isinstance(e.args[0], int):
exitcode = e.args[0]
elif isinstance(exc.args[0], int):
exitcode = exc.args[0]
else:
sys.stderr.write(str(e.args[0]) + '\n')
sys.stderr.write(str(exc.args[0]) + '\n')
_maybe_flush(sys.stderr)
exitcode = 0 if isinstance(e.args[0], str) else 1
exitcode = 0 if isinstance(exc.args[0], str) else 1
except:
exitcode = 1
if not util.error('Process %s', self.name, exc_info=True):
@ -347,7 +351,7 @@ del _MainProcess
_exitcode_to_name = {}
for name, signum in signal.__dict__.items():
for name, signum in items(signal.__dict__):
if name[:3] == 'SIG' and '_' not in name:
_exitcode_to_name[-signum] = name

View File

@ -7,7 +7,6 @@
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
from __future__ import with_statement
__all__ = ['Queue', 'SimpleQueue', 'JoinableQueue']
@ -15,17 +14,16 @@ import sys
import os
import threading
import collections
import time
import weakref
import errno
from Queue import Empty, Full
from . import Pipe
from ._ext import _billiard
from .compat import get_errno
from .five import monotonic
from .synchronize import Lock, BoundedSemaphore, Semaphore, Condition
from .util import debug, error, info, Finalize, register_after_fork
from .five import Empty, Full
from .forking import assert_spawning
@ -96,12 +94,12 @@ class Queue(object):
else:
if block:
deadline = time.time() + timeout
deadline = monotonic() + timeout
if not self._rlock.acquire(block, timeout):
raise Empty
try:
if block:
timeout = deadline - time.time()
timeout = deadline - monotonic()
if timeout < 0 or not self._poll(timeout):
raise Empty
elif not self._poll():
@ -238,7 +236,7 @@ class Queue(object):
send(obj)
except IndexError:
pass
except Exception, exc:
except Exception as exc:
if ignore_epipe and get_errno(exc) == errno.EPIPE:
return
# Since this runs in a daemon thread the resources it uses
@ -306,19 +304,17 @@ class JoinableQueue(Queue):
self._cond.wait()
class SimpleQueue(object):
class _SimpleQueue(object):
'''
Simplified Queue type -- really just a locked pipe
'''
def __init__(self):
self._reader, self._writer = Pipe(duplex=False)
self._rlock = Lock()
def __init__(self, rnonblock=False, wnonblock=False):
self._reader, self._writer = Pipe(
duplex=False, rnonblock=rnonblock, wnonblock=wnonblock,
)
self._poll = self._reader.poll
if sys.platform == 'win32':
self._wlock = None
else:
self._wlock = Lock()
self._rlock = self._wlock = None
self._make_methods()
def empty(self):
@ -337,19 +333,22 @@ class SimpleQueue(object):
try:
recv_payload = self._reader.recv_payload
except AttributeError:
recv_payload = None # C extension not installed
recv_payload = self._reader.recv_bytes
rlock = self._rlock
def get():
with rlock:
return recv()
self.get = get
if rlock is not None:
def get():
with rlock:
return recv()
self.get = get
if recv_payload is not None:
def get_payload():
with rlock:
return recv_payload()
self.get_payload = get_payload
else:
self.get = recv
self.get_payload = recv_payload
if self._wlock is None:
# writes to a message oriented win32 pipe are atomic
@ -362,3 +361,12 @@ class SimpleQueue(object):
with wlock:
return send(obj)
self.put = put
class SimpleQueue(_SimpleQueue):
def __init__(self):
self._reader, self._writer = Pipe(duplex=False)
self._rlock = Lock()
self._wlock = Lock() if sys.platform != 'win32' else None
self._make_methods()

View File

@ -1,200 +1,10 @@
#
# Module to allow connection and socket objects to be transferred
# between processes
#
# multiprocessing/reduction.py
#
# Copyright (c) 2006-2008, R Oudkerk
# Licensed to PSF under a Contributor Agreement.
#
from __future__ import absolute_import
__all__ = []
import os
import sys
import socket
import threading
if sys.version_info[0] == 3:
from multiprocessing.connection import Client, Listener
from . import _reduction3 as reduction
else:
from billiard._connection import Client, Listener # noqa
from . import _reduction as reduction # noqa
from . import current_process
from ._ext import _billiard, win32
from .forking import Popen, duplicate, close, ForkingPickler
from .util import register_after_fork, debug, sub_debug
if not(sys.platform == 'win32' or hasattr(_billiard, 'recvfd')):
raise ImportError('pickling of connections not supported')
# globals set later
_listener = None
_lock = None
_cache = set()
#
# Platform specific definitions
#
if sys.platform == 'win32':
# XXX Should this subprocess import be here?
import _subprocess # noqa
def send_handle(conn, handle, destination_pid):
process_handle = win32.OpenProcess(
win32.PROCESS_ALL_ACCESS, False, destination_pid
)
try:
new_handle = duplicate(handle, process_handle)
conn.send(new_handle)
finally:
close(process_handle)
def recv_handle(conn):
return conn.recv()
else:
def send_handle(conn, handle, destination_pid): # noqa
_billiard.sendfd(conn.fileno(), handle)
def recv_handle(conn): # noqa
return _billiard.recvfd(conn.fileno())
#
# Support for a per-process server thread which caches pickled handles
#
def _reset(obj):
global _lock, _listener, _cache
for h in _cache:
close(h)
_cache.clear()
_lock = threading.Lock()
_listener = None
_reset(None)
register_after_fork(_reset, _reset)
def _get_listener():
global _listener
if _listener is None:
_lock.acquire()
try:
if _listener is None:
debug('starting listener and thread for sending handles')
_listener = Listener(authkey=current_process().authkey)
t = threading.Thread(target=_serve)
t.daemon = True
t.start()
finally:
_lock.release()
return _listener
def _serve():
from .util import is_exiting, sub_warning
while 1:
try:
conn = _listener.accept()
handle_wanted, destination_pid = conn.recv()
_cache.remove(handle_wanted)
send_handle(conn, handle_wanted, destination_pid)
close(handle_wanted)
conn.close()
except:
if not is_exiting():
sub_warning('thread for sharing handles raised exception',
exc_info=True)
#
# Functions to be used for pickling/unpickling objects with handles
#
def reduce_handle(handle):
if Popen.thread_is_spawning():
return (None, Popen.duplicate_for_child(handle), True)
dup_handle = duplicate(handle)
_cache.add(dup_handle)
sub_debug('reducing handle %d', handle)
return (_get_listener().address, dup_handle, False)
def rebuild_handle(pickled_data):
address, handle, inherited = pickled_data
if inherited:
return handle
sub_debug('rebuilding handle %d', handle)
conn = Client(address, authkey=current_process().authkey)
conn.send((handle, os.getpid()))
new_handle = recv_handle(conn)
conn.close()
return new_handle
#
# Register `_billiard.Connection` with `ForkingPickler`
#
def reduce_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_connection, (rh, conn.readable, conn.writable)
def rebuild_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.Connection(
handle, readable=readable, writable=writable
)
ForkingPickler.register(_billiard.Connection, reduce_connection)
#
# Register `socket.socket` with `ForkingPickler`
#
def fromfd(fd, family, type_, proto=0):
s = socket.fromfd(fd, family, type_, proto)
if s.__class__ is not socket.socket:
s = socket.socket(_sock=s)
return s
def reduce_socket(s):
reduced_handle = reduce_handle(s.fileno())
return rebuild_socket, (reduced_handle, s.family, s.type, s.proto)
def rebuild_socket(reduced_handle, family, type_, proto):
fd = rebuild_handle(reduced_handle)
_sock = fromfd(fd, family, type_, proto)
close(fd)
return _sock
ForkingPickler.register(socket.socket, reduce_socket)
#
# Register `_billiard.PipeConnection` with `ForkingPickler`
#
if sys.platform == 'win32':
def reduce_pipe_connection(conn):
rh = reduce_handle(conn.fileno())
return rebuild_pipe_connection, (rh, conn.readable, conn.writable)
def rebuild_pipe_connection(reduced_handle, readable, writable):
handle = rebuild_handle(reduced_handle)
return _billiard.PipeConnection(
handle, readable=readable, writable=writable
)
ForkingPickler.register(_billiard.PipeConnection, reduce_pipe_connection)
sys.modules[__name__] = reduction

View File

@ -12,7 +12,9 @@ import ctypes
import weakref
from . import heap, RLock
from .forking import assert_spawning, ForkingPickler
from .five import int_types
from .forking import assert_spawning
from .reduction import ForkingPickler
__all__ = ['RawValue', 'RawArray', 'Value', 'Array', 'copy', 'synchronized']
@ -48,7 +50,7 @@ def RawArray(typecode_or_type, size_or_initializer):
Returns a ctypes array allocated from shared memory
'''
type_ = typecode_to_type.get(typecode_or_type, typecode_or_type)
if isinstance(size_or_initializer, (int, long)):
if isinstance(size_or_initializer, int_types):
type_ = type_ * size_or_initializer
obj = _new_value(type_)
ctypes.memset(ctypes.addressof(obj), 0, ctypes.sizeof(obj))
@ -66,7 +68,8 @@ def Value(typecode_or_type, *args, **kwds):
'''
lock = kwds.pop('lock', None)
if kwds:
raise ValueError('unrecognized keyword argument(s): %s' % kwds.keys())
raise ValueError(
'unrecognized keyword argument(s): %s' % list(kwds.keys()))
obj = RawValue(typecode_or_type, *args)
if lock is False:
return obj
@ -83,7 +86,8 @@ def Array(typecode_or_type, size_or_initializer, **kwds):
'''
lock = kwds.pop('lock', None)
if kwds:
raise ValueError('unrecognized keyword argument(s): %s' % kwds.keys())
raise ValueError(
'unrecognized keyword argument(s): %s' % list(kwds.keys()))
obj = RawArray(typecode_or_type, size_or_initializer)
if lock is False:
return obj

View File

@ -19,9 +19,8 @@ import sys
import threading
from time import time as _time
from ._ext import _billiard, ensure_SemLock
from .five import range, monotonic
from .process import current_process
from .util import Finalize, register_after_fork, debug
from .forking import assert_spawning, Popen
@ -36,7 +35,7 @@ ensure_SemLock()
# Constants
#
RECURSIVE_MUTEX, SEMAPHORE = range(2)
RECURSIVE_MUTEX, SEMAPHORE = list(range(2))
SEM_VALUE_MAX = _billiard.SemLock.SEM_VALUE_MAX
try:
@ -115,7 +114,7 @@ class SemLock(object):
@staticmethod
def _make_name():
return '/%s-%s-%s' % (current_process()._semprefix,
os.getpid(), SemLock._counter.next())
os.getpid(), next(SemLock._counter))
class Semaphore(SemLock):
@ -248,7 +247,7 @@ class Condition(object):
# release lock
count = self._lock._semlock._count()
for i in xrange(count):
for i in range(count):
self._lock.release()
try:
@ -259,7 +258,7 @@ class Condition(object):
self._woken_count.release()
# reacquire lock
for i in xrange(count):
for i in range(count):
self._lock.acquire()
return ret
@ -296,7 +295,7 @@ class Condition(object):
sleepers += 1
if sleepers:
for i in xrange(sleepers):
for i in range(sleepers):
self._woken_count.acquire() # wait for a sleeper to wake
# rezero wait_semaphore in case some timeouts just happened
@ -308,13 +307,13 @@ class Condition(object):
if result:
return result
if timeout is not None:
endtime = _time() + timeout
endtime = monotonic() + timeout
else:
endtime = None
waittime = None
while not result:
if endtime is not None:
waittime = endtime - _time()
waittime = endtime - monotonic()
if waittime <= 0:
break
self.wait(waittime)

View File

@ -13,6 +13,9 @@ def teardown():
except (AttributeError, ImportError):
pass
atexit._exithandlers[:] = [
e for e in atexit._exithandlers if e[0] not in cancelled
]
try:
atexit._exithandlers[:] = [
e for e in atexit._exithandlers if e[0] not in cancelled
]
except AttributeError:
pass

View File

@ -1,5 +1,4 @@
from __future__ import absolute_import
from __future__ import with_statement
import os
import signal

View File

@ -1,5 +1,4 @@
from __future__ import absolute_import
from __future__ import with_statement
import re
import sys
@ -13,6 +12,8 @@ except AttributeError:
import unittest2 as unittest # noqa
from unittest2.util import safe_repr, unorderable_list_difference # noqa
from billiard.five import string_t, items, values
from .compat import catch_warnings
# -- adds assertWarns from recent unittest2, not in Python 2.7.
@ -25,7 +26,7 @@ class _AssertRaisesBaseContext(object):
self.expected = expected
self.failureException = test_case.failureException
self.obj_name = None
if isinstance(expected_regex, basestring):
if isinstance(expected_regex, string_t):
expected_regex = re.compile(expected_regex)
self.expected_regex = expected_regex
@ -37,7 +38,7 @@ class _AssertWarnsContext(_AssertRaisesBaseContext):
# The __warningregistry__'s need to be in a pristine state for tests
# to work properly.
warnings.resetwarnings()
for v in sys.modules.values():
for v in values(sys.modules):
if getattr(v, '__warningregistry__', None):
v.__warningregistry__ = {}
self.warnings_manager = catch_warnings(record=True)
@ -93,7 +94,7 @@ class Case(unittest.TestCase):
def assertDictContainsSubset(self, expected, actual, msg=None):
missing, mismatched = [], []
for key, value in expected.iteritems():
for key, value in items(expected):
if key not in actual:
missing.append(key)
elif value != actual[key]:

View File

@ -10,16 +10,25 @@ from __future__ import absolute_import
import errno
import functools
import itertools
import weakref
import atexit
import shutil
import tempfile
import threading # we want threading to install its
# cleanup function before multiprocessing does
from multiprocessing.util import ( # noqa
_afterfork_registry,
_afterfork_counter,
_exit_function,
_finalizer_registry,
_finalizer_counter,
Finalize,
ForkAwareLocal,
ForkAwareThreadLock,
get_temp_dir,
is_exiting,
register_after_fork,
_run_after_forkers,
_run_finalizers,
)
from .compat import get_errno
from .process import current_process, active_children
__all__ = [
'sub_debug', 'debug', 'info', 'sub_warning', 'get_logger',
@ -45,17 +54,6 @@ DEFAULT_LOGGING_FORMAT = '[%(levelname)s/%(processName)s] %(message)s'
_logger = None
_log_to_stderr = False
#: Support for reinitialization of objects when bootstrapping a child process
_afterfork_registry = weakref.WeakValueDictionary()
_afterfork_counter = itertools.count()
#: Finalization using weakrefs
_finalizer_registry = {}
_finalizer_counter = itertools.count()
#: set to true if the process is shutting down.
_exiting = False
def sub_debug(msg, *args, **kwargs):
if _logger:
@ -138,195 +136,6 @@ def log_to_stderr(level=None):
return _logger
def get_temp_dir():
'''
Function returning a temp directory which will be removed on exit
'''
# get name of a temp directory which will be automatically cleaned up
if current_process()._tempdir is None:
tempdir = tempfile.mkdtemp(prefix='pymp-')
info('created temp directory %s', tempdir)
Finalize(None, shutil.rmtree, args=[tempdir], exitpriority=-100)
current_process()._tempdir = tempdir
return current_process()._tempdir
def _run_after_forkers():
items = list(_afterfork_registry.items())
items.sort()
for (index, ident, func), obj in items:
try:
func(obj)
except Exception, e:
info('after forker raised exception %s', e)
def register_after_fork(obj, func):
_afterfork_registry[(_afterfork_counter.next(), id(obj), func)] = obj
class Finalize(object):
'''
Class which supports object finalization using weakrefs
'''
def __init__(self, obj, callback, args=(), kwargs=None, exitpriority=None):
assert exitpriority is None or type(exitpriority) is int
if obj is not None:
self._weakref = weakref.ref(obj, self)
else:
assert exitpriority is not None
self._callback = callback
self._args = args
self._kwargs = kwargs or {}
self._key = (exitpriority, _finalizer_counter.next())
_finalizer_registry[self._key] = self
def __call__(self, wr=None,
# Need to bind these locally because the globals
# could've been cleared at shutdown
_finalizer_registry=_finalizer_registry,
sub_debug=sub_debug):
'''
Run the callback unless it has already been called or cancelled
'''
try:
del _finalizer_registry[self._key]
except KeyError:
sub_debug('finalizer no longer registered')
else:
sub_debug(
'finalizer calling %s with args %s and kwargs %s',
self._callback, self._args, self._kwargs,
)
res = self._callback(*self._args, **self._kwargs)
self._weakref = self._callback = self._args = \
self._kwargs = self._key = None
return res
def cancel(self):
'''
Cancel finalization of the object
'''
try:
del _finalizer_registry[self._key]
except KeyError:
pass
else:
self._weakref = self._callback = self._args = \
self._kwargs = self._key = None
def still_active(self):
'''
Return whether this finalizer is still waiting to invoke callback
'''
return self._key in _finalizer_registry
def __repr__(self):
try:
obj = self._weakref()
except (AttributeError, TypeError):
obj = None
if obj is None:
return '<Finalize object, dead>'
x = '<Finalize object, callback=%s' % \
getattr(self._callback, '__name__', self._callback)
if self._args:
x += ', args=' + str(self._args)
if self._kwargs:
x += ', kwargs=' + str(self._kwargs)
if self._key[0] is not None:
x += ', exitprority=' + str(self._key[0])
return x + '>'
def _run_finalizers(minpriority=None,
_finalizer_registry=_finalizer_registry,
sub_debug=sub_debug, error=error):
'''
Run all finalizers whose exit priority is not None and at least minpriority
Finalizers with highest priority are called first; finalizers with
the same priority will be called in reverse order of creation.
'''
if minpriority is None:
f = lambda p: p[0][0] is not None
else:
f = lambda p: p[0][0] is not None and p[0][0] >= minpriority
items = [x for x in _finalizer_registry.items() if f(x)]
items.sort(reverse=True)
for key, finalizer in items:
sub_debug('calling %s', finalizer)
try:
finalizer()
except Exception:
if not error("Error calling finalizer %r", finalizer,
exc_info=True):
import traceback
traceback.print_exc()
if minpriority is None:
_finalizer_registry.clear()
def is_exiting():
'''
Returns true if the process is shutting down
'''
return _exiting or _exiting is None
def _exit_function(info=info, debug=debug,
active_children=active_children,
_run_finalizers=_run_finalizers):
'''
Clean up on exit
'''
global _exiting
info('process shutting down')
debug('running all "atexit" finalizers with priority >= 0')
_run_finalizers(0)
for p in active_children():
if p._daemonic:
info('calling terminate() for daemon %s', p.name)
p._popen.terminate()
for p in active_children():
info('calling join() for process %s', p.name)
p.join()
debug('running the remaining "atexit" finalizers')
_run_finalizers()
atexit.register(_exit_function)
class ForkAwareThreadLock(object):
def __init__(self):
self._lock = threading.Lock()
self.acquire = self._lock.acquire
self.release = self._lock.release
register_after_fork(self, ForkAwareThreadLock.__init__)
class ForkAwareLocal(threading.local):
def __init__(self):
register_after_fork(self, lambda obj: obj.__dict__.clear())
def __reduce__(self):
return type(self), ()
def _eintr_retry(func):
'''
Automatic retry after EINTR.
@ -337,7 +146,7 @@ def _eintr_retry(func):
while 1:
try:
return func(*args, **kwargs)
except OSError, exc:
except OSError as exc:
if get_errno(exc) != errno.EINTR:
raise
return wrapped

View File

@ -36,7 +36,7 @@ import logging.config
import urlparse
from boto.exception import InvalidUriError
__version__ = '2.13.3'
__version__ = '2.17.0'
Version = __version__ # for backware compatibility
UserAgent = 'Boto/%s Python/%s %s/%s' % (
@ -721,6 +721,29 @@ def connect_support(aws_access_key_id=None,
)
def connect_cloudtrail(aws_access_key_id=None,
aws_secret_access_key=None,
**kwargs):
"""
Connect to AWS CloudTrail
:type aws_access_key_id: string
:param aws_access_key_id: Your AWS Access Key ID
:type aws_secret_access_key: string
:param aws_secret_access_key: Your AWS Secret Access Key
:rtype: :class:`boto.cloudtrail.layer1.CloudtrailConnection`
:return: A connection to the AWS Cloudtrail service
"""
from boto.cloudtrail.layer1 import CloudTrailConnection
return CloudTrailConnection(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
**kwargs
)
def storage_uri(uri_str, default_scheme='file', debug=0, validate=True,
bucket_storage_uri_class=BucketStorageUri,
suppress_consec_slashes=True, is_latest=False):

View File

@ -431,13 +431,17 @@ class HmacAuthV4Handler(AuthHandler, HmacKeys):
parts = http_request.host.split('.')
if self.region_name is not None:
region_name = self.region_name
elif parts[1] == 'us-gov':
region_name = 'us-gov-west-1'
else:
if len(parts) == 3:
region_name = 'us-east-1'
elif len(parts) > 1:
if parts[1] == 'us-gov':
region_name = 'us-gov-west-1'
else:
region_name = parts[1]
if len(parts) == 3:
region_name = 'us-east-1'
else:
region_name = parts[1]
else:
region_name = parts[0]
if self.service_name is not None:
service_name = self.service_name
else:

View File

@ -191,12 +191,9 @@ class DocumentServiceConnection(object):
session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
pool_connections=20,
pool_maxsize=50
pool_maxsize=50,
max_retries=5
)
# Now kludge in the right number of retries.
# Once we're requiring ``requests>=1.2.1``, this can become an
# initialization parameter above.
adapter.max_retries = 5
session.mount('http://', adapter)
session.mount('https://', adapter)
r = session.post(url, data=sdf, headers={'Content-Type': 'application/json'})

View File

@ -79,7 +79,7 @@ class SearchResults(object):
class Query(object):
RESULTS_PER_PAGE = 500
def __init__(self, q=None, bq=None, rank=None,
@ -147,7 +147,7 @@ class Query(object):
class SearchConnection(object):
def __init__(self, domain=None, endpoint=None):
self.domain = domain
self.endpoint = endpoint
@ -209,7 +209,7 @@ class SearchConnection(object):
:param facet_sort: Rules used to specify the order in which facet
values should be returned. Allowed values are *alpha*, *count*,
*max*, *sum*. Use *alpha* to sort alphabetical, and *count* to sort
the facet by number of available result.
the facet by number of available result.
``{'color': 'alpha', 'size': 'count'}``
:type facet_top_n: dict
@ -243,10 +243,10 @@ class SearchConnection(object):
the search string.
>>> search(bq="'Tim*'") # Return documents with words like Tim or Timothy)
Search terms can also be combined. Allowed operators are "and", "or",
"not", "field", "optional", "token", "phrase", or "filter"
>>> search(bq="(and 'Tim' (field author 'John Smith'))")
Facets allow you to show classification information about the search
@ -258,12 +258,12 @@ class SearchConnection(object):
With facet_constraints, facet_top_n and facet_sort more complicated
constraints can be specified such as returning the top author out of
John Smith and Mark Smith who have a document with the word Tim in it.
>>> search(q='Tim',
... facet=['Author'],
... facet_constraints={'author': "'John Smith','Mark Smith'"},
... facet=['author'],
... facet_top_n={'author': 1},
>>> search(q='Tim',
... facet=['Author'],
... facet_constraints={'author': "'John Smith','Mark Smith'"},
... facet=['author'],
... facet_top_n={'author': 1},
... facet_sort={'author': 'count'})
"""
@ -300,9 +300,7 @@ class SearchConnection(object):
except AttributeError:
pass
raise SearchServiceException('Authentication error from Amazon%s' % msg)
raise SearchServiceException("Got non-json response from Amazon")
data['query'] = query
data['search_service'] = self
raise SearchServiceException("Got non-json response from Amazon. %s" % r.content, query)
if 'messages' in data and 'error' in data:
for m in data['messages']:
@ -311,7 +309,10 @@ class SearchConnection(object):
"=> %s" % (params, m['message']), query)
elif 'error' in data:
raise SearchServiceException("Unknown error processing search %s"
% (params), query)
% json.dumps(data), query)
data['query'] = query
data['search_service'] = self
return SearchResults(**data)

View File

@ -0,0 +1,48 @@
# Copyright (c) 2013 Amazon.com, Inc. or its affiliates.
# All Rights Reserved
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
from boto.regioninfo import RegionInfo
def regions():
"""
Get all available regions for the AWS Cloudtrail service.
:rtype: list
:return: A list of :class:`boto.regioninfo.RegionInfo`
"""
from boto.cloudtrail.layer1 import CloudTrailConnection
return [RegionInfo(name='us-east-1',
endpoint='cloudtrail.us-east-1.amazonaws.com',
connection_cls=CloudTrailConnection),
RegionInfo(name='us-west-2',
endpoint='cloudtrail.us-west-2.amazonaws.com',
connection_cls=CloudTrailConnection),
]
def connect_to_region(region_name, **kw_params):
for region in regions():
if region.name == region_name:
return region.connect(**kw_params)
return None

View File

@ -0,0 +1,86 @@
"""
Exceptions that are specific to the cloudtrail module.
"""
from boto.exception import BotoServerError
class InvalidSnsTopicNameException(BotoServerError):
"""
Raised when an invalid SNS topic name is passed to Cloudtrail.
"""
pass
class InvalidS3BucketNameException(BotoServerError):
"""
Raised when an invalid S3 bucket name is passed to Cloudtrail.
"""
pass
class TrailAlreadyExistsException(BotoServerError):
"""
Raised when the given trail name already exists.
"""
pass
class InsufficientSnsTopicPolicyException(BotoServerError):
"""
Raised when the SNS topic does not allow Cloudtrail to post
messages.
"""
pass
class InvalidTrailNameException(BotoServerError):
"""
Raised when the trail name is invalid.
"""
pass
class InternalErrorException(BotoServerError):
"""
Raised when there was an internal Cloudtrail error.
"""
pass
class TrailNotFoundException(BotoServerError):
"""
Raised when the given trail name is not found.
"""
pass
class S3BucketDoesNotExistException(BotoServerError):
"""
Raised when the given S3 bucket does not exist.
"""
pass
class TrailNotProvidedException(BotoServerError):
"""
Raised when no trail name was provided.
"""
pass
class InvalidS3PrefixException(BotoServerError):
"""
Raised when an invalid key prefix is given.
"""
pass
class MaximumNumberOfTrailsExceededException(BotoServerError):
"""
Raised when no more trails can be created.
"""
pass
class InsufficientS3BucketPolicyException(BotoServerError):
"""
Raised when the S3 bucket does not allow Cloudtrail to
write files into the prefix.
"""
pass

View File

@ -0,0 +1,309 @@
# Copyright (c) 2013 Amazon.com, Inc. or its affiliates. All Rights Reserved
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
try:
import json
except ImportError:
import simplejson as json
import boto
from boto.connection import AWSQueryConnection
from boto.regioninfo import RegionInfo
from boto.exception import JSONResponseError
from boto.cloudtrail import exceptions
class CloudTrailConnection(AWSQueryConnection):
"""
AWS Cloud Trail
This is the CloudTrail API Reference. It provides descriptions of
actions, data types, common parameters, and common errors for
CloudTrail.
CloudTrail is a web service that records AWS API calls for your
AWS account and delivers log files to an Amazon S3 bucket. The
recorded information includes the identity of the user, the start
time of the event, the source IP address, the request parameters,
and the response elements returned by the service.
As an alternative to using the API, you can use one of the AWS
SDKs, which consist of libraries and sample code for various
programming languages and platforms (Java, Ruby, .NET, iOS,
Android, etc.). The SDKs provide a convenient way to create
programmatic access to AWSCloudTrail. For example, the SDKs take
care of cryptographically signing requests, managing errors, and
retrying requests automatically. For information about the AWS
SDKs, including how to download and install them, see the Tools
for Amazon Web Services page.
See the CloudTrail User Guide for information about the data that
is included with each event listed in the log files.
"""
APIVersion = "2013-11-01"
DefaultRegionName = "us-east-1"
DefaultRegionEndpoint = "cloudtrail.us-east-1.amazonaws.com"
ServiceName = "CloudTrail"
TargetPrefix = "com.amazonaws.cloudtrail.v20131101.CloudTrail_20131101"
ResponseError = JSONResponseError
_faults = {
"InvalidSnsTopicNameException": exceptions.InvalidSnsTopicNameException,
"InvalidS3BucketNameException": exceptions.InvalidS3BucketNameException,
"TrailAlreadyExistsException": exceptions.TrailAlreadyExistsException,
"InsufficientSnsTopicPolicyException": exceptions.InsufficientSnsTopicPolicyException,
"InvalidTrailNameException": exceptions.InvalidTrailNameException,
"InternalErrorException": exceptions.InternalErrorException,
"TrailNotFoundException": exceptions.TrailNotFoundException,
"S3BucketDoesNotExistException": exceptions.S3BucketDoesNotExistException,
"TrailNotProvidedException": exceptions.TrailNotProvidedException,
"InvalidS3PrefixException": exceptions.InvalidS3PrefixException,
"MaximumNumberOfTrailsExceededException": exceptions.MaximumNumberOfTrailsExceededException,
"InsufficientS3BucketPolicyException": exceptions.InsufficientS3BucketPolicyException,
}
def __init__(self, **kwargs):
region = kwargs.pop('region', None)
if not region:
region = RegionInfo(self, self.DefaultRegionName,
self.DefaultRegionEndpoint)
if 'host' not in kwargs:
kwargs['host'] = region.endpoint
AWSQueryConnection.__init__(self, **kwargs)
self.region = region
def _required_auth_capability(self):
return ['hmac-v4']
def create_trail(self, trail=None):
"""
From the command line, use create-subscription.
Creates a trail that specifies the settings for delivery of
log data to an Amazon S3 bucket. The request includes a Trail
structure that specifies the following:
+ Trail name.
+ The name of the Amazon S3 bucket to which CloudTrail
delivers your log files.
+ The name of the Amazon S3 key prefix that precedes each log
file.
+ The name of the Amazon SNS topic that notifies you that a
new file is available in your bucket.
+ Whether the log file should include events from global
services. Currently, the only events included in CloudTrail
log files are from IAM and AWS STS.
Returns the appropriate HTTP status code if successful. If
not, it returns either one of the CommonErrors or a
FrontEndException with one of the following error codes:
**MaximumNumberOfTrailsExceeded**
An attempt was made to create more trails than allowed. You
can only create one trail for each account in each region.
**TrailAlreadyExists**
At attempt was made to create a trail with a name that already
exists.
**S3BucketDoesNotExist**
Specified Amazon S3 bucket does not exist.
**InsufficientS3BucketPolicy**
Policy on Amazon S3 bucket does not permit CloudTrail to write
to your bucket. See the AWS AWS CloudTrail User Guide for the
required bucket policy.
**InsufficientSnsTopicPolicy**
The policy on Amazon SNS topic does not permit CloudTrail to
write to it. Can also occur when an Amazon SNS topic does not
exist.
:type trail: dict
:param trail: Contains the Trail structure that specifies the settings
for each trail.
"""
params = {}
if trail is not None:
params['trail'] = trail
return self.make_request(action='CreateTrail',
body=json.dumps(params))
def delete_trail(self, name=None):
"""
Deletes a trail.
:type name: string
:param name: The name of a trail to be deleted.
"""
params = {}
if name is not None:
params['Name'] = name
return self.make_request(action='DeleteTrail',
body=json.dumps(params))
def describe_trails(self, trail_name_list=None):
"""
Retrieves the settings for some or all trails associated with
an account. Returns a list of Trail structures in JSON format.
:type trail_name_list: list
:param trail_name_list: The list of Trail object names.
"""
params = {}
if trail_name_list is not None:
params['trailNameList'] = trail_name_list
return self.make_request(action='DescribeTrails',
body=json.dumps(params))
def get_trail_status(self, name=None):
"""
Returns GetTrailStatusResult, which contains a JSON-formatted
list of information about the trail specified in the request.
JSON fields include information such as delivery errors,
Amazon SNS and Amazon S3 errors, and times that logging
started and stopped for each trail.
:type name: string
:param name: The name of the trail for which you are requesting the
current status.
"""
params = {}
if name is not None:
params['Name'] = name
return self.make_request(action='GetTrailStatus',
body=json.dumps(params))
def start_logging(self, name=None):
"""
Starts the processing of recording user activity events and
log file delivery for a trail.
:type name: string
:param name: The name of the Trail for which CloudTrail logs events.
"""
params = {}
if name is not None:
params['Name'] = name
return self.make_request(action='StartLogging',
body=json.dumps(params))
def stop_logging(self, name=None):
"""
Suspends the recording of user activity events and log file
delivery for the specified trail. Under most circumstances,
there is no need to use this action. You can update a trail
without stopping it first. This action is the only way to stop
logging activity.
:type name: string
:param name: Communicates to CloudTrail the name of the Trail for which
to stop logging events.
"""
params = {}
if name is not None:
params['Name'] = name
return self.make_request(action='StopLogging',
body=json.dumps(params))
def update_trail(self, trail=None):
"""
From the command line, use update-subscription.
Updates the settings that specify delivery of log files.
Changes to a trail do not require stopping the CloudTrail
service. You can use this action to designate an existing
bucket for log delivery, or to create a new bucket and prefix.
If the existing bucket has previously been a target for
CloudTrail log files, an IAM policy exists for the bucket. If
you create a new bucket using UpdateTrail, you need to apply
the policy to the bucket using one of the means provided by
the Amazon S3 service.
The request includes a Trail structure that specifies the
following:
+ Trail name.
+ The name of the Amazon S3 bucket to which CloudTrail
delivers your log files.
+ The name of the Amazon S3 key prefix that precedes each log
file.
+ The name of the Amazon SNS topic that notifies you that a
new file is available in your bucket.
+ Whether the log file should include events from global
services, such as IAM or AWS STS.
**CreateTrail** returns the appropriate HTTP status code if
successful. If not, it returns either one of the common errors
or one of the exceptions listed at the end of this page.
:type trail: dict
:param trail: Represents the Trail structure that contains the
CloudTrail setting for an account.
"""
params = {}
if trail is not None:
params['trail'] = trail
return self.make_request(action='UpdateTrail',
body=json.dumps(params))
def make_request(self, action, body):
headers = {
'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),
'Host': self.region.endpoint,
'Content-Type': 'application/x-amz-json-1.1',
'Content-Length': str(len(body)),
}
http_request = self.build_base_http_request(
method='POST', path='/', auth_path='/', params={},
headers=headers, data=body)
response = self._mexe(http_request, sender=None,
override_num_retries=10)
response_body = response.read()
boto.log.debug(response_body)
if response.status == 200:
if response_body:
return json.loads(response_body)
else:
json_body = json.loads(response_body)
fault_name = json_body.get('__type', None)
exception_class = self._faults.get(fault_name, self.ResponseError)
raise exception_class(response.status, response.reason,
body=json_body)

View File

@ -101,7 +101,7 @@ DEFAULT_CA_CERTS_FILE = os.path.join(os.path.dirname(os.path.abspath(boto.cacert
class HostConnectionPool(object):
"""
A pool of connections for one remote (host,is_secure).
A pool of connections for one remote (host,port,is_secure).
When connections are added to the pool, they are put into a
pending queue. The _mexe method returns connections to the pool
@ -145,7 +145,7 @@ class HostConnectionPool(object):
def get(self):
"""
Returns the next connection in this pool that is ready to be
reused. Returns None of there aren't any.
reused. Returns None if there aren't any.
"""
# Discard ready connections that are too old.
self.clean()
@ -234,7 +234,7 @@ class ConnectionPool(object):
STALE_DURATION = 60.0
def __init__(self):
# Mapping from (host,is_secure) to HostConnectionPool.
# Mapping from (host,port,is_secure) to HostConnectionPool.
# If a pool becomes empty, it is removed.
self.host_to_pool = {}
# The last time the pool was cleaned.
@ -259,7 +259,7 @@ class ConnectionPool(object):
"""
return sum(pool.size() for pool in self.host_to_pool.values())
def get_http_connection(self, host, is_secure):
def get_http_connection(self, host, port, is_secure):
"""
Gets a connection from the pool for the named host. Returns
None if there is no connection that can be reused. It's the caller's
@ -268,18 +268,18 @@ class ConnectionPool(object):
"""
self.clean()
with self.mutex:
key = (host, is_secure)
key = (host, port, is_secure)
if key not in self.host_to_pool:
return None
return self.host_to_pool[key].get()
def put_http_connection(self, host, is_secure, conn):
def put_http_connection(self, host, port, is_secure, conn):
"""
Adds a connection to the pool of connections that can be
reused for the named host.
"""
with self.mutex:
key = (host, is_secure)
key = (host, port, is_secure)
if key not in self.host_to_pool:
self.host_to_pool[key] = HostConnectionPool()
self.host_to_pool[key].put(conn)
@ -486,6 +486,11 @@ class AWSAuthConnection(object):
"2.6 or later.")
self.ca_certificates_file = config.get_value(
'Boto', 'ca_certificates_file', DEFAULT_CA_CERTS_FILE)
if port:
self.port = port
else:
self.port = PORTS_BY_SECURITY[is_secure]
self.handle_proxy(proxy, proxy_port, proxy_user, proxy_pass)
# define exceptions from httplib that we want to catch and retry
self.http_exceptions = (httplib.HTTPException, socket.error,
@ -513,10 +518,6 @@ class AWSAuthConnection(object):
if not isinstance(debug, (int, long)):
debug = 0
self.debug = config.getint('Boto', 'debug', debug)
if port:
self.port = port
else:
self.port = PORTS_BY_SECURITY[is_secure]
self.host_header = None
# Timeout used to tell httplib how long to wait for socket timeouts.
@ -551,7 +552,7 @@ class AWSAuthConnection(object):
self.host_header = self.provider.host_header
self._pool = ConnectionPool()
self._connection = (self.server_name(), self.is_secure)
self._connection = (self.host, self.port, self.is_secure)
self._last_rs = None
self._auth_handler = auth.get_auth_handler(
host, config, self.provider, self._required_auth_capability())
@ -652,7 +653,7 @@ class AWSAuthConnection(object):
if 'http_proxy' in os.environ and not self.proxy:
pattern = re.compile(
'(?:http://)?' \
'(?:(?P<user>\w+):(?P<pass>.*)@)?' \
'(?:(?P<user>[\w\-\.]+):(?P<pass>.*)@)?' \
'(?P<host>[\w\-\.]+)' \
'(?::(?P<port>\d+))?'
)
@ -680,12 +681,12 @@ class AWSAuthConnection(object):
self.no_proxy = os.environ.get('no_proxy', '') or os.environ.get('NO_PROXY', '')
self.use_proxy = (self.proxy != None)
def get_http_connection(self, host, is_secure):
conn = self._pool.get_http_connection(host, is_secure)
def get_http_connection(self, host, port, is_secure):
conn = self._pool.get_http_connection(host, port, is_secure)
if conn is not None:
return conn
else:
return self.new_http_connection(host, is_secure)
return self.new_http_connection(host, port, is_secure)
def skip_proxy(self, host):
if not self.no_proxy:
@ -703,16 +704,29 @@ class AWSAuthConnection(object):
return False
def new_http_connection(self, host, is_secure):
if self.use_proxy and not is_secure and \
not self.skip_proxy(host):
host = '%s:%d' % (self.proxy, int(self.proxy_port))
def new_http_connection(self, host, port, is_secure):
if host is None:
host = self.server_name()
# Make sure the host is really just the host, not including
# the port number
host = host.split(':', 1)[0]
http_connection_kwargs = self.http_connection_kwargs.copy()
# Connection factories below expect a port keyword argument
http_connection_kwargs['port'] = port
# Override host with proxy settings if needed
if self.use_proxy and not is_secure and \
not self.skip_proxy(host):
host = self.proxy
http_connection_kwargs['port'] = int(self.proxy_port)
if is_secure:
boto.log.debug(
'establishing HTTPS connection: host=%s, kwargs=%s',
host, self.http_connection_kwargs)
host, http_connection_kwargs)
if self.use_proxy and not self.skip_proxy(host):
connection = self.proxy_ssl(host, is_secure and 443 or 80)
elif self.https_connection_factory:
@ -720,35 +734,35 @@ class AWSAuthConnection(object):
elif self.https_validate_certificates and HAVE_HTTPS_CONNECTION:
connection = https_connection.CertValidatingHTTPSConnection(
host, ca_certs=self.ca_certificates_file,
**self.http_connection_kwargs)
**http_connection_kwargs)
else:
connection = httplib.HTTPSConnection(host,
**self.http_connection_kwargs)
**http_connection_kwargs)
else:
boto.log.debug('establishing HTTP connection: kwargs=%s' %
self.http_connection_kwargs)
http_connection_kwargs)
if self.https_connection_factory:
# even though the factory says https, this is too handy
# to not be able to allow overriding for http also.
connection = self.https_connection_factory(host,
**self.http_connection_kwargs)
**http_connection_kwargs)
else:
connection = httplib.HTTPConnection(host,
**self.http_connection_kwargs)
**http_connection_kwargs)
if self.debug > 1:
connection.set_debuglevel(self.debug)
# self.connection must be maintained for backwards-compatibility
# however, it must be dynamically pulled from the connection pool
# set a private variable which will enable that
if host.split(':')[0] == self.host and is_secure == self.is_secure:
self._connection = (host, is_secure)
self._connection = (host, port, is_secure)
# Set the response class of the http connection to use our custom
# class.
connection.response_class = HTTPResponse
return connection
def put_http_connection(self, host, is_secure, connection):
self._pool.put_http_connection(host, is_secure, connection)
def put_http_connection(self, host, port, is_secure, connection):
self._pool.put_http_connection(host, port, is_secure, connection)
def proxy_ssl(self, host=None, port=None):
if host and port:
@ -841,6 +855,7 @@ class AWSAuthConnection(object):
boto.log.debug('Data: %s' % request.body)
boto.log.debug('Headers: %s' % request.headers)
boto.log.debug('Host: %s' % request.host)
boto.log.debug('Port: %s' % request.port)
boto.log.debug('Params: %s' % request.params)
response = None
body = None
@ -850,7 +865,8 @@ class AWSAuthConnection(object):
else:
num_retries = override_num_retries
i = 0
connection = self.get_http_connection(request.host, self.is_secure)
connection = self.get_http_connection(request.host, request.port,
self.is_secure)
while i <= num_retries:
# Use binary exponential backoff to desynchronize client requests.
next_sleep = random.random() * (2 ** i)
@ -858,6 +874,12 @@ class AWSAuthConnection(object):
# we now re-sign each request before it is retried
boto.log.debug('Token: %s' % self.provider.security_token)
request.authorize(connection=self)
# Only force header for non-s3 connections, because s3 uses
# an older signing method + bucket resource URLs that include
# the port info. All others should be now be up to date and
# not include the port.
if 's3' not in self._required_auth_capability():
request.headers['Host'] = self.host.split(':', 1)[0]
if callable(sender):
response = sender(connection, request.method, request.path,
request.body, request.headers)
@ -880,31 +902,45 @@ class AWSAuthConnection(object):
boto.log.debug(msg)
time.sleep(next_sleep)
continue
if response.status == 500 or response.status == 503:
if response.status in [500, 502, 503, 504]:
msg = 'Received %d response. ' % response.status
msg += 'Retrying in %3.1f seconds' % next_sleep
boto.log.debug(msg)
body = response.read()
elif response.status < 300 or response.status >= 400 or \
not location:
self.put_http_connection(request.host, self.is_secure,
connection)
# don't return connection to the pool if response contains
# Connection:close header, because the connection has been
# closed and default reconnect behavior may do something
# different than new_http_connection. Also, it's probably
# less efficient to try to reuse a closed connection.
conn_header_value = response.getheader('connection')
if conn_header_value == 'close':
connection.close()
else:
self.put_http_connection(request.host, request.port,
self.is_secure, connection)
return response
else:
scheme, request.host, request.path, \
params, query, fragment = urlparse.urlparse(location)
if query:
request.path += '?' + query
# urlparse can return both host and port in netloc, so if
# that's the case we need to split them up properly
if ':' in request.host:
request.host, request.port = request.host.split(':', 1)
msg = 'Redirecting: %s' % scheme + '://'
msg += request.host + request.path
boto.log.debug(msg)
connection = self.get_http_connection(request.host,
request.port,
scheme == 'https')
response = None
continue
except PleaseRetryException, e:
boto.log.debug('encountered a retry exception: %s' % e)
connection = self.new_http_connection(request.host,
connection = self.new_http_connection(request.host, request.port,
self.is_secure)
response = e.response
except self.http_exceptions, e:
@ -913,10 +949,10 @@ class AWSAuthConnection(object):
boto.log.debug(
'encountered unretryable %s exception, re-raising' %
e.__class__.__name__)
raise e
raise
boto.log.debug('encountered %s exception, reconnecting' % \
e.__class__.__name__)
connection = self.new_http_connection(request.host,
connection = self.new_http_connection(request.host, request.port,
self.is_secure)
time.sleep(next_sleep)
i += 1
@ -927,7 +963,7 @@ class AWSAuthConnection(object):
if response:
raise BotoServerError(response.status, response.reason, body)
elif e:
raise e
raise
else:
msg = 'Please report this exception as a Boto Issue!'
raise BotoClientError(msg)
@ -1006,7 +1042,7 @@ class AWSQueryConnection(AWSAuthConnection):
def make_request(self, action, params=None, path='/', verb='GET'):
http_request = self.build_base_http_request(verb, path, None,
params, {}, '',
self.server_name())
self.host)
if action:
http_request.params['Action'] = action
if self.APIVersion:

View File

@ -50,11 +50,11 @@ class Item(dict):
if range_key == None:
range_key = attrs.get(self._range_key_name, None)
self[self._range_key_name] = range_key
self._updates = {}
for key, value in attrs.items():
if key != self._hash_key_name and key != self._range_key_name:
self[key] = value
self.consumed_units = 0
self._updates = {}
@property
def hash_key(self):

View File

@ -277,6 +277,10 @@ class Dynamizer(object):
if len(attr) > 1 or not attr:
return attr
dynamodb_type = attr.keys()[0]
if dynamodb_type.lower() == dynamodb_type:
# It's not an actual type, just a single character attr that
# overlaps with the DDB types. Return it.
return attr
try:
decoder = getattr(self, '_decode_%s' % dynamodb_type.lower())
except AttributeError:

View File

@ -21,7 +21,11 @@
#
from binascii import crc32
import json
try:
import json
except ImportError:
import simplejson as json
import boto
from boto.connection import AWSQueryConnection
from boto.regioninfo import RegionInfo
@ -67,7 +71,11 @@ class DynamoDBConnection(AWSQueryConnection):
if reg.name == region_name:
region = reg
break
kwargs['host'] = region.endpoint
# Only set host if it isn't manually overwritten
if 'host' not in kwargs:
kwargs['host'] = region.endpoint
AWSQueryConnection.__init__(self, **kwargs)
self.region = region
self._validate_checksums = boto.config.getbool(
@ -1467,13 +1475,13 @@ class DynamoDBConnection(AWSQueryConnection):
def make_request(self, action, body):
headers = {
'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),
'Host': self.region.endpoint,
'Host': self.host,
'Content-Type': 'application/x-amz-json-1.0',
'Content-Length': str(len(body)),
}
http_request = self.build_base_http_request(
method='POST', path='/', auth_path='/', params={},
headers=headers, data=body)
headers=headers, data=body, host=self.host)
response = self._mexe(http_request, sender=None,
override_num_retries=self.NumberRetries,
retry_handler=self._retry_handler)

View File

@ -418,6 +418,45 @@ class Table(object):
item.load(item_data)
return item
def lookup(self, *args, **kwargs):
"""
Look up an entry in DynamoDB. This is mostly backwards compatible
with boto.dynamodb. Unlike get_item, it takes hash_key and range_key first,
although you may still specify keyword arguments instead.
Also unlike the get_item command, if the returned item has no keys
(i.e., it does not exist in DynamoDB), a None result is returned, instead
of an empty key object.
Example::
>>> user = users.lookup(username)
>>> user = users.lookup(username, consistent=True)
>>> app = apps.lookup('my_customer_id', 'my_app_id')
"""
if not self.schema:
self.describe()
for x, arg in enumerate(args):
kwargs[self.schema[x].name] = arg
ret = self.get_item(**kwargs)
if not ret.keys():
return None
return ret
def new_item(self, *args):
"""
Returns a new, blank item
This is mostly for consistency with boto.dynamodb
"""
if not self.schema:
self.describe()
data = {}
for x, arg in enumerate(args):
data[self.schema[x].name] = arg
return Item(self, data=data)
def put_item(self, data, overwrite=False):
"""
Saves an entire item to DynamoDB.
@ -1164,4 +1203,4 @@ class BatchTable(object):
self.handle_unprocessed(resp)
boto.log.info(
"%s unprocessed items left" % len(self._unprocessed)
)
)

View File

@ -241,6 +241,10 @@ class AutoScaleConnection(AWSQueryConnection):
params['EbsOptimized'] = 'true'
else:
params['EbsOptimized'] = 'false'
if launch_config.associate_public_ip_address is True:
params['AssociatePublicIpAddress'] = 'true'
elif launch_config.associate_public_ip_address is False:
params['AssociatePublicIpAddress'] = 'false'
return self.get_object('CreateLaunchConfiguration', params,
Request, verb='POST')
@ -492,15 +496,19 @@ class AutoScaleConnection(AWSQueryConnection):
If no group name or list of policy names are provided, all
available policies are returned.
:type as_name: str
:param as_name: The name of the
:type as_group: str
:param as_group: The name of the
:class:`boto.ec2.autoscale.group.AutoScalingGroup` to filter for.
:type names: list
:param names: List of policy names which should be searched for.
:type policy_names: list
:param policy_names: List of policy names which should be searched for.
:type max_records: int
:param max_records: Maximum amount of groups to return.
:type next_token: str
:param next_token: If you have more results than can be returned
at once, pass in this parameter to page through all results.
"""
params = {}
if as_group:
@ -681,9 +689,9 @@ class AutoScaleConnection(AWSQueryConnection):
Configures an Auto Scaling group to send notifications when
specified events take place.
:type as_group: str or
:type autoscale_group: str or
:class:`boto.ec2.autoscale.group.AutoScalingGroup` object
:param as_group: The Auto Scaling group to put notification
:param autoscale_group: The Auto Scaling group to put notification
configuration on.
:type topic: str
@ -692,7 +700,12 @@ class AutoScaleConnection(AWSQueryConnection):
:type notification_types: list
:param notification_types: The type of events that will trigger
the notification.
the notification. Valid types are:
'autoscaling:EC2_INSTANCE_LAUNCH',
'autoscaling:EC2_INSTANCE_LAUNCH_ERROR',
'autoscaling:EC2_INSTANCE_TERMINATE',
'autoscaling:EC2_INSTANCE_TERMINATE_ERROR',
'autoscaling:TEST_NOTIFICATION'
"""
name = autoscale_group
@ -704,6 +717,29 @@ class AutoScaleConnection(AWSQueryConnection):
self.build_list_params(params, notification_types, 'NotificationTypes')
return self.get_status('PutNotificationConfiguration', params)
def delete_notification_configuration(self, autoscale_group, topic):
"""
Deletes notifications created by put_notification_configuration.
:type autoscale_group: str or
:class:`boto.ec2.autoscale.group.AutoScalingGroup` object
:param autoscale_group: The Auto Scaling group to put notification
configuration on.
:type topic: str
:param topic: The Amazon Resource Name (ARN) of the Amazon Simple
Notification Service (SNS) topic.
"""
name = autoscale_group
if isinstance(autoscale_group, AutoScalingGroup):
name = autoscale_group.name
params = {'AutoScalingGroupName': name,
'TopicARN': topic}
return self.get_status('DeleteNotificationConfiguration', params)
def set_instance_health(self, instance_id, health_status,
should_respect_grace_period=True):
"""

View File

@ -148,6 +148,9 @@ class AutoScalingGroup(object):
:type vpc_zone_identifier: str
:param vpc_zone_identifier: The subnet identifier of the Virtual
Private Cloud.
:type tags: list
:param tags: List of :class:`boto.ec2.autoscale.tag.Tag`s
:type termination_policies: list
:param termination_policies: A list of termination policies. Valid values
@ -296,12 +299,23 @@ class AutoScalingGroup(object):
def put_notification_configuration(self, topic, notification_types):
"""
Configures an Auto Scaling group to send notifications when
specified events take place.
specified events take place. Valid notification types are:
'autoscaling:EC2_INSTANCE_LAUNCH',
'autoscaling:EC2_INSTANCE_LAUNCH_ERROR',
'autoscaling:EC2_INSTANCE_TERMINATE',
'autoscaling:EC2_INSTANCE_TERMINATE_ERROR',
'autoscaling:TEST_NOTIFICATION'
"""
return self.connection.put_notification_configuration(self,
topic,
notification_types)
def delete_notification_configuration(self, topic):
"""
Deletes notifications created by put_notification_configuration.
"""
return self.connection.delete_notification_configuration(self, topic)
def suspend_processes(self, scaling_processes=None):
"""
Suspends Auto Scaling processes for an Auto Scaling group.

View File

@ -94,7 +94,8 @@ class LaunchConfiguration(object):
instance_type='m1.small', kernel_id=None,
ramdisk_id=None, block_device_mappings=None,
instance_monitoring=False, spot_price=None,
instance_profile_name=None, ebs_optimized=False):
instance_profile_name=None, ebs_optimized=False,
associate_public_ip_address=None):
"""
A launch configuration.
@ -109,8 +110,9 @@ class LaunchConfiguration(object):
:param key_name: The name of the EC2 key pair.
:type security_groups: list
:param security_groups: Names of the security groups with which to
associate the EC2 instances.
:param security_groups: Names or security group id's of the security
groups with which to associate the EC2 instances or VPC instances,
respectively.
:type user_data: str
:param user_data: The user data available to launched EC2 instances.
@ -144,6 +146,10 @@ class LaunchConfiguration(object):
:type ebs_optimized: bool
:param ebs_optimized: Specifies whether the instance is optimized
for EBS I/O (true) or not (false).
:type associate_public_ip_address: bool
:param associate_public_ip_address: Used for Auto Scaling groups that launch instances into an Amazon Virtual Private Cloud.
Specifies whether to assign a public IP address to each instance launched in a Amazon VPC.
"""
self.connection = connection
self.name = name
@ -163,6 +169,7 @@ class LaunchConfiguration(object):
self.instance_profile_name = instance_profile_name
self.launch_configuration_arn = None
self.ebs_optimized = ebs_optimized
self.associate_public_ip_address = associate_public_ip_address
def __repr__(self):
return 'LaunchConfiguration:%s' % self.name

View File

@ -55,11 +55,11 @@ class Tag(object):
self.key = value
elif name == 'Value':
self.value = value
elif name == 'PropogateAtLaunch':
elif name == 'PropagateAtLaunch':
if value.lower() == 'true':
self.propogate_at_launch = True
self.propagate_at_launch = True
else:
self.propogate_at_launch = False
self.propagate_at_launch = False
elif name == 'ResourceId':
self.resource_id = value
elif name == 'ResourceType':

View File

@ -95,7 +95,7 @@ class MetricAlarm(object):
statistic is applied.
:type evaluation_periods: int
:param evaluation_period: The number of periods over which data is
:param evaluation_periods: The number of periods over which data is
compared to the specified threshold.
:type unit: str
@ -112,9 +112,16 @@ class MetricAlarm(object):
:type description: str
:param description: Description of MetricAlarm
:type dimensions: list of dicts
:param description: Dimensions of alarm, such as:
[{'InstanceId':['i-0123456,i-0123457']}]
:type dimensions: dict
:param dimensions: A dictionary of dimension key/values where
the key is the dimension name and the value
is either a scalar value or an iterator
of values to be associated with that
dimension.
Example: {
'InstanceId': ['i-0123456', 'i-0123457'],
'LoadBalancerName': 'test-lb'
}
:type alarm_actions: list of strs
:param alarm_actions: A list of the ARNs of the actions to take in

View File

@ -69,7 +69,7 @@ from boto.exception import EC2ResponseError
class EC2Connection(AWSQueryConnection):
APIVersion = boto.config.get('Boto', 'ec2_version', '2013-07-15')
APIVersion = boto.config.get('Boto', 'ec2_version', '2013-10-01')
DefaultRegionName = boto.config.get('Boto', 'ec2_region_name', 'us-east-1')
DefaultRegionEndpoint = boto.config.get('Boto', 'ec2_region_endpoint',
'ec2.us-east-1.amazonaws.com')
@ -260,7 +260,7 @@ class EC2Connection(AWSQueryConnection):
def register_image(self, name=None, description=None, image_location=None,
architecture=None, kernel_id=None, ramdisk_id=None,
root_device_name=None, block_device_map=None,
dry_run=False):
dry_run=False, virtualization_type=None):
"""
Register an image.
@ -293,6 +293,12 @@ class EC2Connection(AWSQueryConnection):
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:type virtualization_type: string
:param virtualization_type: The virutalization_type of the image.
Valid choices are:
* paravirtual
* hvm
:rtype: string
:return: The new image id
"""
@ -315,6 +321,9 @@ class EC2Connection(AWSQueryConnection):
block_device_map.ec2_build_list_params(params)
if dry_run:
params['DryRun'] = 'true'
if virtualization_type:
params['VirtualizationType'] = virtualization_type
rs = self.get_object('RegisterImage', params, ResultSet, verb='POST')
image_id = getattr(rs, 'imageId', None)
return image_id
@ -355,7 +364,8 @@ class EC2Connection(AWSQueryConnection):
return result
def create_image(self, instance_id, name,
description=None, no_reboot=False, dry_run=False):
description=None, no_reboot=False,
block_device_mapping=None, dry_run=False):
"""
Will create an AMI from the instance in the running or stopped
state.
@ -377,6 +387,10 @@ class EC2Connection(AWSQueryConnection):
responsibility of maintaining file system integrity is
left to the owner of the instance.
:type block_device_mapping: :class:`boto.ec2.blockdevicemapping.BlockDeviceMapping`
:param block_device_mapping: A BlockDeviceMapping data structure
describing the EBS volumes associated with the Image.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
@ -389,6 +403,8 @@ class EC2Connection(AWSQueryConnection):
params['Description'] = description
if no_reboot:
params['NoReboot'] = 'true'
if block_device_mapping:
block_device_mapping.ec2_build_list_params(params)
if dry_run:
params['DryRun'] = 'true'
img = self.get_object('CreateImage', params, Image, verb='POST')
@ -1500,7 +1516,7 @@ class EC2Connection(AWSQueryConnection):
if dry_run:
params['DryRun'] = 'true'
return self.get_list('CancelSpotInstanceRequests', params,
[('item', Instance)], verb='POST')
[('item', SpotInstanceRequest)], verb='POST')
def get_spot_datafeed_subscription(self, dry_run=False):
"""
@ -2189,17 +2205,17 @@ class EC2Connection(AWSQueryConnection):
present, only the Snapshots associated with
these snapshot ids will be returned.
:type owner: str
:param owner: If present, only the snapshots owned by the specified user
:type owner: str or list
:param owner: If present, only the snapshots owned by the specified user(s)
will be returned. Valid values are:
* self
* amazon
* AWS Account ID
:type restorable_by: str
:type restorable_by: str or list
:param restorable_by: If present, only the snapshots that are restorable
by the specified account id will be returned.
by the specified account id(s) will be returned.
:type filters: dict
:param filters: Optional filters that can be used to limit
@ -2220,10 +2236,11 @@ class EC2Connection(AWSQueryConnection):
params = {}
if snapshot_ids:
self.build_list_params(params, snapshot_ids, 'SnapshotId')
if owner:
params['Owner'] = owner
self.build_list_params(params, owner, 'Owner')
if restorable_by:
params['RestorableBy'] = restorable_by
self.build_list_params(params, restorable_by, 'RestorableBy')
if filters:
self.build_filter_params(params, filters)
if dry_run:

View File

@ -188,13 +188,13 @@ class ELBConnection(AWSQueryConnection):
(LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol,
SSLCertificateId).
Where;
- LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535.
- Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM
Where:
- LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535
- Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM
:rtype: :class:`boto.ec2.elb.loadbalancer.LoadBalancer`
:return: The newly created
@ -272,13 +272,13 @@ class ELBConnection(AWSQueryConnection):
(LoadBalancerPortNumber, InstancePortNumber, Protocol, InstanceProtocol,
SSLCertificateId).
Where;
- LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535.
- Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM
Where:
- LoadBalancerPortNumber and InstancePortNumber are integer
values between 1 and 65535
- Protocol and InstanceProtocol is a string containing either 'TCP',
'SSL', 'HTTP', or 'HTTPS'
- SSLCertificateId is the ARN of an SSL certificate loaded into
AWS IAM
:return: The status of the request
"""

View File

@ -342,7 +342,7 @@ class LoadBalancer(object):
"""
if isinstance(subnets, str) or isinstance(subnets, unicode):
subnets = [subnets]
new_subnets = self.connection.detach_lb_to_subnets(self.name, subnets)
new_subnets = self.connection.detach_lb_from_subnets(self.name, subnets)
self.subnets = new_subnets
def apply_security_groups(self, security_groups):

View File

@ -340,14 +340,6 @@ class Instance(TaggedEC2Object):
self.ami_launch_index = value
elif name == 'previousState':
self.previous_state = value
elif name == 'name':
self.state = value
elif name == 'code':
try:
self.state_code = int(value)
except ValueError:
boto.log.warning('Error converting code (%s) to int' % value)
self.state_code = value
elif name == 'instanceType':
self.instance_type = value
elif name == 'rootDeviceName':

View File

@ -234,11 +234,12 @@ class PriceSchedule(object):
class ReservedInstancesConfiguration(object):
def __init__(self, connection=None, availability_zone=None, platform=None,
instance_count=None):
instance_count=None, instance_type=None):
self.connection = connection
self.availability_zone = availability_zone
self.platform = platform
self.instance_count = instance_count
self.instance_type = instance_type
def startElement(self, name, attrs, connection):
return None
@ -250,6 +251,8 @@ class ReservedInstancesConfiguration(object):
self.platform = value
elif name == 'instanceCount':
self.instance_count = int(value)
elif name == 'instanceType':
self.instance_type = value
else:
setattr(self, name, value)
@ -271,12 +274,14 @@ class ModifyReservedInstancesResult(object):
class ModificationResult(object):
def __init__(self, connection=None, modification_id=None,
availability_zone=None, platform=None, instance_count=None):
availability_zone=None, platform=None, instance_count=None,
instance_type=None):
self.connection = connection
self.modification_id = modification_id
self.availability_zone = availability_zone
self.platform = platform
self.instance_count = instance_count
self.instance_type = instance_type
def startElement(self, name, attrs, connection):
return None
@ -290,6 +295,8 @@ class ModificationResult(object):
self.platform = value
elif name == 'instanceCount':
self.instance_count = int(value)
elif name == 'instanceType':
self.instance_type = value
else:
setattr(self, name, value)

View File

@ -123,6 +123,9 @@ class SecurityGroup(TaggedEC2Object):
only changes the local version of the object. No information
is sent to EC2.
"""
if not self.rules:
raise ValueError("The security group has no rules")
target_rule = None
for rule in self.rules:
if rule.ip_protocol == ip_protocol:
@ -136,9 +139,9 @@ class SecurityGroup(TaggedEC2Object):
if grant.cidr_ip == cidr_ip:
target_grant = grant
if target_grant:
rule.grants.remove(target_grant, dry_run=dry_run)
if len(rule.grants) == 0:
self.rules.remove(target_rule, dry_run=dry_run)
rule.grants.remove(target_grant)
if len(rule.grants) == 0:
self.rules.remove(target_rule)
def authorize(self, ip_protocol=None, from_port=None, to_port=None,
cidr_ip=None, src_group=None, dry_run=False):

View File

@ -387,8 +387,8 @@ class ElasticTranscoderConnection(AWSAuthConnection):
:param description: A description of the preset.
:type container: string
:param container: The container type for the output file. This value
must be `mp4`.
:param container: The container type for the output file. Valid values
include `mp3`, `mp4`, `ogg`, `ts`, and `webm`.
:type video: dict
:param video: A section of the request body that specifies the video

View File

@ -43,25 +43,25 @@ def regions():
endpoint='elasticmapreduce.us-east-1.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='us-west-1',
endpoint='elasticmapreduce.us-west-1.amazonaws.com',
endpoint='us-west-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='us-west-2',
endpoint='elasticmapreduce.us-west-2.amazonaws.com',
endpoint='us-west-2.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='ap-northeast-1',
endpoint='elasticmapreduce.ap-northeast-1.amazonaws.com',
endpoint='ap-northeast-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='ap-southeast-1',
endpoint='elasticmapreduce.ap-southeast-1.amazonaws.com',
endpoint='ap-southeast-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='ap-southeast-2',
endpoint='elasticmapreduce.ap-southeast-2.amazonaws.com',
endpoint='ap-southeast-2.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='eu-west-1',
endpoint='elasticmapreduce.eu-west-1.amazonaws.com',
endpoint='eu-west-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
RegionInfo(name='sa-east-1',
endpoint='elasticmapreduce.sa-east-1.amazonaws.com',
endpoint='sa-east-1.elasticmapreduce.amazonaws.com',
connection_cls=EmrConnection),
]

View File

@ -28,9 +28,12 @@ import types
import boto
import boto.utils
from boto.ec2.regioninfo import RegionInfo
from boto.emr.emrobject import JobFlow, RunJobFlowResponse
from boto.emr.emrobject import AddInstanceGroupsResponse
from boto.emr.emrobject import ModifyInstanceGroupsResponse
from boto.emr.emrobject import AddInstanceGroupsResponse, BootstrapActionList, \
Cluster, ClusterSummaryList, HadoopStep, \
InstanceGroupList, InstanceList, JobFlow, \
JobFlowStepList, \
ModifyInstanceGroupsResponse, \
RunJobFlowResponse, StepSummaryList
from boto.emr.step import JarStep
from boto.connection import AWSQueryConnection
from boto.exception import EmrResponseError
@ -65,10 +68,30 @@ class EmrConnection(AWSQueryConnection):
https_connection_factory, path,
security_token,
validate_certs=validate_certs)
# Many of the EMR hostnames are of the form:
# <region>.<service_name>.amazonaws.com
# rather than the more common:
# <service_name>.<region>.amazonaws.com
# so we need to explicitly set the region_name and service_name
# for the SigV4 signing.
self.auth_region_name = self.region.name
self.auth_service_name = 'elasticmapreduce'
def _required_auth_capability(self):
return ['hmac-v4']
def describe_cluster(self, cluster_id):
"""
Describes an Elastic MapReduce cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
"""
params = {
'ClusterId': cluster_id
}
return self.get_object('DescribeCluster', params, Cluster)
def describe_jobflow(self, jobflow_id):
"""
Describes a single Elastic MapReduce job flow
@ -111,6 +134,139 @@ class EmrConnection(AWSQueryConnection):
return self.get_list('DescribeJobFlows', params, [('member', JobFlow)])
def describe_step(self, cluster_id, step_id):
"""
Describe an Elastic MapReduce step
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type step_id: str
:param step_id: The step id of interest
"""
params = {
'ClusterId': cluster_id,
'StepId': step_id
}
return self.get_object('DescribeStep', params, HadoopStep)
def list_bootstrap_actions(self, cluster_id, marker=None):
"""
Get a list of bootstrap actions for an Elastic MapReduce cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
return self.get_object('ListBootstrapActions', params, BootstrapActionList)
def list_clusters(self, created_after=None, created_before=None,
cluster_states=None, marker=None):
"""
List Elastic MapReduce clusters with optional filtering
:type created_after: datetime
:param created_after: Bound on cluster creation time
:type created_before: datetime
:param created_before: Bound on cluster creation time
:type cluster_states: list
:param cluster_states: Bound on cluster states
:type marker: str
:param marker: Pagination marker
"""
params = {}
if created_after:
params['CreatedAfter'] = created_after.strftime(
boto.utils.ISO8601)
if created_before:
params['CreatedBefore'] = created_before.strftime(
boto.utils.ISO8601)
if marker:
params['Marker'] = marker
if cluster_states:
self.build_list_params(params, cluster_states, 'ClusterStates.member')
return self.get_object('ListClusters', params, ClusterSummaryList)
def list_instance_groups(self, cluster_id, marker=None):
"""
List EC2 instance groups in a cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
return self.get_object('ListInstanceGroups', params, InstanceGroupList)
def list_instances(self, cluster_id, instance_group_id=None,
instance_group_types=None, marker=None):
"""
List EC2 instances in a cluster
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type instance_group_id: str
:param instance_group_id: The EC2 instance group id of interest
:type instance_group_types: list
:param instance_group_types: Filter by EC2 instance group type
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if instance_group_id:
params['InstanceGroupId'] = instance_group_id
if marker:
params['Marker'] = marker
if instance_group_types:
self.build_list_params(params, instance_group_types,
'InstanceGroupTypeList.member')
return self.get_object('ListInstances', params, InstanceList)
def list_steps(self, cluster_id, step_states=None, marker=None):
"""
List cluster steps
:type cluster_id: str
:param cluster_id: The cluster id of interest
:type step_states: list
:param step_states: Filter by step states
:type marker: str
:param marker: Pagination marker
"""
params = {
'ClusterId': cluster_id
}
if marker:
params['Marker'] = marker
if step_states:
self.build_list_params(params, step_states, 'StepStateList.member')
self.get_object('ListSteps', params, StepSummaryList)
def terminate_jobflow(self, jobflow_id):
"""
Terminate an Elastic MapReduce job flow
@ -150,7 +306,7 @@ class EmrConnection(AWSQueryConnection):
params.update(self._build_step_list(step_args))
return self.get_object(
'AddJobFlowSteps', params, RunJobFlowResponse, verb='POST')
'AddJobFlowSteps', params, JobFlowStepList, verb='POST')
def add_instance_groups(self, jobflow_id, instance_groups):
"""

View File

@ -60,11 +60,29 @@ class Arg(EmrObject):
self.value = value
class StepId(Arg):
pass
class JobFlowStepList(EmrObject):
def __ini__(self, connection=None):
self.connection = connection
self.stepids = None
def startElement(self, name, attrs, connection):
if name == 'StepIds':
self.stepids = ResultSet([('member', StepId)])
return self.stepids
else:
return None
class BootstrapAction(EmrObject):
Fields = set([
'Args',
'Name',
'Path',
'ScriptPath',
])
def startElement(self, name, attrs, connection):
@ -174,3 +192,281 @@ class JobFlow(EmrObject):
return self.bootstrapactions
else:
return None
class ClusterTimeline(EmrObject):
Fields = set([
'CreationDateTime',
'ReadyDateTime',
'EndDateTime'
])
class ClusterStatus(EmrObject):
Fields = set([
'State',
'StateChangeReason',
'Timeline'
])
def __init__(self, connection=None):
self.connection = connection
self.timeline = None
def startElement(self, name, attrs, connection):
if name == 'Timeline':
self.timeline = ClusterTimeline()
return self.timeline
else:
return None
class Ec2InstanceAttributes(EmrObject):
Fields = set([
'Ec2KeyName',
'Ec2SubnetId',
'Ec2AvailabilityZone',
'IamInstanceProfile'
])
class Application(EmrObject):
Fields = set([
'Name',
'Version',
'Args',
'AdditionalInfo'
])
class Cluster(EmrObject):
Fields = set([
'Id',
'Name',
'LogUri',
'RequestedAmiVersion',
'RunningAmiVersion',
'AutoTerminate',
'TerminationProtected',
'VisibleToAllUsers'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
self.ec2instanceattributes = None
self.applications = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
elif name == 'EC2InstanceAttributes':
self.ec2instanceattributes = Ec2InstanceAttributes()
return self.ec2instanceattributes
elif name == 'Applications':
self.applications = ResultSet([('member', Application)])
else:
return None
class ClusterSummary(Cluster):
Fields = set([
'Id',
'Name'
])
class ClusterSummaryList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection):
self.connection = connection
self.clusters = None
def startElement(self, name, attrs, connection):
if name == 'Clusters':
self.clusters = ResultSet([('member', ClusterSummary)])
return self.clusters
else:
return None
class StepConfig(EmrObject):
Fields = set([
'Jar'
'MainClass'
])
def __init__(self, connection=None):
self.connection = connection
self.properties = None
self.args = None
def startElement(self, name, attrs, connection):
if name == 'Properties':
self.properties = ResultSet([('member', KeyValue)])
return self.properties
elif name == 'Args':
self.args = ResultSet([('member', Arg)])
return self.args
else:
return None
class HadoopStep(EmrObject):
Fields = set([
'Id',
'Name',
'ActionOnFailure'
])
def __init__(self, connection=None):
self.connection = connection
self.config = None
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Config':
self.config = StepConfig()
return self.config
elif name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class InstanceGroupInfo(EmrObject):
Fields = set([
'Id',
'Name',
'Market',
'InstanceGroupType',
'BidPrice',
'InstanceType',
'RequestedInstanceCount',
'RunningInstanceCount'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class InstanceGroupList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection=None):
self.connection = connection
self.instancegroups = None
def startElement(self, name, attrs, connection):
if name == 'InstanceGroups':
self.instancegroups = ResultSet([('member', InstanceGroupInfo)])
return self.instancegroups
else:
return None
class InstanceInfo(EmrObject):
Fields = set([
'Id',
'Ec2InstanceId',
'PublicDnsName',
'PublicIpAddress',
'PrivateDnsName',
'PrivateIpAddress'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class InstanceList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection=None):
self.connection = connection
self.instances = None
def startElement(self, name, attrs, connection):
if name == 'Instances':
self.instances = ResultSet([('member', InstanceInfo)])
return self.instances
else:
return None
class StepSummary(EmrObject):
Fields = set([
'Id',
'Name'
])
def __init__(self, connection=None):
self.connection = connection
self.status = None
def startElement(self, name, attrs, connection):
if name == 'Status':
self.status = ClusterStatus()
return self.status
else:
return None
class StepSummaryList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection=None):
self.connection = connection
self.steps = None
def startElement(self, name, attrs, connection):
if name == 'Steps':
self.steps = ResultSet([('member', StepSummary)])
return self.steps
else:
return None
class BootstrapActionList(EmrObject):
Fields = set([
'Marker'
])
def __init__(self, connection=None):
self.connection = connection
self.actions = None
def startElement(self, name, attrs, connection):
if name == 'BootstrapActions':
self.actions = ResultSet([('member', BootstrapAction)])
return self.actions
else:
return None

View File

@ -47,6 +47,9 @@ def regions():
RegionInfo(name='eu-west-1',
endpoint='glacier.eu-west-1.amazonaws.com',
connection_cls=Layer2),
RegionInfo(name='ap-southeast-2',
endpoint='glacier.ap-southeast-2.amazonaws.com',
connection_cls=Layer2),
]

View File

@ -19,12 +19,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import re
import urllib
import xml.sax
import boto
from boto import handler
from boto.resultset import ResultSet
from boto.exception import GSResponseError
from boto.exception import InvalidAclError
from boto.gs.acl import ACL, CannedACLStrings
from boto.gs.acl import SupportedPermissions as GSPermissions
@ -41,6 +43,7 @@ DEF_OBJ_ACL = 'defaultObjectAcl'
STANDARD_ACL = 'acl'
CORS_ARG = 'cors'
LIFECYCLE_ARG = 'lifecycle'
ERROR_DETAILS_REGEX = re.compile(r'<Details>(?P<details>.*)</Details>')
class Bucket(S3Bucket):
"""Represents a Google Cloud Storage bucket."""
@ -99,9 +102,16 @@ class Bucket(S3Bucket):
if response_headers:
for rk, rv in response_headers.iteritems():
query_args_l.append('%s=%s' % (rk, urllib.quote(rv)))
key, resp = self._get_key_internal(key_name, headers,
query_args_l=query_args_l)
try:
key, resp = self._get_key_internal(key_name, headers,
query_args_l=query_args_l)
except GSResponseError, e:
if e.status == 403 and 'Forbidden' in e.reason:
# If we failed getting an object, let the user know which object
# failed rather than just returning a generic 403.
e.reason = ("Access denied to 'gs://%s/%s'." %
(self.name, key_name))
raise
return key
def copy_key(self, new_key_name, src_bucket_name, src_key_name,
@ -312,6 +322,14 @@ class Bucket(S3Bucket):
headers=headers)
body = response.read()
if response.status != 200:
if response.status == 403:
match = ERROR_DETAILS_REGEX.search(body)
details = match.group('details') if match else None
if details:
details = (('<Details>%s. Note that Full Control access'
' is required to access ACLs.</Details>') %
details)
body = re.sub(ERROR_DETAILS_REGEX, details, body)
raise self.connection.provider.storage_response_error(
response.status, response.reason, body)
return body

View File

@ -482,7 +482,7 @@ class ResumableUploadHandler(object):
# pool connections) because httplib requires a new HTTP connection per
# transaction. (Without this, calling http_conn.getresponse() would get
# "ResponseNotReady".)
http_conn = conn.new_http_connection(self.tracker_uri_host,
http_conn = conn.new_http_connection(self.tracker_uri_host, conn.port,
conn.is_secure)
http_conn.set_debuglevel(conn.debug)

View File

@ -38,6 +38,8 @@ class XmlHandler(xml.sax.ContentHandler):
def endElement(self, name):
self.nodes[-1][1].endElement(name, self.current_text, self.connection)
if self.nodes[-1][0] == name:
if hasattr(self.nodes[-1][1], 'endNode'):
self.nodes[-1][1].endNode(self.connection)
self.nodes.pop()
self.current_text = ''

View File

@ -836,7 +836,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user
:type serial_number: string
:param seriasl_number: The serial number which uniquely identifies
:param serial_number: The serial number which uniquely identifies
the MFA device.
:type auth_code_1: string
@ -862,7 +862,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user
:type serial_number: string
:param seriasl_number: The serial number which uniquely identifies
:param serial_number: The serial number which uniquely identifies
the MFA device.
"""
@ -879,7 +879,7 @@ class IAMConnection(AWSQueryConnection):
:param user_name: The username of the user
:type serial_number: string
:param seriasl_number: The serial number which uniquely identifies
:param serial_number: The serial number which uniquely identifies
the MFA device.
:type auth_code_1: string

View File

@ -34,10 +34,11 @@ class SSHClient(object):
def __init__(self, server,
host_key_file='~/.ssh/known_hosts',
uname='root', ssh_pwd=None):
uname='root', timeout=None, ssh_pwd=None):
self.server = server
self.host_key_file = host_key_file
self.uname = uname
self._timeout = timeout
self._pkey = paramiko.RSAKey.from_private_key_file(server.ssh_key_file,
password=ssh_pwd)
self._ssh_client = paramiko.SSHClient()
@ -52,7 +53,8 @@ class SSHClient(object):
try:
self._ssh_client.connect(self.server.hostname,
username=self.uname,
pkey=self._pkey)
pkey=self._pkey,
timeout=self._timeout)
return
except socket.error, (value, message):
if value in (51, 61, 111):

View File

@ -37,15 +37,16 @@ api_version_path = {
'Products': ('2011-10-01', 'SellerId', '/Products/2011-10-01'),
'Sellers': ('2011-07-01', 'SellerId', '/Sellers/2011-07-01'),
'Inbound': ('2010-10-01', 'SellerId',
'/FulfillmentInboundShipment/2010-10-01'),
'/FulfillmentInboundShipment/2010-10-01'),
'Outbound': ('2010-10-01', 'SellerId',
'/FulfillmentOutboundShipment/2010-10-01'),
'/FulfillmentOutboundShipment/2010-10-01'),
'Inventory': ('2010-10-01', 'SellerId',
'/FulfillmentInventory/2010-10-01'),
'/FulfillmentInventory/2010-10-01'),
}
content_md5 = lambda c: base64.encodestring(hashlib.md5(c).digest()).strip()
decorated_attrs = ('action', 'response', 'section',
'quota', 'restore', 'version')
api_call_map = {}
def add_attrs_from(func, to):
@ -67,7 +68,7 @@ def structured_lists(*fields):
kw.pop(key)
return func(self, *args, **kw)
wrapper.__doc__ = "{0}\nLists: {1}".format(func.__doc__,
', '.join(fields))
', '.join(fields))
return add_attrs_from(func, to=wrapper)
return decorator
@ -101,7 +102,7 @@ def destructure_object(value, into={}, prefix=''):
destructure_object(attr, into=into, prefix=prefix + '.' + name)
elif filter(lambda x: isinstance(value, x), (list, set, tuple)):
for index, element in [(prefix + '.' + str(i + 1), value[i])
for i in range(len(value))]:
for i in range(len(value))]:
destructure_object(element, into=into, prefix=index)
elif isinstance(value, bool):
into[prefix] = str(value).lower()
@ -118,7 +119,7 @@ def structured_objects(*fields):
destructure_object(kw.pop(field), into=kw, prefix=field)
return func(*args, **kw)
wrapper.__doc__ = "{0}\nObjects: {1}".format(func.__doc__,
', '.join(fields))
', '.join(fields))
return add_attrs_from(func, to=wrapper)
return decorator
@ -137,7 +138,7 @@ def requires(*groups):
return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\nRequired: {1}".format(func.__doc__,
message)
message)
return add_attrs_from(func, to=wrapper)
return decorator
@ -156,7 +157,7 @@ def exclusive(*groups):
return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\nEither: {1}".format(func.__doc__,
message)
message)
return add_attrs_from(func, to=wrapper)
return decorator
@ -175,8 +176,8 @@ def dependent(field, *groups):
return func(*args, **kw)
message = ' OR '.join(['+'.join(g) for g in groups])
wrapper.__doc__ = "{0}\n{1} requires: {2}".format(func.__doc__,
field,
message)
field,
message)
return add_attrs_from(func, to=wrapper)
return decorator
@ -192,7 +193,7 @@ def requires_some_of(*fields):
raise KeyError(message)
return func(*args, **kw)
wrapper.__doc__ = "{0}\nSome Required: {1}".format(func.__doc__,
', '.join(fields))
', '.join(fields))
return add_attrs_from(func, to=wrapper)
return decorator
@ -206,7 +207,7 @@ def boolean_arguments(*fields):
kw[field] = str(kw[field]).lower()
return func(*args, **kw)
wrapper.__doc__ = "{0}\nBooleans: {1}".format(func.__doc__,
', '.join(fields))
', '.join(fields))
return add_attrs_from(func, to=wrapper)
return decorator
@ -237,6 +238,7 @@ def api_action(section, quota, restore, *api):
wrapper.__doc__ = "MWS {0}/{1} API call; quota={2} restore={3:.2f}\n" \
"{4}".format(action, version, quota, restore,
func.__doc__)
api_call_map[action] = func.func_name
return wrapper
return decorator
@ -260,7 +262,8 @@ class MWSConnection(AWSQueryConnection):
Modelled off of the inherited get_object/make_request flow.
"""
request = self.build_base_http_request('POST', path, None, data=body,
params=params, headers=headers, host=self.server_name())
params=params, headers=headers,
host=self.host)
response = self._mexe(request, override_num_retries=None)
body = response.read()
boto.log.debug(body)
@ -275,6 +278,9 @@ class MWSConnection(AWSQueryConnection):
digest = response.getheader('Content-MD5')
assert content_md5(body) == digest
return body
return self._parse_response(cls, body)
def _parse_response(self, cls, body):
obj = cls(self)
h = XmlHandler(obj, self)
xml.sax.parseString(body, h)
@ -285,13 +291,10 @@ class MWSConnection(AWSQueryConnection):
The named method can be in CamelCase or underlined_lower_case.
This is the complement to MWSConnection.any_call.action
"""
# this looks ridiculous but it should be better than regex
action = '_' in name and string.capwords(name, '_') or name
attribs = [getattr(self, m) for m in dir(self)]
ismethod = lambda m: type(m) is type(self.method_for)
ismatch = lambda m: getattr(m, 'action', None) == action
method = filter(ismatch, filter(ismethod, attribs))
return method and method[0] or None
if action in api_call_map:
return getattr(self, api_call_map[action])
return None
def iter_call(self, call, *args, **kw):
"""Pass a call name as the first argument and a generator
@ -322,7 +325,7 @@ class MWSConnection(AWSQueryConnection):
"""Uploads a feed for processing by Amazon MWS.
"""
return self.post_request(path, kw, response, body=body,
headers=headers)
headers=headers)
@structured_lists('FeedSubmissionIdList.Id', 'FeedTypeList.Type',
'FeedProcessingStatusList.Status')
@ -365,10 +368,10 @@ class MWSConnection(AWSQueryConnection):
def get_service_status(self, **kw):
"""Instruct the user on how to get service status.
"""
sections = ', '.join(map(str.lower, api_version_path.keys()))
message = "Use {0}.get_(section)_service_status(), " \
"where (section) is one of the following: " \
"{1}".format(self.__class__.__name__,
', '.join(map(str.lower, api_version_path.keys())))
"{1}".format(self.__class__.__name__, sections)
raise AttributeError(message)
@structured_lists('MarketplaceIdList.Id')
@ -583,6 +586,14 @@ class MWSConnection(AWSQueryConnection):
"""
return self.post_request(path, kw, response)
@requires(['PackageNumber'])
@api_action('Outbound', 30, 0.5)
def get_package_tracking_details(self, path, response, **kw):
"""Returns delivery tracking information for a package in
an outbound shipment for a Multi-Channel Fulfillment order.
"""
return self.post_request(path, kw, response)
@structured_objects('Address', 'Items')
@requires(['Address', 'Items'])
@api_action('Outbound', 30, 0.5)
@ -659,8 +670,8 @@ class MWSConnection(AWSQueryConnection):
frame that you specify.
"""
toggle = set(('FulfillmentChannel.Channel.1',
'OrderStatus.Status.1', 'PaymentMethod.1',
'LastUpdatedAfter', 'LastUpdatedBefore'))
'OrderStatus.Status.1', 'PaymentMethod.1',
'LastUpdatedAfter', 'LastUpdatedBefore'))
for do, dont in {
'BuyerEmail': toggle.union(['SellerOrderId']),
'SellerOrderId': toggle.union(['BuyerEmail']),
@ -804,7 +815,7 @@ class MWSConnection(AWSQueryConnection):
@requires(['NextToken'])
@api_action('Sellers', 15, 60)
def list_marketplace_participations_by_next_token(self, path, response,
**kw):
**kw):
"""Returns the next page of marketplaces and participations
using the NextToken value that was returned by your
previous request to either ListMarketplaceParticipations

View File

@ -33,20 +33,30 @@ class ComplexType(dict):
class DeclarativeType(object):
def __init__(self, _hint=None, **kw):
self._value = None
if _hint is not None:
self._hint = _hint
else:
class JITResponse(ResponseElement):
pass
self._hint = JITResponse
for name, value in kw.items():
setattr(self._hint, name, value)
self._value = None
return
class JITResponse(ResponseElement):
pass
self._hint = JITResponse
self._hint.__name__ = 'JIT_{0}/{1}'.format(self.__class__.__name__,
hex(id(self._hint))[2:])
for name, value in kw.items():
setattr(self._hint, name, value)
def __repr__(self):
parent = getattr(self, '_parent', None)
return '<{0}_{1}/{2}_{3}>'.format(self.__class__.__name__,
parent and parent._name or '?',
getattr(self, '_name', '?'),
hex(id(self.__class__)))
def setup(self, parent, name, *args, **kw):
self._parent = parent
self._name = name
self._clone = self.__class__(self._hint)
self._clone = self.__class__(_hint=self._hint)
self._clone._parent = parent
self._clone._name = name
setattr(self._parent, self._name, self._clone)
@ -58,10 +68,7 @@ class DeclarativeType(object):
raise NotImplemented
def teardown(self, *args, **kw):
if self._value is None:
delattr(self._parent, self._name)
else:
setattr(self._parent, self._name, self._value)
setattr(self._parent, self._name, self._value)
class Element(DeclarativeType):
@ -78,11 +85,6 @@ class SimpleList(DeclarativeType):
DeclarativeType.__init__(self, *args, **kw)
self._value = []
def teardown(self, *args, **kw):
if self._value == []:
self._value = None
DeclarativeType.teardown(self, *args, **kw)
def start(self, *args, **kw):
return None
@ -93,35 +95,46 @@ class SimpleList(DeclarativeType):
class ElementList(SimpleList):
def start(self, *args, **kw):
value = self._hint(parent=self._parent, **kw)
self._value += [value]
return self._value[-1]
self._value.append(value)
return value
def end(self, *args, **kw):
pass
class MemberList(ElementList):
def __init__(self, *args, **kw):
self._this = kw.get('this')
ElementList.__init__(self, *args, **kw)
def start(self, attrs={}, **kw):
Class = self._this or self._parent._type_for(self._name, attrs)
if issubclass(self._hint, ResponseElement):
ListClass = ElementList
class MemberList(Element):
def __init__(self, _member=None, _hint=None, *args, **kw):
message = 'Invalid `member` specification in {0}'.format(self.__class__.__name__)
assert 'member' not in kw, message
if _member is None:
if _hint is None:
Element.__init__(self, *args, member=ElementList(**kw))
else:
Element.__init__(self, _hint=_hint)
else:
ListClass = SimpleList
setattr(Class, Class._member, ListClass(self._hint))
self._value = Class(attrs=attrs, parent=self._parent, **kw)
return self._value
if _hint is None:
if issubclass(_member, DeclarativeType):
member = _member(**kw)
else:
member = ElementList(_member, **kw)
Element.__init__(self, *args, member=member)
else:
message = 'Nonsensical {0} hint {1!r}'.format(self.__class__.__name__,
_hint)
raise AssertionError(message)
def end(self, *args, **kw):
self._value = getattr(self._value, self._value._member)
ElementList.end(self, *args, **kw)
def teardown(self, *args, **kw):
if self._value is None:
self._value = []
else:
if isinstance(self._value.member, DeclarativeType):
self._value.member = []
self._value = self._value.member
Element.teardown(self, *args, **kw)
def ResponseFactory(action):
result = globals().get(action + 'Result', ResponseElement)
def ResponseFactory(action, force=None):
result = force or globals().get(action + 'Result', ResponseElement)
class MWSResponse(Response):
_name = action + 'Response'
@ -141,18 +154,17 @@ def strip_namespace(func):
class ResponseElement(dict):
_override = {}
_member = 'member'
_name = None
_namespace = None
def __init__(self, connection=None, name=None, parent=None, attrs={}):
def __init__(self, connection=None, name=None, parent=None, attrs=None):
if parent is not None and self._namespace is None:
self._namespace = parent._namespace
if connection is not None:
self._connection = connection
self._name = name or self._name or self.__class__.__name__
self._declared('setup', attrs=attrs)
dict.__init__(self, attrs.copy())
dict.__init__(self, attrs and attrs.copy() or {})
def _declared(self, op, **kw):
def inherit(obj):
@ -177,7 +189,7 @@ class ResponseElement(dict):
do_show = lambda pair: not pair[0].startswith('_')
attrs = filter(do_show, self.__dict__.items())
name = self.__class__.__name__
if name == 'JITResponse':
if name.startswith('JIT_'):
name = '^{0}^'.format(self._name or '')
elif name == 'MWSResponse':
name = '^{0}^'.format(self._name or name)
@ -192,7 +204,7 @@ class ResponseElement(dict):
attribute = getattr(self, name, None)
if isinstance(attribute, DeclarativeType):
return attribute.start(name=name, attrs=attrs,
connection=connection)
connection=connection)
elif attrs.getLength():
setattr(self, name, ComplexType(attrs.copy()))
else:
@ -316,7 +328,7 @@ class CreateInboundShipmentPlanResult(ResponseElement):
class ListInboundShipmentsResult(ResponseElement):
ShipmentData = MemberList(Element(ShipFromAddress=Element()))
ShipmentData = MemberList(ShipFromAddress=Element())
class ListInboundShipmentsByNextTokenResult(ListInboundShipmentsResult):
@ -334,8 +346,8 @@ class ListInboundShipmentItemsByNextTokenResult(ListInboundShipmentItemsResult):
class ListInventorySupplyResult(ResponseElement):
InventorySupplyList = MemberList(
EarliestAvailability=Element(),
SupplyDetail=MemberList(\
EarliestAvailabileToPick=Element(),
SupplyDetail=MemberList(
EarliestAvailableToPick=Element(),
LatestAvailableToPick=Element(),
)
)
@ -431,13 +443,9 @@ class FulfillmentPreviewItem(ResponseElement):
class FulfillmentPreview(ResponseElement):
EstimatedShippingWeight = Element(ComplexWeight)
EstimatedFees = MemberList(\
Element(\
Amount=Element(ComplexAmount),
),
)
EstimatedFees = MemberList(Amount=Element(ComplexAmount))
UnfulfillablePreviewItems = MemberList(FulfillmentPreviewItem)
FulfillmentPreviewShipments = MemberList(\
FulfillmentPreviewShipments = MemberList(
FulfillmentPreviewItems=MemberList(FulfillmentPreviewItem),
)
@ -448,15 +456,14 @@ class GetFulfillmentPreviewResult(ResponseElement):
class FulfillmentOrder(ResponseElement):
DestinationAddress = Element()
NotificationEmailList = MemberList(str)
NotificationEmailList = MemberList(SimpleList)
class GetFulfillmentOrderResult(ResponseElement):
FulfillmentOrder = Element(FulfillmentOrder)
FulfillmentShipment = MemberList(Element(\
FulfillmentShipmentItem=MemberList(),
FulfillmentShipmentPackage=MemberList(),
)
FulfillmentShipment = MemberList(
FulfillmentShipmentItem=MemberList(),
FulfillmentShipmentPackage=MemberList(),
)
FulfillmentOrderItem = MemberList()
@ -469,6 +476,11 @@ class ListAllFulfillmentOrdersByNextTokenResult(ListAllFulfillmentOrdersResult):
pass
class GetPackageTrackingDetailsResult(ResponseElement):
ShipToAddress = Element()
TrackingEvents = MemberList(EventAddress=Element())
class Image(ResponseElement):
pass
@ -533,17 +545,17 @@ class Product(ResponseElement):
_namespace = 'ns2'
Identifiers = Element(MarketplaceASIN=Element(),
SKUIdentifier=Element())
AttributeSets = Element(\
AttributeSets = Element(
ItemAttributes=ElementList(ItemAttributes),
)
Relationships = Element(\
Relationships = Element(
VariationParent=ElementList(VariationRelationship),
)
CompetitivePricing = ElementList(CompetitivePricing)
SalesRankings = Element(\
SalesRankings = Element(
SalesRank=ElementList(SalesRank),
)
LowestOfferListings = Element(\
LowestOfferListings = Element(
LowestOfferListing=ElementList(LowestOfferListing),
)
@ -569,6 +581,10 @@ class GetMatchingProductForIdResult(ListMatchingProductsResult):
pass
class GetMatchingProductForIdResponse(ResponseResultList):
_ResultClass = GetMatchingProductForIdResult
class GetCompetitivePricingForSKUResponse(ProductsBulkOperationResponse):
pass
@ -607,9 +623,9 @@ class GetProductCategoriesForASINResult(GetProductCategoriesResult):
class Order(ResponseElement):
OrderTotal = Element(ComplexMoney)
ShippingAddress = Element()
PaymentExecutionDetail = Element(\
PaymentExecutionDetailItem=ElementList(\
PaymentExecutionDetailItem=Element(\
PaymentExecutionDetail = Element(
PaymentExecutionDetailItem=ElementList(
PaymentExecutionDetailItem=Element(
Payment=Element(ComplexMoney)
)
)

View File

@ -80,11 +80,51 @@ class OpsWorksConnection(AWSQueryConnection):
def _required_auth_capability(self):
return ['hmac-v4']
def assign_volume(self, volume_id, instance_id=None):
"""
Assigns one of the stack's registered Amazon EBS volumes to a
specified instance. The volume must first be registered with
the stack by calling RegisterVolume. For more information, see
``_.
:type volume_id: string
:param volume_id: The volume ID.
:type instance_id: string
:param instance_id: The instance ID.
"""
params = {'VolumeId': volume_id, }
if instance_id is not None:
params['InstanceId'] = instance_id
return self.make_request(action='AssignVolume',
body=json.dumps(params))
def associate_elastic_ip(self, elastic_ip, instance_id=None):
"""
Associates one of the stack's registered Elastic IP addresses
with a specified instance. The address must first be
registered with the stack by calling RegisterElasticIp. For
more information, see ``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
:type instance_id: string
:param instance_id: The instance ID.
"""
params = {'ElasticIp': elastic_ip, }
if instance_id is not None:
params['InstanceId'] = instance_id
return self.make_request(action='AssociateElasticIp',
body=json.dumps(params))
def attach_elastic_load_balancer(self, elastic_load_balancer_name,
layer_id):
"""
Attaches an Elastic Load Balancing instance to a specified
layer.
Attaches an Elastic Load Balancing load balancer to a
specified layer.
You must create the Elastic Load Balancing instance
separately, by using the Elastic Load Balancing console, API,
@ -136,8 +176,8 @@ class OpsWorksConnection(AWSQueryConnection):
will be launched into this VPC, and you cannot change the ID later.
+ If your account supports EC2 Classic, the default value is no VPC.
+ If you account does not support EC2 Classic, the default value is the
default VPC for the specified region.
+ If your account does not support EC2 Classic, the default value is
the default VPC for the specified region.
If the VPC ID corresponds to a default VPC and you have specified
@ -559,7 +599,8 @@ class OpsWorksConnection(AWSQueryConnection):
custom_instance_profile_arn=None,
custom_security_group_ids=None, packages=None,
volume_configurations=None, enable_auto_healing=None,
auto_assign_elastic_ips=None, custom_recipes=None,
auto_assign_elastic_ips=None,
auto_assign_public_ips=None, custom_recipes=None,
install_updates_on_boot=None):
"""
Creates a layer. For more information, see `How to Create a
@ -629,7 +670,13 @@ class OpsWorksConnection(AWSQueryConnection):
:type auto_assign_elastic_ips: boolean
:param auto_assign_elastic_ips: Whether to automatically assign an
`Elastic IP address`_ to the layer.
`Elastic IP address`_ to the layer's instances. For more
information, see `How to Edit a Layer`_.
:type auto_assign_public_ips: boolean
:param auto_assign_public_ips: For stacks that are running in a VPC,
whether to automatically assign a public IP address to the layer's
instances. For more information, see `How to Edit a Layer`_.
:type custom_recipes: dict
:param custom_recipes: A `LayerCustomRecipes` object that specifies the
@ -668,6 +715,8 @@ class OpsWorksConnection(AWSQueryConnection):
params['EnableAutoHealing'] = enable_auto_healing
if auto_assign_elastic_ips is not None:
params['AutoAssignElasticIps'] = auto_assign_elastic_ips
if auto_assign_public_ips is not None:
params['AutoAssignPublicIps'] = auto_assign_public_ips
if custom_recipes is not None:
params['CustomRecipes'] = custom_recipes
if install_updates_on_boot is not None:
@ -700,8 +749,8 @@ class OpsWorksConnection(AWSQueryConnection):
into this VPC, and you cannot change the ID later.
+ If your account supports EC2 Classic, the default value is no VPC.
+ If you account does not support EC2 Classic, the default value is the
default VPC for the specified region.
+ If your account does not support EC2 Classic, the default value is
the default VPC for the specified region.
If the VPC ID corresponds to a default VPC and you have specified
@ -954,6 +1003,33 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DeleteUserProfile',
body=json.dumps(params))
def deregister_elastic_ip(self, elastic_ip):
"""
Deregisters a specified Elastic IP address. The address can
then be registered by another stack. For more information, see
``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
"""
params = {'ElasticIp': elastic_ip, }
return self.make_request(action='DeregisterElasticIp',
body=json.dumps(params))
def deregister_volume(self, volume_id):
"""
Deregisters an Amazon EBS volume. The volume can then be
registered by another stack. For more information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
"""
params = {'VolumeId': volume_id, }
return self.make_request(action='DeregisterVolume',
body=json.dumps(params))
def describe_apps(self, stack_id=None, app_ids=None):
"""
Requests a description of a specified set of apps.
@ -1047,7 +1123,7 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeDeployments',
body=json.dumps(params))
def describe_elastic_ips(self, instance_id=None, ips=None):
def describe_elastic_ips(self, instance_id=None, stack_id=None, ips=None):
"""
Describes `Elastic IP addresses`_.
@ -1058,6 +1134,11 @@ class OpsWorksConnection(AWSQueryConnection):
`DescribeElasticIps` returns a description of the Elastic IP
addresses associated with the specified instance.
:type stack_id: string
:param stack_id: A stack ID. If you include this parameter,
`DescribeElasticIps` returns a description of the Elastic IP
addresses that are registered with the specified stack.
:type ips: list
:param ips: An array of Elastic IP addresses to be described. If you
include this parameter, `DescribeElasticIps` returns a description
@ -1068,6 +1149,8 @@ class OpsWorksConnection(AWSQueryConnection):
params = {}
if instance_id is not None:
params['InstanceId'] = instance_id
if stack_id is not None:
params['StackId'] = stack_id
if ips is not None:
params['Ips'] = ips
return self.make_request(action='DescribeElasticIps',
@ -1080,8 +1163,8 @@ class OpsWorksConnection(AWSQueryConnection):
You must specify at least one of the parameters.
:type stack_id: string
:param stack_id: A stack ID. The action describes the Elastic Load
Balancing instances for the stack.
:param stack_id: A stack ID. The action describes the stack's Elastic
Load Balancing instances.
:type layer_ids: list
:param layer_ids: A list of layer IDs. The action describes the Elastic
@ -1130,7 +1213,7 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeInstances',
body=json.dumps(params))
def describe_layers(self, stack_id, layer_ids=None):
def describe_layers(self, stack_id=None, layer_ids=None):
"""
Requests a description of one or more layers in a specified
stack.
@ -1146,7 +1229,9 @@ class OpsWorksConnection(AWSQueryConnection):
description of every layer in the specified stack.
"""
params = {'StackId': stack_id, }
params = {}
if stack_id is not None:
params['StackId'] = stack_id
if layer_ids is not None:
params['LayerIds'] = layer_ids
return self.make_request(action='DescribeLayers',
@ -1285,8 +1370,8 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DescribeUserProfiles',
body=json.dumps(params))
def describe_volumes(self, instance_id=None, raid_array_id=None,
volume_ids=None):
def describe_volumes(self, instance_id=None, stack_id=None,
raid_array_id=None, volume_ids=None):
"""
Describes an instance's Amazon EBS volumes.
@ -1297,6 +1382,10 @@ class OpsWorksConnection(AWSQueryConnection):
`DescribeVolumes` returns descriptions of the volumes associated
with the specified instance.
:type stack_id: string
:param stack_id: A stack ID. The action describes the stack's
registered Amazon EBS volumes.
:type raid_array_id: string
:param raid_array_id: The RAID array ID. If you use this parameter,
`DescribeVolumes` returns descriptions of the volumes associated
@ -1311,6 +1400,8 @@ class OpsWorksConnection(AWSQueryConnection):
params = {}
if instance_id is not None:
params['InstanceId'] = instance_id
if stack_id is not None:
params['StackId'] = stack_id
if raid_array_id is not None:
params['RaidArrayId'] = raid_array_id
if volume_ids is not None:
@ -1321,7 +1412,7 @@ class OpsWorksConnection(AWSQueryConnection):
def detach_elastic_load_balancer(self, elastic_load_balancer_name,
layer_id):
"""
Detaches a specified Elastic Load Balancing instance from it's
Detaches a specified Elastic Load Balancing instance from its
layer.
:type elastic_load_balancer_name: string
@ -1340,6 +1431,20 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='DetachElasticLoadBalancer',
body=json.dumps(params))
def disassociate_elastic_ip(self, elastic_ip):
"""
Disassociates an Elastic IP address from its instance. The
address remains registered with the stack. For more
information, see ``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
"""
params = {'ElasticIp': elastic_ip, }
return self.make_request(action='DisassociateElasticIp',
body=json.dumps(params))
def get_hostname_suggestion(self, layer_id):
"""
Gets a generated host name for the specified layer, based on
@ -1366,6 +1471,45 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='RebootInstance',
body=json.dumps(params))
def register_elastic_ip(self, elastic_ip, stack_id):
"""
Registers an Elastic IP address with a specified stack. An
address can be registered with only one stack at a time. If
the address is already registered, you must first deregister
it by calling DeregisterElasticIp. For more information, see
``_.
:type elastic_ip: string
:param elastic_ip: The Elastic IP address.
:type stack_id: string
:param stack_id: The stack ID.
"""
params = {'ElasticIp': elastic_ip, 'StackId': stack_id, }
return self.make_request(action='RegisterElasticIp',
body=json.dumps(params))
def register_volume(self, stack_id, ec_2_volume_id=None):
"""
Registers an Amazon EBS volume with a specified stack. A
volume can be registered with only one stack at a time. If the
volume is already registered, you must first deregister it by
calling DeregisterVolume. For more information, see ``_.
:type ec_2_volume_id: string
:param ec_2_volume_id: The Amazon EBS volume ID.
:type stack_id: string
:param stack_id: The stack ID.
"""
params = {'StackId': stack_id, }
if ec_2_volume_id is not None:
params['Ec2VolumeId'] = ec_2_volume_id
return self.make_request(action='RegisterVolume',
body=json.dumps(params))
def set_load_based_auto_scaling(self, layer_id, enable=None,
up_scaling=None, down_scaling=None):
"""
@ -1511,6 +1655,19 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='StopStack',
body=json.dumps(params))
def unassign_volume(self, volume_id):
"""
Unassigns an assigned Amazon EBS volume. The volume remains
registered with the stack. For more information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
"""
params = {'VolumeId': volume_id, }
return self.make_request(action='UnassignVolume',
body=json.dumps(params))
def update_app(self, app_id, name=None, description=None, type=None,
app_source=None, domains=None, enable_ssl=None,
ssl_configuration=None, attributes=None):
@ -1568,6 +1725,24 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='UpdateApp',
body=json.dumps(params))
def update_elastic_ip(self, elastic_ip, name=None):
"""
Updates a registered Elastic IP address's name. For more
information, see ``_.
:type elastic_ip: string
:param elastic_ip: The address.
:type name: string
:param name: The new name.
"""
params = {'ElasticIp': elastic_ip, }
if name is not None:
params['Name'] = name
return self.make_request(action='UpdateElasticIp',
body=json.dumps(params))
def update_instance(self, instance_id, layer_ids=None,
instance_type=None, auto_scaling_type=None,
hostname=None, os=None, ami_id=None,
@ -1673,7 +1848,8 @@ class OpsWorksConnection(AWSQueryConnection):
attributes=None, custom_instance_profile_arn=None,
custom_security_group_ids=None, packages=None,
volume_configurations=None, enable_auto_healing=None,
auto_assign_elastic_ips=None, custom_recipes=None,
auto_assign_elastic_ips=None,
auto_assign_public_ips=None, custom_recipes=None,
install_updates_on_boot=None):
"""
Updates a specified layer.
@ -1718,7 +1894,13 @@ class OpsWorksConnection(AWSQueryConnection):
:type auto_assign_elastic_ips: boolean
:param auto_assign_elastic_ips: Whether to automatically assign an
`Elastic IP address`_ to the layer.
`Elastic IP address`_ to the layer's instances. For more
information, see `How to Edit a Layer`_.
:type auto_assign_public_ips: boolean
:param auto_assign_public_ips: For stacks that are running in a VPC,
whether to automatically assign a public IP address to the layer's
instances. For more information, see `How to Edit a Layer`_.
:type custom_recipes: dict
:param custom_recipes: A `LayerCustomRecipes` object that specifies the
@ -1756,6 +1938,8 @@ class OpsWorksConnection(AWSQueryConnection):
params['EnableAutoHealing'] = enable_auto_healing
if auto_assign_elastic_ips is not None:
params['AutoAssignElasticIps'] = auto_assign_elastic_ips
if auto_assign_public_ips is not None:
params['AutoAssignPublicIps'] = auto_assign_public_ips
if custom_recipes is not None:
params['CustomRecipes'] = custom_recipes
if install_updates_on_boot is not None:
@ -1934,6 +2118,29 @@ class OpsWorksConnection(AWSQueryConnection):
return self.make_request(action='UpdateUserProfile',
body=json.dumps(params))
def update_volume(self, volume_id, name=None, mount_point=None):
"""
Updates an Amazon EBS volume's name or mount point. For more
information, see ``_.
:type volume_id: string
:param volume_id: The volume ID.
:type name: string
:param name: The new name.
:type mount_point: string
:param mount_point: The new mount point.
"""
params = {'VolumeId': volume_id, }
if name is not None:
params['Name'] = name
if mount_point is not None:
params['MountPoint'] = mount_point
return self.make_request(action='UpdateVolume',
body=json.dumps(params))
def make_request(self, action, body):
headers = {
'X-Amz-Target': '%s.%s' % (self.TargetPrefix, action),

View File

@ -45,6 +45,12 @@ def regions():
RegionInfo(name='ap-northeast-1',
endpoint='redshift.ap-northeast-1.amazonaws.com',
connection_cls=cls),
RegionInfo(name='ap-southeast-1',
endpoint='redshift.ap-southeast-1.amazonaws.com',
connection_cls=cls),
RegionInfo(name='ap-southeast-2',
endpoint='redshift.ap-southeast-2.amazonaws.com',
connection_cls=cls),
]

View File

@ -188,3 +188,272 @@ class AccessToSnapshotDeniedFault(JSONResponseError):
class UnauthorizedOperationFault(JSONResponseError):
pass
class SnapshotCopyAlreadyDisabled(JSONResponseError):
pass
class ClusterNotFound(JSONResponseError):
pass
class UnknownSnapshotCopyRegion(JSONResponseError):
pass
class InvalidClusterSubnetState(JSONResponseError):
pass
class ReservedNodeQuotaExceeded(JSONResponseError):
pass
class InvalidClusterState(JSONResponseError):
pass
class HsmClientCertificateQuotaExceeded(JSONResponseError):
pass
class SubscriptionCategoryNotFound(JSONResponseError):
pass
class HsmClientCertificateNotFound(JSONResponseError):
pass
class SubscriptionEventIdNotFound(JSONResponseError):
pass
class ClusterSecurityGroupAlreadyExists(JSONResponseError):
pass
class HsmConfigurationAlreadyExists(JSONResponseError):
pass
class NumberOfNodesQuotaExceeded(JSONResponseError):
pass
class ReservedNodeOfferingNotFound(JSONResponseError):
pass
class BucketNotFound(JSONResponseError):
pass
class InsufficientClusterCapacity(JSONResponseError):
pass
class InvalidRestore(JSONResponseError):
pass
class UnauthorizedOperation(JSONResponseError):
pass
class ClusterQuotaExceeded(JSONResponseError):
pass
class InvalidVPCNetworkState(JSONResponseError):
pass
class ClusterSnapshotNotFound(JSONResponseError):
pass
class AuthorizationQuotaExceeded(JSONResponseError):
pass
class InvalidHsmClientCertificateState(JSONResponseError):
pass
class SNSTopicArnNotFound(JSONResponseError):
pass
class ResizeNotFound(JSONResponseError):
pass
class ClusterSubnetGroupNotFound(JSONResponseError):
pass
class SNSNoAuthorization(JSONResponseError):
pass
class ClusterSnapshotQuotaExceeded(JSONResponseError):
pass
class AccessToSnapshotDenied(JSONResponseError):
pass
class InvalidClusterSecurityGroupState(JSONResponseError):
pass
class NumberOfNodesPerClusterLimitExceeded(JSONResponseError):
pass
class ClusterSubnetQuotaExceeded(JSONResponseError):
pass
class SNSInvalidTopic(JSONResponseError):
pass
class ClusterSecurityGroupNotFound(JSONResponseError):
pass
class InvalidElasticIp(JSONResponseError):
pass
class InvalidClusterParameterGroupState(JSONResponseError):
pass
class InvalidHsmConfigurationState(JSONResponseError):
pass
class ClusterAlreadyExists(JSONResponseError):
pass
class HsmConfigurationQuotaExceeded(JSONResponseError):
pass
class ClusterSnapshotAlreadyExists(JSONResponseError):
pass
class SubscriptionSeverityNotFound(JSONResponseError):
pass
class SourceNotFound(JSONResponseError):
pass
class ReservedNodeAlreadyExists(JSONResponseError):
pass
class ClusterSubnetGroupQuotaExceeded(JSONResponseError):
pass
class ClusterParameterGroupNotFound(JSONResponseError):
pass
class InvalidS3BucketName(JSONResponseError):
pass
class InvalidS3KeyPrefix(JSONResponseError):
pass
class SubscriptionAlreadyExist(JSONResponseError):
pass
class HsmConfigurationNotFound(JSONResponseError):
pass
class AuthorizationNotFound(JSONResponseError):
pass
class ClusterSecurityGroupQuotaExceeded(JSONResponseError):
pass
class EventSubscriptionQuotaExceeded(JSONResponseError):
pass
class AuthorizationAlreadyExists(JSONResponseError):
pass
class InvalidClusterSnapshotState(JSONResponseError):
pass
class ClusterParameterGroupQuotaExceeded(JSONResponseError):
pass
class SnapshotCopyDisabled(JSONResponseError):
pass
class ClusterSubnetGroupAlreadyExists(JSONResponseError):
pass
class ReservedNodeNotFound(JSONResponseError):
pass
class HsmClientCertificateAlreadyExists(JSONResponseError):
pass
class InvalidClusterSubnetGroupState(JSONResponseError):
pass
class SubscriptionNotFound(JSONResponseError):
pass
class InsufficientS3BucketPolicy(JSONResponseError):
pass
class ClusterParameterGroupAlreadyExists(JSONResponseError):
pass
class UnsupportedOption(JSONResponseError):
pass
class CopyToRegionDisabled(JSONResponseError):
pass
class SnapshotCopyAlreadyEnabled(JSONResponseError):
pass
class IncompatibleOrderableOptions(JSONResponseError):
pass

File diff suppressed because it is too large Load Diff

View File

@ -18,22 +18,25 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
import xml.sax
import uuid
import exception
import random
import urllib
import uuid
import xml.sax
import boto
from boto.connection import AWSAuthConnection
from boto import handler
import boto.jsonresponse
from boto.route53.record import ResourceRecordSets
from boto.route53.zone import Zone
import boto.jsonresponse
import exception
HZXML = """<?xml version="1.0" encoding="UTF-8"?>
<CreateHostedZoneRequest xmlns="%(xmlns)s">
@ -43,7 +46,7 @@ HZXML = """<?xml version="1.0" encoding="UTF-8"?>
<Comment>%(comment)s</Comment>
</HostedZoneConfig>
</CreateHostedZoneRequest>"""
#boto.set_stream_logger('dns')
@ -60,12 +63,13 @@ class Route53Connection(AWSAuthConnection):
def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
port=None, proxy=None, proxy_port=None,
host=DefaultHost, debug=0, security_token=None,
validate_certs=True):
validate_certs=True, https_connection_factory=None):
AWSAuthConnection.__init__(self, host,
aws_access_key_id, aws_secret_access_key,
True, port, proxy, proxy_port, debug=debug,
security_token=security_token,
validate_certs=validate_certs)
validate_certs=validate_certs,
https_connection_factory=https_connection_factory)
def _required_auth_capability(self):
return ['route53']
@ -79,7 +83,8 @@ class Route53Connection(AWSAuthConnection):
pairs.append(key + '=' + urllib.quote(str(val)))
path += '?' + '&'.join(pairs)
return AWSAuthConnection.make_request(self, action, path,
headers, data)
headers, data,
retry_handler=self._retry_handler)
# Hosted Zones
@ -118,7 +123,7 @@ class Route53Connection(AWSAuthConnection):
def get_hosted_zone(self, hosted_zone_id):
"""
Get detailed information about a particular Hosted Zone.
:type hosted_zone_id: str
:param hosted_zone_id: The unique identifier for the Hosted Zone
@ -158,7 +163,7 @@ class Route53Connection(AWSAuthConnection):
"""
Create a new Hosted Zone. Returns a Python data structure with
information about the newly created Hosted Zone.
:type domain_name: str
:param domain_name: The name of the domain. This should be a
fully-specified domain, and should end with a final period
@ -178,7 +183,7 @@ class Route53Connection(AWSAuthConnection):
use that.
:type comment: str
:param comment: Any comments you want to include about the hosted
:param comment: Any comments you want to include about the hosted
zone.
"""
@ -204,7 +209,7 @@ class Route53Connection(AWSAuthConnection):
raise exception.DNSServerError(response.status,
response.reason,
body)
def delete_hosted_zone(self, hosted_zone_id):
uri = '/%s/hostedzone/%s' % (self.Version, hosted_zone_id)
response = self.make_request('DELETE', uri)
@ -226,7 +231,7 @@ class Route53Connection(AWSAuthConnection):
"""
Retrieve the Resource Record Sets defined for this Hosted Zone.
Returns the raw XML data returned by the Route53 call.
:type hosted_zone_id: str
:param hosted_zone_id: The unique identifier for the Hosted Zone
@ -401,3 +406,24 @@ class Route53Connection(AWSAuthConnection):
if value and not value[-1] == '.':
value = "%s." % value
return value
def _retry_handler(self, response, i, next_sleep):
status = None
boto.log.debug("Saw HTTP status: %s" % response.status)
if response.status == 400:
code = response.getheader('Code')
if code and 'PriorRequestNotComplete' in code:
# This is a case where we need to ignore a 400 error, as
# Route53 returns this. See
# http://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DNSLimitations.html
msg = "%s, retry attempt %s" % (
'PriorRequestNotComplete',
i
)
next_sleep = random.random() * (2 ** i)
i += 1
status = (msg, i, next_sleep)
return status

View File

@ -63,6 +63,7 @@ class S3WebsiteEndpointTranslate:
trans_region['sa-east-1'] = 's3-website-sa-east-1'
trans_region['ap-northeast-1'] = 's3-website-ap-northeast-1'
trans_region['ap-southeast-1'] = 's3-website-ap-southeast-1'
trans_region['ap-southeast-2'] = 's3-website-ap-southeast-2'
@classmethod
def translate_region(self, reg):
@ -341,6 +342,11 @@ class Bucket(object):
raise self.connection.provider.storage_response_error(
response.status, response.reason, body)
def _validate_kwarg_names(self, kwargs, names):
for kwarg in kwargs:
if kwarg not in names:
raise TypeError('Invalid argument %s!' % kwarg)
def get_all_keys(self, headers=None, **params):
"""
A lower-level method for listing contents of a bucket. This
@ -370,6 +376,8 @@ class Bucket(object):
:return: The result from S3 listing the keys requested
"""
self._validate_kwarg_names(params, ['maxkeys', 'max_keys', 'prefix',
'marker', 'delimiter'])
return self._get_all([('Contents', self.key_class),
('CommonPrefixes', Prefix)],
'', headers, **params)
@ -407,6 +415,9 @@ class Bucket(object):
:rtype: ResultSet
:return: The result from S3 listing the keys requested
"""
self._validate_kwarg_names(params, ['maxkeys', 'max_keys', 'prefix',
'key_marker', 'version_id_marker',
'delimiter'])
return self._get_all([('Version', self.key_class),
('CommonPrefixes', Prefix),
('DeleteMarker', DeleteMarker)],
@ -450,6 +461,8 @@ class Bucket(object):
:return: The result from S3 listing the uploads requested
"""
self._validate_kwarg_names(params, ['max_uploads', 'key_marker',
'upload_id_marker'])
return self._get_all([('Upload', MultiPartUpload),
('CommonPrefixes', Prefix)],
'uploads', headers, **params)
@ -693,7 +706,8 @@ class Bucket(object):
if self.name == src_bucket_name:
src_bucket = self
else:
src_bucket = self.connection.get_bucket(src_bucket_name)
src_bucket = self.connection.get_bucket(
src_bucket_name, validate=False)
acl = src_bucket.get_xml_acl(src_key_name)
if encrypt_key:
headers[provider.server_side_encryption_header] = 'AES256'
@ -1300,6 +1314,7 @@ class Bucket(object):
* ErrorDocument
* Key : name of object to serve when an error occurs
"""
return self.get_website_configuration_with_xml(headers)[0]
@ -1320,15 +1335,24 @@ class Bucket(object):
:rtype: 2-Tuple
:returns: 2-tuple containing:
1) A dictionary containing a Python representation
of the XML response. The overall structure is:
* WebsiteConfiguration
* IndexDocument
* Suffix : suffix that is appended to request that
is for a "directory" on the website endpoint
* ErrorDocument
* Key : name of object to serve when an error occurs
2) unparsed XML describing the bucket's website configuration.
1) A dictionary containing a Python representation \
of the XML response. The overall structure is:
* WebsiteConfiguration
* IndexDocument
* Suffix : suffix that is appended to request that \
is for a "directory" on the website endpoint
* ErrorDocument
* Key : name of object to serve when an error occurs
2) unparsed XML describing the bucket's website configuration
"""
body = self.get_website_configuration_xml(headers=headers)

View File

@ -264,7 +264,7 @@ class SNSConnection(AWSQueryConnection):
:type protocol: string
:param protocol: The protocol used to communicate with
the subscriber. Current choices are:
email|email-json|http|https|sqs
email|email-json|http|https|sqs|sms
:type endpoint: string
:param endpoint: The location of the endpoint for
@ -274,6 +274,7 @@ class SNSConnection(AWSQueryConnection):
* For http, this would be a URL beginning with http
* For https, this would be a URL beginning with https
* For sqs, this would be the ARN of an SQS Queue
* For sms, this would be a phone number of an SMS-enabled device
"""
params = {'TopicArn': topic,
'Protocol': protocol,

View File

@ -286,8 +286,8 @@ class SQSConnection(AWSQueryConnection):
:param queue: The Queue from which messages are read.
:type receipt_handle: str
:param queue: The receipt handle associated with the message whose
visibility timeout will be changed.
:param receipt_handle: The receipt handle associated with the message
whose visibility timeout will be changed.
:type visibility_timeout: int
:param visibility_timeout: The new value of the message's visibility
@ -337,16 +337,19 @@ class SQSConnection(AWSQueryConnection):
params['QueueNamePrefix'] = prefix
return self.get_list('ListQueues', params, [('QueueUrl', Queue)])
def get_queue(self, queue_name):
def get_queue(self, queue_name, owner_acct_id=None):
"""
Retrieves the queue with the given name, or ``None`` if no match
was found.
:param str queue_name: The name of the queue to retrieve.
:param str owner_acct_id: Optionally, the AWS account ID of the account that created the queue.
:rtype: :py:class:`boto.sqs.queue.Queue` or ``None``
:returns: The requested queue, or ``None`` if no match was found.
"""
params = {'QueueName': queue_name}
if owner_acct_id:
params['QueueOwnerAWSAccountId']=owner_acct_id
try:
return self.get_object('GetQueueUrl', params, Queue)
except SQSError:

View File

@ -95,7 +95,7 @@ class RawMessage:
def endElement(self, name, value, connection):
if name == 'Body':
self.set_body(self.decode(value))
self.set_body(value)
elif name == 'MessageId':
self.id = value
elif name == 'ReceiptHandle':
@ -105,6 +105,9 @@ class RawMessage:
else:
setattr(self, name, value)
def endNode(self, connection):
self.set_body(self.decode(self.get_body()))
def encode(self, value):
"""Transform body object into serialized byte array format."""
return value

View File

@ -188,7 +188,11 @@ class ActivityWorker(Actor):
@wraps(Layer1.poll_for_activity_task)
def poll(self, **kwargs):
"""PollForActivityTask."""
task = self._swf.poll_for_activity_task(self.domain, self.task_list,
task_list = self.task_list
if 'task_list' in kwargs:
task_list = kwargs.get('task_list')
del kwargs['task_list']
task = self._swf.poll_for_activity_task(self.domain, task_list,
**kwargs)
self.last_tasktoken = task.get('taskToken')
return task
@ -211,12 +215,14 @@ class Decider(Actor):
@wraps(Layer1.poll_for_decision_task)
def poll(self, **kwargs):
"""PollForDecisionTask."""
result = self._swf.poll_for_decision_task(self.domain, self.task_list,
task_list = self.task_list
if 'task_list' in kwargs:
task_list = kwargs.get('task_list')
del kwargs['task_list']
decision_task = self._swf.poll_for_decision_task(self.domain, task_list,
**kwargs)
# Record task token.
self.last_tasktoken = result.get('taskToken')
# Record the last event.
return result
self.last_tasktoken = decision_task.get('taskToken')
return decision_task
class WorkflowType(SWFBase):

View File

@ -27,6 +27,7 @@ from boto.ec2.connection import EC2Connection
from boto.resultset import ResultSet
from boto.vpc.vpc import VPC
from boto.vpc.customergateway import CustomerGateway
from boto.vpc.networkacl import NetworkAcl
from boto.vpc.routetable import RouteTable
from boto.vpc.internetgateway import InternetGateway
from boto.vpc.vpngateway import VpnGateway, Attachment
@ -36,6 +37,7 @@ from boto.vpc.vpnconnection import VpnConnection
from boto.ec2 import RegionData
from boto.regioninfo import RegionInfo
def regions(**kw_params):
"""
Get all available regions for the EC2 service.
@ -53,9 +55,8 @@ def regions(**kw_params):
connection_cls=VPCConnection)
regions.append(region)
regions.append(RegionInfo(name='us-gov-west-1',
endpoint=RegionData[region_name],
connection_cls=VPCConnection)
)
endpoint=RegionData[region_name],
connection_cls=VPCConnection))
return regions
@ -117,20 +118,26 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true'
return self.get_list('DescribeVpcs', params, [('item', VPC)])
def create_vpc(self, cidr_block, dry_run=False):
def create_vpc(self, cidr_block, instance_tenancy=None, dry_run=False):
"""
Create a new Virtual Private Cloud.
:type cidr_block: str
:param cidr_block: A valid CIDR block
:type instance_tenancy: str
:param instance_tenancy: The supported tenancy options for instances
launched into the VPC. Valid values are 'default' and 'dedicated'.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: The newly created VPC
:return: A :class:`boto.vpc.vpc.VPC` object
"""
params = {'CidrBlock' : cidr_block}
params = {'CidrBlock': cidr_block}
if instance_tenancy:
params['InstanceTenancy'] = instance_tenancy
if dry_run:
params['DryRun'] = 'true'
return self.get_object('CreateVpc', params, VPC)
@ -266,7 +273,7 @@ class VPCConnection(EC2Connection):
:rtype: bool
:return: True if successful
"""
params = { 'AssociationId': association_id }
params = {'AssociationId': association_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('DisassociateRouteTable', params)
@ -284,7 +291,7 @@ class VPCConnection(EC2Connection):
:rtype: The newly created route table
:return: A :class:`boto.vpc.routetable.RouteTable` object
"""
params = { 'VpcId': vpc_id }
params = {'VpcId': vpc_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_object('CreateRouteTable', params, RouteTable)
@ -302,13 +309,96 @@ class VPCConnection(EC2Connection):
:rtype: bool
:return: True if successful
"""
params = { 'RouteTableId': route_table_id }
params = {'RouteTableId': route_table_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('DeleteRouteTable', params)
def _replace_route_table_association(self, association_id,
route_table_id, dry_run=False):
"""
Helper function for replace_route_table_association and
replace_route_table_association_with_assoc. Should not be used directly.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: ResultSet
:return: ResultSet of Amazon resposne
"""
params = {
'AssociationId': association_id,
'RouteTableId': route_table_id
}
if dry_run:
params['DryRun'] = 'true'
return self.get_object('ReplaceRouteTableAssociation', params,
ResultSet)
def replace_route_table_assocation(self, association_id,
route_table_id, dry_run=False):
"""
Replaces a route association with a new route table. This can be
used to replace the 'main' route table by using the main route
table association instead of the more common subnet type
association.
NOTE: It may be better to use replace_route_table_association_with_assoc
instead of this function; this function does not return the new
association ID. This function is retained for backwards compatibility.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: bool
:return: True if successful
"""
return self._replace_route_table_association(
association_id, route_table_id, dry_run=dry_run).status
def replace_route_table_association_with_assoc(self, association_id,
route_table_id,
dry_run=False):
"""
Replaces a route association with a new route table. This can be
used to replace the 'main' route table by using the main route
table association instead of the more common subnet type
association. Returns the new association ID.
:type association_id: str
:param association_id: The ID of the existing association to replace.
:type route_table_id: str
:param route_table_id: The route table to ID to be used in the
association.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: str
:return: New association ID
"""
return self._replace_route_table_association(
association_id, route_table_id, dry_run=dry_run).newAssociationId
def create_route(self, route_table_id, destination_cidr_block,
gateway_id=None, instance_id=None, dry_run=False):
gateway_id=None, instance_id=None, interface_id=None,
dry_run=False):
"""
Creates a new route in the route table within a VPC. The route's target
can be either a gateway attached to the VPC or a NAT instance in the
@ -327,6 +417,9 @@ class VPCConnection(EC2Connection):
:type instance_id: str
:param instance_id: The ID of a NAT instance in your VPC.
:type interface_id: str
:param interface_id: Allows routing to network interface attachments.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
@ -342,14 +435,16 @@ class VPCConnection(EC2Connection):
params['GatewayId'] = gateway_id
elif instance_id is not None:
params['InstanceId'] = instance_id
elif interface_id is not None:
params['NetworkInterfaceId'] = interface_id
if dry_run:
params['DryRun'] = 'true'
return self.get_status('CreateRoute', params)
def replace_route(self, route_table_id, destination_cidr_block,
gateway_id=None, instance_id=None, interface_id=None,
dry_run=False):
gateway_id=None, instance_id=None, interface_id=None,
dry_run=False):
"""
Replaces an existing route within a route table in a VPC.
@ -417,6 +512,271 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true'
return self.get_status('DeleteRoute', params)
#Network ACLs
def get_all_network_acls(self, network_acl_ids=None, filters=None):
"""
Retrieve information about your network acls. You can filter results
to return information only about those network acls that match your
search parameters. Otherwise, all network acls associated with your
account are returned.
:type network_acl_ids: list
:param network_acl_ids: A list of strings with the desired network ACL
IDs.
:type filters: list of tuples
:param filters: A list of tuples containing filters. Each tuple
consists of a filter key and a filter value.
:rtype: list
:return: A list of :class:`boto.vpc.networkacl.NetworkAcl`
"""
params = {}
if network_acl_ids:
self.build_list_params(params, network_acl_ids, "NetworkAclId")
if filters:
self.build_filter_params(params, dict(filters))
return self.get_list('DescribeNetworkAcls', params,
[('item', NetworkAcl)])
def associate_network_acl(self, network_acl_id, subnet_id):
"""
Associates a network acl with a specific subnet.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL to associate.
:type subnet_id: str
:param subnet_id: The ID of the subnet to associate with.
:rtype: str
:return: The ID of the association created
"""
acl = self.get_all_network_acls(filters=[('association.subnet-id', subnet_id)])[0]
association = [ association for association in acl.associations if association.subnet_id == subnet_id ][0]
params = {
'AssociationId': association.id,
'NetworkAclId': network_acl_id
}
result = self.get_object('ReplaceNetworkAclAssociation', params, ResultSet)
return result.newAssociationId
def disassociate_network_acl(self, subnet_id, vpc_id=None):
"""
Figures out what the default ACL is for the VPC, and associates
current network ACL with the default.
:type subnet_id: str
:param association_id: The ID of the subnet to which the ACL belongs.
:type vpc_id: str
:param vpc_id: The ID of the VPC to which the ACL/subnet belongs. Queries EC2 if omitted.
:rtype: str
:return: The ID of the association created
"""
if not vpc_id:
vpc_id = self.get_all_subnets([subnet_id])[0].vpc_id
acls = self.get_all_network_acls(filters=[('vpc-id', vpc_id), ('default', 'true')])
default_acl_id = acls[0].id
return self.associate_network_acl(default_acl_id, subnet_id)
def create_network_acl(self, vpc_id):
"""
Creates a new network ACL.
:type vpc_id: str
:param vpc_id: The VPC ID to associate this network ACL with.
:rtype: The newly created network ACL
:return: A :class:`boto.vpc.networkacl.NetworkAcl` object
"""
params = {'VpcId': vpc_id}
return self.get_object('CreateNetworkAcl', params, NetworkAcl)
def delete_network_acl(self, network_acl_id):
"""
Delete a network ACL
:type network_acl_id: str
:param network_acl_id: The ID of the network_acl to delete.
:rtype: bool
:return: True if successful
"""
params = {'NetworkAclId': network_acl_id}
return self.get_status('DeleteNetworkAcl', params)
def create_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action,
cidr_block, egress=None, icmp_code=None, icmp_type=None,
port_range_from=None, port_range_to=None):
"""
Creates a new network ACL entry in a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL for this network ACL entry.
:type rule_number: int
:param rule_number: The rule number to assign to the entry (for example, 100).
:type protocol: int
:param protocol: Valid values: -1 or a protocol number
(http://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml)
:type rule_action: str
:param rule_action: Indicates whether to allow or deny traffic that matches the rule.
:type cidr_block: str
:param cidr_block: The CIDR range to allow or deny, in CIDR notation (for example,
172.16.0.0/24).
:type egress: bool
:param egress: Indicates whether this rule applies to egress traffic from the subnet (true)
or ingress traffic to the subnet (false).
:type icmp_type: int
:param icmp_type: For the ICMP protocol, the ICMP type. You can use -1 to specify
all ICMP types.
:type icmp_code: int
:param icmp_code: For the ICMP protocol, the ICMP code. You can use -1 to specify
all ICMP codes for the given ICMP type.
:type port_range_from: int
:param port_range_from: The first port in the range.
:type port_range_to: int
:param port_range_to: The last port in the range.
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number,
'Protocol': protocol,
'RuleAction': rule_action,
'CidrBlock': cidr_block
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
if icmp_code is not None:
params['Icmp.Code'] = icmp_code
if icmp_type is not None:
params['Icmp.Type'] = icmp_type
if port_range_from is not None:
params['PortRange.From'] = port_range_from
if port_range_to is not None:
params['PortRange.To'] = port_range_to
return self.get_status('CreateNetworkAclEntry', params)
def replace_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action,
cidr_block, egress=None, icmp_code=None, icmp_type=None,
port_range_from=None, port_range_to=None):
"""
Creates a new network ACL entry in a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL for the id you want to replace
:type rule_number: int
:param rule_number: The rule number that you want to replace(for example, 100).
:type protocol: int
:param protocol: Valid values: -1 or a protocol number
(http://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml)
:type rule_action: str
:param rule_action: Indicates whether to allow or deny traffic that matches the rule.
:type cidr_block: str
:param cidr_block: The CIDR range to allow or deny, in CIDR notation (for example,
172.16.0.0/24).
:type egress: bool
:param egress: Indicates whether this rule applies to egress traffic from the subnet (true)
or ingress traffic to the subnet (false).
:type icmp_type: int
:param icmp_type: For the ICMP protocol, the ICMP type. You can use -1 to specify
all ICMP types.
:type icmp_code: int
:param icmp_code: For the ICMP protocol, the ICMP code. You can use -1 to specify
all ICMP codes for the given ICMP type.
:type port_range_from: int
:param port_range_from: The first port in the range.
:type port_range_to: int
:param port_range_to: The last port in the range.
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number,
'Protocol': protocol,
'RuleAction': rule_action,
'CidrBlock': cidr_block
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
if icmp_code is not None:
params['Icmp.Code'] = icmp_code
if icmp_type is not None:
params['Icmp.Type'] = icmp_type
if port_range_from is not None:
params['PortRange.From'] = port_range_from
if port_range_to is not None:
params['PortRange.To'] = port_range_to
return self.get_status('ReplaceNetworkAclEntry', params)
def delete_network_acl_entry(self, network_acl_id, rule_number, egress=None):
"""
Deletes a network ACL entry from a network ACL within a VPC.
:type network_acl_id: str
:param network_acl_id: The ID of the network ACL with the network ACL entry.
:type rule_number: int
:param rule_number: The rule number for the entry to delete.
:type egress: bool
:param egress: Specifies whether the rule to delete is an egress rule (true)
or ingress rule (false).
:rtype: bool
:return: True if successful
"""
params = {
'NetworkAclId': network_acl_id,
'RuleNumber': rule_number
}
if egress is not None:
if isinstance(egress, bool):
egress = str(egress).lower()
params['Egress'] = egress
return self.get_status('DeleteNetworkAclEntry', params)
# Internet Gateways
def get_all_internet_gateways(self, internet_gateway_ids=None,
@ -476,7 +836,7 @@ class VPCConnection(EC2Connection):
:rtype: Bool
:return: True if successful
"""
params = { 'InternetGatewayId': internet_gateway_id }
params = {'InternetGatewayId': internet_gateway_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('DeleteInternetGateway', params)
@ -586,7 +946,7 @@ class VPCConnection(EC2Connection):
:param ip_address: Internet-routable IP address for customer's gateway.
Must be a static address.
:type bgp_asn: str
:type bgp_asn: int
:param bgp_asn: Customer gateway's Border Gateway Protocol (BGP)
Autonomous System Number (ASN)
@ -596,9 +956,9 @@ class VPCConnection(EC2Connection):
:rtype: The newly created CustomerGateway
:return: A :class:`boto.vpc.customergateway.CustomerGateway` object
"""
params = {'Type' : type,
'IpAddress' : ip_address,
'BgpAsn' : bgp_asn}
params = {'Type': type,
'IpAddress': ip_address,
'BgpAsn': bgp_asn}
if dry_run:
params['DryRun'] = 'true'
return self.get_object('CreateCustomerGateway', params, CustomerGateway)
@ -677,7 +1037,7 @@ class VPCConnection(EC2Connection):
:rtype: The newly created VpnGateway
:return: A :class:`boto.vpc.vpngateway.VpnGateway` object
"""
params = {'Type' : type}
params = {'Type': type}
if availability_zone:
params['AvailabilityZone'] = availability_zone
if dry_run:
@ -719,11 +1079,33 @@ class VPCConnection(EC2Connection):
:return: a :class:`boto.vpc.vpngateway.Attachment`
"""
params = {'VpnGatewayId': vpn_gateway_id,
'VpcId' : vpc_id}
'VpcId': vpc_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_object('AttachVpnGateway', params, Attachment)
def detach_vpn_gateway(self, vpn_gateway_id, vpc_id, dry_run=False):
"""
Detaches a VPN gateway from a VPC.
:type vpn_gateway_id: str
:param vpn_gateway_id: The ID of the vpn_gateway to detach
:type vpc_id: str
:param vpc_id: The ID of the VPC you want to detach the gateway from.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: bool
:return: True if successful
"""
params = {'VpnGatewayId': vpn_gateway_id,
'VpcId': vpc_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('DetachVpnGateway', params)
# Subnets
def get_all_subnets(self, subnet_ids=None, filters=None, dry_run=False):
@ -784,8 +1166,8 @@ class VPCConnection(EC2Connection):
:rtype: The newly created Subnet
:return: A :class:`boto.vpc.customergateway.Subnet` object
"""
params = {'VpcId' : vpc_id,
'CidrBlock' : cidr_block}
params = {'VpcId': vpc_id,
'CidrBlock': cidr_block}
if availability_zone:
params['AvailabilityZone'] = availability_zone
if dry_run:
@ -810,16 +1192,19 @@ class VPCConnection(EC2Connection):
params['DryRun'] = 'true'
return self.get_status('DeleteSubnet', params)
# DHCP Options
def get_all_dhcp_options(self, dhcp_options_ids=None, dry_run=False):
def get_all_dhcp_options(self, dhcp_options_ids=None, filters=None, dry_run=False):
"""
Retrieve information about your DhcpOptions.
:type dhcp_options_ids: list
:param dhcp_options_ids: A list of strings with the desired DhcpOption ID's
:type filters: list of tuples
:param filters: A list of tuples containing filters. Each tuple
consists of a filter key and a filter value.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
@ -829,6 +1214,8 @@ class VPCConnection(EC2Connection):
params = {}
if dhcp_options_ids:
self.build_list_params(params, dhcp_options_ids, 'DhcpOptionsId')
if filters:
self.build_filter_params(params, dict(filters))
if dry_run:
params['DryRun'] = 'true'
return self.get_list('DescribeDhcpOptions', params,
@ -890,19 +1277,19 @@ class VPCConnection(EC2Connection):
if domain_name:
key_counter = insert_option(params,
'domain-name', domain_name)
'domain-name', domain_name)
if domain_name_servers:
key_counter = insert_option(params,
'domain-name-servers', domain_name_servers)
'domain-name-servers', domain_name_servers)
if ntp_servers:
key_counter = insert_option(params,
'ntp-servers', ntp_servers)
'ntp-servers', ntp_servers)
if netbios_name_servers:
key_counter = insert_option(params,
'netbios-name-servers', netbios_name_servers)
'netbios-name-servers', netbios_name_servers)
if netbios_node_type:
key_counter = insert_option(params,
'netbios-node-type', netbios_node_type)
'netbios-node-type', netbios_node_type)
if dry_run:
params['DryRun'] = 'true'
@ -943,7 +1330,7 @@ class VPCConnection(EC2Connection):
:return: True if successful
"""
params = {'DhcpOptionsId': dhcp_options_id,
'VpcId' : vpc_id}
'VpcId': vpc_id}
if dry_run:
params['DryRun'] = 'true'
return self.get_status('AssociateDhcpOptions', params)
@ -983,7 +1370,7 @@ class VPCConnection(EC2Connection):
params = {}
if vpn_connection_ids:
self.build_list_params(params, vpn_connection_ids,
'Vpn_ConnectionId')
'VpnConnectionId')
if filters:
self.build_filter_params(params, dict(filters))
if dry_run:
@ -992,7 +1379,7 @@ class VPCConnection(EC2Connection):
[('item', VpnConnection)])
def create_vpn_connection(self, type, customer_gateway_id, vpn_gateway_id,
dry_run=False):
static_routes_only=None, dry_run=False):
"""
Create a new VPN Connection.
@ -1006,15 +1393,24 @@ class VPCConnection(EC2Connection):
:type vpn_gateway_id: str
:param vpn_gateway_id: The ID of the VPN gateway.
:type static_routes_only: bool
:param static_routes_only: Indicates whether the VPN connection
requires static routes. If you are creating a VPN connection
for a device that does not support BGP, you must specify true.
:type dry_run: bool
:param dry_run: Set to True if the operation should not actually run.
:rtype: The newly created VpnConnection
:return: A :class:`boto.vpc.vpnconnection.VpnConnection` object
"""
params = {'Type' : type,
'CustomerGatewayId' : customer_gateway_id,
'VpnGatewayId' : vpn_gateway_id}
params = {'Type': type,
'CustomerGatewayId': customer_gateway_id,
'VpnGatewayId': vpn_gateway_id}
if static_routes_only is not None:
if isinstance(static_routes_only, bool):
static_routes_only = str(static_routes_only).lower()
params['Options.StaticRoutesOnly'] = static_routes_only
if dry_run:
params['DryRun'] = 'true'
return self.get_object('CreateVpnConnection', params, VpnConnection)

View File

@ -14,7 +14,7 @@
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
@ -25,6 +25,7 @@ Represents a Customer Gateway
from boto.ec2.ec2object import TaggedEC2Object
class CustomerGateway(TaggedEC2Object):
def __init__(self, connection=None):
@ -37,7 +38,7 @@ class CustomerGateway(TaggedEC2Object):
def __repr__(self):
return 'CustomerGateway:%s' % self.id
def endElement(self, name, value, connection):
if name == 'customerGatewayId':
self.id = value
@ -48,7 +49,6 @@ class CustomerGateway(TaggedEC2Object):
elif name == 'state':
self.state = value
elif name == 'bgpAsn':
self.bgp_asn = value
self.bgp_asn = int(value)
else:
setattr(self, name, value)

View File

@ -0,0 +1,164 @@
# Copyright (c) 2009-2010 Mitch Garnaat http://garnaat.org/
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
"""
Represents a Network ACL
"""
from boto.ec2.ec2object import TaggedEC2Object
from boto.resultset import ResultSet
class Icmp(object):
"""
Defines the ICMP code and type.
"""
def __init__(self, connection=None):
self.code = None
self.type = None
def __repr__(self):
return 'Icmp::code:%s, type:%s)' % ( self.code, self.type)
def startElement(self, name, attrs, connection):
pass
def endElement(self, name, value, connection):
if name == 'code':
self.code = value
elif name == 'type':
self.type = value
class NetworkAcl(TaggedEC2Object):
def __init__(self, connection=None):
TaggedEC2Object.__init__(self, connection)
self.id = None
self.vpc_id = None
self.network_acl_entries = []
self.associations = []
def __repr__(self):
return 'NetworkAcl:%s' % self.id
def startElement(self, name, attrs, connection):
result = super(NetworkAcl, self).startElement(name, attrs, connection)
if result is not None:
# Parent found an interested element, just return it
return result
if name == 'entrySet':
self.network_acl_entries = ResultSet([('item', NetworkAclEntry)])
return self.network_acl_entries
elif name == 'associationSet':
self.associations = ResultSet([('item', NetworkAclAssociation)])
return self.associations
else:
return None
def endElement(self, name, value, connection):
if name == 'networkAclId':
self.id = value
elif name == 'vpcId':
self.vpc_id = value
else:
setattr(self, name, value)
class NetworkAclEntry(object):
def __init__(self, connection=None):
self.rule_number = None
self.protocol = None
self.rule_action = None
self.egress = None
self.cidr_block = None
self.port_range = PortRange()
self.icmp = Icmp()
def __repr__(self):
return 'Acl:%s' % self.rule_number
def startElement(self, name, attrs, connection):
if name == 'portRange':
return self.port_range
elif name == 'icmpTypeCode':
return self.icmp
else:
return None
def endElement(self, name, value, connection):
if name == 'cidrBlock':
self.cidr_block = value
elif name == 'egress':
self.egress = value
elif name == 'protocol':
self.protocol = value
elif name == 'ruleAction':
self.rule_action = value
elif name == 'ruleNumber':
self.rule_number = value
class NetworkAclAssociation(object):
def __init__(self, connection=None):
self.id = None
self.subnet_id = None
self.network_acl_id = None
def __repr__(self):
return 'NetworkAclAssociation:%s' % self.id
def startElement(self, name, attrs, connection):
return None
def endElement(self, name, value, connection):
if name == 'networkAclAssociationId':
self.id = value
elif name == 'networkAclId':
self.route_table_id = value
elif name == 'subnetId':
self.subnet_id = value
class PortRange(object):
"""
Define the port range for the ACL entry if it is tcp / udp
"""
def __init__(self, connection=None):
self.from_port = None
self.to_port = None
def __repr__(self):
return 'PortRange:(%s-%s)' % ( self.from_port, self.to_port)
def startElement(self, name, attrs, connection):
pass
def endElement(self, name, value, connection):
if name == 'from':
self.from_port = value
elif name == 'to':
self.to_port = value

View File

@ -1,208 +0,0 @@
# -*- coding: utf-8 -*-
"""
celery.__compat__
~~~~~~~~~~~~~~~~~
This module contains utilities to dynamically
recreate modules, either for lazy loading or
to create old modules at runtime instead of
having them litter the source tree.
"""
from __future__ import absolute_import
import operator
import sys
# import fails in python 2.5. fallback to reduce in stdlib
try:
from functools import reduce
except ImportError:
pass
from importlib import import_module
from types import ModuleType
from .local import Proxy
MODULE_DEPRECATED = """
The module %s is deprecated and will be removed in a future version.
"""
DEFAULT_ATTRS = set(['__file__', '__path__', '__doc__', '__all__'])
# im_func is no longer available in Py3.
# instead the unbound method itself can be used.
if sys.version_info[0] == 3: # pragma: no cover
def fun_of_method(method):
return method
else:
def fun_of_method(method): # noqa
return method.im_func
def getappattr(path):
"""Gets attribute from the current_app recursively,
e.g. getappattr('amqp.get_task_consumer')``."""
from celery import current_app
return current_app._rgetattr(path)
def _compat_task_decorator(*args, **kwargs):
from celery import current_app
kwargs.setdefault('accept_magic_kwargs', True)
return current_app.task(*args, **kwargs)
def _compat_periodic_task_decorator(*args, **kwargs):
from celery.task import periodic_task
kwargs.setdefault('accept_magic_kwargs', True)
return periodic_task(*args, **kwargs)
COMPAT_MODULES = {
'celery': {
'execute': {
'send_task': 'send_task',
},
'decorators': {
'task': _compat_task_decorator,
'periodic_task': _compat_periodic_task_decorator,
},
'log': {
'get_default_logger': 'log.get_default_logger',
'setup_logger': 'log.setup_logger',
'setup_loggig_subsystem': 'log.setup_logging_subsystem',
'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger',
},
'messaging': {
'TaskPublisher': 'amqp.TaskPublisher',
'TaskConsumer': 'amqp.TaskConsumer',
'establish_connection': 'connection',
'with_connection': 'with_default_connection',
'get_consumer_set': 'amqp.TaskConsumer',
},
'registry': {
'tasks': 'tasks',
},
},
'celery.task': {
'control': {
'broadcast': 'control.broadcast',
'rate_limit': 'control.rate_limit',
'time_limit': 'control.time_limit',
'ping': 'control.ping',
'revoke': 'control.revoke',
'discard_all': 'control.purge',
'inspect': 'control.inspect',
},
'schedules': 'celery.schedules',
'chords': 'celery.canvas',
}
}
class class_property(object):
def __init__(self, fget=None, fset=None):
assert fget and isinstance(fget, classmethod)
assert isinstance(fset, classmethod) if fset else True
self.__get = fget
self.__set = fset
info = fget.__get__(object) # just need the info attrs.
self.__doc__ = info.__doc__
self.__name__ = info.__name__
self.__module__ = info.__module__
def __get__(self, obj, type=None):
if obj and type is None:
type = obj.__class__
return self.__get.__get__(obj, type)()
def __set__(self, obj, value):
if obj is None:
return self
return self.__set.__get__(obj)(value)
def reclassmethod(method):
return classmethod(fun_of_method(method))
class MagicModule(ModuleType):
_compat_modules = ()
_all_by_module = {}
_direct = {}
_object_origins = {}
def __getattr__(self, name):
if name in self._object_origins:
module = __import__(self._object_origins[name], None, None, [name])
for item in self._all_by_module[module.__name__]:
setattr(self, item, getattr(module, item))
return getattr(module, name)
elif name in self._direct:
module = __import__(self._direct[name], None, None, [name])
setattr(self, name, module)
return module
return ModuleType.__getattribute__(self, name)
def __dir__(self):
return list(set(self.__all__) | DEFAULT_ATTRS)
def create_module(name, attrs, cls_attrs=None, pkg=None,
base=MagicModule, prepare_attr=None):
fqdn = '.'.join([pkg.__name__, name]) if pkg else name
cls_attrs = {} if cls_attrs is None else cls_attrs
attrs = dict((attr_name, prepare_attr(attr) if prepare_attr else attr)
for attr_name, attr in attrs.iteritems())
module = sys.modules[fqdn] = type(name, (base, ), cls_attrs)(fqdn)
module.__dict__.update(attrs)
return module
def recreate_module(name, compat_modules=(), by_module={}, direct={},
base=MagicModule, **attrs):
old_module = sys.modules[name]
origins = get_origins(by_module)
compat_modules = COMPAT_MODULES.get(name, ())
cattrs = dict(
_compat_modules=compat_modules,
_all_by_module=by_module, _direct=direct,
_object_origins=origins,
__all__=tuple(set(reduce(
operator.add,
[tuple(v) for v in [compat_modules, origins, direct, attrs]],
))),
)
new_module = create_module(name, attrs, cls_attrs=cattrs, base=base)
new_module.__dict__.update(dict((mod, get_compat_module(new_module, mod))
for mod in compat_modules))
return old_module, new_module
def get_compat_module(pkg, name):
def prepare(attr):
if isinstance(attr, basestring):
return Proxy(getappattr, (attr, ))
return attr
attrs = COMPAT_MODULES[pkg.__name__][name]
if isinstance(attrs, basestring):
fqdn = '.'.join([pkg.__name__, name])
module = sys.modules[fqdn] = import_module(attrs)
return module
attrs['__all__'] = list(attrs)
return create_module(name, dict(attrs), pkg=pkg, prepare_attr=prepare)
def get_origins(defs):
origins = {}
for module, items in defs.iteritems():
origins.update(dict((item, module) for item in items))
return origins

View File

@ -2,45 +2,126 @@
"""Distributed Task Queue"""
# :copyright: (c) 2009 - 2012 Ask Solem and individual contributors,
# All rights reserved.
# :copyright: (c) 2012 VMware, Inc., All rights reserved.
# :copyright: (c) 2012-2013 GoPivotal, Inc., All rights reserved.
# :license: BSD (3 Clause), see LICENSE for more details.
from __future__ import absolute_import
SERIES = 'Chiastic Slide'
VERSION = (3, 0, 23)
SERIES = 'Cipater'
VERSION = (3, 1, 3)
__version__ = '.'.join(str(p) for p in VERSION[0:3]) + ''.join(VERSION[3:])
__author__ = 'Ask Solem'
__contact__ = 'ask@celeryproject.org'
__homepage__ = 'http://celeryproject.org'
__docformat__ = 'restructuredtext'
__all__ = [
'Celery', 'bugreport', 'shared_task', 'Task',
'current_app', 'current_task',
'chain', 'chord', 'chunks', 'group', 'subtask',
'xmap', 'xstarmap', 'uuid', 'VERSION', '__version__',
'Celery', 'bugreport', 'shared_task', 'task',
'current_app', 'current_task', 'maybe_signature',
'chain', 'chord', 'chunks', 'group', 'signature',
'xmap', 'xstarmap', 'uuid', 'version', '__version__',
]
VERSION_BANNER = '%s (%s)' % (__version__, SERIES)
VERSION_BANNER = '{0} ({1})'.format(__version__, SERIES)
# -eof meta-
import os
import sys
if os.environ.get('C_IMPDEBUG'): # pragma: no cover
from .five import builtins
real_import = builtins.__import__
def debug_import(name, locals=None, globals=None,
fromlist=None, level=-1):
glob = globals or getattr(sys, 'emarfteg_'[::-1])(1).f_globals
importer_name = glob and glob.get('__name__') or 'unknown'
print('-- {0} imports {1}'.format(importer_name, name))
return real_import(name, locals, globals, fromlist, level)
builtins.__import__ = debug_import
# This is never executed, but tricks static analyzers (PyDev, PyCharm,
# pylint, etc.) into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True
globals()['kcah_acitats'[::-1].upper()] = False
if STATICA_HACK:
# This is never executed, but tricks static analyzers (PyDev, PyCharm,
# pylint, etc.) into knowing the types of these symbols, and what
# they contain.
from celery.app.base import Celery
from celery.app.utils import bugreport
from celery.app.task import Task
from celery._state import current_app, current_task
from celery.canvas import (
chain, chord, chunks, group, subtask, xmap, xstarmap,
if STATICA_HACK: # pragma: no cover
from celery.app import shared_task # noqa
from celery.app.base import Celery # noqa
from celery.app.utils import bugreport # noqa
from celery.app.task import Task # noqa
from celery._state import current_app, current_task # noqa
from celery.canvas import ( # noqa
chain, chord, chunks, group,
signature, maybe_signature, xmap, xstarmap, subtask,
)
from celery.utils import uuid
from celery.utils import uuid # noqa
# Eventlet/gevent patching must happen before importing
# anything else, so these tools must be at top-level.
def _find_option_with_arg(argv, short_opts=None, long_opts=None):
"""Search argv for option specifying its short and longopt
alternatives.
Return the value of the option if found.
"""
for i, arg in enumerate(argv):
if arg.startswith('-'):
if long_opts and arg.startswith('--'):
name, _, val = arg.partition('=')
if name in long_opts:
return val
if short_opts and arg in short_opts:
return argv[i + 1]
raise KeyError('|'.join(short_opts or [] + long_opts or []))
def _patch_eventlet():
import eventlet
import eventlet.debug
eventlet.monkey_patch()
EVENTLET_DBLOCK = int(os.environ.get('EVENTLET_NOBLOCK', 0))
if EVENTLET_DBLOCK:
eventlet.debug.hub_blocking_detection(EVENTLET_DBLOCK)
def _patch_gevent():
from gevent import monkey, version_info
monkey.patch_all()
if version_info[0] == 0: # pragma: no cover
# Signals aren't working in gevent versions <1.0,
# and are not monkey patched by patch_all()
from gevent import signal as _gevent_signal
_signal = __import__('signal')
_signal.signal = _gevent_signal
def maybe_patch_concurrency(argv=sys.argv,
short_opts=['-P'], long_opts=['--pool'],
patches={'eventlet': _patch_eventlet,
'gevent': _patch_gevent}):
"""With short and long opt alternatives that specify the command line
option to set the pool, this makes sure that anything that needs
to be patched is completed as early as possible.
(e.g. eventlet/gevent monkey patches)."""
try:
pool = _find_option_with_arg(argv, short_opts, long_opts)
except KeyError:
pass
else:
try:
patcher = patches[pool]
except KeyError:
pass
else:
patcher()
# set up eventlet/gevent environments ASAP.
from celery import concurrency
concurrency.get_implementation(pool)
# Lazy loading
from .__compat__ import recreate_module
from .five import recreate_module
old_module, new_module = recreate_module( # pragma: no cover
__name__,
@ -49,7 +130,8 @@ old_module, new_module = recreate_module( # pragma: no cover
'celery.app.task': ['Task'],
'celery._state': ['current_app', 'current_task'],
'celery.canvas': ['chain', 'chord', 'chunks', 'group',
'subtask', 'xmap', 'xstarmap'],
'signature', 'maybe_signature', 'subtask',
'xmap', 'xstarmap'],
'celery.utils': ['uuid'],
},
direct={'task': 'celery.task'},
@ -58,4 +140,6 @@ old_module, new_module = recreate_module( # pragma: no cover
__author__=__author__, __contact__=__contact__,
__homepage__=__homepage__, __docformat__=__docformat__,
VERSION=VERSION, SERIES=SERIES, VERSION_BANNER=VERSION_BANNER,
maybe_patch_concurrency=maybe_patch_concurrency,
_find_option_with_arg=_find_option_with_arg,
)

View File

@ -2,10 +2,25 @@ from __future__ import absolute_import
import sys
from os.path import basename
def maybe_patch_concurrency():
from celery.platforms import maybe_patch_concurrency
maybe_patch_concurrency(sys.argv, ['-P'], ['--pool'])
from . import maybe_patch_concurrency
__all__ = ['main']
DEPRECATED_FMT = """
The {old!r} command is deprecated, please use {new!r} instead:
$ {new_argv}
"""
def _warn_deprecated(new):
print(DEPRECATED_FMT.format(
old=basename(sys.argv[0]), new=new,
new_argv=' '.join([new] + sys.argv[1:])),
)
def main():
@ -16,21 +31,24 @@ def main():
def _compat_worker():
maybe_patch_concurrency()
from celery.bin.celeryd import main
_warn_deprecated('celery worker')
from celery.bin.worker import main
main()
def _compat_multi():
maybe_patch_concurrency()
from celery.bin.celeryd_multi import main
_warn_deprecated('celery multi')
from celery.bin.multi import main
main()
def _compat_beat():
maybe_patch_concurrency()
from celery.bin.celerybeat import main
_warn_deprecated('celery beat')
from celery.bin.beat import main
main()
if __name__ == '__main__':
if __name__ == '__main__': # pragma: no cover
main()

View File

@ -9,7 +9,7 @@
This module shouldn't be used directly.
"""
from __future__ import absolute_import
from __future__ import absolute_import, print_function
import os
import sys
@ -19,12 +19,26 @@ import weakref
from celery.local import Proxy
from celery.utils.threads import LocalStack
__all__ = ['set_default_app', 'get_current_app', 'get_current_task',
'get_current_worker_task', 'current_app', 'current_task']
#: Global default app used when no current app.
default_app = None
#: List of all app instances (weakrefs), must not be used directly.
_apps = set()
_task_join_will_block = False
def _set_task_join_will_block(blocks):
global _task_join_will_block
_task_join_will_block = True
def task_join_will_block():
return _task_join_will_block
class _TLS(threading.local):
#: Apps with the :attr:`~celery.app.base.BaseApp.set_as_current` attribute
@ -53,10 +67,11 @@ def _get_current_app():
return _tls.current_app or default_app
C_STRICT_APP = os.environ.get('C_STRICT_APP')
if os.environ.get('C_STRICT_APP'):
if os.environ.get('C_STRICT_APP'): # pragma: no cover
def get_current_app():
raise Exception('USES CURRENT APP')
import traceback
sys.stderr.write('USES CURRENT_APP\n')
print('-- USES CURRENT_APP', file=sys.stderr) # noqa+
traceback.print_stack(file=sys.stderr)
return _get_current_app()
else:

View File

@ -7,22 +7,29 @@
"""
from __future__ import absolute_import
from __future__ import with_statement
import os
from collections import Callable
from celery.local import Proxy
from celery import _state
from celery._state import ( # noqa
from celery._state import (
set_default_app,
get_current_app as current_app,
get_current_task as current_task,
_get_active_apps,
_task_stack,
)
from celery.utils import gen_task_name
from .builtins import shared_task as _shared_task
from .base import Celery, AppPickler # noqa
from .base import Celery, AppPickler
__all__ = ['Celery', 'AppPickler', 'default_app', 'app_or_default',
'bugreport', 'enable_trace', 'disable_trace', 'shared_task',
'set_default_app', 'current_app', 'current_task',
'push_current_task', 'pop_current_task']
#: Proxy always returning the app set as default.
default_app = Proxy(lambda: _state.default_app)
@ -40,8 +47,18 @@ app_or_default = None
default_loader = os.environ.get('CELERY_LOADER') or 'default' # XXX
def bugreport():
return current_app().bugreport()
#: Function used to push a task to the thread local stack
#: keeping track of the currently executing task.
#: You must remember to pop the task after.
push_current_task = _task_stack.push
#: Function used to pop a task from the thread local stack
#: keeping track of the currently executing task.
pop_current_task = _task_stack.pop
def bugreport(app=None):
return (app or current_app()).bugreport()
def _app_or_default(app=None):
@ -84,8 +101,8 @@ App = Celery # XXX Compat
def shared_task(*args, **kwargs):
"""Task decorator that creates shared tasks,
and returns a proxy that always returns the task from the current apps
"""Create shared tasks (decorator).
Will return a proxy that always takes the task from the current apps
task registry.
This can be used by library authors to create tasks that will work
@ -121,7 +138,7 @@ def shared_task(*args, **kwargs):
with app._finalize_mutex:
app._task_from_fun(fun, **options)
# Returns a proxy that always gets the task from the current
# Return a proxy that always gets the task from the current
# apps task registry.
def task_by_cons():
app = current_app()
@ -131,6 +148,6 @@ def shared_task(*args, **kwargs):
return Proxy(task_by_cons)
return __inner
if len(args) == 1 and callable(args[0]):
if len(args) == 1 and isinstance(args[0], Callable):
return create_shared_task(**kwargs)(args[0])
return create_shared_task(*args, **kwargs)

View File

@ -1,63 +0,0 @@
# -*- coding: utf-8 -*-
"""
celery.app.abstract
~~~~~~~~~~~~~~~~~~~
Abstract class that takes default attribute values
from the configuration.
"""
from __future__ import absolute_import
class from_config(object):
def __init__(self, key=None):
self.key = key
def get_key(self, attr):
return attr if self.key is None else self.key
class _configurated(type):
def __new__(cls, name, bases, attrs):
attrs['__confopts__'] = dict((attr, spec.get_key(attr))
for attr, spec in attrs.iteritems()
if isinstance(spec, from_config))
inherit_from = attrs.get('inherit_confopts', ())
for subcls in bases:
try:
attrs['__confopts__'].update(subcls.__confopts__)
except AttributeError:
pass
for subcls in inherit_from:
attrs['__confopts__'].update(subcls.__confopts__)
attrs = dict((k, v if not isinstance(v, from_config) else None)
for k, v in attrs.iteritems())
return super(_configurated, cls).__new__(cls, name, bases, attrs)
class configurated(object):
__metaclass__ = _configurated
def setup_defaults(self, kwargs, namespace='celery'):
confopts = self.__confopts__
app, find = self.app, self.app.conf.find_value_for_key
for attr, keyname in confopts.iteritems():
try:
value = kwargs[attr]
except KeyError:
value = find(keyname, namespace)
else:
if value is None:
value = find(keyname, namespace)
setattr(self, attr, value)
for attr_name, attr_value in kwargs.iteritems():
if attr_name not in confopts and attr_value is not None:
setattr(self, attr_name, attr_value)
def confopts_as_dict(self):
return dict((key, getattr(self, key)) for key in self.__confopts__)

View File

@ -12,20 +12,25 @@ from datetime import timedelta
from weakref import WeakValueDictionary
from kombu import Connection, Consumer, Exchange, Producer, Queue
from kombu.common import entry_to_queue
from kombu.common import Broadcast
from kombu.pools import ProducerPool
from kombu.utils import cached_property, uuid
from kombu.utils.encoding import safe_repr
from kombu.utils.functional import maybe_list
from celery import signals
from celery.five import items, string_t
from celery.utils.text import indent as textindent
from . import app_or_default
from . import routes as _routes
__all__ = ['AMQP', 'Queues', 'TaskProducer', 'TaskConsumer']
#: Human readable queue declaration.
QUEUE_FORMAT = """
.> %(name)s exchange:%(exchange)s(%(exchange_type)s) binding:%(routing_key)s
.> {0.name:<16} exchange={0.exchange.name}({0.exchange.type}) \
key={0.routing_key}
"""
@ -46,15 +51,16 @@ class Queues(dict):
_consume_from = None
def __init__(self, queues=None, default_exchange=None,
create_missing=True, ha_policy=None):
create_missing=True, ha_policy=None, autoexchange=None):
dict.__init__(self)
self.aliases = WeakValueDictionary()
self.default_exchange = default_exchange
self.create_missing = create_missing
self.ha_policy = ha_policy
self.autoexchange = Exchange if autoexchange is None else autoexchange
if isinstance(queues, (tuple, list)):
queues = dict((q.name, q) for q in queues)
for name, q in (queues or {}).iteritems():
for name, q in items(queues or {}):
self.add(q) if isinstance(q, Queue) else self.add_compat(name, **q)
def __getitem__(self, name):
@ -79,11 +85,16 @@ class Queues(dict):
def add(self, queue, **kwargs):
"""Add new queue.
:param queue: Name of the queue.
:keyword exchange: Name of the exchange.
:keyword routing_key: Binding key.
:keyword exchange_type: Type of exchange.
:keyword \*\*options: Additional declaration options.
The first argument can either be a :class:`kombu.Queue` instance,
or the name of a queue. If the former the rest of the keyword
arguments are ignored, and options are simply taken from the queue
instance.
:param queue: :class:`kombu.Queue` instance or name of the queue.
:keyword exchange: (if named) specifies exchange name.
:keyword routing_key: (if named) specifies binding key.
:keyword exchange_type: (if named) specifies type of exchange.
:keyword \*\*options: (if named) Additional declaration options.
"""
if not isinstance(queue, Queue):
@ -102,7 +113,7 @@ class Queues(dict):
options['routing_key'] = name
if self.ha_policy is not None:
self._set_ha_policy(options.setdefault('queue_arguments', {}))
q = self[name] = entry_to_queue(name, **options)
q = self[name] = Queue.from_dict(name, **options)
return q
def _set_ha_policy(self, args):
@ -117,13 +128,8 @@ class Queues(dict):
active = self.consume_from
if not active:
return ''
info = [
QUEUE_FORMAT.strip() % {
'name': (name + ':').ljust(12),
'exchange': q.exchange.name,
'exchange_type': q.exchange.type,
'routing_key': q.routing_key}
for name, q in sorted(active.iteritems())]
info = [QUEUE_FORMAT.strip().format(q)
for _, q in sorted(items(active))]
if indent_first:
return textindent('\n'.join(info), indent)
return info[0] + '\n' + textindent('\n'.join(info[1:]), indent)
@ -136,23 +142,37 @@ class Queues(dict):
self._consume_from[q.name] = q
return q
def select_subset(self, wanted):
def select(self, include):
"""Sets :attr:`consume_from` by selecting a subset of the
currently defined queues.
:param wanted: List of wanted queue names.
:param include: Names of queues to consume from.
Can be iterable or string.
"""
if wanted:
self._consume_from = dict((name, self[name]) for name in wanted)
if include:
self._consume_from = dict((name, self[name])
for name in maybe_list(include))
select_subset = select # XXX compat
def select_remove(self, queue):
if self._consume_from is None:
self.select_subset(k for k in self if k != queue)
else:
self._consume_from.pop(queue, None)
def deselect(self, exclude):
"""Deselect queues so that they will not be consumed from.
:param exclude: Names of queues to avoid consuming from.
Can be iterable or string.
"""
if exclude:
exclude = maybe_list(exclude)
if self._consume_from is None:
# using selection
return self.select(k for k in self if k not in exclude)
# using all queues
for queue in exclude:
self._consume_from.pop(queue, None)
select_remove = deselect # XXX compat
def new_missing(self, name):
return Queue(name, Exchange(name), name)
return Queue(name, self.autoexchange(name), name)
@property
def consume_from(self):
@ -189,20 +209,30 @@ class TaskProducer(Producer):
queue=None, now=None, retries=0, chord=None,
callbacks=None, errbacks=None, routing_key=None,
serializer=None, delivery_mode=None, compression=None,
declare=None, **kwargs):
reply_to=None, time_limit=None, soft_time_limit=None,
declare=None, headers=None,
send_before_publish=signals.before_task_publish.send,
before_receivers=signals.before_task_publish.receivers,
send_after_publish=signals.after_task_publish.send,
after_receivers=signals.after_task_publish.receivers,
send_task_sent=signals.task_sent.send, # XXX deprecated
sent_receivers=signals.task_sent.receivers,
**kwargs):
"""Send task message."""
retry = self.retry if retry is None else retry
qname = queue
if queue is None and exchange is None:
queue = self.default_queue
if queue is not None:
if isinstance(queue, basestring):
if isinstance(queue, string_t):
qname, queue = queue, self.queues[queue]
else:
qname = queue.name
exchange = exchange or queue.exchange.name
routing_key = routing_key or queue.routing_key
declare = declare or ([queue] if queue else [])
if declare is None and queue and not isinstance(queue, Broadcast):
declare = [queue]
# merge default and custom policy
retry = self.retry if retry is None else retry
@ -218,9 +248,13 @@ class TaskProducer(Producer):
if countdown: # Convert countdown to ETA.
now = now or self.app.now()
eta = now + timedelta(seconds=countdown)
if self.utc:
eta = eta.replace(tzinfo=self.app.timezone)
if isinstance(expires, (int, float)):
now = now or self.app.now()
expires = now + timedelta(seconds=expires)
if self.utc:
expires = expires.replace(tzinfo=self.app.timezone)
eta = eta and eta.isoformat()
expires = expires and expires.isoformat()
@ -235,21 +269,44 @@ class TaskProducer(Producer):
'utc': self.utc,
'callbacks': callbacks,
'errbacks': errbacks,
'timelimit': (time_limit, soft_time_limit),
'taskset': group_id or taskset_id,
'chord': chord,
}
if before_receivers:
send_before_publish(
sender=task_name, body=body,
exchange=exchange,
routing_key=routing_key,
declare=declare,
headers=headers,
properties=kwargs,
retry_policy=retry_policy,
)
self.publish(
body,
exchange=exchange, routing_key=routing_key,
serializer=serializer or self.serializer,
compression=compression or self.compression,
headers=headers,
retry=retry, retry_policy=_rp,
reply_to=reply_to,
correlation_id=task_id,
delivery_mode=delivery_mode, declare=declare,
**kwargs
)
signals.task_sent.send(sender=task_name, **body)
if after_receivers:
send_after_publish(sender=task_name, body=body,
exchange=exchange, routing_key=routing_key)
if sent_receivers: # XXX deprecated
send_task_sent(sender=task_name, task_id=task_id,
task=task_name, args=task_args,
kwargs=task_kwargs, eta=eta,
taskset=group_id or taskset_id)
if self.send_sent_event:
evd = event_dispatcher or self.event_dispatcher
exname = exchange or self.exchange
@ -306,7 +363,7 @@ class TaskConsumer(Consumer):
accept = self.app.conf.CELERY_ACCEPT_CONTENT
super(TaskConsumer, self).__init__(
channel,
queues or self.app.amqp.queues.consume_from.values(),
queues or list(self.app.amqp.queues.consume_from.values()),
accept=accept,
**kw
)
@ -329,13 +386,20 @@ class AMQP(object):
#: set by the :attr:`producer_pool`.
_producer_pool = None
# Exchange class/function used when defining automatic queues.
# E.g. you can use ``autoexchange = lambda n: None`` to use the
# amqp default exchange, which is a shortcut to bypass routing
# and instead send directly to the queue named in the routing key.
autoexchange = None
def __init__(self, app):
self.app = app
def flush_routes(self):
self._rtable = _routes.prepare(self.app.conf.CELERY_ROUTES)
def Queues(self, queues, create_missing=None, ha_policy=None):
def Queues(self, queues, create_missing=None, ha_policy=None,
autoexchange=None):
"""Create new :class:`Queues` instance, using queue defaults
from the current configuration."""
conf = self.app.conf
@ -347,10 +411,15 @@ class AMQP(object):
queues = (Queue(conf.CELERY_DEFAULT_QUEUE,
exchange=self.default_exchange,
routing_key=conf.CELERY_DEFAULT_ROUTING_KEY), )
return Queues(queues, self.default_exchange, create_missing, ha_policy)
autoexchange = (self.autoexchange if autoexchange is None
else autoexchange)
return Queues(
queues, self.default_exchange, create_missing,
ha_policy, autoexchange,
)
def Router(self, queues=None, create_missing=None):
"""Returns the current task router."""
"""Return the current task router."""
return _routes.Router(self.routes, queues or self.queues,
self.app.either('CELERY_CREATE_MISSING_QUEUES',
create_missing), app=self.app)
@ -365,7 +434,7 @@ class AMQP(object):
@cached_property
def TaskProducer(self):
"""Returns publisher used to send tasks.
"""Return publisher used to send tasks.
You should use `app.send_task` instead.

View File

@ -12,15 +12,14 @@
"""
from __future__ import absolute_import
from celery.utils.functional import firstmethod, mpromise
from celery.five import string_t
from celery.utils.functional import firstmethod, mlazy
from celery.utils.imports import instantiate
_first_match = firstmethod('annotate')
_first_match_any = firstmethod('annotate_any')
def resolve_all(anno, task):
return (r for r in (_first_match(anno, task), _first_match_any(anno)) if r)
__all__ = ['MapAnnotation', 'prepare', 'resolve_all']
class MapAnnotation(dict):
@ -44,8 +43,8 @@ def prepare(annotations):
def expand_annotation(annotation):
if isinstance(annotation, dict):
return MapAnnotation(annotation)
elif isinstance(annotation, basestring):
return mpromise(instantiate, annotation)
elif isinstance(annotation, string_t):
return mlazy(instantiate, annotation)
return annotation
if annotations is None:
@ -53,3 +52,7 @@ def prepare(annotations):
elif not isinstance(annotations, (list, tuple)):
annotations = (annotations, )
return [expand_annotation(anno) for anno in annotations]
def resolve_all(anno, task):
return (x for x in (_first_match(anno, task), _first_match_any(anno)) if x)

View File

@ -7,36 +7,59 @@
"""
from __future__ import absolute_import
from __future__ import with_statement
import os
import threading
import warnings
from collections import deque
from collections import Callable, defaultdict, deque
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from operator import attrgetter
from billiard.util import register_after_fork
from kombu.clocks import LamportClock
from kombu.utils import cached_property
from kombu.common import oid_from
from kombu.utils import cached_property, uuid
from celery import platforms
from celery._state import (
_task_stack, _tls, get_current_app, _register_app, get_current_worker_task,
)
from celery.exceptions import AlwaysEagerIgnored, ImproperlyConfigured
from celery.five import items, values
from celery.loaders import get_loader_cls
from celery.local import PromiseProxy, maybe_evaluate
from celery._state import _task_stack, _tls, get_current_app, _register_app
from celery.utils.functional import first, maybe_list
from celery.utils.imports import instantiate, symbol_by_name
from celery.utils.log import ensure_process_aware_logger
from celery.utils.objects import mro_lookup
from .annotations import prepare as prepare_annotations
from .builtins import shared_task, load_shared_tasks
from .defaults import DEFAULTS, find_deprecated_settings
from .registry import TaskRegistry
from .utils import AppPickler, Settings, bugreport, _unpickle_app
from .utils import (
AppPickler, Settings, bugreport, _unpickle_app, _unpickle_app_v2, appstr,
)
__all__ = ['Celery']
_EXECV = os.environ.get('FORKED_BY_MULTIPROCESSING')
BUILTIN_FIXUPS = frozenset([
'celery.fixups.django:fixup',
])
ERR_ENVVAR_NOT_SET = """\
The environment variable {0!r} is not set,
and as such the configuration could not be loaded.
Please set this variable and make it point to
a configuration module."""
def app_has_custom(app, attr):
return mro_lookup(app.__class__, attr, stop=(Celery, object),
monkey_patched=[__name__])
def _unpickle_appattr(reverse_name, args):
@ -46,6 +69,7 @@ def _unpickle_appattr(reverse_name, args):
class Celery(object):
#: This is deprecated, use :meth:`reduce_keys` instead
Pickler = AppPickler
SYSTEM = platforms.SYSTEM
@ -57,6 +81,7 @@ class Celery(object):
loader_cls = 'celery.loaders.app:AppLoader'
log_cls = 'celery.app.log:Logging'
control_cls = 'celery.app.control:Control'
task_cls = 'celery.app.task:Task'
registry_cls = TaskRegistry
_pool = None
@ -64,8 +89,7 @@ class Celery(object):
amqp=None, events=None, log=None, control=None,
set_as_current=True, accept_magic_kwargs=False,
tasks=None, broker=None, include=None, changes=None,
config_source=None,
**kwargs):
config_source=None, fixups=None, task_cls=None, **kwargs):
self.clock = LamportClock()
self.main = main
self.amqp_cls = amqp or self.amqp_cls
@ -74,10 +98,13 @@ class Celery(object):
self.loader_cls = loader or self.loader_cls
self.log_cls = log or self.log_cls
self.control_cls = control or self.control_cls
self.task_cls = task_cls or self.task_cls
self.set_as_current = set_as_current
self.registry_cls = symbol_by_name(self.registry_cls)
self.accept_magic_kwargs = accept_magic_kwargs
self.user_options = defaultdict(set)
self._config_source = config_source
self.steps = defaultdict(set)
self.configured = False
self._pending_defaults = deque()
@ -89,6 +116,11 @@ class Celery(object):
if not isinstance(self._tasks, TaskRegistry):
self._tasks = TaskRegistry(self._tasks or {})
# If the class defins a custom __reduce_args__ we need to use
# the old way of pickling apps, which is pickling a list of
# args instead of the new way that pickles a dict of keywords.
self._using_v1_reduce = app_has_custom(self, '__reduce_args__')
# these options are moved to the config to
# simplify pickling of the app object.
self._preconf = changes or {}
@ -97,6 +129,11 @@ class Celery(object):
if include:
self._preconf['CELERY_IMPORTS'] = include
# Apply fixups.
self.fixups = set(fixups or ())
for fixup in self.fixups | BUILTIN_FIXUPS:
symbol_by_name(fixup)(self)
if self.set_as_current:
self.set_current()
@ -133,7 +170,7 @@ class Celery(object):
def worker_main(self, argv=None):
return instantiate(
'celery.bin.celeryd:WorkerCommand',
'celery.bin.worker:worker',
app=self).execute_from_commandline(argv)
def task(self, *args, **opts):
@ -145,8 +182,7 @@ class Celery(object):
# the task instance from the current app.
# Really need a better solution for this :(
from . import shared_task as proxies_to_curapp
opts['_force_evaluate'] = True # XXX Py2.5
return proxies_to_curapp(*args, **opts)
return proxies_to_curapp(*args, _force_evaluate=True, **opts)
def inner_create_task_cls(shared=True, filter=None, **opts):
_filt = filter # stupid 2to3
@ -162,16 +198,20 @@ class Celery(object):
task = filter(task)
return task
# return a proxy object that is only evaluated when first used
promise = PromiseProxy(self._task_from_fun, (fun, ), opts)
self._pending.append(promise)
if self.finalized or opts.get('_force_evaluate'):
ret = self._task_from_fun(fun, **opts)
else:
# return a proxy object that evaluates on first use
ret = PromiseProxy(self._task_from_fun, (fun, ), opts,
__doc__=fun.__doc__)
self._pending.append(ret)
if _filt:
return _filt(promise)
return promise
return _filt(ret)
return ret
return _create_task_cls
if len(args) == 1 and callable(args[0]):
if len(args) == 1 and isinstance(args[0], Callable):
return inner_create_task_cls(**opts)(*args)
if args:
raise TypeError(
@ -180,15 +220,16 @@ class Celery(object):
def _task_from_fun(self, fun, **options):
base = options.pop('base', None) or self.Task
bind = options.pop('bind', False)
T = type(fun.__name__, (base, ), dict({
'app': self,
'accept_magic_kwargs': False,
'run': staticmethod(fun),
'run': fun if bind else staticmethod(fun),
'_decorated': True,
'__doc__': fun.__doc__,
'__module__': fun.__module__}, **options))()
task = self._tasks[T.name] # return global instance.
task.bind(self)
return task
def finalize(self):
@ -201,11 +242,11 @@ class Celery(object):
while pending:
maybe_evaluate(pending.popleft())
for task in self._tasks.itervalues():
for task in values(self._tasks):
task.bind(self)
def add_defaults(self, fun):
if not callable(fun):
if not isinstance(fun, Callable):
d, fun = fun, lambda: d
if self.configured:
return self.conf.add_defaults(fun())
@ -221,56 +262,83 @@ class Celery(object):
if not module_name:
if silent:
return False
raise ImproperlyConfigured(self.error_envvar_not_set % module_name)
raise ImproperlyConfigured(ERR_ENVVAR_NOT_SET.format(module_name))
return self.config_from_object(module_name, silent=silent)
def config_from_cmdline(self, argv, namespace='celery'):
self.conf.update(self.loader.cmdline_config_parser(argv, namespace))
def setup_security(self, allowed_serializers=None, key=None, cert=None,
store=None, digest='sha1', serializer='json'):
from celery.security import setup_security
return setup_security(allowed_serializers, key, cert,
store, digest, serializer, app=self)
def autodiscover_tasks(self, packages, related_name='tasks'):
if self.conf.CELERY_FORCE_BILLIARD_LOGGING:
# we'll use billiard's processName instead of
# multiprocessing's one in all the loggers
# created after this call
ensure_process_aware_logger()
self.loader.autodiscover_tasks(packages, related_name)
def send_task(self, name, args=None, kwargs=None, countdown=None,
eta=None, task_id=None, producer=None, connection=None,
result_cls=None, expires=None, queues=None, publisher=None,
link=None, link_error=None,
**options):
router=None, result_cls=None, expires=None,
publisher=None, link=None, link_error=None,
add_to_parent=True, reply_to=None, **options):
task_id = task_id or uuid()
producer = producer or publisher # XXX compat
if self.conf.CELERY_ALWAYS_EAGER: # pragma: no cover
router = router or self.amqp.router
conf = self.conf
if conf.CELERY_ALWAYS_EAGER: # pragma: no cover
warnings.warn(AlwaysEagerIgnored(
'CELERY_ALWAYS_EAGER has no effect on send_task'))
result_cls = result_cls or self.AsyncResult
router = self.amqp.Router(queues)
options.setdefault('compression',
self.conf.CELERY_MESSAGE_COMPRESSION)
options = router.route(options, name, args, kwargs)
with self.producer_or_acquire(producer) as producer:
return result_cls(producer.publish_task(
name, args, kwargs,
task_id=task_id,
countdown=countdown, eta=eta,
callbacks=maybe_list(link),
errbacks=maybe_list(link_error),
expires=expires, **options
))
if connection:
producer = self.amqp.TaskProducer(connection)
with self.producer_or_acquire(producer) as P:
self.backend.on_task_call(P, task_id)
task_id = P.publish_task(
name, args, kwargs, countdown=countdown, eta=eta,
task_id=task_id, expires=expires,
callbacks=maybe_list(link), errbacks=maybe_list(link_error),
reply_to=reply_to or self.oid, **options
)
result = (result_cls or self.AsyncResult)(task_id)
if add_to_parent:
parent = get_current_worker_task()
if parent:
parent.add_trail(result)
return result
def connection(self, hostname=None, userid=None,
password=None, virtual_host=None, port=None, ssl=None,
insist=None, connect_timeout=None, transport=None,
transport_options=None, heartbeat=None, **kwargs):
def connection(self, hostname=None, userid=None, password=None,
virtual_host=None, port=None, ssl=None,
connect_timeout=None, transport=None,
transport_options=None, heartbeat=None,
login_method=None, failover_strategy=None, **kwargs):
conf = self.conf
return self.amqp.Connection(
hostname or conf.BROKER_HOST,
hostname or conf.BROKER_URL,
userid or conf.BROKER_USER,
password or conf.BROKER_PASSWORD,
virtual_host or conf.BROKER_VHOST,
port or conf.BROKER_PORT,
transport=transport or conf.BROKER_TRANSPORT,
insist=self.either('BROKER_INSIST', insist),
ssl=self.either('BROKER_USE_SSL', ssl),
connect_timeout=self.either(
'BROKER_CONNECTION_TIMEOUT', connect_timeout),
heartbeat=heartbeat,
transport_options=dict(conf.BROKER_TRANSPORT_OPTIONS,
**transport_options or {}))
login_method=login_method or conf.BROKER_LOGIN_METHOD,
failover_strategy=(
failover_strategy or conf.BROKER_FAILOVER_STRATEGY
),
transport_options=dict(
conf.BROKER_TRANSPORT_OPTIONS, **transport_options or {}
),
connect_timeout=self.either(
'BROKER_CONNECTION_TIMEOUT', connect_timeout
),
)
broker_connection = connection
@contextmanager
@ -296,26 +364,6 @@ class Celery(object):
yield producer
default_producer = producer_or_acquire # XXX compat
def with_default_connection(self, fun):
"""With any function accepting a `connection`
keyword argument, establishes a default connection if one is
not already passed to it.
Any automatically established connection will be closed after
the function returns.
**Deprecated**
Use ``with app.connection_or_acquire(connection)`` instead.
"""
@wraps(fun)
def _inner(*args, **kwargs):
connection = kwargs.pop('connection', None)
with self.connection_or_acquire(connection) as c:
return fun(*args, **dict(kwargs, connection=c))
return _inner
def prepare_config(self, c):
"""Prepare configuration before it is merged with the defaults."""
return find_deprecated_settings(c)
@ -339,7 +387,7 @@ class Celery(object):
)
def select_queues(self, queues=None):
return self.amqp.queues.select_subset(queues)
return self.amqp.queues.select(queues)
def either(self, default_key, *values):
"""Fallback to the value of a configuration key if none of the
@ -356,7 +404,12 @@ class Celery(object):
self.loader)
return backend(app=self, url=url)
def on_configure(self):
"""Callback calld when the app loads configuration"""
pass
def _get_config(self):
self.on_configure()
self.configured = True
s = Settings({}, [self.prepare_config(self.loader.conf),
deepcopy(DEFAULTS)])
@ -364,9 +417,9 @@ class Celery(object):
# load lazy config dict initializers.
pending = self._pending_defaults
while pending:
s.add_defaults(pending.popleft()())
s.add_defaults(maybe_evaluate(pending.popleft()()))
if self._preconf:
for key, value in self._preconf.iteritems():
for key, value in items(self._preconf):
setattr(s, key, value)
return s
@ -382,14 +435,20 @@ class Celery(object):
amqp._producer_pool.force_close_all()
amqp._producer_pool = None
def signature(self, *args, **kwargs):
kwargs['app'] = self
return self.canvas.signature(*args, **kwargs)
def create_task_cls(self):
"""Creates a base task class using default configuration
taken from this app."""
return self.subclass_with_self('celery.app.task:Task', name='Task',
attribute='_app', abstract=True)
return self.subclass_with_self(
self.task_cls, name='Task', attribute='_app',
keep_reduce=True, abstract=True,
)
def subclass_with_self(self, Class, name=None, attribute='app',
reverse=None, **kw):
reverse=None, keep_reduce=False, **kw):
"""Subclass an app-compatible class by setting its app attribute
to be this app instance.
@ -410,18 +469,24 @@ class Celery(object):
return _unpickle_appattr, (reverse, self.__reduce_args__())
attrs = dict({attribute: self}, __module__=Class.__module__,
__doc__=Class.__doc__, __reduce__=__reduce__, **kw)
__doc__=Class.__doc__, **kw)
if not keep_reduce:
attrs['__reduce__'] = __reduce__
return type(name or Class.__name__, (Class, ), attrs)
def _rgetattr(self, path):
return reduce(getattr, [self] + path.split('.'))
return attrgetter(path)(self)
def __repr__(self):
return '<%s %s:0x%x>' % (self.__class__.__name__,
self.main or '__main__', id(self), )
return '<{0} {1}>'.format(type(self).__name__, appstr(self))
def __reduce__(self):
if self._using_v1_reduce:
return self.__reduce_v1__()
return (_unpickle_app_v2, (self.__class__, self.__reduce_keys__()))
def __reduce_v1__(self):
# Reduce only pickles the configuration changes,
# so the default configuration doesn't have to be passed
# between processes.
@ -430,11 +495,30 @@ class Celery(object):
(self.__class__, self.Pickler) + self.__reduce_args__(),
)
def __reduce_keys__(self):
"""Return keyword arguments used to reconstruct the object
when unpickling."""
return {
'main': self.main,
'changes': self.conf.changes,
'loader': self.loader_cls,
'backend': self.backend_cls,
'amqp': self.amqp_cls,
'events': self.events_cls,
'log': self.log_cls,
'control': self.control_cls,
'accept_magic_kwargs': self.accept_magic_kwargs,
'fixups': self.fixups,
'config_source': self._config_source,
'task_cls': self.task_cls,
}
def __reduce_args__(self):
return (self.main, self.conf.changes, self.loader_cls,
self.backend_cls, self.amqp_cls, self.events_cls,
self.log_cls, self.control_cls, self.accept_magic_kwargs,
self._config_source)
"""Deprecated method, please use :meth:`__reduce_keys__` instead."""
return (self.main, self.conf.changes,
self.loader_cls, self.backend_cls, self.amqp_cls,
self.events_cls, self.log_cls, self.control_cls,
self.accept_magic_kwargs, self._config_source)
@cached_property
def Worker(self):
@ -448,10 +532,6 @@ class Celery(object):
def Beat(self, **kwargs):
return self.subclass_with_self('celery.apps.beat:Beat')
@cached_property
def TaskSet(self):
return self.subclass_with_self('celery.task.sets:TaskSet')
@cached_property
def Task(self):
return self.create_task_cls()
@ -464,12 +544,22 @@ class Celery(object):
def AsyncResult(self):
return self.subclass_with_self('celery.result:AsyncResult')
@cached_property
def ResultSet(self):
return self.subclass_with_self('celery.result:ResultSet')
@cached_property
def GroupResult(self):
return self.subclass_with_self('celery.result:GroupResult')
@cached_property
def TaskSet(self): # XXX compat
"""Deprecated! Please use :class:`celery.group` instead."""
return self.subclass_with_self('celery.task.sets:TaskSet')
@cached_property
def TaskSetResult(self): # XXX compat
"""Deprecated! Please use :attr:`GroupResult` instead."""
return self.subclass_with_self('celery.result:TaskSetResult')
@property
@ -484,6 +574,10 @@ class Celery(object):
def current_task(self):
return _task_stack.top
@cached_property
def oid(self):
return oid_from(self)
@cached_property
def amqp(self):
return instantiate(self.amqp_cls, app=self)
@ -512,8 +606,23 @@ class Celery(object):
def log(self):
return instantiate(self.log_cls, app=self)
@cached_property
def canvas(self):
from celery import canvas
return canvas
@cached_property
def tasks(self):
self.finalize()
return self._tasks
@cached_property
def timezone(self):
from celery.utils.timeutils import timezone
conf = self.conf
tz = conf.CELERY_TIMEZONE
if not tz:
return (timezone.get_timezone('UTC') if conf.CELERY_ENABLE_UTC
else timezone.local)
return timezone.get_timezone(self.conf.CELERY_TIMEZONE)
App = Celery # compat

View File

@ -8,32 +8,35 @@
"""
from __future__ import absolute_import
from __future__ import with_statement
from collections import deque
from celery._state import get_current_worker_task
from celery.utils import uuid
__all__ = ['shared_task', 'load_shared_tasks']
#: global list of functions defining tasks that should be
#: added to all apps.
_shared_tasks = []
_shared_tasks = set()
def shared_task(constructor):
"""Decorator that specifies that the decorated function is a function
that generates a built-in task.
"""Decorator that specifies a function that generates a built-in task.
The function will then be called for every new app instance created
(lazily, so more exactly when the task registry for that app is needed).
The function must take a single ``app`` argument.
"""
_shared_tasks.append(constructor)
_shared_tasks.add(constructor)
return constructor
def load_shared_tasks(app):
"""Loads the built-in tasks for an app instance."""
for constructor in _shared_tasks:
"""Create built-in tasks for an app instance."""
constructors = set(_shared_tasks)
for constructor in constructors:
constructor(app)
@ -42,17 +45,13 @@ def add_backend_cleanup_task(app):
"""The backend cleanup task can be used to clean up the default result
backend.
This task is also added do the periodic task schedule so that it is
run every day at midnight, but :program:`celerybeat` must be running
for this to be effective.
Note that not all backends do anything for this, what needs to be
done at cleanup is up to each backend, and some backends
may even clean up in realtime so that a periodic cleanup is not necessary.
If the configured backend requires periodic cleanup this task is also
automatically configured to run every day at midnight (requires
:program:`celery beat` to be running).
"""
@app.task(name='celery.backend_cleanup', _force_evaluate=True)
@app.task(name='celery.backend_cleanup',
shared=False, _force_evaluate=True)
def backend_cleanup():
app.backend.cleanup()
return backend_cleanup
@ -60,58 +59,62 @@ def add_backend_cleanup_task(app):
@shared_task
def add_unlock_chord_task(app):
"""The unlock chord task is used by result backends that doesn't
have native chord support.
"""This task is used by result backends without native chord support.
It creates a task chain polling the header for completion.
It joins chords by creating a task chain polling the header for completion.
"""
from celery.canvas import subtask
from celery.canvas import signature
from celery.exceptions import ChordError
from celery.result import from_serializable
from celery.result import result_from_tuple
default_propagate = app.conf.CELERY_CHORD_PROPAGATES
@app.task(name='celery.chord_unlock', max_retries=None,
@app.task(name='celery.chord_unlock', max_retries=None, shared=False,
default_retry_delay=1, ignore_result=True, _force_evaluate=True)
def unlock_chord(group_id, callback, interval=None, propagate=None,
max_retries=None, result=None,
Result=app.AsyncResult, GroupResult=app.GroupResult,
from_serializable=from_serializable):
result_from_tuple=result_from_tuple):
# if propagate is disabled exceptions raised by chord tasks
# will be sent as part of the result list to the chord callback.
# Since 3.1 propagate will be enabled by default, and instead
# the chord callback changes state to FAILURE with the
# exception set to ChordError.
propagate = default_propagate if propagate is None else propagate
if interval is None:
interval = unlock_chord.default_retry_delay
# check if the task group is ready, and if so apply the callback.
deps = GroupResult(
group_id,
[from_serializable(r, app=app) for r in result],
[result_from_tuple(r, app=app) for r in result],
)
j = deps.join_native if deps.supports_native_join else deps.join
if deps.ready():
callback = subtask(callback)
callback = signature(callback, app=app)
try:
ret = j(propagate=propagate)
except Exception, exc:
except Exception as exc:
try:
culprit = deps._failed_join_report().next()
reason = 'Dependency %s raised %r' % (culprit.id, exc)
culprit = next(deps._failed_join_report())
reason = 'Dependency {0.id} raised {1!r}'.format(
culprit, exc,
)
except StopIteration:
reason = repr(exc)
app._tasks[callback.task].backend.fail_from_current_stack(
callback.id, exc=ChordError(reason),
)
else:
try:
callback.delay(ret)
except Exception, exc:
except Exception as exc:
app._tasks[callback.task].backend.fail_from_current_stack(
callback.id,
exc=ChordError('Callback error: %r' % (exc, )),
exc=ChordError('Callback error: {0!r}'.format(exc)),
)
else:
return unlock_chord.retry(countdown=interval,
@ -121,23 +124,23 @@ def add_unlock_chord_task(app):
@shared_task
def add_map_task(app):
from celery.canvas import subtask
from celery.canvas import signature
@app.task(name='celery.map', _force_evaluate=True)
@app.task(name='celery.map', shared=False, _force_evaluate=True)
def xmap(task, it):
task = subtask(task).type
return [task(value) for value in it]
task = signature(task, app=app).type
return [task(item) for item in it]
return xmap
@shared_task
def add_starmap_task(app):
from celery.canvas import subtask
from celery.canvas import signature
@app.task(name='celery.starmap', _force_evaluate=True)
@app.task(name='celery.starmap', shared=False, _force_evaluate=True)
def xstarmap(task, it):
task = subtask(task).type
return [task(*args) for args in it]
task = signature(task, app=app).type
return [task(*item) for item in it]
return xstarmap
@ -145,7 +148,7 @@ def add_starmap_task(app):
def add_chunk_task(app):
from celery.canvas import chunks as _chunks
@app.task(name='celery.chunks', _force_evaluate=True)
@app.task(name='celery.chunks', shared=False, _force_evaluate=True)
def chunks(task, it, n):
return _chunks.apply_chunks(task, it, n)
return chunks
@ -154,19 +157,20 @@ def add_chunk_task(app):
@shared_task
def add_group_task(app):
_app = app
from celery.canvas import maybe_subtask, subtask
from celery.result import from_serializable
from celery.canvas import maybe_signature, signature
from celery.result import result_from_tuple
class Group(app.Task):
app = _app
name = 'celery.group'
accept_magic_kwargs = False
_decorated = True
def run(self, tasks, result, group_id, partial_args):
app = self.app
result = from_serializable(result, app)
result = result_from_tuple(result, app)
# any partial args are added to all tasks in the group
taskit = (subtask(task).clone(partial_args)
taskit = (signature(task, app=app).clone(partial_args)
for i, task in enumerate(tasks))
if self.request.is_eager or app.conf.CELERY_ALWAYS_EAGER:
return app.GroupResult(
@ -178,30 +182,25 @@ def add_group_task(app):
add_to_parent=False) for stask in taskit]
parent = get_current_worker_task()
if parent:
parent.request.children.append(result)
parent.add_trail(result)
return result
def prepare(self, options, tasks, args, **kwargs):
AsyncResult = self.AsyncResult
options['group_id'] = group_id = (
options.setdefault('task_id', uuid()))
def prepare_member(task):
task = maybe_subtask(task)
opts = task.options
opts['group_id'] = group_id
try:
tid = opts['task_id']
except KeyError:
tid = opts['task_id'] = uuid()
return task, AsyncResult(tid)
task = maybe_signature(task, app=self.app)
task.options['group_id'] = group_id
return task, task.freeze()
try:
tasks, results = zip(*[prepare_member(task) for task in tasks])
tasks, res = list(zip(
*[prepare_member(task) for task in tasks]
))
except ValueError: # tasks empty
tasks, results = [], []
return (tasks, self.app.GroupResult(group_id, results),
group_id, args)
tasks, res = [], []
return (tasks, self.app.GroupResult(group_id, res), group_id, args)
def apply_async(self, partial_args=(), kwargs={}, **options):
if self.app.conf.CELERY_ALWAYS_EAGER:
@ -210,7 +209,7 @@ def add_group_task(app):
options, args=partial_args, **kwargs
)
super(Group, self).apply_async((
list(tasks), result.serializable(), gid, args), **options
list(tasks), result.as_tuple(), gid, args), **options
)
return result
@ -223,50 +222,55 @@ def add_group_task(app):
@shared_task
def add_chain_task(app):
from celery.canvas import Signature, chord, group, maybe_subtask
from celery.canvas import Signature, chord, group, maybe_signature
_app = app
class Chain(app.Task):
app = _app
name = 'celery.chain'
accept_magic_kwargs = False
_decorated = True
def prepare_steps(self, args, tasks):
app = self.app
steps = deque(tasks)
next_step = prev_task = prev_res = None
tasks, results = [], []
i = 0
while steps:
# First task get partial args from chain.
task = maybe_subtask(steps.popleft())
task = maybe_signature(steps.popleft(), app=app)
task = task.clone() if i else task.clone(args)
res = task._freeze()
res = task.freeze()
i += 1
if isinstance(task, group):
if isinstance(task, group) and steps and \
not isinstance(steps[0], group):
# automatically upgrade group(..) | s to chord(group, s)
try:
next_step = steps.popleft()
# for chords we freeze by pretending it's a normal
# task instead of a group.
res = Signature._freeze(task)
res = Signature.freeze(next_step)
task = chord(task, body=next_step, task_id=res.task_id)
except IndexError:
pass
pass # no callback, so keep as group
if prev_task:
# link previous task to this task.
prev_task.link(task)
# set the results parent attribute.
res.parent = prev_res
if not res.parent:
res.parent = prev_res
results.append(res)
tasks.append(task)
if not isinstance(prev_task, chord):
results.append(res)
tasks.append(task)
prev_task, prev_res = task, res
return tasks, results
def apply_async(self, args=(), kwargs={}, group_id=None, chord=None,
task_id=None, **options):
task_id=None, link=None, link_error=None, **options):
if self.app.conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, **options)
options.pop('publisher', None)
@ -279,13 +283,24 @@ def add_chain_task(app):
if task_id:
tasks[-1].set(task_id=task_id)
result = tasks[-1].type.AsyncResult(task_id)
# make sure we can do a link() and link_error() on a chain object.
if link:
tasks[-1].set(link=link)
# and if any task in the chain fails, call the errbacks
if link_error:
for task in tasks:
task.set(link_error=link_error)
tasks[0].apply_async()
return result
def apply(self, args=(), kwargs={}, subtask=maybe_subtask, **options):
def apply(self, args=(), kwargs={}, signature=maybe_signature,
**options):
app = self.app
last, fargs = None, args # fargs passed to first task only
for task in kwargs['tasks']:
res = subtask(task).clone(fargs).apply(last and (last.get(), ))
res = signature(task, app=app).clone(fargs).apply(
last and (last.get(), ),
)
res.parent, last, fargs = last, res, None
return last
return Chain
@ -294,10 +309,10 @@ def add_chain_task(app):
@shared_task
def add_chord_task(app):
"""Every chord is executed in a dedicated task, so that the chord
can be used as a subtask, and this generates the task
can be used as a signature, and this generates the task
responsible for that."""
from celery import group
from celery.canvas import maybe_subtask
from celery.canvas import maybe_signature
_app = app
default_propagate = app.conf.CELERY_CHORD_PROPAGATES
@ -306,18 +321,22 @@ def add_chord_task(app):
name = 'celery.chord'
accept_magic_kwargs = False
ignore_result = False
_decorated = True
def run(self, header, body, partial_args=(), interval=None,
countdown=1, max_retries=None, propagate=None,
eager=False, **kwargs):
app = self.app
propagate = default_propagate if propagate is None else propagate
group_id = uuid()
AsyncResult = self.app.AsyncResult
AsyncResult = app.AsyncResult
prepare_member = self._prepare_member
# - convert back to group if serialized
tasks = header.tasks if isinstance(header, group) else header
header = group([maybe_subtask(s).clone() for s in tasks])
header = group([
maybe_signature(s, app=app).clone() for s in tasks
])
# - eager applies the group inline
if eager:
return header.apply(args=partial_args, task_id=group_id)
@ -333,8 +352,9 @@ def add_chord_task(app):
propagate=propagate,
result=results)
# - call the header group, returning the GroupResult.
# XXX Python 2.5 doesn't allow kwargs after star-args.
return header(*partial_args, **{'task_id': group_id})
final_res = header(*partial_args, task_id=group_id)
return final_res
def _prepare_member(self, task, body, group_id):
opts = task.options
@ -346,23 +366,25 @@ def add_chord_task(app):
opts.update(chord=body, group_id=group_id)
return task_id
def apply_async(self, args=(), kwargs={}, task_id=None, **options):
if self.app.conf.CELERY_ALWAYS_EAGER:
def apply_async(self, args=(), kwargs={}, task_id=None,
group_id=None, chord=None, **options):
app = self.app
if app.conf.CELERY_ALWAYS_EAGER:
return self.apply(args, kwargs, **options)
group_id = options.pop('group_id', None)
chord = options.pop('chord', None)
header = kwargs.pop('header')
body = kwargs.pop('body')
header, body = (list(maybe_subtask(header)),
maybe_subtask(body))
if group_id:
body.set(group_id=group_id)
if chord:
body.set(chord=chord)
callback_id = body.options.setdefault('task_id', task_id or uuid())
header, body = (list(maybe_signature(header, app=app)),
maybe_signature(body, app=app))
# forward certain options to body
if chord is not None:
body.options['chord'] = chord
if group_id is not None:
body.options['group_id'] = group_id
[body.link(s) for s in options.pop('link', [])]
[body.link_error(s) for s in options.pop('link_error', [])]
body_result = body.freeze(task_id)
parent = super(Chord, self).apply_async((header, body, args),
kwargs, **options)
body_result = self.AsyncResult(callback_id)
body_result.parent = parent
return body_result
@ -370,6 +392,6 @@ def add_chord_task(app):
body = kwargs['body']
res = super(Chord, self).apply(args, dict(kwargs, eager=True),
**options)
return maybe_subtask(body).apply(
return maybe_signature(body, app=self.app).apply(
args=(res.get(propagate=propagate).get(), ))
return Chord

View File

@ -8,17 +8,32 @@
"""
from __future__ import absolute_import
from __future__ import with_statement
import warnings
from kombu.pidbox import Mailbox
from kombu.utils import cached_property
from . import app_or_default
from celery.exceptions import DuplicateNodenameWarning
__all__ = ['Inspect', 'Control', 'flatten_reply']
W_DUPNODE = """\
Received multiple replies from node name {0!r}.
Please make sure you give each node a unique nodename using the `-n` option.\
"""
def flatten_reply(reply):
nodes = {}
seen = set()
for item in reply:
dup = next((nodename in seen for nodename in item), None)
if dup:
warnings.warn(DuplicateNodenameWarning(
W_DUPNODE.format(dup),
))
seen.update(item)
nodes.update(item)
return nodes
@ -58,6 +73,9 @@ class Inspect(object):
def report(self):
return self._request('report')
def clock(self):
return self._request('clock')
def active(self, safe=False):
return self._request('dump_active', safe=safe)
@ -83,15 +101,30 @@ class Inspect(object):
def active_queues(self):
return self._request('active_queues')
def conf(self):
return self._request('dump_conf')
def query_task(self, ids):
return self._request('query_task', ids=ids)
def conf(self, with_defaults=False):
return self._request('dump_conf', with_defaults=with_defaults)
def hello(self, from_node, revoked=None):
return self._request('hello', from_node=from_node, revoked=revoked)
def memsample(self):
return self._request('memsample')
def memdump(self, samples=10):
return self._request('memdump', samples=samples)
def objgraph(self, type='Request', n=200, max_depth=10):
return self._request('objgraph', num=n, max_depth=max_depth, type=type)
class Control(object):
Mailbox = Mailbox
def __init__(self, app=None):
self.app = app_or_default(app)
self.app = app
self.mailbox = self.Mailbox('celery', type='fanout',
accept=self.app.conf.CELERY_ACCEPT_CONTENT)
@ -112,6 +145,11 @@ class Control(object):
return self.app.amqp.TaskConsumer(conn).purge()
discard_all = purge
def election(self, id, topic, action=None, connection=None):
self.broadcast('election', connection=connection, arguments={
'id': id, 'topic': topic, 'action': action,
})
def revoke(self, task_id, destination=None, terminate=False,
signal='SIGTERM', **kwargs):
"""Tell all (or specific) workers to revoke a task by id.
@ -136,7 +174,7 @@ class Control(object):
def ping(self, destination=None, timeout=1, **kwargs):
"""Ping all (or specific) workers.
Returns answer from alive workers.
Will return the list of answers.
See :meth:`broadcast` for supported keyword arguments.
@ -234,7 +272,7 @@ class Control(object):
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast('pool_grow', {}, destination, **kwargs)
return self.broadcast('pool_grow', {'n': n}, destination, **kwargs)
def pool_shrink(self, n=1, destination=None, **kwargs):
"""Tell all (or specific) workers to shrink the pool by ``n``.
@ -242,7 +280,7 @@ class Control(object):
Supports the same arguments as :meth:`broadcast`.
"""
return self.broadcast('pool_shrink', {}, destination, **kwargs)
return self.broadcast('pool_shrink', {'n': n}, destination, **kwargs)
def broadcast(self, command, arguments=None, destination=None,
connection=None, reply=False, timeout=1, limit=None,

Some files were not shown because too many files have changed in this diff Show More