Merge pull request #48 from matburt/upgrade_vendored_packages

Upgrade vendored packages
Matthew Jones 2015-01-30 11:12:32 -05:00
commit b73c7ce41b
367 changed files with 26588 additions and 25350 deletions


@@ -110,7 +110,7 @@ push:
 # locally downloaded packages).
 requirements:
 	@if [ "$(VIRTUAL_ENV)" ]; then \
-	    (cd requirements && pip install --no-index setuptools-2.2.tar.gz); \
+	    (cd requirements && pip install --no-index setuptools-12.0.5.tar.gz); \
 	    (cd requirements && pip install --no-index Django-1.6.7.tar.gz); \
 	    (cd requirements && pip install --no-index -r dev_local.txt); \
 	    $(PYTHON) fix_virtualenv_setuptools.py; \
@@ -122,7 +122,7 @@ requirements:
 # (downloading from PyPI if necessary).
 requirements_pypi:
 	@if [ "$(VIRTUAL_ENV)" ]; then \
-	    pip install setuptools==2.2; \
+	    pip install setuptools==12.0.5; \
 	    pip install Django\>=1.6.7,\<1.7; \
 	    pip install -r requirements/dev.txt; \
 	    $(PYTHON) fix_virtualenv_setuptools.py; \
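Either target leaves the new pins importable inside the virtualenv; a quick sanity check (a sketch using pkg_resources, which ships with setuptools):

# Verify the upgraded pins actually installed (run inside the virtualenv).
import pkg_resources

for name in ('setuptools', 'Django'):
    print(name, pkg_resources.get_distribution(name).version)
# Expected after this commit: setuptools 12.0.5 and Django 1.6.x (< 1.7).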


@@ -5,6 +5,7 @@ amqp==1.4.5 (amqp/*)
 ansi2html==1.0.6 (ansi2html/*)
 anyjson==0.3.3 (anyjson/*)
 argparse==1.2.1 (argparse.py, needed for Python 2.6 support)
+azure==0.9.0 (azure/*)
 Babel==1.3 (babel/*, excluded bin/pybabel)
 billiard==3.3.0.16 (billiard/*, funtests/*, excluded _billiard.so)
 boto==2.34.0 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin,
@@ -30,9 +31,9 @@ gevent-websocket==0.9.3 (geventwebsocket/*)
 httplib2==0.9 (httplib2/*)
 importlib==1.0.3 (importlib/*, needed for Python 2.6 support)
 iso8601==0.1.10 (iso8601/*)
-keyring==4.0 (keyring/*, excluded bin/keyring)
+keyring==4.1 (keyring/*, excluded bin/keyring)
 kombu==3.0.21 (kombu/*)
-Markdown==2.4.1 (markdown/*, excluded bin/markdown_py)
+Markdown==2.5.2 (markdown/*, excluded bin/markdown_py)
 mock==1.0.1 (mock.py)
 ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)
 os-diskconfig-python-novaclient-ext==0.1.2 (os_diskconfig_python_novaclient_ext/*)
@@ -44,16 +45,16 @@ pexpect==3.1 (pexpect/*, excluded pxssh.py, fdpexpect.py, FSM.py, screen.py,
 pip==1.5.4 (pip/*, excluded bin/pip*)
 prettytable==0.7.2 (prettytable.py)
 pyrax==1.9.0 (pyrax/*)
-python-dateutil==2.2 (dateutil/*)
+python-dateutil==2.4.0 (dateutil/*)
 python-novaclient==2.18.1 (novaclient/*, excluded bin/nova)
 python-swiftclient==2.2.0 (swiftclient/*, excluded bin/swift)
-pytz==2014.4 (pytz/*)
+pytz==2014.10 (pytz/*)
 rackspace-auth-openstack==1.3 (rackspace_auth_openstack/*)
 rackspace-novaclient==1.4 (no files)
 rax-default-network-flags-python-novaclient-ext==0.2.3 (rax_default_network_flags_python_novaclient_ext/*)
 rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*)
-requests==2.3.0 (requests/*)
-setuptools==2.2 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*)
+requests==2.5.1 (requests/*)
+setuptools==12.0.5 (setuptools/*, _markerlib/*, pkg_resources/*, easy_install.py)
 simplejson==3.6.0 (simplejson/*, excluded simplejson/_speedups.so)
-six==1.7.3 (six.py)
+six==1.9.0 (six.py)
 South==0.8.4 (south/*)

File diff suppressed because it is too large.


@@ -0,0 +1,81 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0">
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<SchemaVersion>2.0</SchemaVersion>
<ProjectGuid>{25b2c65a-0553-4452-8907-8b5b17544e68}</ProjectGuid>
<ProjectHome>
</ProjectHome>
<StartupFile>storage\blobservice.py</StartupFile>
<SearchPath>..</SearchPath>
<WorkingDirectory>.</WorkingDirectory>
<OutputPath>.</OutputPath>
<Name>azure</Name>
<RootNamespace>azure</RootNamespace>
<IsWindowsApplication>False</IsWindowsApplication>
<LaunchProvider>Standard Python launcher</LaunchProvider>
<CommandLineArguments />
<InterpreterPath />
<InterpreterArguments />
<InterpreterId>{2af0f10d-7135-4994-9156-5d01c9c11b7e}</InterpreterId>
<InterpreterVersion>2.7</InterpreterVersion>
<SccProjectName>SAK</SccProjectName>
<SccProvider>SAK</SccProvider>
<SccAuxPath>SAK</SccAuxPath>
<SccLocalPath>SAK</SccLocalPath>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
<DebugSymbols>true</DebugSymbols>
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">
<DebugSymbols>true</DebugSymbols>
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<ItemGroup>
<Compile Include="http\batchclient.py" />
<Compile Include="http\httpclient.py" />
<Compile Include="http\requestsclient.py" />
<Compile Include="http\winhttp.py" />
<Compile Include="http\__init__.py" />
<Compile Include="servicemanagement\schedulermanagementservice.py" />
<Compile Include="servicemanagement\servicebusmanagementservice.py" />
<Compile Include="servicemanagement\servicemanagementclient.py" />
<Compile Include="servicemanagement\servicemanagementservice.py" />
<Compile Include="servicemanagement\sqldatabasemanagementservice.py" />
<Compile Include="servicemanagement\websitemanagementservice.py" />
<Compile Include="servicemanagement\__init__.py" />
<Compile Include="servicebus\servicebusservice.py" />
<Compile Include="storage\blobservice.py" />
<Compile Include="storage\queueservice.py" />
<Compile Include="storage\cloudstorageaccount.py" />
<Compile Include="storage\tableservice.py" />
<Compile Include="storage\sharedaccesssignature.py" />
<Compile Include="__init__.py" />
<Compile Include="servicebus\__init__.py" />
<Compile Include="storage\storageclient.py" />
<Compile Include="storage\__init__.py" />
</ItemGroup>
<ItemGroup>
<Folder Include="http" />
<Folder Include="servicemanagement" />
<Folder Include="servicebus" />
<Folder Include="storage" />
</ItemGroup>
<ItemGroup>
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.6" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.7" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.3" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.4" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\2.7" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.3" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.4" />
</ItemGroup>
<PropertyGroup>
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
<PtvsTargetsFile>$(VSToolsPath)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile>
</PropertyGroup>
<Import Condition="Exists($(PtvsTargetsFile))" Project="$(PtvsTargetsFile)" />
<Import Condition="!Exists($(PtvsTargetsFile))" Project="$(MSBuildToolsPath)\Microsoft.Common.targets" />
</Project>


@@ -1,73 +1,73 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
HTTP_RESPONSE_NO_CONTENT = 204
class HTTPError(Exception):
''' HTTP Exception when response status code >= 300 '''
def __init__(self, status, message, respheader, respbody):
'''Creates a new HTTPError with the specified status, message,
response headers and body'''
self.status = status
self.respheader = respheader
self.respbody = respbody
Exception.__init__(self, message)
class HTTPResponse(object):
"""Represents a response from an HTTP request. An HTTPResponse has the
following attributes:
status: the status code of the response
message: the message
headers: the returned headers, as a list of (name, value) pairs
body: the body of the response
"""
def __init__(self, status, message, headers, body):
self.status = status
self.message = message
self.headers = headers
self.body = body
class HTTPRequest(object):
'''Represents an HTTP Request. An HTTP Request consists of the following
attributes:
host: the host name to connect to
method: the method to use to connect (string such as GET, POST, PUT, etc.)
path: the uri fragment
query: query parameters specified as a list of (name, value) pairs
headers: header values specified as (name, value) pairs
body: the body of the request.
protocol_override:
specify to use this protocol instead of the global one stored in
_HTTPClient.
'''
def __init__(self):
self.host = ''
self.method = ''
self.path = ''
self.query = [] # list of (name, value)
self.headers = [] # list of (header name, header value)
self.body = ''
self.protocol_override = None
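For orientation, a minimal sketch of how these classes fit together; the host below is a placeholder, and real requests are normally built and signed by the service layer:

# Illustrative only: build a request by hand and surface an error response
# as HTTPError, mirroring how the client layer uses these classes.
from azure.http import HTTPRequest, HTTPResponse, HTTPError

req = HTTPRequest()
req.method = 'GET'
req.host = 'example.table.core.windows.net'  # placeholder host
req.path = '/Tables'
req.headers = [('Accept', 'application/atom+xml')]

resp = HTTPResponse(404, 'Not Found', [('content-length', '0')], b'')
if resp.status >= 300:
    raise HTTPError(resp.status, resp.message, resp.headers, resp.body)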


@@ -1,339 +1,339 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import sys
import uuid
from azure import (
_update_request_uri_query,
WindowsAzureError,
WindowsAzureBatchOperationError,
_get_children_from_path,
url_unquote,
_ERROR_CANNOT_FIND_PARTITION_KEY,
_ERROR_CANNOT_FIND_ROW_KEY,
_ERROR_INCORRECT_TABLE_IN_BATCH,
_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH,
_ERROR_DUPLICATE_ROW_KEY_IN_BATCH,
_ERROR_BATCH_COMMIT_FAIL,
)
from azure.http import HTTPError, HTTPRequest, HTTPResponse
from azure.http.httpclient import _HTTPClient
from azure.storage import (
_update_storage_table_header,
METADATA_NS,
_sign_storage_table_request,
)
from xml.dom import minidom
_DATASERVICES_NS = 'http://schemas.microsoft.com/ado/2007/08/dataservices'
if sys.version_info < (3,):
def _new_boundary():
return str(uuid.uuid1())
else:
def _new_boundary():
return str(uuid.uuid1()).encode('utf-8')
class _BatchClient(_HTTPClient):
'''
This class is used for batch operations against the storage table
service. It only supports one changeset.
'''
def __init__(self, service_instance, account_key, account_name,
protocol='http'):
_HTTPClient.__init__(self, service_instance, account_name=account_name,
account_key=account_key, protocol=protocol)
self.is_batch = False
self.batch_requests = []
self.batch_table = ''
self.batch_partition_key = ''
self.batch_row_keys = []
def get_request_table(self, request):
'''
Extracts the table name from request.path. The path has either
"/mytable(...)" or "/mytable" format.
request: the request to insert, update or delete entity
'''
if '(' in request.path:
pos = request.path.find('(')
return request.path[1:pos]
else:
return request.path[1:]
def get_request_partition_key(self, request):
'''
Extracts the PartitionKey from request.body for a POST request, or from
request.path otherwise. Only insert operations use POST, and only then
is the PartitionKey in the request body.
request: the request to insert, update or delete entity
'''
if request.method == 'POST':
doc = minidom.parseString(request.body)
part_key = _get_children_from_path(
doc, 'entry', 'content', (METADATA_NS, 'properties'),
(_DATASERVICES_NS, 'PartitionKey'))
if not part_key:
raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY)
return part_key[0].firstChild.nodeValue
else:
uri = url_unquote(request.path)
pos1 = uri.find('PartitionKey=\'')
pos2 = uri.find('\',', pos1)
if pos1 == -1 or pos2 == -1:
raise WindowsAzureError(_ERROR_CANNOT_FIND_PARTITION_KEY)
return uri[pos1 + len('PartitionKey=\''):pos2]
def get_request_row_key(self, request):
'''
Extracts the RowKey from request.body for a POST request, or from
request.path otherwise. Only insert operations use POST, and only then
is the RowKey in the request body.
request: the request to insert, update or delete entity
'''
if request.method == 'POST':
doc = minidom.parseString(request.body)
row_key = _get_children_from_path(
doc, 'entry', 'content', (METADATA_NS, 'properties'),
(_DATASERVICES_NS, 'RowKey'))
if not row_key:
raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
return row_key[0].firstChild.nodeValue
else:
uri = url_unquote(request.path)
pos1 = uri.find('RowKey=\'')
pos2 = uri.find('\')', pos1)
if pos1 == -1 or pos2 == -1:
raise WindowsAzureError(_ERROR_CANNOT_FIND_ROW_KEY)
row_key = uri[pos1 + len('RowKey=\''):pos2]
return row_key
def validate_request_table(self, request):
'''
Validates that all requests have the same table name. Sets the table
name if it is the first request for the batch operation.
request: the request to insert, update or delete entity
'''
if self.batch_table:
if self.get_request_table(request) != self.batch_table:
raise WindowsAzureError(_ERROR_INCORRECT_TABLE_IN_BATCH)
else:
self.batch_table = self.get_request_table(request)
def validate_request_partition_key(self, request):
'''
Validates that all requests have the same PartitionKey. Sets the
PartitionKey if it is the first request for the batch operation.
request: the request to insert, update or delete entity
'''
if self.batch_partition_key:
if self.get_request_partition_key(request) != \
self.batch_partition_key:
raise WindowsAzureError(_ERROR_INCORRECT_PARTITION_KEY_IN_BATCH)
else:
self.batch_partition_key = self.get_request_partition_key(request)
def validate_request_row_key(self, request):
'''
Validates that each request has a distinct RowKey and adds the RowKey
to the existing RowKey list.
request: the request to insert, update or delete entity
'''
if self.batch_row_keys:
if self.get_request_row_key(request) in self.batch_row_keys:
raise WindowsAzureError(_ERROR_DUPLICATE_ROW_KEY_IN_BATCH)
else:
self.batch_row_keys.append(self.get_request_row_key(request))
def begin_batch(self):
'''
Starts the batch operation. Initializes the batch variables:
is_batch: batch operation flag.
batch_table: the table name of the batch operation
batch_partition_key: the PartitionKey of the batch requests.
batch_row_keys: the RowKey list of adding requests.
batch_requests: the list of the requests.
'''
self.is_batch = True
self.batch_table = ''
self.batch_partition_key = ''
self.batch_row_keys = []
self.batch_requests = []
def insert_request_to_batch(self, request):
'''
Adds request to batch operation.
request: the request to insert, update or delete entity
'''
self.validate_request_table(request)
self.validate_request_partition_key(request)
self.validate_request_row_key(request)
self.batch_requests.append(request)
def commit_batch(self):
''' Resets batch flag and commits the batch requests. '''
if self.is_batch:
self.is_batch = False
self.commit_batch_requests()
def commit_batch_requests(self):
''' Commits the batch requests. '''
batch_boundary = b'batch_' + _new_boundary()
changeset_boundary = b'changeset_' + _new_boundary()
# Commits the batch only if the requests list is not empty.
if self.batch_requests:
request = HTTPRequest()
request.method = 'POST'
request.host = self.batch_requests[0].host
request.path = '/$batch'
request.headers = [
('Content-Type', 'multipart/mixed; boundary=' + \
batch_boundary.decode('utf-8')),
('Accept', 'application/atom+xml,application/xml'),
('Accept-Charset', 'UTF-8')]
request.body = b'--' + batch_boundary + b'\n'
request.body += b'Content-Type: multipart/mixed; boundary='
request.body += changeset_boundary + b'\n\n'
content_id = 1
# Adds each request body to the POST data.
for batch_request in self.batch_requests:
request.body += b'--' + changeset_boundary + b'\n'
request.body += b'Content-Type: application/http\n'
request.body += b'Content-Transfer-Encoding: binary\n\n'
request.body += batch_request.method.encode('utf-8')
request.body += b' http://'
request.body += batch_request.host.encode('utf-8')
request.body += batch_request.path.encode('utf-8')
request.body += b' HTTP/1.1\n'
request.body += b'Content-ID: '
request.body += str(content_id).encode('utf-8') + b'\n'
content_id += 1
# Add different headers for different request types.
if not batch_request.method == 'DELETE':
request.body += \
b'Content-Type: application/atom+xml;type=entry\n'
for name, value in batch_request.headers:
if name == 'If-Match':
request.body += name.encode('utf-8') + b': '
request.body += value.encode('utf-8') + b'\n'
break
request.body += b'Content-Length: '
request.body += str(len(batch_request.body)).encode('utf-8')
request.body += b'\n\n'
request.body += batch_request.body + b'\n'
else:
for name, value in batch_request.headers:
# If-Match should be already included in
# batch_request.headers, but in case it is missing,
# just add it.
if name == 'If-Match':
request.body += name.encode('utf-8') + b': '
request.body += value.encode('utf-8') + b'\n\n'
break
else:
request.body += b'If-Match: *\n\n'
request.body += b'--' + changeset_boundary + b'--' + b'\n'
request.body += b'--' + batch_boundary + b'--'
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_storage_table_header(request)
auth = _sign_storage_table_request(request,
self.account_name,
self.account_key)
request.headers.append(('Authorization', auth))
# Submit the whole request as batch request.
response = self.perform_request(request)
if response.status >= 300:
raise HTTPError(response.status,
_ERROR_BATCH_COMMIT_FAIL,
self.respheader,
response.body)
# http://www.odata.org/documentation/odata-version-2-0/batch-processing/
# The body of a ChangeSet response is either a response for all the
# successfully processed change request within the ChangeSet,
# formatted exactly as it would have appeared outside of a batch,
# or a single response indicating a failure of the entire ChangeSet.
responses = self._parse_batch_response(response.body)
if responses and responses[0].status >= 300:
self._report_batch_error(responses[0])
def cancel_batch(self):
''' Resets the batch flag. '''
self.is_batch = False
def _parse_batch_response(self, body):
parts = body.split(b'--changesetresponse_')
responses = []
for part in parts:
httpLocation = part.find(b'HTTP/')
if httpLocation > 0:
response = self._parse_batch_response_part(part[httpLocation:])
responses.append(response)
return responses
def _parse_batch_response_part(self, part):
lines = part.splitlines()
# First line is the HTTP status/reason
status, _, reason = lines[0].partition(b' ')[2].partition(b' ')
# Followed by headers and body
headers = []
body = b''
isBody = False
for line in lines[1:]:
if line == b'' and not isBody:
isBody = True
elif isBody:
body += line
else:
headerName, _, headerVal = line.partition(b':')
headers.append((headerName.lower(), headerVal))
return HTTPResponse(int(status), reason.strip(), headers, body)
def _report_batch_error(self, response):
xml = response.body.decode('utf-8')
doc = minidom.parseString(xml)
n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'code')
code = n[0].firstChild.nodeValue if n and n[0].firstChild else ''
n = _get_children_from_path(doc, (METADATA_NS, 'error'), 'message')
message = n[0].firstChild.nodeValue if n and n[0].firstChild else xml
raise WindowsAzureBatchOperationError(message, code)
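The multipart body assembled by commit_batch_requests has roughly the shape below; boundaries are abbreviated and the entity payload is elided, so treat this as a sketch rather than captured wire traffic:

# Shape of the POST body sent to /$batch (abbreviated; illustrative only).
EXAMPLE_BATCH_BODY = b'''--batch_<uuid>
Content-Type: multipart/mixed; boundary=changeset_<uuid>

--changeset_<uuid>
Content-Type: application/http
Content-Transfer-Encoding: binary

POST http://<account>.table.core.windows.net/mytable HTTP/1.1
Content-ID: 1
Content-Type: application/atom+xml;type=entry
Content-Length: 123

<entry>...</entry>
--changeset_<uuid>--
--batch_<uuid>--'''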


@@ -1,223 +1,251 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import base64
import os
import sys
if sys.version_info < (3,):
from httplib import (
HTTPSConnection,
HTTPConnection,
HTTP_PORT,
HTTPS_PORT,
)
from urlparse import urlparse
else:
from http.client import (
HTTPSConnection,
HTTPConnection,
HTTP_PORT,
HTTPS_PORT,
)
from urllib.parse import urlparse
from azure.http import HTTPError, HTTPResponse
from azure import _USER_AGENT_STRING, _update_request_uri_query
class _HTTPClient(object):
'''
Takes a request, sends it to the cloud service, and returns the response.
'''
def __init__(self, service_instance, cert_file=None, account_name=None,
account_key=None, service_namespace=None, issuer=None,
protocol='https'):
'''
service_instance: service client instance.
cert_file:
certificate file name/location. This is only used in hosted
service management.
account_name: the storage account.
account_key:
the storage account access key for storage services or servicebus
access key for service bus service.
service_namespace: the service namespace for service bus.
issuer: the issuer for service bus service.
'''
self.service_instance = service_instance
self.status = None
self.respheader = None
self.message = None
self.cert_file = cert_file
self.account_name = account_name
self.account_key = account_key
self.service_namespace = service_namespace
self.issuer = issuer
self.protocol = protocol
self.proxy_host = None
self.proxy_port = None
self.proxy_user = None
self.proxy_password = None
self.use_httplib = self.should_use_httplib()
def should_use_httplib(self):
if sys.platform.lower().startswith('win') and self.cert_file:
# On Windows, auto-detect between Windows Store Certificate
# (winhttp) and OpenSSL .pem certificate file (httplib).
#
# We used to only support certificates installed in the Windows
# Certificate Store.
# cert_file example: CURRENT_USER\my\CertificateName
#
# We now support using an OpenSSL .pem certificate file,
# for a consistent experience across all platforms.
# cert_file example: account\certificate.pem
#
# When using OpenSSL .pem certificate file on Windows, make sure
# you are on CPython 2.7.4 or later.
# If it's not an existing file on disk, then treat it as a path in
# the Windows Certificate Store, which means we can't use httplib.
if not os.path.isfile(self.cert_file):
return False
return True
def set_proxy(self, host, port, user, password):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self.proxy_host = host
self.proxy_port = port
self.proxy_user = user
self.proxy_password = password
def get_connection(self, request):
''' Create connection for the request. '''
protocol = request.protocol_override \
if request.protocol_override else self.protocol
target_host = request.host
target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT
if not self.use_httplib:
import azure.http.winhttp
connection = azure.http.winhttp._HTTPConnection(
target_host, cert_file=self.cert_file, protocol=protocol)
proxy_host = self.proxy_host
proxy_port = self.proxy_port
else:
if ':' in target_host:
target_host, _, target_port = target_host.rpartition(':')
if self.proxy_host:
proxy_host = target_host
proxy_port = target_port
host = self.proxy_host
port = self.proxy_port
else:
host = target_host
port = target_port
if protocol == 'http':
connection = HTTPConnection(host, int(port))
else:
connection = HTTPSConnection(
host, int(port), cert_file=self.cert_file)
if self.proxy_host:
headers = None
if self.proxy_user and self.proxy_password:
auth = base64.encodestring(
"{0}:{1}".format(self.proxy_user, self.proxy_password))
headers = {'Proxy-Authorization': 'Basic {0}'.format(auth)}
connection.set_tunnel(proxy_host, int(proxy_port), headers)
return connection
def send_request_headers(self, connection, request_headers):
if self.use_httplib:
if self.proxy_host:
for i in connection._buffer:
if i.startswith("Host: "):
connection._buffer.remove(i)
connection.putheader(
'Host', "{0}:{1}".format(connection._tunnel_host,
connection._tunnel_port))
for name, value in request_headers:
if value:
connection.putheader(name, value)
connection.putheader('User-Agent', _USER_AGENT_STRING)
connection.endheaders()
def send_request_body(self, connection, request_body):
if request_body:
assert isinstance(request_body, bytes)
connection.send(request_body)
elif (not isinstance(connection, HTTPSConnection) and
not isinstance(connection, HTTPConnection)):
connection.send(None)
def perform_request(self, request):
''' Sends the request to the cloud service and returns the response. '''
connection = self.get_connection(request)
try:
connection.putrequest(request.method, request.path)
if not self.use_httplib:
if self.proxy_host and self.proxy_user:
connection.set_proxy_credentials(
self.proxy_user, self.proxy_password)
self.send_request_headers(connection, request.headers)
self.send_request_body(connection, request.body)
resp = connection.getresponse()
self.status = int(resp.status)
self.message = resp.reason
self.respheader = headers = resp.getheaders()
# for consistency across platforms, make header names lowercase
for i, value in enumerate(headers):
headers[i] = (value[0].lower(), value[1])
respbody = None
if resp.length is None:
respbody = resp.read()
elif resp.length > 0:
respbody = resp.read(resp.length)
response = HTTPResponse(
int(resp.status), resp.reason, headers, respbody)
if self.status == 307:
new_url = urlparse(dict(headers)['location'])
request.host = new_url.hostname
request.path = new_url.path
request.path, request.query = _update_request_uri_query(request)
return self.perform_request(request)
if self.status >= 300:
raise HTTPError(self.status, self.message,
self.respheader, respbody)
return response
finally:
connection.close()
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import base64
import os
import sys
if sys.version_info < (3,):
from httplib import (
HTTPSConnection,
HTTPConnection,
HTTP_PORT,
HTTPS_PORT,
)
from urlparse import urlparse
else:
from http.client import (
HTTPSConnection,
HTTPConnection,
HTTP_PORT,
HTTPS_PORT,
)
from urllib.parse import urlparse
from azure.http import HTTPError, HTTPResponse
from azure import _USER_AGENT_STRING, _update_request_uri_query
DEBUG_REQUESTS = False
DEBUG_RESPONSES = False
class _HTTPClient(object):
'''
Takes a request, sends it to the cloud service, and returns the response.
'''
def __init__(self, service_instance, cert_file=None, account_name=None,
account_key=None, protocol='https', request_session=None):
'''
service_instance: service client instance.
cert_file:
certificate file name/location. This is only used in hosted
service management.
account_name: the storage account.
account_key:
the storage account access key.
request_session:
session object created with requests library (or compatible).
'''
self.service_instance = service_instance
self.status = None
self.respheader = None
self.message = None
self.cert_file = cert_file
self.account_name = account_name
self.account_key = account_key
self.protocol = protocol
self.proxy_host = None
self.proxy_port = None
self.proxy_user = None
self.proxy_password = None
self.request_session = request_session
if request_session:
self.use_httplib = True
else:
self.use_httplib = self.should_use_httplib()
def should_use_httplib(self):
if sys.platform.lower().startswith('win') and self.cert_file:
# On Windows, auto-detect between Windows Store Certificate
# (winhttp) and OpenSSL .pem certificate file (httplib).
#
# We used to only support certificates installed in the Windows
# Certificate Store.
# cert_file example: CURRENT_USER\my\CertificateName
#
# We now support using an OpenSSL .pem certificate file,
# for a consistent experience across all platforms.
# cert_file example: account\certificate.pem
#
# When using OpenSSL .pem certificate file on Windows, make sure
# you are on CPython 2.7.4 or later.
# If it's not an existing file on disk, then treat it as a path in
# the Windows Certificate Store, which means we can't use httplib.
if not os.path.isfile(self.cert_file):
return False
return True
def set_proxy(self, host, port, user, password):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self.proxy_host = host
self.proxy_port = port
self.proxy_user = user
self.proxy_password = password
def get_uri(self, request):
''' Return the target uri for the request.'''
protocol = request.protocol_override \
if request.protocol_override else self.protocol
port = HTTP_PORT if protocol == 'http' else HTTPS_PORT
return protocol + '://' + request.host + ':' + str(port) + request.path
def get_connection(self, request):
''' Create connection for the request. '''
protocol = request.protocol_override \
if request.protocol_override else self.protocol
target_host = request.host
target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT
if self.request_session:
import azure.http.requestsclient
connection = azure.http.requestsclient._RequestsConnection(
target_host, protocol, self.request_session)
#TODO: proxy stuff
elif not self.use_httplib:
import azure.http.winhttp
connection = azure.http.winhttp._HTTPConnection(
target_host, cert_file=self.cert_file, protocol=protocol)
proxy_host = self.proxy_host
proxy_port = self.proxy_port
else:
if ':' in target_host:
target_host, _, target_port = target_host.rpartition(':')
if self.proxy_host:
proxy_host = target_host
proxy_port = target_port
host = self.proxy_host
port = self.proxy_port
else:
host = target_host
port = target_port
if protocol == 'http':
connection = HTTPConnection(host, int(port))
else:
connection = HTTPSConnection(
host, int(port), cert_file=self.cert_file)
if self.proxy_host:
headers = None
if self.proxy_user and self.proxy_password:
auth = base64.encodestring(
"{0}:{1}".format(self.proxy_user, self.proxy_password))
headers = {'Proxy-Authorization': 'Basic {0}'.format(auth)}
connection.set_tunnel(proxy_host, int(proxy_port), headers)
return connection
def send_request_headers(self, connection, request_headers):
if self.use_httplib:
if self.proxy_host:
for i in connection._buffer:
if i.startswith("Host: "):
connection._buffer.remove(i)
connection.putheader(
'Host', "{0}:{1}".format(connection._tunnel_host,
connection._tunnel_port))
for name, value in request_headers:
if value:
connection.putheader(name, value)
connection.putheader('User-Agent', _USER_AGENT_STRING)
connection.endheaders()
def send_request_body(self, connection, request_body):
if request_body:
assert isinstance(request_body, bytes)
connection.send(request_body)
elif (not isinstance(connection, HTTPSConnection) and
not isinstance(connection, HTTPConnection)):
connection.send(None)
def perform_request(self, request):
''' Sends the request to the cloud service and returns the response. '''
connection = self.get_connection(request)
try:
connection.putrequest(request.method, request.path)
if not self.use_httplib:
if self.proxy_host and self.proxy_user:
connection.set_proxy_credentials(
self.proxy_user, self.proxy_password)
self.send_request_headers(connection, request.headers)
self.send_request_body(connection, request.body)
if DEBUG_REQUESTS and request.body:
print('request:')
try:
print(request.body)
except:
pass
resp = connection.getresponse()
self.status = int(resp.status)
self.message = resp.reason
self.respheader = headers = resp.getheaders()
# for consistency across platforms, make header names lowercase
for i, value in enumerate(headers):
headers[i] = (value[0].lower(), value[1])
respbody = None
if resp.length is None:
respbody = resp.read()
elif resp.length > 0:
respbody = resp.read(resp.length)
if DEBUG_RESPONSES and respbody:
print('response:')
try:
print(respbody)
except:
pass
response = HTTPResponse(
int(resp.status), resp.reason, headers, respbody)
if self.status == 307:
new_url = urlparse(dict(headers)['location'])
request.host = new_url.hostname
request.path = new_url.path
request.path, request.query = _update_request_uri_query(request)
return self.perform_request(request)
if self.status >= 300:
raise HTTPError(self.status, self.message,
self.respheader, respbody)
return response
finally:
connection.close()
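A minimal sketch of the new requests-backed path; the account name and key are placeholders, and real callers go through the storage service classes, which also sign the request:

import requests
from azure.http import HTTPRequest
from azure.http.httpclient import _HTTPClient

# Passing request_session routes get_connection() through
# azure.http.requestsclient instead of httplib/winhttp.
client = _HTTPClient(service_instance=None,
                     account_name='<account>',   # placeholder
                     account_key='<key>',        # placeholder
                     request_session=requests.Session())

req = HTTPRequest()
req.method = 'GET'
req.host = '<account>.table.core.windows.net'   # placeholder
req.path = '/Tables'
req.headers = [('Accept', 'application/atom+xml')]
response = client.perform_request(req)  # raises HTTPError for status >= 300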


@@ -0,0 +1,74 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
class _Response(object):
''' Response class corresponding to the response returned from httplib
HTTPConnection. '''
def __init__(self, response):
self.status = response.status_code
self.reason = response.reason
self.respbody = response.content
self.length = len(response.content)
self.headers = []
for key, name in response.headers.items():
self.headers.append((key.lower(), name))
def getheaders(self):
'''Returns response headers.'''
return self.headers
def read(self, _length):
'''Returns response body. '''
return self.respbody[:_length]
class _RequestsConnection(object):
def __init__(self, host, protocol, session):
self.host = host
self.protocol = protocol
self.session = session
self.headers = {}
self.method = None
self.body = None
self.response = None
self.uri = None
def close(self):
pass
def set_tunnel(self, host, port=None, headers=None):
pass
def set_proxy_credentials(self, user, password):
pass
def putrequest(self, method, uri):
self.method = method
self.uri = self.protocol + '://' + self.host + uri
def putheader(self, name, value):
self.headers[name] = value
def endheaders(self):
pass
def send(self, request_body):
self.response = self.session.request(self.method, self.uri, data=request_body, headers=self.headers)
def getresponse(self):
return _Response(self.response)
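The adapter can also be exercised directly, which makes the httplib-shaped call sequence visible; a sketch against a public host:

import requests
from azure.http.requestsclient import _RequestsConnection

conn = _RequestsConnection('www.example.com', 'https', requests.Session())
conn.putrequest('GET', '/')             # records the method and full URI
conn.putheader('Accept', 'text/html')
conn.endheaders()                       # no-op, kept for httplib parity
conn.send(None)                         # the session performs the request here
resp = conn.getresponse()
print(resp.status, resp.reason)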


@@ -1,471 +1,471 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from ctypes import (
c_void_p,
c_long,
c_ulong,
c_longlong,
c_ulonglong,
c_short,
c_ushort,
c_wchar_p,
c_byte,
byref,
Structure,
Union,
POINTER,
WINFUNCTYPE,
HRESULT,
oledll,
WinDLL,
)
import ctypes
import sys
if sys.version_info >= (3,):
def unicode(text):
return text
#------------------------------------------------------------------------------
# Constants that are used in COM operations
VT_EMPTY = 0
VT_NULL = 1
VT_I2 = 2
VT_I4 = 3
VT_BSTR = 8
VT_BOOL = 11
VT_I1 = 16
VT_UI1 = 17
VT_UI2 = 18
VT_UI4 = 19
VT_I8 = 20
VT_UI8 = 21
VT_ARRAY = 8192
HTTPREQUEST_PROXYSETTING_PROXY = 2
HTTPREQUEST_SETCREDENTIALS_FOR_PROXY = 1
HTTPREQUEST_PROXY_SETTING = c_long
HTTPREQUEST_SETCREDENTIALS_FLAGS = c_long
#------------------------------------------------------------------------------
# Com related APIs that are used.
_ole32 = oledll.ole32
_oleaut32 = WinDLL('oleaut32')
_CLSIDFromString = _ole32.CLSIDFromString
_CoInitialize = _ole32.CoInitialize
_CoInitialize.argtypes = [c_void_p]
_CoCreateInstance = _ole32.CoCreateInstance
_SysAllocString = _oleaut32.SysAllocString
_SysAllocString.restype = c_void_p
_SysAllocString.argtypes = [c_wchar_p]
_SysFreeString = _oleaut32.SysFreeString
_SysFreeString.argtypes = [c_void_p]
# SAFEARRAY*
# SafeArrayCreateVector(_In_ VARTYPE vt,_In_ LONG lLbound,_In_ ULONG
# cElements);
_SafeArrayCreateVector = _oleaut32.SafeArrayCreateVector
_SafeArrayCreateVector.restype = c_void_p
_SafeArrayCreateVector.argtypes = [c_ushort, c_long, c_ulong]
# HRESULT
# SafeArrayAccessData(_In_ SAFEARRAY *psa, _Out_ void **ppvData);
_SafeArrayAccessData = _oleaut32.SafeArrayAccessData
_SafeArrayAccessData.argtypes = [c_void_p, POINTER(c_void_p)]
# HRESULT
# SafeArrayUnaccessData(_In_ SAFEARRAY *psa);
_SafeArrayUnaccessData = _oleaut32.SafeArrayUnaccessData
_SafeArrayUnaccessData.argtypes = [c_void_p]
# HRESULT
# SafeArrayGetUBound(_In_ SAFEARRAY *psa, _In_ UINT nDim, _Out_ LONG
# *plUbound);
_SafeArrayGetUBound = _oleaut32.SafeArrayGetUBound
_SafeArrayGetUBound.argtypes = [c_void_p, c_ulong, POINTER(c_long)]
#------------------------------------------------------------------------------
class BSTR(c_wchar_p):
''' BSTR class in python. '''
def __init__(self, value):
super(BSTR, self).__init__(_SysAllocString(value))
def __del__(self):
_SysFreeString(self)
class VARIANT(Structure):
'''
VARIANT structure in python. Does not match the MSDN definition
exactly; it maps only the fields that are used. Field names are also
slightly different.
'''
class _tagData(Union):
class _tagRecord(Structure):
_fields_ = [('pvoid', c_void_p), ('precord', c_void_p)]
_fields_ = [('llval', c_longlong),
('ullval', c_ulonglong),
('lval', c_long),
('ulval', c_ulong),
('ival', c_short),
('boolval', c_ushort),
('bstrval', BSTR),
('parray', c_void_p),
('record', _tagRecord)]
_fields_ = [('vt', c_ushort),
('wReserved1', c_ushort),
('wReserved2', c_ushort),
('wReserved3', c_ushort),
('vdata', _tagData)]
@staticmethod
def create_empty():
variant = VARIANT()
variant.vt = VT_EMPTY
variant.vdata.llval = 0
return variant
@staticmethod
def create_safearray_from_str(text):
variant = VARIANT()
variant.vt = VT_ARRAY | VT_UI1
length = len(text)
variant.vdata.parray = _SafeArrayCreateVector(VT_UI1, 0, length)
pvdata = c_void_p()
_SafeArrayAccessData(variant.vdata.parray, byref(pvdata))
ctypes.memmove(pvdata, text, length)
_SafeArrayUnaccessData(variant.vdata.parray)
return variant
@staticmethod
def create_bstr_from_str(text):
variant = VARIANT()
variant.vt = VT_BSTR
variant.vdata.bstrval = BSTR(text)
return variant
@staticmethod
def create_bool_false():
variant = VARIANT()
variant.vt = VT_BOOL
variant.vdata.boolval = 0
return variant
def is_safearray_of_bytes(self):
return self.vt == VT_ARRAY | VT_UI1
def str_from_safearray(self):
assert self.vt == VT_ARRAY | VT_UI1
pvdata = c_void_p()
count = c_long()
_SafeArrayGetUBound(self.vdata.parray, 1, byref(count))
count = c_long(count.value + 1)
_SafeArrayAccessData(self.vdata.parray, byref(pvdata))
text = ctypes.string_at(pvdata, count)
_SafeArrayUnaccessData(self.vdata.parray)
return text
def __del__(self):
_VariantClear(self)
# HRESULT VariantClear(_Inout_ VARIANTARG *pvarg);
_VariantClear = _oleaut32.VariantClear
_VariantClear.argtypes = [POINTER(VARIANT)]
class GUID(Structure):
''' GUID structure in python. '''
_fields_ = [("data1", c_ulong),
("data2", c_ushort),
("data3", c_ushort),
("data4", c_byte * 8)]
def __init__(self, name=None):
if name is not None:
_CLSIDFromString(unicode(name), byref(self))
class _WinHttpRequest(c_void_p):
'''
Maps the Com API to Python class functions. Not all methods in
IWinHttpWebRequest are mapped - only the methods we use.
'''
_AddRef = WINFUNCTYPE(c_long) \
(1, 'AddRef')
_Release = WINFUNCTYPE(c_long) \
(2, 'Release')
_SetProxy = WINFUNCTYPE(HRESULT,
HTTPREQUEST_PROXY_SETTING,
VARIANT,
VARIANT) \
(7, 'SetProxy')
_SetCredentials = WINFUNCTYPE(HRESULT,
BSTR,
BSTR,
HTTPREQUEST_SETCREDENTIALS_FLAGS) \
(8, 'SetCredentials')
_Open = WINFUNCTYPE(HRESULT, BSTR, BSTR, VARIANT) \
(9, 'Open')
_SetRequestHeader = WINFUNCTYPE(HRESULT, BSTR, BSTR) \
(10, 'SetRequestHeader')
_GetResponseHeader = WINFUNCTYPE(HRESULT, BSTR, POINTER(c_void_p)) \
(11, 'GetResponseHeader')
_GetAllResponseHeaders = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(12, 'GetAllResponseHeaders')
_Send = WINFUNCTYPE(HRESULT, VARIANT) \
(13, 'Send')
_Status = WINFUNCTYPE(HRESULT, POINTER(c_long)) \
(14, 'Status')
_StatusText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(15, 'StatusText')
_ResponseText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(16, 'ResponseText')
_ResponseBody = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
(17, 'ResponseBody')
_ResponseStream = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
(18, 'ResponseStream')
_WaitForResponse = WINFUNCTYPE(HRESULT, VARIANT, POINTER(c_ushort)) \
(21, 'WaitForResponse')
_Abort = WINFUNCTYPE(HRESULT) \
(22, 'Abort')
_SetTimeouts = WINFUNCTYPE(HRESULT, c_long, c_long, c_long, c_long) \
(23, 'SetTimeouts')
_SetClientCertificate = WINFUNCTYPE(HRESULT, BSTR) \
(24, 'SetClientCertificate')
def open(self, method, url):
'''
Opens the request.
method: the request VERB 'GET', 'POST', etc.
url: the url to connect
'''
_WinHttpRequest._SetTimeouts(self, 0, 65000, 65000, 65000)
flag = VARIANT.create_bool_false()
_method = BSTR(method)
_url = BSTR(url)
_WinHttpRequest._Open(self, _method, _url, flag)
def set_request_header(self, name, value):
''' Sets the request header. '''
_name = BSTR(name)
_value = BSTR(value)
_WinHttpRequest._SetRequestHeader(self, _name, _value)
def get_all_response_headers(self):
''' Gets back all response headers. '''
bstr_headers = c_void_p()
_WinHttpRequest._GetAllResponseHeaders(self, byref(bstr_headers))
bstr_headers = ctypes.cast(bstr_headers, c_wchar_p)
headers = bstr_headers.value
_SysFreeString(bstr_headers)
return headers
def send(self, request=None):
''' Sends the request body. '''
# Sends VT_EMPTY if it is a GET or HEAD request.
if request is None:
var_empty = VARIANT.create_empty()
_WinHttpRequest._Send(self, var_empty)
else: # Sends request body as SAFEArray.
_request = VARIANT.create_safearray_from_str(request)
_WinHttpRequest._Send(self, _request)
def status(self):
''' Gets status of response. '''
status = c_long()
_WinHttpRequest._Status(self, byref(status))
return int(status.value)
def status_text(self):
''' Gets status text of response. '''
bstr_status_text = c_void_p()
_WinHttpRequest._StatusText(self, byref(bstr_status_text))
bstr_status_text = ctypes.cast(bstr_status_text, c_wchar_p)
status_text = bstr_status_text.value
_SysFreeString(bstr_status_text)
return status_text
def response_body(self):
'''
Gets the response body as a SAFEARRAY and converts it to str.
XML responses may begin with a 3-byte UTF-8 BOM (b'\xef\xbb\xbf')
before <?xml, which is stripped.
'''
var_respbody = VARIANT()
_WinHttpRequest._ResponseBody(self, byref(var_respbody))
if var_respbody.is_safearray_of_bytes():
respbody = var_respbody.str_from_safearray()
if respbody[3:].startswith(b'<?xml') and\
respbody.startswith(b'\xef\xbb\xbf'):
respbody = respbody[3:]
return respbody
else:
return ''
def set_client_certificate(self, certificate):
'''Sets client certificate for the request. '''
_certificate = BSTR(certificate)
_WinHttpRequest._SetClientCertificate(self, _certificate)
def set_tunnel(self, host, port):
''' Sets up the host and the port for the HTTP CONNECT Tunnelling.'''
url = host
if port:
url = url + u':' + port
var_host = VARIANT.create_bstr_from_str(url)
var_empty = VARIANT.create_empty()
_WinHttpRequest._SetProxy(
self, HTTPREQUEST_PROXYSETTING_PROXY, var_host, var_empty)
def set_proxy_credentials(self, user, password):
_WinHttpRequest._SetCredentials(
self, BSTR(user), BSTR(password),
HTTPREQUEST_SETCREDENTIALS_FOR_PROXY)
def __del__(self):
if self.value is not None:
_WinHttpRequest._Release(self)
class _Response(object):
''' Response class corresponding to the response returned from httplib
HTTPConnection. '''
def __init__(self, _status, _status_text, _length, _headers, _respbody):
self.status = _status
self.reason = _status_text
self.length = _length
self.headers = _headers
self.respbody = _respbody
def getheaders(self):
'''Returns response headers.'''
return self.headers
def read(self, _length):
'''Returns response body. '''
return self.respbody[:_length]
class _HTTPConnection(object):
''' Class corresponding to httplib HTTPConnection class. '''
def __init__(self, host, cert_file=None, key_file=None, protocol='http'):
''' Initializes the IWinHttpWebRequest COM object.'''
self.host = unicode(host)
self.cert_file = cert_file
self._httprequest = _WinHttpRequest()
self.protocol = protocol
clsid = GUID('{2087C2F4-2CEF-4953-A8AB-66779B670495}')
iid = GUID('{016FE2EC-B2C8-45F8-B23B-39E53A75396B}')
_CoInitialize(None)
_CoCreateInstance(byref(clsid), 0, 1, byref(iid),
byref(self._httprequest))
def close(self):
pass
def set_tunnel(self, host, port=None, headers=None):
''' Sets up the host and the port for the HTTP CONNECT Tunnelling. '''
self._httprequest.set_tunnel(unicode(host), unicode(str(port)))
def set_proxy_credentials(self, user, password):
self._httprequest.set_proxy_credentials(
unicode(user), unicode(password))
def putrequest(self, method, uri):
''' Connects to host and sends the request. '''
protocol = unicode(self.protocol + '://')
url = protocol + self.host + unicode(uri)
self._httprequest.open(unicode(method), url)
# sets certificate for the connection if cert_file is set.
if self.cert_file is not None:
self._httprequest.set_client_certificate(unicode(self.cert_file))
def putheader(self, name, value):
''' Sends the request headers. '''
if sys.version_info < (3,):
name = str(name).decode('utf-8')
value = str(value).decode('utf-8')
self._httprequest.set_request_header(name, value)
def endheaders(self):
''' No operation. Exists only to provide the same interface as httplib
HTTPConnection.'''
pass
def send(self, request_body):
''' Sends request body. '''
if not request_body:
self._httprequest.send()
else:
self._httprequest.send(request_body)
def getresponse(self):
''' Gets the response and generates the _Response object. '''
status = self._httprequest.status()
status_text = self._httprequest.status_text()
resp_headers = self._httprequest.get_all_response_headers()
fixed_headers = []
for resp_header in resp_headers.split('\n'):
if (resp_header.startswith('\t') or\
resp_header.startswith(' ')) and fixed_headers:
# append to previous header
fixed_headers[-1] += resp_header
else:
fixed_headers.append(resp_header)
headers = []
for resp_header in fixed_headers:
if ':' in resp_header:
pos = resp_header.find(':')
headers.append(
(resp_header[:pos].lower(), resp_header[pos + 1:].strip()))
body = self._httprequest.response_body()
length = len(body)
return _Response(status, status_text, length, headers, body)
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from ctypes import (
c_void_p,
c_long,
c_ulong,
c_longlong,
c_ulonglong,
c_short,
c_ushort,
c_wchar_p,
c_byte,
byref,
Structure,
Union,
POINTER,
WINFUNCTYPE,
HRESULT,
oledll,
WinDLL,
)
import ctypes
import sys
if sys.version_info >= (3,):
def unicode(text):
return text
#------------------------------------------------------------------------------
# Constants that are used in COM operations
VT_EMPTY = 0
VT_NULL = 1
VT_I2 = 2
VT_I4 = 3
VT_BSTR = 8
VT_BOOL = 11
VT_I1 = 16
VT_UI1 = 17
VT_UI2 = 18
VT_UI4 = 19
VT_I8 = 20
VT_UI8 = 21
VT_ARRAY = 8192
HTTPREQUEST_PROXYSETTING_PROXY = 2
HTTPREQUEST_SETCREDENTIALS_FOR_PROXY = 1
HTTPREQUEST_PROXY_SETTING = c_long
HTTPREQUEST_SETCREDENTIALS_FLAGS = c_long
#------------------------------------------------------------------------------
# COM-related APIs that are used.
_ole32 = oledll.ole32
_oleaut32 = WinDLL('oleaut32')
_CLSIDFromString = _ole32.CLSIDFromString
_CoInitialize = _ole32.CoInitialize
_CoInitialize.argtypes = [c_void_p]
_CoCreateInstance = _ole32.CoCreateInstance
_SysAllocString = _oleaut32.SysAllocString
_SysAllocString.restype = c_void_p
_SysAllocString.argtypes = [c_wchar_p]
_SysFreeString = _oleaut32.SysFreeString
_SysFreeString.argtypes = [c_void_p]
# SAFEARRAY*
# SafeArrayCreateVector(_In_ VARTYPE vt,_In_ LONG lLbound,_In_ ULONG
# cElements);
_SafeArrayCreateVector = _oleaut32.SafeArrayCreateVector
_SafeArrayCreateVector.restype = c_void_p
_SafeArrayCreateVector.argtypes = [c_ushort, c_long, c_ulong]
# HRESULT
# SafeArrayAccessData(_In_ SAFEARRAY *psa, _Out_ void **ppvData);
_SafeArrayAccessData = _oleaut32.SafeArrayAccessData
_SafeArrayAccessData.argtypes = [c_void_p, POINTER(c_void_p)]
# HRESULT
# SafeArrayUnaccessData(_In_ SAFEARRAY *psa);
_SafeArrayUnaccessData = _oleaut32.SafeArrayUnaccessData
_SafeArrayUnaccessData.argtypes = [c_void_p]
# HRESULT
# SafeArrayGetUBound(_In_ SAFEARRAY *psa, _In_ UINT nDim, _Out_ LONG
# *plUbound);
_SafeArrayGetUBound = _oleaut32.SafeArrayGetUBound
_SafeArrayGetUBound.argtypes = [c_void_p, c_ulong, POINTER(c_long)]
#------------------------------------------------------------------------------
class BSTR(c_wchar_p):
''' BSTR class in python. '''
def __init__(self, value):
super(BSTR, self).__init__(_SysAllocString(value))
def __del__(self):
_SysFreeString(self)
class VARIANT(Structure):
'''
VARIANT structure in python. Does not match the definition in
MSDN exactly; only the fields we use are mapped. Field names are also
slightly different.
'''
class _tagData(Union):
class _tagRecord(Structure):
_fields_ = [('pvoid', c_void_p), ('precord', c_void_p)]
_fields_ = [('llval', c_longlong),
('ullval', c_ulonglong),
('lval', c_long),
('ulval', c_ulong),
('ival', c_short),
('boolval', c_ushort),
('bstrval', BSTR),
('parray', c_void_p),
('record', _tagRecord)]
_fields_ = [('vt', c_ushort),
('wReserved1', c_ushort),
('wReserved2', c_ushort),
('wReserved3', c_ushort),
('vdata', _tagData)]
@staticmethod
def create_empty():
variant = VARIANT()
variant.vt = VT_EMPTY
variant.vdata.llval = 0
return variant
@staticmethod
def create_safearray_from_str(text):
variant = VARIANT()
variant.vt = VT_ARRAY | VT_UI1
length = len(text)
variant.vdata.parray = _SafeArrayCreateVector(VT_UI1, 0, length)
pvdata = c_void_p()
_SafeArrayAccessData(variant.vdata.parray, byref(pvdata))
ctypes.memmove(pvdata, text, length)
_SafeArrayUnaccessData(variant.vdata.parray)
return variant
@staticmethod
def create_bstr_from_str(text):
variant = VARIANT()
variant.vt = VT_BSTR
variant.vdata.bstrval = BSTR(text)
return variant
@staticmethod
def create_bool_false():
variant = VARIANT()
variant.vt = VT_BOOL
variant.vdata.boolval = 0
return variant
def is_safearray_of_bytes(self):
return self.vt == VT_ARRAY | VT_UI1
def str_from_safearray(self):
assert self.vt == VT_ARRAY | VT_UI1
pvdata = c_void_p()
count = c_long()
_SafeArrayGetUBound(self.vdata.parray, 1, byref(count))
count = c_long(count.value + 1)
_SafeArrayAccessData(self.vdata.parray, byref(pvdata))
text = ctypes.string_at(pvdata, count)
_SafeArrayUnaccessData(self.vdata.parray)
return text
def __del__(self):
_VariantClear(self)
# HRESULT VariantClear(_Inout_ VARIANTARG *pvarg);
_VariantClear = _oleaut32.VariantClear
_VariantClear.argtypes = [POINTER(VARIANT)]
class GUID(Structure):
''' GUID structure in python. '''
_fields_ = [("data1", c_ulong),
("data2", c_ushort),
("data3", c_ushort),
("data4", c_byte * 8)]
def __init__(self, name=None):
if name is not None:
_CLSIDFromString(unicode(name), byref(self))
class _WinHttpRequest(c_void_p):
'''
Maps the COM API to Python class functions. Not all methods in
IWinHttpWebRequest are mapped - only the methods we use.
'''
_AddRef = WINFUNCTYPE(c_long) \
(1, 'AddRef')
_Release = WINFUNCTYPE(c_long) \
(2, 'Release')
_SetProxy = WINFUNCTYPE(HRESULT,
HTTPREQUEST_PROXY_SETTING,
VARIANT,
VARIANT) \
(7, 'SetProxy')
_SetCredentials = WINFUNCTYPE(HRESULT,
BSTR,
BSTR,
HTTPREQUEST_SETCREDENTIALS_FLAGS) \
(8, 'SetCredentials')
_Open = WINFUNCTYPE(HRESULT, BSTR, BSTR, VARIANT) \
(9, 'Open')
_SetRequestHeader = WINFUNCTYPE(HRESULT, BSTR, BSTR) \
(10, 'SetRequestHeader')
_GetResponseHeader = WINFUNCTYPE(HRESULT, BSTR, POINTER(c_void_p)) \
(11, 'GetResponseHeader')
_GetAllResponseHeaders = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(12, 'GetAllResponseHeaders')
_Send = WINFUNCTYPE(HRESULT, VARIANT) \
(13, 'Send')
_Status = WINFUNCTYPE(HRESULT, POINTER(c_long)) \
(14, 'Status')
_StatusText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(15, 'StatusText')
_ResponseText = WINFUNCTYPE(HRESULT, POINTER(c_void_p)) \
(16, 'ResponseText')
_ResponseBody = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
(17, 'ResponseBody')
_ResponseStream = WINFUNCTYPE(HRESULT, POINTER(VARIANT)) \
(18, 'ResponseStream')
_WaitForResponse = WINFUNCTYPE(HRESULT, VARIANT, POINTER(c_ushort)) \
(21, 'WaitForResponse')
_Abort = WINFUNCTYPE(HRESULT) \
(22, 'Abort')
_SetTimeouts = WINFUNCTYPE(HRESULT, c_long, c_long, c_long, c_long) \
(23, 'SetTimeouts')
_SetClientCertificate = WINFUNCTYPE(HRESULT, BSTR) \
(24, 'SetClientCertificate')
def open(self, method, url):
'''
Opens the request.
method: the request VERB 'GET', 'POST', etc.
url: the url to connect
'''
_WinHttpRequest._SetTimeouts(self, 0, 65000, 65000, 65000)
flag = VARIANT.create_bool_false()
_method = BSTR(method)
_url = BSTR(url)
_WinHttpRequest._Open(self, _method, _url, flag)
def set_request_header(self, name, value):
''' Sets the request header. '''
_name = BSTR(name)
_value = BSTR(value)
_WinHttpRequest._SetRequestHeader(self, _name, _value)
def get_all_response_headers(self):
''' Gets back all response headers. '''
bstr_headers = c_void_p()
_WinHttpRequest._GetAllResponseHeaders(self, byref(bstr_headers))
bstr_headers = ctypes.cast(bstr_headers, c_wchar_p)
headers = bstr_headers.value
_SysFreeString(bstr_headers)
return headers
def send(self, request=None):
''' Sends the request body. '''
# Sends VT_EMPTY if it is a GET or HEAD request.
if request is None:
var_empty = VARIANT.create_empty()
_WinHttpRequest._Send(self, var_empty)
else: # Sends request body as a SAFEARRAY.
_request = VARIANT.create_safearray_from_str(request)
_WinHttpRequest._Send(self, _request)
def status(self):
''' Gets status of response. '''
status = c_long()
_WinHttpRequest._Status(self, byref(status))
return int(status.value)
def status_text(self):
''' Gets status text of response. '''
bstr_status_text = c_void_p()
_WinHttpRequest._StatusText(self, byref(bstr_status_text))
bstr_status_text = ctypes.cast(bstr_status_text, c_wchar_p)
status_text = bstr_status_text.value
_SysFreeString(bstr_status_text)
return status_text
def response_body(self):
'''
Gets the response body as a SAFEARRAY and converts it to str.
XML responses may begin with a 3-byte UTF-8 BOM (b'\xef\xbb\xbf')
before <?xml, which is stripped.
'''
var_respbody = VARIANT()
_WinHttpRequest._ResponseBody(self, byref(var_respbody))
if var_respbody.is_safearray_of_bytes():
respbody = var_respbody.str_from_safearray()
if respbody[3:].startswith(b'<?xml') and\
respbody.startswith(b'\xef\xbb\xbf'):
respbody = respbody[3:]
return respbody
else:
return ''
def set_client_certificate(self, certificate):
'''Sets client certificate for the request. '''
_certificate = BSTR(certificate)
_WinHttpRequest._SetClientCertificate(self, _certificate)
def set_tunnel(self, host, port):
''' Sets up the host and the port for the HTTP CONNECT Tunnelling.'''
url = host
if port:
url = url + u':' + port
var_host = VARIANT.create_bstr_from_str(url)
var_empty = VARIANT.create_empty()
_WinHttpRequest._SetProxy(
self, HTTPREQUEST_PROXYSETTING_PROXY, var_host, var_empty)
def set_proxy_credentials(self, user, password):
_WinHttpRequest._SetCredentials(
self, BSTR(user), BSTR(password),
HTTPREQUEST_SETCREDENTIALS_FOR_PROXY)
def __del__(self):
if self.value is not None:
_WinHttpRequest._Release(self)
class _Response(object):
''' Response class corresponding to the response returned from httplib
HTTPConnection. '''
def __init__(self, _status, _status_text, _length, _headers, _respbody):
self.status = _status
self.reason = _status_text
self.length = _length
self.headers = _headers
self.respbody = _respbody
def getheaders(self):
'''Returns response headers.'''
return self.headers
def read(self, _length):
'''Returns response body. '''
return self.respbody[:_length]
class _HTTPConnection(object):
''' Class corresponding to httplib HTTPConnection class. '''
def __init__(self, host, cert_file=None, key_file=None, protocol='http'):
''' Initializes the IWinHttpWebRequest COM object.'''
self.host = unicode(host)
self.cert_file = cert_file
self._httprequest = _WinHttpRequest()
self.protocol = protocol
clsid = GUID('{2087C2F4-2CEF-4953-A8AB-66779B670495}')
iid = GUID('{016FE2EC-B2C8-45F8-B23B-39E53A75396B}')
_CoInitialize(None)
_CoCreateInstance(byref(clsid), 0, 1, byref(iid),
byref(self._httprequest))
def close(self):
pass
def set_tunnel(self, host, port=None, headers=None):
''' Sets up the host and the port for the HTTP CONNECT Tunnelling. '''
self._httprequest.set_tunnel(unicode(host), unicode(str(port)))
def set_proxy_credentials(self, user, password):
self._httprequest.set_proxy_credentials(
unicode(user), unicode(password))
def putrequest(self, method, uri):
''' Connects to host and sends the request. '''
protocol = unicode(self.protocol + '://')
url = protocol + self.host + unicode(uri)
self._httprequest.open(unicode(method), url)
# sets certificate for the connection if cert_file is set.
if self.cert_file is not None:
self._httprequest.set_client_certificate(unicode(self.cert_file))
def putheader(self, name, value):
''' Sends the request headers. '''
if sys.version_info < (3,):
name = str(name).decode('utf-8')
value = str(value).decode('utf-8')
self._httprequest.set_request_header(name, value)
def endheaders(self):
''' No operation. Exists only to provide the same interface as httplib
HTTPConnection.'''
pass
def send(self, request_body):
''' Sends request body. '''
if not request_body:
self._httprequest.send()
else:
self._httprequest.send(request_body)
def getresponse(self):
''' Gets the response and generates the _Response object. '''
status = self._httprequest.status()
status_text = self._httprequest.status_text()
resp_headers = self._httprequest.get_all_response_headers()
fixed_headers = []
for resp_header in resp_headers.split('\n'):
if (resp_header.startswith('\t') or\
resp_header.startswith(' ')) and fixed_headers:
# append to previous header
fixed_headers[-1] += resp_header
else:
fixed_headers.append(resp_header)
headers = []
for resp_header in fixed_headers:
if ':' in resp_header:
pos = resp_header.find(':')
headers.append(
(resp_header[:pos].lower(), resp_header[pos + 1:].strip()))
body = self._httprequest.response_body()
length = len(body)
return _Response(status, status_text, length, headers, body)
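Taken together, the classes above give the SDK an httplib-shaped transport on Windows: _HTTPConnection mirrors the putrequest / putheader / endheaders / send / getresponse sequence of httplib's HTTPConnection, delegating the actual work to the WinHttp COM object through _WinHttpRequest and wrapping the result in _Response. A minimal sketch of the calling sequence, assuming a Windows host where the WinHttp COM class is registered; the host, path and header values below are placeholders, not part of the vendored code:

# Sketch only -- host, path and header values are placeholders.
conn = _HTTPConnection('management.core.windows.net', protocol='https')
conn.putrequest('GET', '/<subscription-id>/services/hostedservices')
conn.putheader('x-ms-version', '2012-03-01')
conn.endheaders()
conn.send(None)                  # no body, so winhttp sends VT_EMPTY
resp = conn.getresponse()
print(resp.status, resp.reason)
print(resp.read(resp.length))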

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,70 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_str
)
from azure.servicemanagement import (
CloudServices,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class SchedulerManagementService(_ServiceManagementClient):
''' Note that this class is preliminary work on Scheduler
management. Since it lacks a lot of features, the final version
may differ slightly from the current one.
'''
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST, request_session=None):
'''
Initializes the scheduler management service.
subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(SchedulerManagementService, self).__init__(
subscription_id, cert_file, host, request_session)
#--Operations for scheduler ----------------------------------------
def list_cloud_services(self):
'''
List the cloud services for scheduling defined on the account.
'''
return self._perform_get(self._get_list_cloud_services_path(),
CloudServices)
#--Helper functions --------------------------------------------------
def _get_list_cloud_services_path(self):
return self._get_path('cloudservices', None)
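Since the class is explicitly preliminary, a hedged usage sketch is enough to show its surface; the subscription id and certificate path below are placeholders:

# Sketch only -- subscription id and certificate path are placeholders.
sms = SchedulerManagementService('<subscription-id>', 'mycert.pem')
cloud_services = sms.list_cloud_services()  # parsed into a CloudServices feed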

View File

@ -1,113 +1,534 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_convert_response_to_feeds,
_str,
_validate_not_none,
)
from azure.servicemanagement import (
_ServiceBusManagementXmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class ServiceBusManagementService(_ServiceManagementClient):
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST):
super(ServiceBusManagementService, self).__init__(
subscription_id, cert_file, host)
#--Operations for service bus ----------------------------------------
def get_regions(self):
'''
Get list of available service bus regions.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Regions/', None),
None)
return _convert_response_to_feeds(
response,
_ServiceBusManagementXmlSerializer.xml_to_region)
def list_namespaces(self):
'''
List the service bus namespaces defined on the account.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Namespaces/', None),
None)
return _convert_response_to_feeds(
response,
_ServiceBusManagementXmlSerializer.xml_to_namespace)
def get_namespace(self, name):
'''
Get details about a specific namespace.
name: Name of the service bus namespace.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Namespaces', name),
None)
return _ServiceBusManagementXmlSerializer.xml_to_namespace(
response.body)
def create_namespace(self, name, region):
'''
Create a new service bus namespace.
name: Name of the service bus namespace to create.
region: Region to create the namespace in.
'''
_validate_not_none('name', name)
return self._perform_put(
self._get_path('services/serviceBus/Namespaces', name),
_ServiceBusManagementXmlSerializer.namespace_to_xml(region))
def delete_namespace(self, name):
'''
Delete a service bus namespace.
name: Name of the service bus namespace to delete.
'''
_validate_not_none('name', name)
return self._perform_delete(
self._get_path('services/serviceBus/Namespaces', name),
None)
def check_namespace_availability(self, name):
'''
Checks to see if the specified service bus namespace is available, or
if it has already been taken.
name: Name of the service bus namespace to validate.
'''
_validate_not_none('name', name)
response = self._perform_get(
self._get_path('services/serviceBus/CheckNamespaceAvailability',
None) + '/?namespace=' + _str(name), None)
return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability(
response.body)
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_convert_response_to_feeds,
_str,
_validate_not_none,
_convert_xml_to_windows_azure_object,
)
from azure.servicemanagement import (
_ServiceBusManagementXmlSerializer,
QueueDescription,
TopicDescription,
NotificationHubDescription,
RelayDescription,
MetricProperties,
MetricValues,
MetricRollups,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
from functools import partial
X_MS_VERSION = '2012-03-01'
class ServiceBusManagementService(_ServiceManagementClient):
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST, request_session=None):
'''
Initializes the service bus management service.
subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(ServiceBusManagementService, self).__init__(
subscription_id, cert_file, host, request_session)
self.x_ms_version = X_MS_VERSION
# Operations for service bus ----------------------------------------
def get_regions(self):
'''
Get list of available service bus regions.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Regions/', None),
None)
return _convert_response_to_feeds(
response,
_ServiceBusManagementXmlSerializer.xml_to_region)
def list_namespaces(self):
'''
List the service bus namespaces defined on the account.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Namespaces/', None),
None)
return _convert_response_to_feeds(
response,
_ServiceBusManagementXmlSerializer.xml_to_namespace)
def get_namespace(self, name):
'''
Get details about a specific namespace.
name: Name of the service bus namespace.
'''
response = self._perform_get(
self._get_path('services/serviceBus/Namespaces', name),
None)
return _ServiceBusManagementXmlSerializer.xml_to_namespace(
response.body)
def create_namespace(self, name, region):
'''
Create a new service bus namespace.
name: Name of the service bus namespace to create.
region: Region to create the namespace in.
'''
_validate_not_none('name', name)
return self._perform_put(
self._get_path('services/serviceBus/Namespaces', name),
_ServiceBusManagementXmlSerializer.namespace_to_xml(region))
def delete_namespace(self, name):
'''
Delete a service bus namespace.
name: Name of the service bus namespace to delete.
'''
_validate_not_none('name', name)
return self._perform_delete(
self._get_path('services/serviceBus/Namespaces', name),
None)
def check_namespace_availability(self, name):
'''
Checks to see if the specified service bus namespace is available, or
if it has already been taken.
name: Name of the service bus namespace to validate.
'''
_validate_not_none('name', name)
response = self._perform_get(
self._get_path('services/serviceBus/CheckNamespaceAvailability',
None) + '/?namespace=' + _str(name), None)
return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability(
response.body)
def list_queues(self, name):
'''
Enumerates the queues in the service namespace.
name: Name of the service bus namespace.
'''
_validate_not_none('name', name)
response = self._perform_get(
self._get_list_queues_path(name),
None)
return _convert_response_to_feeds(response,
partial(_convert_xml_to_windows_azure_object,
azure_type=QueueDescription))
def list_topics(self, name):
'''
Retrieves the topics in the service namespace.
name: Name of the service bus namespace.
'''
response = self._perform_get(
self._get_list_topics_path(name),
None)
return _convert_response_to_feeds(response,
partial(_convert_xml_to_windows_azure_object,
azure_type=TopicDescription))
def list_notification_hubs(self, name):
'''
Retrieves the notification hubs in the service namespace.
name: Name of the service bus namespace.
'''
response = self._perform_get(
self._get_list_notification_hubs_path(name),
None)
return _convert_response_to_feeds(response,
partial(_convert_xml_to_windows_azure_object,
azure_type=NotificationHubDescription))
def list_relays(self, name):
'''
Retrieves the relays in the service namespace.
name: Name of the service bus namespace.
'''
response = self._perform_get(
self._get_list_relays_path(name),
None)
return _convert_response_to_feeds(response,
partial(_convert_xml_to_windows_azure_object,
azure_type=RelayDescription))
def get_supported_metrics_queue(self, name, queue_name):
'''
Retrieves the list of supported metrics for this namespace and queue
name: Name of the service bus namespace.
queue_name: Name of the service bus queue in this namespace.
'''
response = self._perform_get(
self._get_get_supported_metrics_queue_path(name, queue_name),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricProperties))
def get_supported_metrics_topic(self, name, topic_name):
'''
Retrieves the list of supported metrics for this namespace and topic
name: Name of the service bus namespace.
topic_name: Name of the service bus topic in this namespace.
'''
response = self._perform_get(
self._get_get_supported_metrics_topic_path(name, topic_name),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricProperties))
def get_supported_metrics_notification_hub(self, name, hub_name):
'''
Retrieves the list of supported metrics for this namespace and notification hub
name: Name of the service bus namespace.
hub_name: Name of the service bus notification hub in this namespace.
'''
response = self._perform_get(
self._get_get_supported_metrics_hub_path(name, hub_name),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricProperties))
def get_supported_metrics_relay(self, name, relay_name):
'''
Retrieves the list of supported metrics for this namespace and relay
name: Name of the service bus namespace.
relay_name: Name of the service bus relay in this namespace.
'''
response = self._perform_get(
self._get_get_supported_metrics_relay_path(name, relay_name),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricProperties))
def get_metrics_data_queue(self, name, queue_name, metric, rollup, filter_expression):
'''
Retrieves the metrics data for this namespace and queue
name: Name of the service bus namespace.
queue_name: Name of the service bus queue in this namespace.
metric: name of a supported metric
rollup: name of a supported rollup
filter_expression: filter, for instance "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
'''
response = self._perform_get(
self._get_get_metrics_data_queue_path(name, queue_name, metric, rollup, filter_expression),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricValues))
def get_metrics_data_topic(self, name, topic_name, metric, rollup, filter_expression):
'''
Retrieves the metrics data for this namespace and topic
name: Name of the service bus namespace.
topic_name: Name of the service bus topic in this namespace.
metric: name of a supported metric
rollup: name of a supported rollup
filter_expression: filter, for instance "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
'''
response = self._perform_get(
self._get_get_metrics_data_topic_path(name, topic_name, metric, rollup, filter_expression),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricValues))
def get_metrics_data_notification_hub(self, name, hub_name, metric, rollup, filter_expression):
'''
Retrieves the metrics data for this namespace and notification hub
name: Name of the service bus namespace.
hub_name: Name of the service bus notification hub in this namespace.
metric: name of a supported metric
rollup: name of a supported rollup
filter_expression: filter, for instance "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
'''
response = self._perform_get(
self._get_get_metrics_data_hub_path(name, hub_name, metric, rollup, filter_expression),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricValues))
def get_metrics_data_relay(self, name, relay_name, metric, rollup, filter_expression):
'''
Retrieves the metrics data for this namespace and relay
name: Name of the service bus namespace.
relay_name: Name of the service bus relay in this namespace.
metric: name of a supported metric
rollup: name of a supported rollup
filter_expression: filter, for instance "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
'''
response = self._perform_get(
self._get_get_metrics_data_relay_path(name, relay_name, metric, rollup, filter_expression),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricValues))
def get_metrics_rollups_queue(self, name, queue_name, metric):
'''
This operation gets rollup data for a Service Bus queue metric.
Rollup data includes the time granularity for the telemetry aggregation as well as
the retention settings for each time granularity.
name: Name of the service bus namespace.
queue_name: Name of the service bus queue in this namespace.
metric: name of a supported metric
'''
response = self._perform_get(
self._get_get_metrics_rollup_queue_path(name, queue_name, metric),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricRollups))
def get_metrics_rollups_topic(self, name, topic_name, metric):
'''
This operation gets rollup data for a Service Bus topic metric.
Rollup data includes the time granularity for the telemetry aggregation as well as
the retention settings for each time granularity.
name: Name of the service bus namespace.
topic_name: Name of the service bus topic in this namespace.
metric: name of a supported metric
'''
response = self._perform_get(
self._get_get_metrics_rollup_topic_path(name, topic_name, metric),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricRollups))
def get_metrics_rollups_notification_hub(self, name, hub_name, metric):
'''
This operation gets rollup data for a Service Bus notification hub metric.
Rollup data includes the time granularity for the telemetry aggregation as well as
the retention settings for each time granularity.
name: Name of the service bus namespace.
hub_name: Name of the service bus notification hub in this namespace.
metric: name of a supported metric
'''
response = self._perform_get(
self._get_get_metrics_rollup_hub_path(name, hub_name, metric),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricRollups))
def get_metrics_rollups_relay(self, name, relay_name, metric):
'''
This operation gets rollup data for a Service Bus relay metric.
Rollup data includes the time granularity for the telemetry aggregation as well as
the retention settings for each time granularity.
name: Name of the service bus namespace.
relay_name: Name of the service bus relay in this namespace.
metric: name of a supported metric
'''
response = self._perform_get(
self._get_get_metrics_rollup_relay_path(name, relay_name, metric),
None)
return _convert_response_to_feeds(response,
partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
object_type=MetricRollups))
# Helper functions --------------------------------------------------
def _get_list_queues_path(self, namespace_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Queues'
def _get_list_topics_path(self, namespace_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Topics'
def _get_list_notification_hubs_path(self, namespace_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/NotificationHubs'
def _get_list_relays_path(self, namespace_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Relays'
def _get_get_supported_metrics_queue_path(self, namespace_name, queue_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Queues/' + _str(queue_name) + '/Metrics'
def _get_get_supported_metrics_topic_path(self, namespace_name, topic_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Topics/' + _str(topic_name) + '/Metrics'
def _get_get_supported_metrics_hub_path(self, namespace_name, hub_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/NotificationHubs/' + _str(hub_name) + '/Metrics'
def _get_get_supported_metrics_relay_path(self, namespace_name, relay_name):
return self._get_path('services/serviceBus/Namespaces/',
namespace_name) + '/Relays/' + _str(relay_name) + '/Metrics'
def _get_get_metrics_data_queue_path(self, namespace_name, queue_name, metric, rollup, filter_expr):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Queues/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups/',
_str(rollup),
'/Values?',
filter_expr
])
def _get_get_metrics_data_topic_path(self, namespace_name, queue_name, metric, rollup, filter_expr):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Topics/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups/',
_str(rollup),
'/Values?',
filter_expr
])
def _get_get_metrics_data_hub_path(self, namespace_name, queue_name, metric, rollup, filter_expr):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/NotificationHubs/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups/',
_str(rollup),
'/Values?',
filter_expr
])
def _get_get_metrics_data_relay_path(self, namespace_name, queue_name, metric, rollup, filter_expr):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Relays/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups/',
_str(rollup),
'/Values?',
filter_expr
])
def _get_get_metrics_rollup_queue_path(self, namespace_name, queue_name, metric):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Queues/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups',
])
def _get_get_metrics_rollup_topic_path(self, namespace_name, queue_name, metric):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Topics/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups',
])
def _get_get_metrics_rollup_hub_path(self, namespace_name, queue_name, metric):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/NotificationHubs/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups',
])
def _get_get_metrics_rollup_relay_path(self, namespace_name, queue_name, metric):
return "".join([
self._get_path('services/serviceBus/Namespaces/', namespace_name),
'/Relays/',
_str(queue_name),
'/Metrics/',
_str(metric),
'/Rollups',
])
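All of the metrics operations above share one URL shape, built by the helpers that close the class: .../<entity>/Metrics lists the supported metrics, .../Metrics/<metric>/Rollups lists the rollups, and .../Metrics/<metric>/Rollups/<rollup>/Values?<filter> returns the data points, with each Atom feed parsed into MetricProperties, MetricRollups or MetricValues entries. A hedged sketch of the full sequence; the namespace, queue, metric and rollup names are placeholders, and only the $filter syntax is taken from the docstrings above:

# Sketch only -- names are placeholders; the filter string follows the
# example given in the docstrings.
sbms = ServiceBusManagementService('<subscription-id>', 'mycert.pem')
supported = sbms.get_supported_metrics_queue('mynamespace', 'myqueue')
rollups = sbms.get_metrics_rollups_queue('mynamespace', 'myqueue', 'incoming')
values = sbms.get_metrics_data_queue(
    'mynamespace', 'myqueue', 'incoming', 'PT5M',
    "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'")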

View File

@ -1,166 +1,258 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import os
from azure import (
WindowsAzureError,
MANAGEMENT_HOST,
_get_request_body,
_parse_response,
_str,
_update_request_uri_query,
)
from azure.http import (
HTTPError,
HTTPRequest,
)
from azure.http.httpclient import _HTTPClient
from azure.servicemanagement import (
AZURE_MANAGEMENT_CERTFILE,
AZURE_MANAGEMENT_SUBSCRIPTIONID,
_management_error_handler,
_parse_response_for_async_op,
_update_management_header,
)
class _ServiceManagementClient(object):
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST):
self.requestid = None
self.subscription_id = subscription_id
self.cert_file = cert_file
self.host = host
if not self.cert_file:
if AZURE_MANAGEMENT_CERTFILE in os.environ:
self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE]
if not self.subscription_id:
if AZURE_MANAGEMENT_SUBSCRIPTIONID in os.environ:
self.subscription_id = os.environ[
AZURE_MANAGEMENT_SUBSCRIPTIONID]
if not self.cert_file or not self.subscription_id:
raise WindowsAzureError(
'You need to provide subscription id and certificate file')
self._httpclient = _HTTPClient(
service_instance=self, cert_file=self.cert_file)
self._filter = self._httpclient.perform_request
def with_filter(self, filter):
'''Returns a new service which will process requests with the
specified filter. Filtering operations can include logging, automatic
retrying, etc... The filter is a lambda which receives the HTTPRequest
and another lambda. The filter can perform any pre-processing on the
request, pass it off to the next lambda, and then perform any
post-processing on the response.'''
res = type(self)(self.subscription_id, self.cert_file, self.host)
old_filter = self._filter
def new_filter(request):
return filter(request, old_filter)
res._filter = new_filter
return res
def set_proxy(self, host, port, user=None, password=None):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self._httpclient.set_proxy(host, port, user, password)
#--Helper functions --------------------------------------------------
def _perform_request(self, request):
try:
resp = self._filter(request)
except HTTPError as ex:
return _management_error_handler(ex)
return resp
def _perform_get(self, path, response_type):
request = HTTPRequest()
request.method = 'GET'
request.host = self.host
request.path = path
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if response_type is not None:
return _parse_response(response, response_type)
return response
def _perform_put(self, path, body, async=False):
request = HTTPRequest()
request.method = 'PUT'
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if async:
return _parse_response_for_async_op(response)
return None
def _perform_post(self, path, body, response_type=None, async=False):
request = HTTPRequest()
request.method = 'POST'
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if response_type is not None:
return _parse_response(response, response_type)
if async:
return _parse_response_for_async_op(response)
return None
def _perform_delete(self, path, async=False):
request = HTTPRequest()
request.method = 'DELETE'
request.host = self.host
request.path = path
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if async:
return _parse_response_for_async_op(response)
return None
def _get_path(self, resource, name):
path = '/' + self.subscription_id + '/' + resource
if name is not None:
path += '/' + _str(name)
return path
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import os
from azure import (
WindowsAzureError,
MANAGEMENT_HOST,
_get_request_body,
_parse_response,
_str,
_update_request_uri_query,
)
from azure.http import (
HTTPError,
HTTPRequest,
)
from azure.http.httpclient import _HTTPClient
from azure.servicemanagement import (
AZURE_MANAGEMENT_CERTFILE,
AZURE_MANAGEMENT_SUBSCRIPTIONID,
_management_error_handler,
parse_response_for_async_op,
X_MS_VERSION,
)
class _ServiceManagementClient(object):
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST, request_session=None):
self.requestid = None
self.subscription_id = subscription_id
self.cert_file = cert_file
self.host = host
self.request_session = request_session
self.x_ms_version = X_MS_VERSION
self.content_type = 'application/atom+xml;type=entry;charset=utf-8'
if not self.cert_file and not request_session:
if AZURE_MANAGEMENT_CERTFILE in os.environ:
self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE]
if not self.subscription_id:
if AZURE_MANAGEMENT_SUBSCRIPTIONID in os.environ:
self.subscription_id = os.environ[
AZURE_MANAGEMENT_SUBSCRIPTIONID]
if not self.request_session:
if not self.cert_file or not self.subscription_id:
raise WindowsAzureError(
'You need to provide subscription id and certificate file')
self._httpclient = _HTTPClient(
service_instance=self, cert_file=self.cert_file,
request_session=self.request_session)
self._filter = self._httpclient.perform_request
def with_filter(self, filter):
'''Returns a new service which will process requests with the
specified filter. Filtering operations can include logging, automatic
retrying, etc... The filter is a lambda which receives the HTTPRequest
and another lambda. The filter can perform any pre-processing on the
request, pass it off to the next lambda, and then perform any
post-processing on the response.'''
res = type(self)(self.subscription_id, self.cert_file, self.host,
self.request_session)
old_filter = self._filter
def new_filter(request):
return filter(request, old_filter)
res._filter = new_filter
return res
def set_proxy(self, host, port, user=None, password=None):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self._httpclient.set_proxy(host, port, user, password)
def perform_get(self, path, x_ms_version=None):
'''
Performs a GET request and returns the response.
path:
Path to the resource.
Ex: '/<subscription-id>/services/hostedservices/<service-name>'
x_ms_version:
If specified, this is used for the x-ms-version header.
Otherwise, self.x_ms_version is used.
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self.host
request.path = path
request.path, request.query = _update_request_uri_query(request)
request.headers = self._update_management_header(request, x_ms_version)
response = self._perform_request(request)
return response
def perform_put(self, path, body, x_ms_version=None):
'''
Performs a PUT request and returns the response.
path:
Path to the resource.
Ex: '/<subscription-id>/services/hostedservices/<service-name>'
body:
Body for the PUT request.
x_ms_version:
If specified, this is used for the x-ms-version header.
Otherwise, self.x_ms_version is used.
'''
request = HTTPRequest()
request.method = 'PUT'
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = self._update_management_header(request, x_ms_version)
response = self._perform_request(request)
return response
def perform_post(self, path, body, x_ms_version=None):
'''
Performs a POST request and returns the response.
path:
Path to the resource.
Ex: '/<subscription-id>/services/hostedservices/<service-name>'
body:
Body for the POST request.
x_ms_version:
If specified, this is used for the x-ms-version header.
Otherwise, self.x_ms_version is used.
'''
request = HTTPRequest()
request.method = 'POST'
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = self._update_management_header(request, x_ms_version)
response = self._perform_request(request)
return response
def perform_delete(self, path, x_ms_version=None):
'''
Performs a DELETE request and returns the response.
path:
Path to the resource.
Ex: '/<subscription-id>/services/hostedservices/<service-name>'
x_ms_version:
If specified, this is used for the x-ms-version header.
Otherwise, self.x_ms_version is used.
'''
request = HTTPRequest()
request.method = 'DELETE'
request.host = self.host
request.path = path
request.path, request.query = _update_request_uri_query(request)
request.headers = self._update_management_header(request, x_ms_version)
response = self._perform_request(request)
return response
#--Helper functions --------------------------------------------------
def _perform_request(self, request):
try:
resp = self._filter(request)
except HTTPError as ex:
return _management_error_handler(ex)
return resp
def _update_management_header(self, request, x_ms_version):
''' Add additional headers for management. '''
if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
request.headers.append(('Content-Length', str(len(request.body))))
# append additional headers based on the service
request.headers.append(('x-ms-version', x_ms_version or self.x_ms_version))
# If it is not a GET or HEAD request, Content-Type must be set.
if request.method not in ['GET', 'HEAD']:
for name, _ in request.headers:
if 'content-type' == name.lower():
break
else:
request.headers.append(
('Content-Type',
self.content_type))
return request.headers
def _perform_get(self, path, response_type, x_ms_version=None):
response = self.perform_get(path, x_ms_version)
if response_type is not None:
return _parse_response(response, response_type)
return response
def _perform_put(self, path, body, async=False, x_ms_version=None):
response = self.perform_put(path, body, x_ms_version)
if async:
return parse_response_for_async_op(response)
return None
def _perform_post(self, path, body, response_type=None, async=False,
x_ms_version=None):
response = self.perform_post(path, body, x_ms_version)
if response_type is not None:
return _parse_response(response, response_type)
if async:
return parse_response_for_async_op(response)
return None
def _perform_delete(self, path, async=False, x_ms_version=None):
response = self.perform_delete(path, x_ms_version)
if async:
return parse_response_for_async_op(response)
return None
def _get_path(self, resource, name):
path = '/' + self.subscription_id + '/' + resource
if name is not None:
path += '/' + _str(name)
return path
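Every operation funnels through self._filter, so with_filter can layer cross-cutting behavior such as logging or retrying over any management service without touching the operation methods. A hedged sketch of a print-logging filter; per the docstring, a filter receives the HTTPRequest plus the next filter in the chain, and the credentials below are placeholders:

# Sketch only -- credentials are placeholders; assumes the response
# object exposes .status, as the _Response class above does.
def logging_filter(request, next_filter):
    print('>>', request.method, request.host + request.path)
    response = next_filter(request)  # hand off to the real transport
    print('<<', response.status)
    return response

sbms = ServiceBusManagementService('<subscription-id>', 'mycert.pem')
logged = sbms.with_filter(logging_filter)
namespaces = logged.list_namespaces()  # each request is now logged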

View File

@ -0,0 +1,390 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_parse_service_resources_response,
_validate_not_none,
)
from azure.servicemanagement import (
EventLog,
ServerQuota,
Servers,
ServiceObjective,
Database,
FirewallRule,
_SqlManagementXmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class SqlDatabaseManagementService(_ServiceManagementClient):
''' Note that this class is preliminary work on SQL Database
management. Since it lacks a lot of features, the final version
may differ slightly from the current one.
'''
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST, request_session=None):
'''
Initializes the sql database management service.
subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(SqlDatabaseManagementService, self).__init__(
subscription_id, cert_file, host, request_session)
self.content_type = 'application/xml'
#--Operations for sql servers ----------------------------------------
def create_server(self, admin_login, admin_password, location):
'''
Create a new Azure SQL Database server.
admin_login: The administrator login name for the new server.
admin_password: The administrator login password for the new server.
location: The region to deploy the new server.
'''
_validate_not_none('admin_login', admin_login)
_validate_not_none('admin_password', admin_password)
_validate_not_none('location', location)
response = self.perform_post(
self._get_servers_path(),
_SqlManagementXmlSerializer.create_server_to_xml(
admin_login,
admin_password,
location
)
)
return _SqlManagementXmlSerializer.xml_to_create_server_response(
response.body)
def set_server_admin_password(self, server_name, admin_password):
'''
Reset the administrator password for a server.
server_name: Name of the server to change the password.
admin_password: The new administrator password for the server.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('admin_password', admin_password)
return self._perform_post(
self._get_servers_path(server_name) + '?op=ResetPassword',
_SqlManagementXmlSerializer.set_server_admin_password_to_xml(
admin_password
)
)
def delete_server(self, server_name):
'''
Deletes an Azure SQL Database server (including all its databases).
server_name: Name of the server you want to delete.
'''
_validate_not_none('server_name', server_name)
return self._perform_delete(
self._get_servers_path(server_name))
def list_servers(self):
'''
List the SQL servers defined on the account.
'''
return self._perform_get(self._get_servers_path(),
Servers)
def list_quotas(self, server_name):
'''
Gets quotas for an Azure SQL Database Server.
server_name: Name of the server.
'''
_validate_not_none('server_name', server_name)
response = self._perform_get(self._get_quotas_path(server_name),
None)
return _parse_service_resources_response(response, ServerQuota)
def get_server_event_logs(self, server_name, start_date,
interval_size_in_minutes, event_types=''):
'''
Gets the event logs for an Azure SQL Database Server.
server_name: Name of the server to retrieve the event logs from.
start_date:
The starting date and time of the events to retrieve in UTC format,
for example '2011-09-28 16:05:00'.
interval_size_in_minutes:
Size of the event logs to retrieve (in minutes).
Valid values are: 5, 60, or 1440.
event_types:
The event type of the log entries you want to retrieve.
Valid values are:
- connection_successful
- connection_failed
- connection_terminated
- deadlock
- throttling
- throttling_long_transaction
To return all event types pass in an empty string.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('start_date', start_date)
_validate_not_none('interval_size_in_minutes', interval_size_in_minutes)
_validate_not_none('event_types', event_types)
path = self._get_server_event_logs_path(server_name) + \
'?startDate={0}&intervalSizeInMinutes={1}&eventTypes={2}'.format(
start_date, interval_size_in_minutes, event_types)
response = self._perform_get(path, None)
return _parse_service_resources_response(response, EventLog)
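# Sketch, reusing sms from the sketch above (hypothetical values): one hour
# of failed connections, with start_date in the UTC format shown in the
# docstring.
#
#     logs = sms.get_server_event_logs(
#         'myserver', '2011-09-28 16:05:00', 60, 'connection_failed')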
#--Operations for firewall rules ------------------------------------------
def create_firewall_rule(self, server_name, name, start_ip_address,
end_ip_address):
'''
Creates an Azure SQL Database server firewall rule.
server_name: Name of the server to set the firewall rule on.
name: The name of the new firewall rule.
start_ip_address:
The lowest IP address in the range of the server-level firewall
setting. IP addresses equal to or greater than this can attempt to
connect to the server. The lowest possible IP address is 0.0.0.0.
end_ip_address:
The highest IP address in the range of the server-level firewall
setting. IP addresses equal to or less than this can attempt to
connect to the server. The highest possible IP address is
255.255.255.255.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('name', name)
_validate_not_none('start_ip_address', start_ip_address)
_validate_not_none('end_ip_address', end_ip_address)
return self._perform_post(
self._get_firewall_rules_path(server_name),
_SqlManagementXmlSerializer.create_firewall_rule_to_xml(
name, start_ip_address, end_ip_address
)
)
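# Sketch: a single-address rule for a hypothetical client IP; the rule name
# and addresses are placeholders.
#
#     sms.create_firewall_rule('myserver', 'office', '10.0.0.4', '10.0.0.4')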
def update_firewall_rule(self, server_name, name, start_ip_address,
end_ip_address):
'''
Update a firewall rule for an Azure SQL Database server.
server_name: Name of the server to set the firewall rule on.
name: The name of the firewall rule to update.
start_ip_address:
The lowest IP address in the range of the server-level firewall
setting. IP addresses equal to or greater than this can attempt to
connect to the server. The lowest possible IP address is 0.0.0.0.
end_ip_address:
The highest IP address in the range of the server-level firewall
setting. IP addresses equal to or less than this can attempt to
connect to the server. The highest possible IP address is
255.255.255.255.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('name', name)
_validate_not_none('start_ip_address', start_ip_address)
_validate_not_none('end_ip_address', end_ip_address)
return self._perform_put(
self._get_firewall_rules_path(server_name, name),
_SqlManagementXmlSerializer.update_firewall_rule_to_xml(
name, start_ip_address, end_ip_address
)
)
def delete_firewall_rule(self, server_name, name):
'''
Deletes an Azure SQL Database server firewall rule.
server_name:
Name of the server with the firewall rule you want to delete.
name:
Name of the firewall rule you want to delete.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('name', name)
return self._perform_delete(
self._get_firewall_rules_path(server_name, name))
def list_firewall_rules(self, server_name):
'''
Retrieves the set of firewall rules for an Azure SQL Database Server.
server_name: Name of the server.
'''
_validate_not_none('server_name', server_name)
response = self._perform_get(self._get_firewall_rules_path(server_name),
None)
return _parse_service_resources_response(response, FirewallRule)
def list_service_level_objectives(self, server_name):
'''
Gets the service level objectives for an Azure SQL Database server.
server_name: Name of the server.
'''
_validate_not_none('server_name', server_name)
response = self._perform_get(
self._get_service_objectives_path(server_name), None)
return _parse_service_resources_response(response, ServiceObjective)
#--Operations for sql databases ----------------------------------------
def create_database(self, server_name, name, service_objective_id,
edition=None, collation_name=None,
max_size_bytes=None):
'''
Creates a new Azure SQL Database.
server_name: Name of the server to contain the new database.
name:
Required. The name for the new database. See Naming Requirements
in Azure SQL Database General Guidelines and Limitations and
Database Identifiers for more information.
service_objective_id:
Required. The GUID corresponding to the performance level for
Edition. See List Service Level Objectives for current values.
edition:
Optional. The Service Tier (Edition) for the new database. If
omitted, the default is Web. Valid values are Web, Business,
Basic, Standard, and Premium. See Azure SQL Database Service Tiers
(Editions) and Web and Business Edition Sunset FAQ for more
information.
collation_name:
Optional. The database collation. This can be any collation
supported by SQL. If omitted, the default collation is used. See
SQL Server Collation Support in Azure SQL Database General
Guidelines and Limitations for more information.
max_size_bytes:
Optional. Sets the maximum size, in bytes, for the database. This
value must be within the range of allowed values for Edition. If
omitted, the default value for the edition is used. See Azure SQL
Database Service Tiers (Editions) for current maximum databases
sizes. Convert MB or GB values to bytes.
1 MB = 1048576 bytes. 1 GB = 1073741824 bytes.
'''
_validate_not_none('server_name', server_name)
_validate_not_none('name', name)
_validate_not_none('service_objective_id', service_objective_id)
return self._perform_post(
self._get_databases_path(server_name),
_SqlManagementXmlSerializer.create_database_to_xml(
name, service_objective_id, edition, collation_name,
max_size_bytes
)
)
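# Sketch: pick a service objective from the server, then create a 1 GB Web
# edition database. Indexing into the parsed response, and the .id
# attribute, are assumptions about the ServiceObjective collection.
#
#     objective = sms.list_service_level_objectives('myserver')[0]
#     sms.create_database('myserver', 'mydb', objective.id,
#                         edition='Web', max_size_bytes=1073741824)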
def update_database(self, server_name, name, new_database_name=None,
service_objective_id=None, edition=None,
max_size_bytes=None):
'''
Updates existing database details.
server_name: Name of the server to contain the new database.
name:
Required. The name for the new database. See Naming Requirements
in Azure SQL Database General Guidelines and Limitations and
Database Identifiers for more information.
new_database_name:
Optional. The new name for the new database.
service_objective_id:
Optional. The new service level to apply to the database. For more
information about service levels, see Azure SQL Database Service
Tiers and Performance Levels. Use List Service Level Objectives to
get the correct ID for the desired service objective.
edition:
Optional. The new edition for the new database.
max_size_bytes:
Optional. The new size of the database in bytes. For information on
available sizes for each edition, see Azure SQL Database Service
Tiers (Editions).
'''
_validate_not_none('server_name', server_name)
_validate_not_none('name', name)
return self._perform_put(
self._get_databases_path(server_name, name),
_SqlManagementXmlSerializer.update_database_to_xml(
new_database_name, service_objective_id, edition,
max_size_bytes
)
)
def delete_database(self, server_name, name):
'''
Deletes an Azure SQL Database.
server_name: Name of the server where the database is located.
name: Name of the database to delete.
'''
return self._perform_delete(self._get_databases_path(server_name, name))
def list_databases(self, name):
'''
List the SQL databases defined on the specified server.
name: Name of the server.
'''
response = self._perform_get(self._get_list_databases_path(name),
None)
return _parse_service_resources_response(response, Database)
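# Sketch: enumerate the databases on a server; the Database attribute name
# is assumed.
#
#     for db in sms.list_databases('myserver'):
#         print(db.name)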
#--Helper functions --------------------------------------------------
def _get_servers_path(self, server_name=None):
return self._get_path('services/sqlservers/servers', server_name)
def _get_firewall_rules_path(self, server_name, name=None):
path = self._get_servers_path(server_name) + '/firewallrules'
if name:
path = path + '/' + name
return path
def _get_databases_path(self, server_name, name=None):
path = self._get_servers_path(server_name) + '/databases'
if name:
path = path + '/' + name
return path
def _get_server_event_logs_path(self, server_name):
return self._get_servers_path(server_name) + '/events'
def _get_service_objectives_path(self, server_name):
return self._get_servers_path(server_name) + '/serviceobjectives'
def _get_quotas_path(self, server_name, name=None):
path = self._get_servers_path(server_name) + '/serverquotas'
if name:
path = path + '/' + name
return path
def _get_list_databases_path(self, name):
# *contentview=generic is mandatory*
return self._get_path('services/sqlservers/servers/',
name) + '/databases?contentview=generic'

View File

@ -0,0 +1,256 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_str,
)
from azure.servicemanagement import (
WebSpaces,
WebSpace,
Sites,
Site,
MetricResponses,
MetricDefinitions,
PublishData,
_XmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class WebsiteManagementService(_ServiceManagementClient):
''' Note that this class is preliminary work on WebSite
management. Since it lacks many features, the final version
may differ slightly from the current one.
'''
def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST, request_session=None):
'''
Initializes the website management service.
subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with the requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(WebsiteManagementService, self).__init__(
subscription_id, cert_file, host, request_session)
#--Operations for web sites ----------------------------------------
def list_webspaces(self):
'''
List the webspaces defined on the account.
'''
return self._perform_get(self._get_list_webspaces_path(),
WebSpaces)
def get_webspace(self, webspace_name):
'''
Get details of a specific webspace.
webspace_name: The name of the webspace.
'''
return self._perform_get(self._get_webspace_details_path(webspace_name),
WebSpace)
def list_sites(self, webspace_name):
'''
List the web sites defined on this webspace.
webspace_name: The name of the webspace.
'''
return self._perform_get(self._get_sites_path(webspace_name),
Sites)
def get_site(self, webspace_name, website_name):
'''
Get details of a specific website on this webspace.
webspace_name: The name of the webspace.
website_name: The name of the website.
'''
return self._perform_get(self._get_sites_details_path(webspace_name,
website_name),
Site)
def create_site(self, webspace_name, website_name, geo_region, host_names,
plan='VirtualDedicatedPlan', compute_mode='Shared',
server_farm=None, site_mode=None):
'''
Create a website.
webspace_name: The name of the webspace.
website_name: The name of the website.
geo_region:
The geographical region of the webspace that will be created.
host_names:
An array of fully qualified domain names for website. Only one
hostname can be specified in the azurewebsites.net domain.
The hostname should match the name of the website. Custom domains
can only be specified for Shared or Standard websites.
plan:
This value must be 'VirtualDedicatedPlan'.
compute_mode:
This value should be 'Shared' for the Free or Paid Shared
offerings, or 'Dedicated' for the Standard offering. The default
value is 'Shared'. If you set it to 'Dedicated', you must specify
a value for the server_farm parameter.
server_farm:
The name of the Server Farm associated with this website. This is
a required value for Standard mode.
site_mode:
Can be None, 'Limited' or 'Basic'. This value is 'Limited' for the
Free offering, and 'Basic' for the Paid Shared offering. Standard
mode does not use the site_mode parameter; it uses the compute_mode
parameter.
'''
xml = _XmlSerializer.create_website_to_xml(webspace_name, website_name, geo_region, plan, host_names, compute_mode, server_farm, site_mode)
return self._perform_post(
self._get_sites_path(webspace_name),
xml,
Site)
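# Sketch: a free-tier site, per the compute_mode/site_mode rules in the
# docstring. The webspace name is a placeholder; list_webspaces() returns
# the real ones for the subscription.
#
#     wms = WebsiteManagementService('<subscription-id>', '/path/to/pem')
#     wms.create_site('westuswebspace', 'mysite', 'West US',
#                     ['mysite.azurewebsites.net'], site_mode='Limited')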
def delete_site(self, webspace_name, website_name,
delete_empty_server_farm=False, delete_metrics=False):
'''
Delete a website.
webspace_name: The name of the webspace.
website_name: The name of the website.
delete_empty_server_farm:
If the site being deleted is the last web site in a server farm,
you can delete the server farm by setting this to True.
delete_metrics:
To also delete the metrics for the site that you are deleting, you
can set this to True.
'''
path = self._get_sites_details_path(webspace_name, website_name)
query = ''
if delete_empty_server_farm:
query += '&deleteEmptyServerFarm=true'
if delete_metrics:
query += '&deleteMetrics=true'
if query:
path = path + '?' + query.lstrip('&')
return self._perform_delete(path)
def restart_site(self, webspace_name, website_name):
'''
Restart a web site.
webspace_name: The name of the webspace.
website_name: The name of the website.
'''
return self._perform_post(
self._get_restart_path(webspace_name, website_name),
'')
def get_historical_usage_metrics(self, webspace_name, website_name,
metrics = None, start_time=None, end_time=None, time_grain=None):
'''
Get historical usage metrics.
webspace_name: The name of the webspace.
website_name: The name of the website.
metrics: Optional. List of metric names. Otherwise, all metrics are returned.
start_time: Optional. An ISO8601 date. Otherwise, the current hour is used.
end_time: Optional. An ISO8601 date. Otherwise, the current time is used.
time_grain: Optional. A rollup name, such as P1D. Otherwise, the default rollup for the metrics is used.
More information and metric names at:
http://msdn.microsoft.com/en-us/library/azure/dn166964.aspx
'''
metrics = ('names='+','.join(metrics)) if metrics else ''
start_time = ('StartTime='+start_time) if start_time else ''
end_time = ('EndTime='+end_time) if end_time else ''
time_grain = ('TimeGrain='+time_grain) if time_grain else ''
parameters = ('&'.join(v for v in (metrics, start_time, end_time, time_grain) if v))
parameters = '?'+parameters if parameters else ''
return self._perform_get(self._get_historical_usage_metrics_path(webspace_name, website_name) + parameters,
MetricResponses)
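# Sketch, reusing wms from the sketch above: one day of a single metric at
# the default rollup. 'CpuTime' is an assumed metric name; see the MSDN
# link in the docstring for real ones.
#
#     responses = wms.get_historical_usage_metrics(
#         'westuswebspace', 'mysite', metrics=['CpuTime'],
#         start_time='2015-01-29T00:00:00', end_time='2015-01-30T00:00:00')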
def get_metric_definitions(self, webspace_name, website_name):
'''
Get metric definitions of metrics available of this web site.
webspace_name: The name of the webspace.
website_name: The name of the website.
'''
return self._perform_get(self._get_metric_definitions_path(webspace_name, website_name),
MetricDefinitions)
def get_publish_profile_xml(self, webspace_name, website_name):
'''
Get a site's publish profile as a string.
webspace_name: The name of the webspace.
website_name: The name of the website.
'''
return self._perform_get(self._get_publishxml_path(webspace_name, website_name),
None).body.decode("utf-8")
def get_publish_profile(self, webspace_name, website_name):
'''
Get a site's publish profile as an object.
webspace_name: The name of the webspace.
website_name: The name of the website.
'''
return self._perform_get(self._get_publishxml_path(webspace_name, website_name),
PublishData)
#--Helper functions --------------------------------------------------
def _get_list_webspaces_path(self):
return self._get_path('services/webspaces', None)
def _get_webspace_details_path(self, webspace_name):
return self._get_path('services/webspaces/', webspace_name)
def _get_sites_path(self, webspace_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites'
def _get_sites_details_path(self, webspace_name, website_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites/' + _str(website_name)
def _get_restart_path(self, webspace_name, website_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites/' + _str(website_name) + '/restart/'
def _get_historical_usage_metrics_path(self, webspace_name, website_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites/' + _str(website_name) + '/metrics/'
def _get_metric_definitions_path(self, webspace_name, website_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites/' + _str(website_name) + '/metricdefinitions/'
def _get_publishxml_path(self, webspace_name, website_name):
return self._get_path('services/webspaces/',
webspace_name) + '/sites/' + _str(website_name) + '/publishxml/'

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,39 +1,39 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure.storage.blobservice import BlobService
from azure.storage.tableservice import TableService
from azure.storage.queueservice import QueueService
class CloudStorageAccount(object):
"""
Provides a factory for creating the blob, queue, and table services
with a common account name and account key. Users can either use the
factory or can construct the appropriate service directly.
"""
def __init__(self, account_name=None, account_key=None):
self.account_name = account_name
self.account_key = account_key
def create_blob_service(self):
return BlobService(self.account_name, self.account_key)
def create_table_service(self):
return TableService(self.account_name, self.account_key)
def create_queue_service(self):
return QueueService(self.account_name, self.account_key)
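# Sketch: the factory is equivalent to constructing the services directly;
# the account name and key are placeholders.
#
#     account = CloudStorageAccount('myaccount', '<base64-key>')
#     queue_service = account.create_queue_service()
#     # ...equivalent to QueueService('myaccount', '<base64-key>')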

View File

@ -1,458 +1,458 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
WindowsAzureConflictError,
WindowsAzureError,
DEV_QUEUE_HOST,
QUEUE_SERVICE_HOST_BASE,
xml_escape,
_convert_class_to_xml,
_dont_fail_not_exist,
_dont_fail_on_exist,
_get_request_body,
_int_or_none,
_parse_enum_results_list,
_parse_response,
_parse_response_for_dict_filter,
_parse_response_for_dict_prefix,
_str,
_str_or_none,
_update_request_uri_query_local_storage,
_validate_not_none,
_ERROR_CONFLICT,
)
from azure.http import (
HTTPRequest,
HTTP_RESPONSE_NO_CONTENT,
)
from azure.storage import (
Queue,
QueueEnumResults,
QueueMessagesList,
StorageServiceProperties,
_update_storage_queue_header,
)
from azure.storage.storageclient import _StorageClient
class QueueService(_StorageClient):
'''
This is the main class managing queue resources.
'''
def __init__(self, account_name=None, account_key=None, protocol='https',
host_base=QUEUE_SERVICE_HOST_BASE, dev_host=DEV_QUEUE_HOST):
'''
account_name: your storage account name, required for all operations.
account_key: your storage account key, required for all operations.
protocol: Optional. Protocol. Defaults to https.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
dev_host: Optional. Dev host url. Defaults to localhost.
'''
super(QueueService, self).__init__(
account_name, account_key, protocol, host_base, dev_host)
def get_queue_service_properties(self, timeout=None):
'''
Gets the properties of a storage account's Queue Service, including
Windows Azure Storage Analytics.
timeout: Optional. The timeout parameter is expressed in seconds.
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.query = [('timeout', _int_or_none(timeout))]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_response(response, StorageServiceProperties)
def list_queues(self, prefix=None, marker=None, maxresults=None,
include=None):
'''
Lists all of the queues in a given storage account.
prefix:
Filters the results to return only queues with names that begin
with the specified prefix.
marker:
A string value that identifies the portion of the list to be
returned with the next list operation. The operation returns a
NextMarker element within the response body if the list returned
was not complete. This value may then be used as a query parameter
in a subsequent call to request the next portion of the list of
queues. The marker value is opaque to the client.
maxresults:
Specifies the maximum number of queues to return. If maxresults is
not specified, the server will return up to 5,000 items.
include:
Optional. Include this parameter to specify that the container's
metadata be returned as part of the response body.
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/?comp=list'
request.query = [
('prefix', _str_or_none(prefix)),
('marker', _str_or_none(marker)),
('maxresults', _int_or_none(maxresults)),
('include', _str_or_none(include))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_enum_results_list(
response, QueueEnumResults, "Queues", Queue)
def create_queue(self, queue_name, x_ms_meta_name_values=None,
fail_on_exist=False):
'''
Creates a queue under the given account.
queue_name: name of the queue.
x_ms_meta_name_values:
Optional. A dict containing name-value pairs to associate with the
queue as metadata.
fail_on_exist: Specify whether to throw an exception when the queue exists.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + ''
request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
if not fail_on_exist:
try:
response = self._perform_request(request)
if response.status == HTTP_RESPONSE_NO_CONTENT:
return False
return True
except WindowsAzureError as ex:
_dont_fail_on_exist(ex)
return False
else:
response = self._perform_request(request)
if response.status == HTTP_RESPONSE_NO_CONTENT:
raise WindowsAzureConflictError(
_ERROR_CONFLICT.format(response.message))
return True
def delete_queue(self, queue_name, fail_not_exist=False):
'''
Permanently deletes the specified queue.
queue_name: Name of the queue.
fail_not_exist:
Specify whether to throw an exception when the queue doesn't exist.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + ''
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
if not fail_not_exist:
try:
self._perform_request(request)
return True
except WindowsAzureError as ex:
_dont_fail_not_exist(ex)
return False
else:
self._perform_request(request)
return True
def get_queue_metadata(self, queue_name):
'''
Retrieves user-defined metadata and queue properties on the specified
queue. Metadata is associated with the queue as name-value pairs.
queue_name: Name of the queue.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '?comp=metadata'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_response_for_dict_prefix(
response,
prefixes=['x-ms-meta', 'x-ms-approximate-messages-count'])
def set_queue_metadata(self, queue_name, x_ms_meta_name_values=None):
'''
Sets user-defined metadata on the specified queue. Metadata is
associated with the queue as name-value pairs.
queue_name: Name of the queue.
x_ms_meta_name_values:
Optional. A dict containing name-value pairs to associate with the
queue as metadata.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '?comp=metadata'
request.headers = [('x-ms-meta-name-values', x_ms_meta_name_values)]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
self._perform_request(request)
def put_message(self, queue_name, message_text, visibilitytimeout=None,
messagettl=None):
'''
Adds a new message to the back of the message queue. A visibility
timeout can also be specified to make the message invisible until the
visibility timeout expires. A message must be in a format that can be
included in an XML request with UTF-8 encoding. The encoded message can
be up to 64KB in size for versions 2011-08-18 and newer, or 8KB in size
for previous versions.
queue_name: Name of the queue.
message_text: Message content.
visibilitytimeout:
Optional. If not specified, the default value is 0. Specifies the
new visibility timeout value, in seconds, relative to server time.
The new value must be larger than or equal to 0, and cannot be
larger than 7 days. The visibility timeout of a message cannot be
set to a value later than the expiry time. visibilitytimeout
should be set to a value smaller than the time-to-live value.
messagettl:
Optional. Specifies the time-to-live interval for the message, in
seconds. The maximum time-to-live allowed is 7 days. If this
parameter is omitted, the default time-to-live is 7 days.
'''
_validate_not_none('queue_name', queue_name)
_validate_not_none('message_text', message_text)
request = HTTPRequest()
request.method = 'POST'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '/messages'
request.query = [
('visibilitytimeout', _str_or_none(visibilitytimeout)),
('messagettl', _str_or_none(messagettl))
]
request.body = _get_request_body(
'<?xml version="1.0" encoding="utf-8"?> \
<QueueMessage> \
<MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \
</QueueMessage>')
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
self._perform_request(request)
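# Sketch: enqueue a message that stays invisible for 30 seconds and expires
# after an hour; the queue name and text are placeholders.
#
#     queue_service = QueueService('myaccount', '<base64-key>')
#     queue_service.create_queue('taskqueue')
#     queue_service.put_message('taskqueue', 'process item 42',
#                               visibilitytimeout=30, messagettl=3600)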
def get_messages(self, queue_name, numofmessages=None,
visibilitytimeout=None):
'''
Retrieves one or more messages from the front of the queue.
queue_name: Name of the queue.
numofmessages:
Optional. A nonzero integer value that specifies the number of
messages to retrieve from the queue, up to a maximum of 32. If
fewer are visible, the visible messages are returned. By default,
a single message is retrieved from the queue with this operation.
visibilitytimeout:
Specifies the new visibility timeout value, in seconds, relative
to server time. The new value must be larger than or equal to 1
second, and cannot be larger than 7 days, or larger than 2 hours
on REST protocol versions prior to version 2011-08-18. The
visibility timeout of a message can be set to a value later than
the expiry time.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '/messages'
request.query = [
('numofmessages', _str_or_none(numofmessages)),
('visibilitytimeout', _str_or_none(visibilitytimeout))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_response(response, QueueMessagesList)
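# Sketch: the usual receive/process/delete loop; the message attribute
# names are assumed from the QueueMessagesList model.
#
#     messages = queue_service.get_messages('taskqueue', numofmessages=16,
#                                           visibilitytimeout=60)
#     for msg in messages:
#         # ...process msg.message_text, then:
#         queue_service.delete_message('taskqueue', msg.message_id,
#                                      msg.pop_receipt)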
def peek_messages(self, queue_name, numofmessages=None):
'''
Retrieves one or more messages from the front of the queue, but does
not alter the visibility of the message.
queue_name: Name of the queue.
numofmessages:
Optional. A nonzero integer value that specifies the number of
messages to peek from the queue, up to a maximum of 32. By default,
a single message is peeked from the queue with this operation.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '/messages?peekonly=true'
request.query = [('numofmessages', _str_or_none(numofmessages))]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_response(response, QueueMessagesList)
def delete_message(self, queue_name, message_id, popreceipt):
'''
Deletes the specified message.
queue_name: Name of the queue.
message_id: Message to delete.
popreceipt:
Required. A valid pop receipt value returned from an earlier call
to the Get Messages or Update Message operation.
'''
_validate_not_none('queue_name', queue_name)
_validate_not_none('message_id', message_id)
_validate_not_none('popreceipt', popreceipt)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/' + \
_str(queue_name) + '/messages/' + _str(message_id) + ''
request.query = [('popreceipt', _str_or_none(popreceipt))]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
self._perform_request(request)
def clear_messages(self, queue_name):
'''
Deletes all messages from the specified queue.
queue_name: Name of the queue.
'''
_validate_not_none('queue_name', queue_name)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/' + _str(queue_name) + '/messages'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
self._perform_request(request)
def update_message(self, queue_name, message_id, message_text, popreceipt,
visibilitytimeout):
'''
Updates the visibility timeout of a message. You can also use this
operation to update the contents of a message.
queue_name: Name of the queue.
message_id: Message to update.
message_text: Content of message.
popreceipt:
Required. A valid pop receipt value returned from an earlier call
to the Get Messages or Update Message operation.
visibilitytimeout:
Required. Specifies the new visibility timeout value, in seconds,
relative to server time. The new value must be larger than or equal
to 0, and cannot be larger than 7 days. The visibility timeout of a
message cannot be set to a value later than the expiry time. A
message can be updated until it has been deleted or has expired.
'''
_validate_not_none('queue_name', queue_name)
_validate_not_none('message_id', message_id)
_validate_not_none('message_text', message_text)
_validate_not_none('popreceipt', popreceipt)
_validate_not_none('visibilitytimeout', visibilitytimeout)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + \
_str(queue_name) + '/messages/' + _str(message_id) + ''
request.query = [
('popreceipt', _str_or_none(popreceipt)),
('visibilitytimeout', _str_or_none(visibilitytimeout))
]
request.body = _get_request_body(
'<?xml version="1.0" encoding="utf-8"?> \
<QueueMessage> \
<MessageText>' + xml_escape(_str(message_text)) + '</MessageText> \
</QueueMessage>')
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
response = self._perform_request(request)
return _parse_response_for_dict_filter(
response,
filter=['x-ms-popreceipt', 'x-ms-time-next-visible'])
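# Sketch, continuing the loop above: extend a message's invisibility after
# a slow processing step; the returned headers carry the refreshed pop
# receipt.
#
#     result = queue_service.update_message(
#         'taskqueue', msg.message_id, 'still working', msg.pop_receipt, 60)
#     new_receipt = result['x-ms-popreceipt']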
def set_queue_service_properties(self, storage_service_properties,
timeout=None):
'''
Sets the properties of a storage account's Queue service, including
Windows Azure Storage Analytics.
storage_service_properties: StorageServiceProperties object.
timeout: Optional. The timeout parameter is expressed in seconds.
'''
_validate_not_none('storage_service_properties',
storage_service_properties)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.query = [('timeout', _int_or_none(timeout))]
request.body = _get_request_body(
_convert_class_to_xml(storage_service_properties))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_queue_header(
request, self.account_name, self.account_key)
self._perform_request(request)

View File

@ -1,230 +1,231 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import url_quote
from azure.storage import _sign_string, X_MS_VERSION
#-------------------------------------------------------------------------
# Constants for the share access signature
SIGNED_START = 'st'
SIGNED_EXPIRY = 'se'
SIGNED_RESOURCE = 'sr'
SIGNED_PERMISSION = 'sp'
SIGNED_IDENTIFIER = 'si'
SIGNED_SIGNATURE = 'sig'
SIGNED_VERSION = 'sv'
RESOURCE_BLOB = 'b'
RESOURCE_CONTAINER = 'c'
SIGNED_RESOURCE_TYPE = 'resource'
SHARED_ACCESS_PERMISSION = 'permission'
#--------------------------------------------------------------------------
class WebResource(object):
'''
Class that represents the resource for which to get the shared access
signature.
path: the resource path.
properties: dict of names and values. Contains 2 items: resource type and
permission.
request_url: the url of the webresource, including all the queries.
'''
def __init__(self, path=None, request_url=None, properties=None):
self.path = path
self.properties = properties or {}
self.request_url = request_url
class Permission(object):
'''
Permission class. Contains the path and query_string for the path.
path: the resource path
query_string: dict of name, values. Contains SIGNED_START, SIGNED_EXPIRY
SIGNED_RESOURCE, SIGNED_PERMISSION, SIGNED_IDENTIFIER,
SIGNED_SIGNATURE name values.
'''
def __init__(self, path=None, query_string=None):
self.path = path
self.query_string = query_string
class SharedAccessPolicy(object):
''' SharedAccessPolicy class. '''
def __init__(self, access_policy, signed_identifier=None):
self.id = signed_identifier
self.access_policy = access_policy
class SharedAccessSignature(object):
'''
The main class used to do the signing and generating the signature.
account_name:
the storage account name used to generate the shared access signature
account_key: the access key used to generate the shared access signature
permission_set: the permission cache used to sign the request url.
'''
def __init__(self, account_name, account_key, permission_set=None):
self.account_name = account_name
self.account_key = account_key
self.permission_set = permission_set
def generate_signed_query_string(self, path, resource_type,
shared_access_policy,
version=X_MS_VERSION):
'''
Generates the query string for path, resource type and shared access
policy.
path: the resource
resource_type: could be blob or container
shared_access_policy: shared access policy
version:
x-ms-version for storage service, or None to get a signed query
string compatible with pre 2012-02-12 clients, where the version
is not included in the query string.
'''
query_string = {}
if shared_access_policy.access_policy.start:
query_string[
SIGNED_START] = shared_access_policy.access_policy.start
if version:
query_string[SIGNED_VERSION] = version
query_string[SIGNED_EXPIRY] = shared_access_policy.access_policy.expiry
query_string[SIGNED_RESOURCE] = resource_type
query_string[
SIGNED_PERMISSION] = shared_access_policy.access_policy.permission
if shared_access_policy.id:
query_string[SIGNED_IDENTIFIER] = shared_access_policy.id
query_string[SIGNED_SIGNATURE] = self._generate_signature(
path, shared_access_policy, version)
return query_string
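# Sketch: a read-only blob SAS query string. AccessPolicy stands for any
# object with start/expiry/permission attributes, as read above; whether
# azure.storage exports such a class with this constructor is an
# assumption here.
#
#     sas = SharedAccessSignature('myaccount', '<base64-key>')
#     policy = SharedAccessPolicy(AccessPolicy(
#         '2015-01-30T00:00:00Z', '2015-01-31T00:00:00Z', 'r'))
#     qs = sas.generate_signed_query_string(
#         '/mycontainer/myblob.txt', RESOURCE_BLOB, policy)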
def sign_request(self, web_resource):
''' Sign the request to generate a request_url with shared access
signature info for web_resource.'''
if self.permission_set:
for shared_access_signature in self.permission_set:
if self._permission_matches_request(
shared_access_signature, web_resource,
web_resource.properties[
SIGNED_RESOURCE_TYPE],
web_resource.properties[SHARED_ACCESS_PERMISSION]):
if web_resource.request_url.find('?') == -1:
web_resource.request_url += '?'
else:
web_resource.request_url += '&'
web_resource.request_url += self._convert_query_string(
shared_access_signature.query_string)
break
return web_resource
def _convert_query_string(self, query_string):
        ''' Converts the query string dict to a str. The order of the
        name/value pairs is significant and must not change. '''
convert_str = ''
if SIGNED_START in query_string:
convert_str += SIGNED_START + '=' + \
url_quote(query_string[SIGNED_START]) + '&'
convert_str += SIGNED_EXPIRY + '=' + \
url_quote(query_string[SIGNED_EXPIRY]) + '&'
convert_str += SIGNED_PERMISSION + '=' + \
query_string[SIGNED_PERMISSION] + '&'
convert_str += SIGNED_RESOURCE + '=' + \
query_string[SIGNED_RESOURCE] + '&'
if SIGNED_IDENTIFIER in query_string:
convert_str += SIGNED_IDENTIFIER + '=' + \
query_string[SIGNED_IDENTIFIER] + '&'
if SIGNED_VERSION in query_string:
convert_str += SIGNED_VERSION + '=' + \
query_string[SIGNED_VERSION] + '&'
convert_str += SIGNED_SIGNATURE + '=' + \
url_quote(query_string[SIGNED_SIGNATURE]) + '&'
return convert_str
def _generate_signature(self, path, shared_access_policy, version):
''' Generates signature for a given path and shared access policy. '''
def get_value_to_append(value, no_new_line=False):
return_value = ''
if value:
return_value = value
if not no_new_line:
return_value += '\n'
return return_value
if path[0] != '/':
path = '/' + path
canonicalized_resource = '/' + self.account_name + path
# Form the string to sign from shared_access_policy and canonicalized
# resource. The order of values is important.
string_to_sign = \
(get_value_to_append(shared_access_policy.access_policy.permission) +
get_value_to_append(shared_access_policy.access_policy.start) +
get_value_to_append(shared_access_policy.access_policy.expiry) +
get_value_to_append(canonicalized_resource))
if version:
string_to_sign += get_value_to_append(shared_access_policy.id)
string_to_sign += get_value_to_append(version, True)
else:
string_to_sign += get_value_to_append(shared_access_policy.id, True)
return self._sign(string_to_sign)
def _permission_matches_request(self, shared_access_signature,
web_resource, resource_type,
required_permission):
''' Check whether requested permission matches given
shared_access_signature, web_resource and resource type. '''
required_resource_type = resource_type
if required_resource_type == RESOURCE_BLOB:
required_resource_type += RESOURCE_CONTAINER
for name, value in shared_access_signature.query_string.items():
if name == SIGNED_RESOURCE and \
required_resource_type.find(value) == -1:
return False
elif name == SIGNED_PERMISSION and \
required_permission.find(value) == -1:
return False
return web_resource.path.find(shared_access_signature.path) != -1
def _sign(self, string_to_sign):
        ''' Uses HMAC-SHA256 to sign the string and returns it as a
        base64-encoded string. '''
return _sign_string(self.account_key, string_to_sign)
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import _sign_string, url_quote
from azure.storage import X_MS_VERSION
#-------------------------------------------------------------------------
# Constants for the shared access signature
SIGNED_VERSION = 'sv'
SIGNED_START = 'st'
SIGNED_EXPIRY = 'se'
SIGNED_RESOURCE = 'sr'
SIGNED_PERMISSION = 'sp'
SIGNED_IDENTIFIER = 'si'
SIGNED_SIGNATURE = 'sig'
RESOURCE_BLOB = 'b'
RESOURCE_CONTAINER = 'c'
SIGNED_RESOURCE_TYPE = 'resource'
SHARED_ACCESS_PERMISSION = 'permission'
#--------------------------------------------------------------------------
class WebResource(object):
    '''
    Class that represents the resource for which to get the shared access
    signature.
    path: the resource path.
    properties: dict of names and values. Contains 2 items: resource type
        and permission.
    request_url: the url of the web resource, including all the queries.
    '''
def __init__(self, path=None, request_url=None, properties=None):
self.path = path
self.properties = properties or {}
self.request_url = request_url
class Permission(object):
    '''
    Permission class. Contains the path and query_string for the path.
    path: the resource path.
    query_string: dict of names and values. Contains SIGNED_START,
        SIGNED_EXPIRY, SIGNED_RESOURCE, SIGNED_PERMISSION,
        SIGNED_IDENTIFIER and SIGNED_SIGNATURE values.
    '''
def __init__(self, path=None, query_string=None):
self.path = path
self.query_string = query_string
class SharedAccessPolicy(object):
''' SharedAccessPolicy class. '''
def __init__(self, access_policy, signed_identifier=None):
self.id = signed_identifier
self.access_policy = access_policy
class SharedAccessSignature(object):
    '''
    The main class used to do the signing and to generate the signature.
    account_name:
        the storage account name used to generate the shared access signature
    account_key: the access key used to generate the shared access signature
    permission_set: the permission cache used to sign the request url.
    '''
def __init__(self, account_name, account_key, permission_set=None):
self.account_name = account_name
self.account_key = account_key
self.permission_set = permission_set
def generate_signed_query_string(self, path, resource_type,
shared_access_policy,
version=X_MS_VERSION):
        '''
        Generates the query string for the given path, resource type and
        shared access policy.
        path: the resource path.
        resource_type: RESOURCE_BLOB ('b') or RESOURCE_CONTAINER ('c').
        shared_access_policy: the shared access policy.
        version:
            x-ms-version for storage service, or None to get a signed query
            string compatible with pre 2012-02-12 clients, where the version
            is not included in the query string.
        '''
query_string = {}
if shared_access_policy.access_policy.start:
query_string[
SIGNED_START] = shared_access_policy.access_policy.start
if version:
query_string[SIGNED_VERSION] = version
query_string[SIGNED_EXPIRY] = shared_access_policy.access_policy.expiry
query_string[SIGNED_RESOURCE] = resource_type
query_string[
SIGNED_PERMISSION] = shared_access_policy.access_policy.permission
if shared_access_policy.id:
query_string[SIGNED_IDENTIFIER] = shared_access_policy.id
query_string[SIGNED_SIGNATURE] = self._generate_signature(
path, shared_access_policy, version)
return query_string
def sign_request(self, web_resource):
        ''' Signs the request to generate request_url with shared access
        signature info for web_resource. '''
if self.permission_set:
for shared_access_signature in self.permission_set:
if self._permission_matches_request(
shared_access_signature, web_resource,
web_resource.properties[
SIGNED_RESOURCE_TYPE],
web_resource.properties[SHARED_ACCESS_PERMISSION]):
if web_resource.request_url.find('?') == -1:
web_resource.request_url += '?'
else:
web_resource.request_url += '&'
web_resource.request_url += self._convert_query_string(
shared_access_signature.query_string)
break
return web_resource
def _convert_query_string(self, query_string):
        ''' Converts the query string dict to a str. The order of the
        name/value pairs is significant and must not change. '''
convert_str = ''
if SIGNED_START in query_string:
convert_str += SIGNED_START + '=' + \
url_quote(query_string[SIGNED_START]) + '&'
convert_str += SIGNED_EXPIRY + '=' + \
url_quote(query_string[SIGNED_EXPIRY]) + '&'
convert_str += SIGNED_PERMISSION + '=' + \
query_string[SIGNED_PERMISSION] + '&'
convert_str += SIGNED_RESOURCE + '=' + \
query_string[SIGNED_RESOURCE] + '&'
if SIGNED_IDENTIFIER in query_string:
convert_str += SIGNED_IDENTIFIER + '=' + \
query_string[SIGNED_IDENTIFIER] + '&'
if SIGNED_VERSION in query_string:
convert_str += SIGNED_VERSION + '=' + \
query_string[SIGNED_VERSION] + '&'
convert_str += SIGNED_SIGNATURE + '=' + \
url_quote(query_string[SIGNED_SIGNATURE]) + '&'
return convert_str
def _generate_signature(self, path, shared_access_policy, version):
''' Generates signature for a given path and shared access policy. '''
def get_value_to_append(value, no_new_line=False):
return_value = ''
if value:
return_value = value
if not no_new_line:
return_value += '\n'
return return_value
if path[0] != '/':
path = '/' + path
canonicalized_resource = '/' + self.account_name + path
# Form the string to sign from shared_access_policy and canonicalized
# resource. The order of values is important.
string_to_sign = \
(get_value_to_append(shared_access_policy.access_policy.permission) +
get_value_to_append(shared_access_policy.access_policy.start) +
get_value_to_append(shared_access_policy.access_policy.expiry) +
get_value_to_append(canonicalized_resource))
if version:
string_to_sign += get_value_to_append(shared_access_policy.id)
string_to_sign += get_value_to_append(version, True)
else:
string_to_sign += get_value_to_append(shared_access_policy.id, True)
return self._sign(string_to_sign)
def _permission_matches_request(self, shared_access_signature,
web_resource, resource_type,
required_permission):
''' Check whether requested permission matches given
shared_access_signature, web_resource and resource type. '''
required_resource_type = resource_type
if required_resource_type == RESOURCE_BLOB:
required_resource_type += RESOURCE_CONTAINER
for name, value in shared_access_signature.query_string.items():
if name == SIGNED_RESOURCE and \
required_resource_type.find(value) == -1:
return False
elif name == SIGNED_PERMISSION and \
required_permission.find(value) == -1:
return False
return web_resource.path.find(shared_access_signature.path) != -1
def _sign(self, string_to_sign):
        ''' Uses HMAC-SHA256 to sign the string and returns it as a
        base64-encoded string. '''
return _sign_string(self.account_key, string_to_sign)
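#--------------------------------------------------------------------------
# Illustrative usage sketch (not part of the vendored file): generating a
# blob SAS query string with the classes above. It assumes
# azure.storage.AccessPolicy exposes start/expiry/permission fields; the
# account name, key and blob path are placeholders.
from azure.storage import AccessPolicy
from azure.storage.sharedaccesssignature import (
    RESOURCE_BLOB, SharedAccessPolicy, SharedAccessSignature)

sas = SharedAccessSignature('myaccount', 'bXlrZXk=')
policy = SharedAccessPolicy(AccessPolicy(start='2015-01-30T00:00:00Z',
                                         expiry='2015-01-31T00:00:00Z',
                                         permission='r'))
# Returns a dict mapping 'st', 'se', 'sr', 'sp', 'sv' and 'sig' to their
# values; _convert_query_string() flattens it in the fixed order above.
qs = sas.generate_signed_query_string('mycontainer/myblob.txt',
                                      RESOURCE_BLOB, policy)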

View File

@@ -1,152 +1,152 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import os
import sys
from azure import (
WindowsAzureError,
DEV_ACCOUNT_NAME,
DEV_ACCOUNT_KEY,
_ERROR_STORAGE_MISSING_INFO,
)
from azure.http import HTTPError
from azure.http.httpclient import _HTTPClient
from azure.storage import _storage_error_handler
#--------------------------------------------------------------------------
# constants for azure app setting environment variables
AZURE_STORAGE_ACCOUNT = 'AZURE_STORAGE_ACCOUNT'
AZURE_STORAGE_ACCESS_KEY = 'AZURE_STORAGE_ACCESS_KEY'
EMULATED = 'EMULATED'
#--------------------------------------------------------------------------
class _StorageClient(object):
'''
This is the base class for BlobManager, TableManager and QueueManager.
'''
def __init__(self, account_name=None, account_key=None, protocol='https',
host_base='', dev_host=''):
'''
account_name: your storage account name, required for all operations.
account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
dev_host: Optional. Dev host url. Defaults to localhost.
'''
self.account_name = account_name
self.account_key = account_key
self.requestid = None
self.protocol = protocol
self.host_base = host_base
self.dev_host = dev_host
        # whether to use the default development storage account and key;
        # set below when the app is run in the emulator.
self.use_local_storage = False
# check whether it is run in emulator.
if EMULATED in os.environ:
self.is_emulated = os.environ[EMULATED].lower() != 'false'
else:
self.is_emulated = False
# get account_name and account key. If they are not set when
# constructing, get the account and key from environment variables if
# the app is not run in azure emulator or use default development
# storage account and key if app is run in emulator.
if not self.account_name or not self.account_key:
if self.is_emulated:
self.account_name = DEV_ACCOUNT_NAME
self.account_key = DEV_ACCOUNT_KEY
self.protocol = 'http'
self.use_local_storage = True
else:
self.account_name = os.environ.get(AZURE_STORAGE_ACCOUNT)
self.account_key = os.environ.get(AZURE_STORAGE_ACCESS_KEY)
if not self.account_name or not self.account_key:
raise WindowsAzureError(_ERROR_STORAGE_MISSING_INFO)
self._httpclient = _HTTPClient(
service_instance=self,
account_key=self.account_key,
account_name=self.account_name,
protocol=self.protocol)
self._batchclient = None
self._filter = self._perform_request_worker
def with_filter(self, filter):
'''
Returns a new service which will process requests with the specified
filter. Filtering operations can include logging, automatic retrying,
etc... The filter is a lambda which receives the HTTPRequest and
another lambda. The filter can perform any pre-processing on the
request, pass it off to the next lambda, and then perform any
post-processing on the response.
'''
res = type(self)(self.account_name, self.account_key, self.protocol)
old_filter = self._filter
def new_filter(request):
return filter(request, old_filter)
res._filter = new_filter
return res
def set_proxy(self, host, port, user=None, password=None):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self._httpclient.set_proxy(host, port, user, password)
def _get_host(self):
if self.use_local_storage:
return self.dev_host
else:
return self.account_name + self.host_base
def _perform_request_worker(self, request):
return self._httpclient.perform_request(request)
def _perform_request(self, request, text_encoding='utf-8'):
        '''
        Sends the request and returns the response. Catches HTTPError and
        hands it to the error handler.
        '''
try:
if self._batchclient is not None:
return self._batchclient.insert_request_to_batch(request)
else:
resp = self._filter(request)
if sys.version_info >= (3,) and isinstance(resp, bytes) and \
text_encoding:
resp = resp.decode(text_encoding)
except HTTPError as ex:
_storage_error_handler(ex)
return resp
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
import os
import sys
from azure import (
WindowsAzureError,
DEV_ACCOUNT_NAME,
DEV_ACCOUNT_KEY,
_ERROR_STORAGE_MISSING_INFO,
)
from azure.http import HTTPError
from azure.http.httpclient import _HTTPClient
from azure.storage import _storage_error_handler
#--------------------------------------------------------------------------
# constants for azure app setting environment variables
AZURE_STORAGE_ACCOUNT = 'AZURE_STORAGE_ACCOUNT'
AZURE_STORAGE_ACCESS_KEY = 'AZURE_STORAGE_ACCESS_KEY'
EMULATED = 'EMULATED'
#--------------------------------------------------------------------------
class _StorageClient(object):
'''
This is the base class for BlobManager, TableManager and QueueManager.
'''
def __init__(self, account_name=None, account_key=None, protocol='https',
host_base='', dev_host=''):
'''
account_name: your storage account name, required for all operations.
account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
dev_host: Optional. Dev host url. Defaults to localhost.
'''
self.account_name = account_name
self.account_key = account_key
self.requestid = None
self.protocol = protocol
self.host_base = host_base
self.dev_host = dev_host
        # whether to use the default development storage account and key;
        # set below when the app is run in the emulator.
self.use_local_storage = False
# check whether it is run in emulator.
if EMULATED in os.environ:
self.is_emulated = os.environ[EMULATED].lower() != 'false'
else:
self.is_emulated = False
# get account_name and account key. If they are not set when
# constructing, get the account and key from environment variables if
# the app is not run in azure emulator or use default development
# storage account and key if app is run in emulator.
if not self.account_name or not self.account_key:
if self.is_emulated:
self.account_name = DEV_ACCOUNT_NAME
self.account_key = DEV_ACCOUNT_KEY
self.protocol = 'http'
self.use_local_storage = True
else:
self.account_name = os.environ.get(AZURE_STORAGE_ACCOUNT)
self.account_key = os.environ.get(AZURE_STORAGE_ACCESS_KEY)
if not self.account_name or not self.account_key:
raise WindowsAzureError(_ERROR_STORAGE_MISSING_INFO)
self._httpclient = _HTTPClient(
service_instance=self,
account_key=self.account_key,
account_name=self.account_name,
protocol=self.protocol)
self._batchclient = None
self._filter = self._perform_request_worker
def with_filter(self, filter):
'''
Returns a new service which will process requests with the specified
filter. Filtering operations can include logging, automatic retrying,
etc... The filter is a lambda which receives the HTTPRequest and
another lambda. The filter can perform any pre-processing on the
request, pass it off to the next lambda, and then perform any
post-processing on the response.
'''
res = type(self)(self.account_name, self.account_key, self.protocol)
old_filter = self._filter
def new_filter(request):
return filter(request, old_filter)
res._filter = new_filter
return res
def set_proxy(self, host, port, user=None, password=None):
'''
Sets the proxy server host and port for the HTTP CONNECT Tunnelling.
host: Address of the proxy. Ex: '192.168.0.100'
port: Port of the proxy. Ex: 6000
user: User for proxy authorization.
password: Password for proxy authorization.
'''
self._httpclient.set_proxy(host, port, user, password)
def _get_host(self):
if self.use_local_storage:
return self.dev_host
else:
return self.account_name + self.host_base
def _perform_request_worker(self, request):
return self._httpclient.perform_request(request)
def _perform_request(self, request, text_encoding='utf-8'):
        '''
        Sends the request and returns the response. Catches HTTPError and
        hands it to the error handler.
        '''
try:
if self._batchclient is not None:
return self._batchclient.insert_request_to_batch(request)
else:
resp = self._filter(request)
if sys.version_info >= (3,) and isinstance(resp, bytes) and \
text_encoding:
resp = resp.decode(text_encoding)
except HTTPError as ex:
_storage_error_handler(ex)
return resp
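#--------------------------------------------------------------------------
# Illustrative sketch (not part of the vendored file) of the filter chain
# described in with_filter() above. The logging_filter name is made up for
# the example; TableService is the concrete subclass vendored below.
from azure.storage.tableservice import TableService

def logging_filter(request, next_filter):
    # pre-process: inspect the outgoing HTTPRequest
    print('>> %s %s' % (request.method, request.path))
    # hand the request off to the next filter (ultimately the HTTP client)
    response = next_filter(request)
    # post-process: the raw response body comes back through here
    return response

service = TableService('myaccount', 'bXlrZXk=').with_filter(logging_filter)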

View File

@@ -1,491 +1,491 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
WindowsAzureError,
TABLE_SERVICE_HOST_BASE,
DEV_TABLE_HOST,
_convert_class_to_xml,
_convert_response_to_feeds,
_dont_fail_not_exist,
_dont_fail_on_exist,
_get_request_body,
_int_or_none,
_parse_response,
_parse_response_for_dict,
_parse_response_for_dict_filter,
_str,
_str_or_none,
_update_request_uri_query_local_storage,
_validate_not_none,
)
from azure.http import HTTPRequest
from azure.http.batchclient import _BatchClient
from azure.storage import (
StorageServiceProperties,
_convert_entity_to_xml,
_convert_response_to_entity,
_convert_table_to_xml,
_convert_xml_to_entity,
_convert_xml_to_table,
_sign_storage_table_request,
_update_storage_table_header,
)
from azure.storage.storageclient import _StorageClient
class TableService(_StorageClient):
'''
This is the main class managing Table resources.
'''
def __init__(self, account_name=None, account_key=None, protocol='https',
host_base=TABLE_SERVICE_HOST_BASE, dev_host=DEV_TABLE_HOST):
'''
account_name: your storage account name, required for all operations.
account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
dev_host: Optional. Dev host url. Defaults to localhost.
'''
super(TableService, self).__init__(
account_name, account_key, protocol, host_base, dev_host)
def begin_batch(self):
if self._batchclient is None:
self._batchclient = _BatchClient(
service_instance=self,
account_key=self.account_key,
account_name=self.account_name)
return self._batchclient.begin_batch()
def commit_batch(self):
try:
ret = self._batchclient.commit_batch()
finally:
self._batchclient = None
return ret
def cancel_batch(self):
self._batchclient = None
def get_table_service_properties(self):
'''
Gets the properties of a storage account's Table service, including
Windows Azure Storage Analytics.
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response(response, StorageServiceProperties)
def set_table_service_properties(self, storage_service_properties):
'''
Sets the properties of a storage account's Table Service, including
Windows Azure Storage Analytics.
storage_service_properties: StorageServiceProperties object.
'''
_validate_not_none('storage_service_properties',
storage_service_properties)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.body = _get_request_body(
_convert_class_to_xml(storage_service_properties))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict(response)
def query_tables(self, table_name=None, top=None, next_table_name=None):
'''
Returns a list of tables under the specified account.
table_name: Optional. The specific table to query.
top: Optional. Maximum number of tables to return.
next_table_name:
Optional. When top is used, the next table name is stored in
result.x_ms_continuation['NextTableName']
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
if table_name is not None:
uri_part_table_name = "('" + table_name + "')"
else:
uri_part_table_name = ""
request.path = '/Tables' + uri_part_table_name + ''
request.query = [
('$top', _int_or_none(top)),
('NextTableName', _str_or_none(next_table_name))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_feeds(response, _convert_xml_to_table)
def create_table(self, table, fail_on_exist=False):
'''
Creates a new table in the storage account.
table:
Name of the table to create. Table name may contain only
alphanumeric characters and cannot begin with a numeric character.
It is case-insensitive and must be from 3 to 63 characters long.
        fail_on_exist: Specify whether to throw an exception when the table exists.
'''
_validate_not_none('table', table)
request = HTTPRequest()
request.method = 'POST'
request.host = self._get_host()
request.path = '/Tables'
request.body = _get_request_body(_convert_table_to_xml(table))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
if not fail_on_exist:
try:
self._perform_request(request)
return True
except WindowsAzureError as ex:
_dont_fail_on_exist(ex)
return False
else:
self._perform_request(request)
return True
def delete_table(self, table_name, fail_not_exist=False):
'''
table_name: Name of the table to delete.
        fail_not_exist:
            Specify whether to throw an exception when the table doesn't exist.
'''
_validate_not_none('table_name', table_name)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/Tables(\'' + _str(table_name) + '\')'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
if not fail_not_exist:
try:
self._perform_request(request)
return True
except WindowsAzureError as ex:
_dont_fail_not_exist(ex)
return False
else:
self._perform_request(request)
return True
def get_entity(self, table_name, partition_key, row_key, select=''):
'''
        Get an entity in a table; includes the $select option.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
select: Property names to select.
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('select', select)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(table_name) + \
'(PartitionKey=\'' + _str(partition_key) + \
'\',RowKey=\'' + \
_str(row_key) + '\')?$select=' + \
_str(select) + ''
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_entity(response)
def query_entities(self, table_name, filter=None, select=None, top=None,
next_partition_key=None, next_row_key=None):
'''
Get entities in a table; includes the $filter and $select options.
table_name: Table to query.
filter:
Optional. Filter as described at
http://msdn.microsoft.com/en-us/library/windowsazure/dd894031.aspx
select: Optional. Property names to select from the entities.
top: Optional. Maximum number of entities to return.
next_partition_key:
Optional. When top is used, the next partition key is stored in
result.x_ms_continuation['NextPartitionKey']
        next_row_key:
            Optional. When top is used, the next row key is stored in
            result.x_ms_continuation['NextRowKey']
'''
_validate_not_none('table_name', table_name)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(table_name) + '()'
request.query = [
('$filter', _str_or_none(filter)),
('$select', _str_or_none(select)),
('$top', _int_or_none(top)),
('NextPartitionKey', _str_or_none(next_partition_key)),
('NextRowKey', _str_or_none(next_row_key))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_feeds(response, _convert_xml_to_entity)
def insert_entity(self, table_name, entity,
content_type='application/atom+xml'):
'''
Inserts a new entity into a table.
table_name: Table name.
        entity:
            Required. The entity to insert. Can be a dict or an entity
            object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'POST'
request.host = self._get_host()
request.path = '/' + _str(table_name) + ''
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_entity(response)
def update_entity(self, table_name, partition_key, row_key, entity,
content_type='application/atom+xml', if_match='*'):
'''
Updates an existing entity in a table. The Update Entity operation
replaces the entire entity and can be used to remove properties.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to update. Can be a dict or an entity
            object.
        content_type: Required. Must be set to application/atom+xml
        if_match:
            Optional. Specifies the condition for which the update should be
            performed. To force an unconditional update, set to the wildcard
            character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def merge_entity(self, table_name, partition_key, row_key, entity,
content_type='application/atom+xml', if_match='*'):
'''
Updates an existing entity by updating the entity's properties. This
operation does not replace the existing entity as the Update Entity
operation does.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to merge. Can be a dict or an entity
            object.
content_type: Required. Must be set to application/atom+xml
if_match:
Optional. Specifies the condition for which the merge should be
performed. To force an unconditional merge, set to the wildcard
character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'MERGE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def delete_entity(self, table_name, partition_key, row_key,
content_type='application/atom+xml', if_match='*'):
'''
Deletes an existing entity in a table.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
content_type: Required. Must be set to application/atom+xml
if_match:
Optional. Specifies the condition for which the delete should be
performed. To force an unconditional delete, set to the wildcard
character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('content_type', content_type)
_validate_not_none('if_match', if_match)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
self._perform_request(request)
def insert_or_replace_entity(self, table_name, partition_key, row_key,
entity, content_type='application/atom+xml'):
'''
Replaces an existing entity or inserts a new entity if it does not
exist in the table. Because this operation can insert or update an
entity, it is also known as an "upsert" operation.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to insert or replace. Can be a dict or an
            entity object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def insert_or_merge_entity(self, table_name, partition_key, row_key,
entity, content_type='application/atom+xml'):
'''
Merges an existing entity or inserts a new entity if it does not exist
in the table. Because this operation can insert or update an entity,
it is also known as an "upsert" operation.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to insert or merge. Can be a dict or an
            entity object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'MERGE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def _perform_request_worker(self, request):
auth = _sign_storage_table_request(request,
self.account_name,
self.account_key)
request.headers.append(('Authorization', auth))
return self._httpclient.perform_request(request)
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
WindowsAzureError,
TABLE_SERVICE_HOST_BASE,
DEV_TABLE_HOST,
_convert_class_to_xml,
_convert_response_to_feeds,
_dont_fail_not_exist,
_dont_fail_on_exist,
_get_request_body,
_int_or_none,
_parse_response,
_parse_response_for_dict,
_parse_response_for_dict_filter,
_str,
_str_or_none,
_update_request_uri_query_local_storage,
_validate_not_none,
)
from azure.http import HTTPRequest
from azure.http.batchclient import _BatchClient
from azure.storage import (
StorageServiceProperties,
_convert_entity_to_xml,
_convert_response_to_entity,
_convert_table_to_xml,
_convert_xml_to_entity,
_convert_xml_to_table,
_sign_storage_table_request,
_update_storage_table_header,
)
from azure.storage.storageclient import _StorageClient
class TableService(_StorageClient):
'''
This is the main class managing Table resources.
'''
def __init__(self, account_name=None, account_key=None, protocol='https',
host_base=TABLE_SERVICE_HOST_BASE, dev_host=DEV_TABLE_HOST):
'''
account_name: your storage account name, required for all operations.
account_key: your storage account key, required for all operations.
        protocol: Optional. Protocol. Defaults to https.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
dev_host: Optional. Dev host url. Defaults to localhost.
'''
super(TableService, self).__init__(
account_name, account_key, protocol, host_base, dev_host)
def begin_batch(self):
if self._batchclient is None:
self._batchclient = _BatchClient(
service_instance=self,
account_key=self.account_key,
account_name=self.account_name)
return self._batchclient.begin_batch()
def commit_batch(self):
try:
ret = self._batchclient.commit_batch()
finally:
self._batchclient = None
return ret
def cancel_batch(self):
self._batchclient = None
def get_table_service_properties(self):
'''
Gets the properties of a storage account's Table service, including
Windows Azure Storage Analytics.
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response(response, StorageServiceProperties)
def set_table_service_properties(self, storage_service_properties):
'''
Sets the properties of a storage account's Table Service, including
Windows Azure Storage Analytics.
storage_service_properties: StorageServiceProperties object.
'''
_validate_not_none('storage_service_properties',
storage_service_properties)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/?restype=service&comp=properties'
request.body = _get_request_body(
_convert_class_to_xml(storage_service_properties))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict(response)
def query_tables(self, table_name=None, top=None, next_table_name=None):
'''
Returns a list of tables under the specified account.
table_name: Optional. The specific table to query.
top: Optional. Maximum number of tables to return.
next_table_name:
Optional. When top is used, the next table name is stored in
result.x_ms_continuation['NextTableName']
'''
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
if table_name is not None:
uri_part_table_name = "('" + table_name + "')"
else:
uri_part_table_name = ""
request.path = '/Tables' + uri_part_table_name + ''
request.query = [
('$top', _int_or_none(top)),
('NextTableName', _str_or_none(next_table_name))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_feeds(response, _convert_xml_to_table)
def create_table(self, table, fail_on_exist=False):
'''
Creates a new table in the storage account.
table:
Name of the table to create. Table name may contain only
alphanumeric characters and cannot begin with a numeric character.
It is case-insensitive and must be from 3 to 63 characters long.
        fail_on_exist: Specify whether to throw an exception when the table exists.
'''
_validate_not_none('table', table)
request = HTTPRequest()
request.method = 'POST'
request.host = self._get_host()
request.path = '/Tables'
request.body = _get_request_body(_convert_table_to_xml(table))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
if not fail_on_exist:
try:
self._perform_request(request)
return True
except WindowsAzureError as ex:
_dont_fail_on_exist(ex)
return False
else:
self._perform_request(request)
return True
def delete_table(self, table_name, fail_not_exist=False):
'''
table_name: Name of the table to delete.
        fail_not_exist:
            Specify whether to throw an exception when the table doesn't exist.
'''
_validate_not_none('table_name', table_name)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/Tables(\'' + _str(table_name) + '\')'
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
if not fail_not_exist:
try:
self._perform_request(request)
return True
except WindowsAzureError as ex:
_dont_fail_not_exist(ex)
return False
else:
self._perform_request(request)
return True
def get_entity(self, table_name, partition_key, row_key, select=''):
'''
        Get an entity in a table; includes the $select option.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
select: Property names to select.
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('select', select)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(table_name) + \
'(PartitionKey=\'' + _str(partition_key) + \
'\',RowKey=\'' + \
_str(row_key) + '\')?$select=' + \
_str(select) + ''
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_entity(response)
def query_entities(self, table_name, filter=None, select=None, top=None,
next_partition_key=None, next_row_key=None):
'''
Get entities in a table; includes the $filter and $select options.
table_name: Table to query.
filter:
Optional. Filter as described at
http://msdn.microsoft.com/en-us/library/windowsazure/dd894031.aspx
select: Optional. Property names to select from the entities.
top: Optional. Maximum number of entities to return.
next_partition_key:
Optional. When top is used, the next partition key is stored in
result.x_ms_continuation['NextPartitionKey']
        next_row_key:
            Optional. When top is used, the next row key is stored in
            result.x_ms_continuation['NextRowKey']
'''
_validate_not_none('table_name', table_name)
request = HTTPRequest()
request.method = 'GET'
request.host = self._get_host()
request.path = '/' + _str(table_name) + '()'
request.query = [
('$filter', _str_or_none(filter)),
('$select', _str_or_none(select)),
('$top', _int_or_none(top)),
('NextPartitionKey', _str_or_none(next_partition_key)),
('NextRowKey', _str_or_none(next_row_key))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_feeds(response, _convert_xml_to_entity)
def insert_entity(self, table_name, entity,
content_type='application/atom+xml'):
'''
Inserts a new entity into a table.
table_name: Table name.
        entity:
            Required. The entity to insert. Can be a dict or an entity
            object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'POST'
request.host = self._get_host()
request.path = '/' + _str(table_name) + ''
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _convert_response_to_entity(response)
def update_entity(self, table_name, partition_key, row_key, entity,
content_type='application/atom+xml', if_match='*'):
'''
Updates an existing entity in a table. The Update Entity operation
replaces the entire entity and can be used to remove properties.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to update. Can be a dict or an entity
            object.
        content_type: Required. Must be set to application/atom+xml
        if_match:
            Optional. Specifies the condition for which the update should be
            performed. To force an unconditional update, set to the wildcard
            character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def merge_entity(self, table_name, partition_key, row_key, entity,
content_type='application/atom+xml', if_match='*'):
'''
Updates an existing entity by updating the entity's properties. This
operation does not replace the existing entity as the Update Entity
operation does.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to merge. Can be a dict or an entity
            object.
content_type: Required. Must be set to application/atom+xml
if_match:
Optional. Specifies the condition for which the merge should be
performed. To force an unconditional merge, set to the wildcard
character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'MERGE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def delete_entity(self, table_name, partition_key, row_key,
content_type='application/atom+xml', if_match='*'):
'''
Deletes an existing entity in a table.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
content_type: Required. Must be set to application/atom+xml
if_match:
Optional. Specifies the condition for which the delete should be
performed. To force an unconditional delete, set to the wildcard
character (*).
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('content_type', content_type)
_validate_not_none('if_match', if_match)
request = HTTPRequest()
request.method = 'DELETE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [
('Content-Type', _str_or_none(content_type)),
('If-Match', _str_or_none(if_match))
]
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
self._perform_request(request)
def insert_or_replace_entity(self, table_name, partition_key, row_key,
entity, content_type='application/atom+xml'):
'''
Replaces an existing entity or inserts a new entity if it does not
exist in the table. Because this operation can insert or update an
entity, it is also known as an "upsert" operation.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to insert or replace. Can be a dict or an
            entity object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def insert_or_merge_entity(self, table_name, partition_key, row_key,
entity, content_type='application/atom+xml'):
'''
Merges an existing entity or inserts a new entity if it does not exist
in the table. Because this operation can insert or update an entity,
it is also known as an "upsert" operation.
table_name: Table name.
partition_key: PartitionKey of the entity.
row_key: RowKey of the entity.
        entity:
            Required. The entity to insert or merge. Can be a dict or an
            entity object.
content_type: Required. Must be set to application/atom+xml
'''
_validate_not_none('table_name', table_name)
_validate_not_none('partition_key', partition_key)
_validate_not_none('row_key', row_key)
_validate_not_none('entity', entity)
_validate_not_none('content_type', content_type)
request = HTTPRequest()
request.method = 'MERGE'
request.host = self._get_host()
request.path = '/' + \
_str(table_name) + '(PartitionKey=\'' + \
_str(partition_key) + '\',RowKey=\'' + _str(row_key) + '\')'
request.headers = [('Content-Type', _str_or_none(content_type))]
request.body = _get_request_body(_convert_entity_to_xml(entity))
request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage)
request.headers = _update_storage_table_header(request)
response = self._perform_request(request)
return _parse_response_for_dict_filter(response, filter=['etag'])
def _perform_request_worker(self, request):
auth = _sign_storage_table_request(request,
self.account_name,
self.account_key)
request.headers.append(('Authorization', auth))
return self._httpclient.perform_request(request)
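#--------------------------------------------------------------------------
# Illustrative usage sketch (not part of the vendored file): the basic table
# workflow using the methods above. The credentials, table and entity values
# are placeholders, and it assumes returned entities expose their properties
# as attributes.
from azure.storage.tableservice import TableService

ts = TableService(account_name='myaccount', account_key='bXlrZXk=')
ts.create_table('tasks')    # True on success; see fail_on_exist above

# Every entity needs PartitionKey and RowKey; a plain dict is accepted.
ts.insert_entity('tasks', {'PartitionKey': 'jobs', 'RowKey': '1',
                           'text': 'upgrade vendored packages'})

# filter/top behave as documented in query_entities() above.
for entity in ts.query_entities('tasks', filter="PartitionKey eq 'jobs'",
                                top=10):
    print(entity.RowKey)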

View File

@@ -1,10 +1,2 @@
# -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
__author__ = "Tomi Pieviläinen <tomi.pievilainen@iki.fi>"
__license__ = "Simplified BSD"
__version__ = "2.2"
__version__ = "2.4.0"

View File

@@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
This module offers a generic easter computing method for any given year, using
Western, Orthodox or Julian algorithms.
"""
__license__ = "Simplified BSD"
import datetime
__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"]
EASTER_JULIAN = 1
EASTER_ORTHODOX = 2
EASTER_WESTERN = 3
def easter(year, method=EASTER_WESTERN):
"""
@@ -24,7 +23,7 @@ def easter(year, method=EASTER_WESTERN):
This algorithm implements three different easter
calculation methods:
1 - Original calculation in Julian calendar, valid in
dates after 326 AD
2 - Original method, with date converted to Gregorian
@@ -39,7 +38,7 @@ def easter(year, method=EASTER_WESTERN):
EASTER_WESTERN = 3
The default method is method 3.
More about the algorithm may be found at:
http://users.chariot.net.au/~gmarts/eastalg.htm
@@ -68,24 +67,23 @@ def easter(year, method=EASTER_WESTERN):
e = 0
if method < 3:
# Old method
i = (19*g+15)%30
j = (y+y//4+i)%7
i = (19*g + 15) % 30
j = (y + y//4 + i) % 7
if method == 2:
# Extra dates to convert Julian to Gregorian date
e = 10
if y > 1600:
e = e+y//100-16-(y//100-16)//4
e = e + y//100 - 16 - (y//100 - 16)//4
else:
# New method
c = y//100
h = (c-c//4-(8*c+13)//25+19*g+15)%30
i = h-(h//28)*(1-(h//28)*(29//(h+1))*((21-g)//11))
j = (y+y//4+i+2-c+c//4)%7
h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30
i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11))
j = (y + y//4 + i + 2 - c + c//4) % 7
# p can be from -6 to 56 corresponding to dates 22 March to 23 May
# (later dates apply to method 2, although 23 May never actually occurs)
p = i-j+e
d = 1+(p+27+(p+6)//40)%31
m = 3+(p+26)//30
p = i - j + e
d = 1 + (p + 27 + (p + 6)//40) % 31
m = 3 + (p + 26)//30
return datetime.date(int(y), int(m), int(d))
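# Illustrative sketch (not part of the vendored file): the three calculation
# methods exposed by easter(). The Western date for 2015 is April 5; the
# other two calls select the algorithm variants described above.
from dateutil.easter import easter, EASTER_JULIAN, EASTER_ORTHODOX

print(easter(2015))                   # datetime.date(2015, 4, 5) (Western)
print(easter(2015, EASTER_ORTHODOX))  # Orthodox Easter, Gregorian calendar
print(easter(2015, EASTER_JULIAN))    # original Julian-calendar calculation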

View File

@@ -1,32 +1,21 @@
# -*- coding:iso-8859-1 -*-
"""
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
This module offers a generic date/time string parser which is able to parse
most known formats to represent a date and/or time.
"""
from __future__ import unicode_literals
__license__ = "Simplified BSD"
import datetime
import string
import time
import sys
import os
import collections
try:
from io import StringIO
except ImportError:
from io import StringIO
from io import StringIO
from six import text_type, binary_type, integer_types
from . import relativedelta
from . import tz
__all__ = ["parse", "parserinfo"]
@@ -83,9 +72,9 @@ class _timelex(object):
state = '0'
elif nextchar in whitespace:
token = ' '
break # emit token
break # emit token
else:
break # emit token
break # emit token
elif state == 'a':
seenletters = True
if nextchar in wordchars:
@ -95,7 +84,7 @@ class _timelex(object):
state = 'a.'
else:
self.charstack.append(nextchar)
break # emit token
break # emit token
elif state == '0':
if nextchar in numchars:
token += nextchar
@ -104,7 +93,7 @@ class _timelex(object):
state = '0.'
else:
self.charstack.append(nextchar)
break # emit token
break # emit token
elif state == 'a.':
seenletters = True
if nextchar == '.' or nextchar in wordchars:
@ -114,7 +103,7 @@ class _timelex(object):
state = '0.'
else:
self.charstack.append(nextchar)
break # emit token
break # emit token
elif state == '0.':
if nextchar == '.' or nextchar in numchars:
token += nextchar
@ -123,9 +112,9 @@ class _timelex(object):
state = 'a.'
else:
self.charstack.append(nextchar)
break # emit token
if (state in ('a.', '0.') and
(seenletters or token.count('.') > 1 or token[-1] == '.')):
break # emit token
if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or
token[-1] == '.')):
l = token.split('.')
token = l[0]
for tok in l[1:]:
@ -183,18 +172,18 @@ class parserinfo(object):
("Fri", "Friday"),
("Sat", "Saturday"),
("Sun", "Sunday")]
MONTHS = [("Jan", "January"),
("Feb", "February"),
("Mar", "March"),
("Apr", "April"),
("May", "May"),
("Jun", "June"),
("Jul", "July"),
("Aug", "August"),
("Sep", "Sept", "September"),
("Oct", "October"),
("Nov", "November"),
("Dec", "December")]
MONTHS = [("Jan", "January"),
("Feb", "February"),
("Mar", "March"),
("Apr", "April"),
("May", "May"),
("Jun", "June"),
("Jul", "July"),
("Aug", "August"),
("Sep", "Sept", "September"),
("Oct", "October"),
("Nov", "November"),
("Dec", "December")]
HMS = [("h", "hour", "hours"),
("m", "minute", "minutes"),
("s", "second", "seconds")]
@ -299,15 +288,16 @@ class parser(object):
def __init__(self, info=None):
self.info = info or parserinfo()
def parse(self, timestr, default=None,
ignoretz=False, tzinfos=None,
**kwargs):
def parse(self, timestr, default=None, ignoretz=False, tzinfos=None,
**kwargs):
if not default:
default = datetime.datetime.now().replace(hour=0, minute=0,
second=0, microsecond=0)
res, skipped_tokens = self._parse(timestr, **kwargs)
if kwargs.get('fuzzy_with_tokens', False):
res, skipped_tokens = self._parse(timestr, **kwargs)
else:
res = self._parse(timestr, **kwargs)
if res is None:
raise ValueError("unknown string format")
@ -321,7 +311,8 @@ class parser(object):
if res.weekday is not None and not res.day:
ret = ret+relativedelta.relativedelta(weekday=res.weekday)
if not ignoretz:
if isinstance(tzinfos, collections.Callable) or tzinfos and res.tzname in tzinfos:
if (isinstance(tzinfos, collections.Callable) or
tzinfos and res.tzname in tzinfos):
if isinstance(tzinfos, collections.Callable):
tzdata = tzinfos(res.tzname, res.tzoffset)
else:
@ -333,8 +324,8 @@ class parser(object):
elif isinstance(tzdata, integer_types):
tzinfo = tz.tzoffset(res.tzname, tzdata)
else:
raise ValueError("offset must be tzinfo subclass, " \
"tz string, or int offset")
raise ValueError("offset must be tzinfo subclass, "
"tz string, or int offset")
ret = ret.replace(tzinfo=tzinfo)
elif res.tzname and res.tzname in time.tzname:
ret = ret.replace(tzinfo=tz.tzlocal())
@ -343,17 +334,18 @@ class parser(object):
elif res.tzoffset:
ret = ret.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset))
if skipped_tokens:
if kwargs.get('fuzzy_with_tokens', False):
return ret, skipped_tokens
return ret
else:
return ret
class _result(_resultbase):
__slots__ = ["year", "month", "day", "weekday",
"hour", "minute", "second", "microsecond",
"tzname", "tzoffset"]
def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, fuzzy_with_tokens=False):
def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False,
fuzzy_with_tokens=False):
if fuzzy_with_tokens:
fuzzy = True
@ -365,7 +357,6 @@ class parser(object):
res = self._result()
l = _timelex.split(timestr)
# keep up with the last token skipped so we can recombine
# consecutively skipped tokens (-2 for when i begins at 0).
last_skipped_token_i = -2
@ -440,12 +431,12 @@ class parser(object):
while True:
if idx == 0:
res.hour = int(value)
if value%1:
res.minute = int(60*(value%1))
if value % 1:
res.minute = int(60*(value % 1))
elif idx == 1:
res.minute = int(value)
if value%1:
res.second = int(60*(value%1))
if value % 1:
res.second = int(60*(value % 1))
elif idx == 2:
res.second, res.microsecond = \
_parsems(value_repr)
@ -465,16 +456,17 @@ class parser(object):
newidx = info.hms(l[i])
if newidx is not None:
idx = newidx
elif i == len_l and l[i-2] == ' ' and info.hms(l[i-3]) is not None:
elif (i == len_l and l[i-2] == ' ' and
info.hms(l[i-3]) is not None):
# X h MM or X m SS
idx = info.hms(l[i-3]) + 1
if idx == 1:
res.minute = int(value)
if value%1:
res.second = int(60*(value%1))
if value % 1:
res.second = int(60*(value % 1))
elif idx == 2:
res.second, res.microsecond = \
_parsems(value_repr)
_parsems(value_repr)
i += 1
elif i+1 < len_l and l[i] == ':':
# HH:MM[:SS[.ss]]
@ -482,8 +474,8 @@ class parser(object):
i += 1
value = float(l[i])
res.minute = int(value)
if value%1:
res.second = int(60*(value%1))
if value % 1:
res.second = int(60*(value % 1))
i += 1
if i < len_l and l[i] == ':':
res.second, res.microsecond = _parsems(l[i+1])
@ -597,8 +589,9 @@ class parser(object):
# Check for a timezone name
if (res.hour is not None and len(l[i]) <= 5 and
res.tzname is None and res.tzoffset is None and
not [x for x in l[i] if x not in string.ascii_uppercase]):
res.tzname is None and res.tzoffset is None and
not [x for x in l[i] if x not in
string.ascii_uppercase]):
res.tzname = l[i]
res.tzoffset = info.tzoffset(res.tzname)
i += 1
@ -643,7 +636,7 @@ class parser(object):
info.jump(l[i]) and l[i+1] == '(' and l[i+3] == ')' and
3 <= len(l[i+2]) <= 5 and
not [x for x in l[i+2]
if x not in string.ascii_uppercase]):
if x not in string.ascii_uppercase]):
# -0300 (BRST)
res.tzname = l[i+2]
i += 4
@ -732,10 +725,12 @@ class parser(object):
if fuzzy_with_tokens:
return res, tuple(skipped_tokens)
return res, None
else:
return res
DEFAULTPARSER = parser()
def parse(timestr, parserinfo=None, **kwargs):
# Python 2.x support: datetimes return their string presentation as
# bytes in 2.x and unicode in 3.x, so it's reasonable to expect that
@ -779,7 +774,7 @@ class _tzparser(object):
# BRST+3[BRDT[+2]]
j = i
while j < len_l and not [x for x in l[j]
if x in "0123456789:,-+"]:
if x in "0123456789:,-+"]:
j += 1
if j != i:
if not res.stdabbr:
@ -789,8 +784,8 @@ class _tzparser(object):
offattr = "dstoffset"
res.dstabbr = "".join(l[i:j])
i = j
if (i < len_l and
(l[i] in ('+', '-') or l[i][0] in "0123456789")):
if (i < len_l and (l[i] in ('+', '-') or l[i][0] in
"0123456789")):
if l[i] in ('+', '-'):
# Yes, that's right. See the TZ variable
# documentation.
@ -801,8 +796,8 @@ class _tzparser(object):
len_li = len(l[i])
if len_li == 4:
# -0300
setattr(res, offattr,
(int(l[i][:2])*3600+int(l[i][2:])*60)*signal)
setattr(res, offattr, (int(l[i][:2])*3600 +
int(l[i][2:])*60)*signal)
elif i+1 < len_l and l[i+1] == ':':
# -03:00
setattr(res, offattr,
@ -822,7 +817,8 @@ class _tzparser(object):
if i < len_l:
for j in range(i, len_l):
if l[j] == ';': l[j] = ','
if l[j] == ';':
l[j] = ','
assert l[i] == ','
@ -831,7 +827,7 @@ class _tzparser(object):
if i >= len_l:
pass
elif (8 <= l.count(',') <= 9 and
not [y for x in l[i:] if x != ','
not [y for x in l[i:] if x != ','
for y in x if y not in "0123456789"]):
# GMT0BST,3,0,30,3600,10,0,26,7200[,3600]
for x in (res.start, res.end):
@ -845,7 +841,7 @@ class _tzparser(object):
i += 2
if value:
x.week = value
x.weekday = (int(l[i])-1)%7
x.weekday = (int(l[i])-1) % 7
else:
x.day = int(l[i])
i += 2
@ -861,7 +857,7 @@ class _tzparser(object):
elif (l.count(',') == 2 and l[i:].count('/') <= 2 and
not [y for x in l[i:] if x not in (',', '/', 'J', 'M',
'.', '-', ':')
for y in x if y not in "0123456789"]):
for y in x if y not in "0123456789"]):
for x in (res.start, res.end):
if l[i] == 'J':
# non-leap year day (1 based)
@ -880,7 +876,7 @@ class _tzparser(object):
i += 1
assert l[i] in ('-', '.')
i += 1
x.weekday = (int(l[i])-1)%7
x.weekday = (int(l[i])-1) % 7
else:
# year day (zero based)
x.yday = int(l[i])+1
@ -921,6 +917,8 @@ class _tzparser(object):
DEFAULTTZPARSER = _tzparser()
def _parsetz(tzstr):
return DEFAULTTZPARSER.parse(tzstr)
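
Two behavioral details in the parse() changes above are easy to miss: the return type now depends on fuzzy_with_tokens, and tzinfos may be either a mapping or a callable, with int offsets converted to tz.tzoffset. A minimal sketch (editor's example, assuming the vendored dateutil):

    from dateutil import parser

    # tzinfos as a mapping; the int offset becomes a tz.tzoffset internally.
    dt = parser.parse("2015-01-30 11:12:32 EST", tzinfos={"EST": -5 * 3600})

    # fuzzy_with_tokens=True changes the return type to (datetime, skipped_tokens).
    dt2, skipped = parser.parse("Meet on January 30, 2015 at 11am",
                                fuzzy_with_tokens=True)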

View File

@ -1,11 +1,4 @@
"""
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
__license__ = "Simplified BSD"
# -*- coding: utf-8 -*-
import datetime
import calendar
@ -13,6 +6,7 @@ from six import integer_types
__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"]
class weekday(object):
__slots__ = ["weekday", "n"]
@ -43,25 +37,35 @@ class weekday(object):
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)])
class relativedelta(object):
"""
The relativedelta type is based on the specification of the excelent
work done by M.-A. Lemburg in his mx.DateTime extension. However,
notice that this type does *NOT* implement the same algorithm as
The relativedelta type is based on the specification of the excellent
work done by M.-A. Lemburg in his
`mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension.
However, notice that this type does *NOT* implement the same algorithm as
his work. Do *NOT* expect it to behave like mx.DateTime's counterpart.
There's two different ways to build a relativedelta instance. The
first one is passing it two date/datetime classes:
There are two different ways to build a relativedelta instance. The
first one is passing it two date/datetime classes::
relativedelta(datetime1, datetime2)
And the other way is to use the following keyword arguments:
The second one is passing it any number of the following keyword arguments::
relativedelta(arg1=x,arg2=y,arg3=z...)
year, month, day, hour, minute, second, microsecond:
Absolute information.
Absolute information (argument is singular); adding or subtracting a
relativedelta with absolute information does not perform an arithmetic
operation, but rather REPLACES the corresponding value in the
original datetime with the value(s) in relativedelta.
years, months, weeks, days, hours, minutes, seconds, microseconds:
Relative information, may be negative.
Relative information, may be negative (argument is plural); adding
or subtracting a relativedelta with relative information performs
the corresponding arithmetic operation on the original datetime value
with the information in the relativedelta.
weekday:
One of the weekday instances (MO, TU, etc). These instances may
@ -80,26 +84,26 @@ And the other way is to use the following keyword arguments:
Here is the behavior of operations with relativedelta:
1) Calculate the absolute year, using the 'year' argument, or the
1. Calculate the absolute year, using the 'year' argument, or the
original datetime year, if the argument is not present.
2) Add the relative 'years' argument to the absolute year.
2. Add the relative 'years' argument to the absolute year.
3) Do steps 1 and 2 for month/months.
3. Do steps 1 and 2 for month/months.
4) Calculate the absolute day, using the 'day' argument, or the
4. Calculate the absolute day, using the 'day' argument, or the
original datetime day, if the argument is not present. Then,
subtract from the day until it fits in the year and month
found after their operations.
5) Add the relative 'days' argument to the absolute day. Notice
5. Add the relative 'days' argument to the absolute day. Notice
that the 'weeks' argument is multiplied by 7 and added to
'days'.
6) Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds,
6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds,
microsecond/microseconds.
7) If the 'weekday' argument is present, calculate the weekday,
7. If the 'weekday' argument is present, calculate the weekday,
with the given (wday, nth) tuple. wday is the index of the
weekday (0-6, 0=Mon), and nth is the number of weeks to add
forward or backward, depending on its signal. Notice that if
@ -114,9 +118,14 @@ Here is the behavior of operations with relativedelta:
yearday=None, nlyearday=None,
hour=None, minute=None, second=None, microsecond=None):
if dt1 and dt2:
if (not isinstance(dt1, datetime.date)) or (not isinstance(dt2, datetime.date)):
# datetime is a subclass of date. So both must be date
if not (isinstance(dt1, datetime.date) and
isinstance(dt2, datetime.date)):
raise TypeError("relativedelta only diffs datetime/date")
if not type(dt1) == type(dt2): #isinstance(dt1, type(dt2)):
# We allow two dates, or two datetimes, so we coerce them to be
# of the same type
if (isinstance(dt1, datetime.datetime) !=
isinstance(dt2, datetime.datetime)):
if not isinstance(dt1, datetime.datetime):
dt1 = datetime.datetime.fromordinal(dt1.toordinal())
elif not isinstance(dt2, datetime.datetime):
@ -185,7 +194,8 @@ Here is the behavior of operations with relativedelta:
if yearday > 59:
self.leapdays = -1
if yday:
ydayidx = [31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 366]
ydayidx = [31, 59, 90, 120, 151, 181, 212,
243, 273, 304, 334, 366]
for idx, ydays in enumerate(ydayidx):
if yday <= ydays:
self.month = idx+1
@ -225,9 +235,9 @@ Here is the behavior of operations with relativedelta:
div, mod = divmod(self.months*s, 12)
self.months = mod*s
self.years += div*s
if (self.hours or self.minutes or self.seconds or self.microseconds or
self.hour is not None or self.minute is not None or
self.second is not None or self.microsecond is not None):
if (self.hours or self.minutes or self.seconds or self.microseconds
or self.hour is not None or self.minute is not None or
self.second is not None or self.microsecond is not None):
self._has_time = 1
else:
self._has_time = 0
@ -245,21 +255,23 @@ Here is the behavior of operations with relativedelta:
def __add__(self, other):
if isinstance(other, relativedelta):
return relativedelta(years=other.years+self.years,
months=other.months+self.months,
days=other.days+self.days,
hours=other.hours+self.hours,
minutes=other.minutes+self.minutes,
seconds=other.seconds+self.seconds,
microseconds=other.microseconds+self.microseconds,
leapdays=other.leapdays or self.leapdays,
year=other.year or self.year,
month=other.month or self.month,
day=other.day or self.day,
weekday=other.weekday or self.weekday,
hour=other.hour or self.hour,
minute=other.minute or self.minute,
second=other.second or self.second,
microsecond=other.microsecond or self.microsecond)
months=other.months+self.months,
days=other.days+self.days,
hours=other.hours+self.hours,
minutes=other.minutes+self.minutes,
seconds=other.seconds+self.seconds,
microseconds=(other.microseconds +
self.microseconds),
leapdays=other.leapdays or self.leapdays,
year=other.year or self.year,
month=other.month or self.month,
day=other.day or self.day,
weekday=other.weekday or self.weekday,
hour=other.hour or self.hour,
minute=other.minute or self.minute,
second=other.second or self.second,
microsecond=(other.microsecond or
self.microsecond))
if not isinstance(other, datetime.date):
raise TypeError("unsupported type for add operation")
elif self._has_time and not isinstance(other, datetime.datetime):
@ -295,9 +307,9 @@ Here is the behavior of operations with relativedelta:
weekday, nth = self.weekday.weekday, self.weekday.n or 1
jumpdays = (abs(nth)-1)*7
if nth > 0:
jumpdays += (7-ret.weekday()+weekday)%7
jumpdays += (7-ret.weekday()+weekday) % 7
else:
jumpdays += (ret.weekday()-weekday)%7
jumpdays += (ret.weekday()-weekday) % 7
jumpdays *= -1
ret += datetime.timedelta(days=jumpdays)
return ret
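
The singular/plural distinction documented above is the core of the API: singular arguments replace a field, plural arguments do arithmetic on it. A short sketch (editor's example, assuming the vendored dateutil):

    import datetime
    from dateutil.relativedelta import relativedelta, FR

    dt = datetime.datetime(2015, 1, 30)
    dt + relativedelta(months=+1)       # relative: 2015-02-28 (day clipped to fit)
    dt + relativedelta(month=1, day=1)  # absolute: 2015-01-01 (fields replaced)
    dt + relativedelta(weekday=FR(+1))  # 2015-01-30 itself, already a Friday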

View File

@ -1,21 +1,19 @@
# -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
The rrule module offers a small, complete, and very fast, implementation of
the recurrence rules documented in the
`iCalendar RFC <http://www.ietf.org/rfc/rfc2445.txt>`_,
including support for caching of results.
"""
__license__ = "Simplified BSD"
import itertools
import datetime
import calendar
try:
import _thread
except ImportError:
import thread as _thread
import sys
from fractions import gcd
from six import advance_iterator, integer_types
from six.moves import _thread
__all__ = ["rrule", "rruleset", "rrulestr",
"YEARLY", "MONTHLY", "WEEKLY", "DAILY",
@ -23,7 +21,7 @@ __all__ = ["rrule", "rruleset", "rrulestr",
"MO", "TU", "WE", "TH", "FR", "SA", "SU"]
# Every mask is 7 days longer to handle cross-year weekly periods.
M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30+
M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30 +
[7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7)
M365MASK = list(M366MASK)
M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32))
@ -51,6 +49,7 @@ M365MASK = tuple(M365MASK)
easter = None
parser = None
class weekday(object):
__slots__ = ["weekday", "n"]
@ -83,12 +82,13 @@ class weekday(object):
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)])
class rrulebase(object):
def __init__(self, cache=False):
if cache:
self._cache = []
self._cache_lock = _thread.allocate_lock()
self._cache_gen = self._iter()
self._cache_gen = self._iter()
self._cache_complete = False
else:
self._cache = None
@ -163,11 +163,17 @@ class rrulebase(object):
# __len__() introduces a large performance penalty.
def count(self):
""" Returns the number of recurrences in this set. It will have go
trough the whole recurrence, if this hasn't been done before. """
if self._len is None:
for x in self: pass
for x in self:
pass
return self._len
def before(self, dt, inc=False):
""" Returns the last recurrence before the given datetime instance. The
inc keyword defines what happens if dt is an occurrence. With
inc=True, if dt itself is an occurrence, it will be returned. """
if self._cache_complete:
gen = self._cache
else:
@ -186,6 +192,9 @@ class rrulebase(object):
return last
def after(self, dt, inc=False):
""" Returns the first recurrence after the given datetime instance. The
inc keyword defines what happens if dt is an occurrence. With
inc=True, if dt itself is an occurrence, it will be returned. """
if self._cache_complete:
gen = self._cache
else:
@ -201,6 +210,10 @@ class rrulebase(object):
return None
def between(self, after, before, inc=False):
""" Returns all the occurrences of the rrule between after and before.
The inc keyword defines what happens if after and/or before are
themselves occurrences. With inc=True, they will be included in the
list, if they are found in the recurrence set. """
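
The three query helpers above share the inc convention. A sketch (editor's example, assuming the vendored dateutil):

    import datetime
    from dateutil.rrule import rrule, DAILY

    rule = rrule(DAILY, dtstart=datetime.datetime(2015, 1, 1), count=10)
    rule.after(datetime.datetime(2015, 1, 3))            # Jan 4
    rule.after(datetime.datetime(2015, 1, 3), inc=True)  # Jan 3 itself
    rule.between(datetime.datetime(2015, 1, 2),
                 datetime.datetime(2015, 1, 5))          # Jan 3 and Jan 4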
if self._cache_complete:
gen = self._cache
else:
@ -229,7 +242,93 @@ class rrulebase(object):
l.append(i)
return l
class rrule(rrulebase):
"""
That's the base of the rrule operation. It accepts all the keywords
defined in the RFC as its constructor parameters (except byday,
which was renamed to byweekday) and more. The constructor prototype is::
rrule(freq)
Where freq must be one of YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY,
or SECONDLY.
Additionally, it supports the following keyword arguments:
:param cache:
If given, it must be a boolean value specifying to enable or disable
caching of results. If you will use the same rrule instance multiple
times, enabling caching will improve the performance considerably.
:param dtstart:
The recurrence start. Besides being the base for the recurrence,
missing parameters in the final recurrence instances will also be
extracted from this date. If not given, datetime.now() will be used
instead.
:param interval:
The interval between each freq iteration. For example, when using
YEARLY, an interval of 2 means once every two years, but with HOURLY,
it means once every two hours. The default interval is 1.
:param wkst:
The week start day. Must be one of the MO, TU, WE constants, or an
integer, specifying the first day of the week. This will affect
recurrences based on weekly periods. The default week start is taken
from calendar.firstweekday(), and may be modified by
calendar.setfirstweekday().
:param count:
How many occurrences will be generated.
:param until:
If given, this must be a datetime instance, that will specify the
limit of the recurrence. If a recurrence instance happens to be the
same as the datetime instance given in the until keyword, this will
be the last occurrence.
:param bysetpos:
If given, it must be either an integer, or a sequence of integers,
positive or negative. Each given integer will specify an occurrence
number, corresponding to the nth occurrence of the rule inside the
frequency period. For example, a bysetpos of -1 if combined with a
MONTHLY frequency, and a byweekday of (MO, TU, WE, TH, FR), will
result in the last work day of every month.
:param bymonth:
If given, it must be either an integer, or a sequence of integers,
meaning the months to apply the recurrence to.
:param bymonthday:
If given, it must be either an integer, or a sequence of integers,
meaning the month days to apply the recurrence to.
:param byyearday:
If given, it must be either an integer, or a sequence of integers,
meaning the year days to apply the recurrence to.
:param byweekno:
If given, it must be either an integer, or a sequence of integers,
meaning the week numbers to apply the recurrence to. Week numbers
have the meaning described in ISO8601, that is, the first week of
the year is that containing at least four days of the new year.
:param byweekday:
If given, it must be either an integer (0 == MO), a sequence of
integers, one of the weekday constants (MO, TU, etc), or a sequence
of these constants. When given, these variables will define the
weekdays where the recurrence will be applied. It's also possible to
use an argument n for the weekday instances, which will mean the nth
occurrence of this weekday in the period. For example, with MONTHLY,
or with YEARLY and BYMONTH, using FR(+1) in byweekday will specify the
first Friday of the month where the recurrence happens. Notice that in
the RFC documentation, this is specified as BYDAY, but was renamed to
avoid the ambiguity of that keyword.
:param byhour:
If given, it must be either an integer, or a sequence of integers,
meaning the hours to apply the recurrence to.
:param byminute:
If given, it must be either an integer, or a sequence of integers,
meaning the minutes to apply the recurrence to.
:param bysecond:
If given, it must be either an integer, or a sequence of integers,
meaning the seconds to apply the recurrence to.
:param byeaster:
If given, it must be either an integer, or a sequence of integers,
positive or negative. Each integer will define an offset from the
Easter Sunday. Passing the offset 0 to byeaster will yield the Easter
Sunday itself. This is an extension to the RFC specification.
"""
def __init__(self, freq, dtstart=None,
interval=1, wkst=None, count=None, until=None, bysetpos=None,
bymonth=None, bymonthday=None, byyearday=None, byeaster=None,
@ -249,15 +348,18 @@ class rrule(rrulebase):
self._freq = freq
self._interval = interval
self._count = count
if until and not isinstance(until, datetime.datetime):
until = datetime.datetime.fromordinal(until.toordinal())
self._until = until
if wkst is None:
self._wkst = calendar.firstweekday()
elif isinstance(wkst, integer_types):
self._wkst = wkst
else:
self._wkst = wkst.weekday
if bysetpos is None:
self._bysetpos = None
elif isinstance(bysetpos, integer_types):
@ -271,30 +373,36 @@ class rrule(rrulebase):
if pos == 0 or not (-366 <= pos <= 366):
raise ValueError("bysetpos must be between 1 and 366, "
"or between -366 and -1")
if not (byweekno or byyearday or bymonthday or
byweekday is not None or byeaster is not None):
if (byweekno is None and byyearday is None and bymonthday is None and
byweekday is None and byeaster is None):
if freq == YEARLY:
if not bymonth:
if bymonth is None:
bymonth = dtstart.month
bymonthday = dtstart.day
elif freq == MONTHLY:
bymonthday = dtstart.day
elif freq == WEEKLY:
byweekday = dtstart.weekday()
# bymonth
if not bymonth:
if bymonth is None:
self._bymonth = None
elif isinstance(bymonth, integer_types):
self._bymonth = (bymonth,)
else:
self._bymonth = tuple(bymonth)
if isinstance(bymonth, integer_types):
bymonth = (bymonth,)
self._bymonth = set(bymonth)
# byyearday
if not byyearday:
if byyearday is None:
self._byyearday = None
elif isinstance(byyearday, integer_types):
self._byyearday = (byyearday,)
else:
self._byyearday = tuple(byyearday)
if isinstance(byyearday, integer_types):
byyearday = (byyearday,)
self._byyearday = set(byyearday)
# byeaster
if byeaster is not None:
if not easter:
@ -305,87 +413,104 @@ class rrule(rrulebase):
self._byeaster = tuple(byeaster)
else:
self._byeaster = None
# bymonthday
if not bymonthday:
if bymonthday is None:
self._bymonthday = ()
self._bynmonthday = ()
elif isinstance(bymonthday, integer_types):
if bymonthday < 0:
self._bynmonthday = (bymonthday,)
self._bymonthday = ()
else:
self._bymonthday = (bymonthday,)
self._bynmonthday = ()
else:
self._bymonthday = tuple([x for x in bymonthday if x > 0])
self._bynmonthday = tuple([x for x in bymonthday if x < 0])
if isinstance(bymonthday, integer_types):
bymonthday = (bymonthday,)
self._bymonthday = set([x for x in bymonthday if x > 0])
self._bynmonthday = set([x for x in bymonthday if x < 0])
# byweekno
if byweekno is None:
self._byweekno = None
elif isinstance(byweekno, integer_types):
self._byweekno = (byweekno,)
else:
self._byweekno = tuple(byweekno)
if isinstance(byweekno, integer_types):
byweekno = (byweekno,)
self._byweekno = set(byweekno)
# byweekday / bynweekday
if byweekday is None:
self._byweekday = None
self._bynweekday = None
elif isinstance(byweekday, integer_types):
self._byweekday = (byweekday,)
self._bynweekday = None
elif hasattr(byweekday, "n"):
if not byweekday.n or freq > MONTHLY:
self._byweekday = (byweekday.weekday,)
self._bynweekday = None
else:
self._bynweekday = ((byweekday.weekday, byweekday.n),)
self._byweekday = None
else:
self._byweekday = []
self._bynweekday = []
if isinstance(byweekday, integer_types):
byweekday = (byweekday,)
elif hasattr(byweekday, "n"):
byweekday = (byweekday.weekday,)
self._byweekday = set()
self._bynweekday = set()
for wday in byweekday:
if isinstance(wday, integer_types):
self._byweekday.append(wday)
self._byweekday.add(wday)
elif not wday.n or freq > MONTHLY:
self._byweekday.append(wday.weekday)
self._byweekday.add(wday.weekday)
else:
self._bynweekday.append((wday.weekday, wday.n))
self._byweekday = tuple(self._byweekday)
self._bynweekday = tuple(self._bynweekday)
self._bynweekday.add((wday.weekday, wday.n))
if not self._byweekday:
self._byweekday = None
elif not self._bynweekday:
self._bynweekday = None
# byhour
if byhour is None:
if freq < HOURLY:
self._byhour = (dtstart.hour,)
self._byhour = set((dtstart.hour,))
else:
self._byhour = None
elif isinstance(byhour, integer_types):
self._byhour = (byhour,)
else:
self._byhour = tuple(byhour)
if isinstance(byhour, integer_types):
byhour = (byhour,)
if freq == HOURLY:
self._byhour = self.__construct_byset(start=dtstart.hour,
byxxx=byhour,
base=24)
else:
self._byhour = set(byhour)
# byminute
if byminute is None:
if freq < MINUTELY:
self._byminute = (dtstart.minute,)
self._byminute = set((dtstart.minute,))
else:
self._byminute = None
elif isinstance(byminute, integer_types):
self._byminute = (byminute,)
else:
self._byminute = tuple(byminute)
if isinstance(byminute, integer_types):
byminute = (byminute,)
if freq == MINUTELY:
self._byminute = self.__construct_byset(start=dtstart.minute,
byxxx=byminute,
base=60)
else:
self._byminute = set(byminute)
# bysecond
if bysecond is None:
if freq < SECONDLY:
self._bysecond = (dtstart.second,)
self._bysecond = ((dtstart.second,))
else:
self._bysecond = None
elif isinstance(bysecond, integer_types):
self._bysecond = (bysecond,)
else:
self._bysecond = tuple(bysecond)
if isinstance(bysecond, integer_types):
bysecond = (bysecond,)
self._bysecond = set(bysecond)
if freq == SECONDLY:
self._bysecond = self.__construct_byset(start=dtstart.second,
byxxx=bysecond,
base=60)
else:
self._bysecond = set(bysecond)
if self._freq >= HOURLY:
self._timeset = None
@ -395,8 +520,8 @@ class rrule(rrulebase):
for minute in self._byminute:
for second in self._bysecond:
self._timeset.append(
datetime.time(hour, minute, second,
tzinfo=self._tzinfo))
datetime.time(hour, minute, second,
tzinfo=self._tzinfo))
self._timeset.sort()
self._timeset = tuple(self._timeset)
@ -424,20 +549,20 @@ class rrule(rrulebase):
ii = _iterinfo(self)
ii.rebuild(year, month)
getdayset = {YEARLY:ii.ydayset,
MONTHLY:ii.mdayset,
WEEKLY:ii.wdayset,
DAILY:ii.ddayset,
HOURLY:ii.ddayset,
MINUTELY:ii.ddayset,
SECONDLY:ii.ddayset}[freq]
getdayset = {YEARLY: ii.ydayset,
MONTHLY: ii.mdayset,
WEEKLY: ii.wdayset,
DAILY: ii.ddayset,
HOURLY: ii.ddayset,
MINUTELY: ii.ddayset,
SECONDLY: ii.ddayset}[freq]
if freq < HOURLY:
timeset = self._timeset
else:
gettimeset = {HOURLY:ii.htimeset,
MINUTELY:ii.mtimeset,
SECONDLY:ii.stimeset}[freq]
gettimeset = {HOURLY: ii.htimeset,
MINUTELY: ii.mtimeset,
SECONDLY: ii.stimeset}[freq]
if ((freq >= HOURLY and
self._byhour and hour not in self._byhour) or
(freq >= MINUTELY and
@ -466,11 +591,10 @@ class rrule(rrulebase):
ii.mdaymask[i] not in bymonthday and
ii.nmdaymask[i] not in bynmonthday) or
(byyearday and
((i < ii.yearlen and i+1 not in byyearday
and -ii.yearlen+i not in byyearday) or
(i >= ii.yearlen and i+1-ii.yearlen not in byyearday
and -ii.nextyearlen+i-ii.yearlen
not in byyearday)))):
((i < ii.yearlen and i+1 not in byyearday and
-ii.yearlen+i not in byyearday) or
(i >= ii.yearlen and i+1-ii.yearlen not in byyearday and
-ii.nextyearlen+i-ii.yearlen not in byyearday)))):
dayset[i] = None
filtered = True
@ -484,7 +608,7 @@ class rrule(rrulebase):
daypos, timepos = divmod(pos-1, len(timeset))
try:
i = [x for x in dayset[start:end]
if x is not None][daypos]
if x is not None][daypos]
time = timeset[timepos]
except IndexError:
pass
@ -559,60 +683,86 @@ class rrule(rrulebase):
if filtered:
# Jump to one iteration before next day
hour += ((23-hour)//interval)*interval
while True:
hour += interval
div, mod = divmod(hour, 24)
if div:
hour = mod
day += div
fixday = True
if not byhour or hour in byhour:
break
if byhour:
ndays, hour = self.__mod_distance(value=hour,
byxxx=self._byhour,
base=24)
else:
ndays, hour = divmod(hour+interval, 24)
if ndays:
day += ndays
fixday = True
timeset = gettimeset(hour, minute, second)
elif freq == MINUTELY:
if filtered:
# Jump to one iteration before next day
minute += ((1439-(hour*60+minute))//interval)*interval
while True:
minute += interval
div, mod = divmod(minute, 60)
valid = False
rep_rate = (24*60)
for j in range(rep_rate // gcd(interval, rep_rate)):
if byminute:
nhours, minute = \
self.__mod_distance(value=minute,
byxxx=self._byminute,
base=60)
else:
nhours, minute = divmod(minute+interval, 60)
div, hour = divmod(hour+nhours, 24)
if div:
minute = mod
hour += div
div, mod = divmod(hour, 24)
if div:
hour = mod
day += div
fixday = True
filtered = False
if ((not byhour or hour in byhour) and
(not byminute or minute in byminute)):
day += div
fixday = True
filtered = False
if not byhour or hour in byhour:
valid = True
break
if not valid:
raise ValueError('Invalid combination of interval and ' +
'byhour resulting in empty rule.')
timeset = gettimeset(hour, minute, second)
elif freq == SECONDLY:
if filtered:
# Jump to one iteration before next day
second += (((86399-(hour*3600+minute*60+second))
//interval)*interval)
while True:
second += self._interval
div, mod = divmod(second, 60)
// interval)*interval)
rep_rate = (24*3600)
valid = False
for j in range(0, rep_rate // gcd(interval, rep_rate)):
if bysecond:
nminutes, second = \
self.__mod_distance(value=second,
byxxx=self._bysecond,
base=60)
else:
nminutes, second = divmod(second+interval, 60)
div, minute = divmod(minute+nminutes, 60)
if div:
second = mod
minute += div
div, mod = divmod(minute, 60)
hour += div
div, hour = divmod(hour, 24)
if div:
minute = mod
hour += div
div, mod = divmod(hour, 24)
if div:
hour = mod
day += div
fixday = True
day += div
fixday = True
if ((not byhour or hour in byhour) and
(not byminute or minute in byminute) and
(not bysecond or second in bysecond)):
(not byminute or minute in byminute) and
(not bysecond or second in bysecond)):
valid = True
break
if not valid:
raise ValueError('Invalid combination of interval, ' +
'byhour and byminute resulting in empty' +
' rule.')
timeset = gettimeset(hour, minute, second)
if fixday and day > 28:
@ -630,6 +780,80 @@ class rrule(rrulebase):
daysinmonth = calendar.monthrange(year, month)[1]
ii.rebuild(year, month)
def __construct_byset(self, start, byxxx, base):
"""
If a `BYXXX` sequence is passed to the constructor at the same level as
`FREQ` (e.g. `FREQ=HOURLY,BYHOUR={2,4,7},INTERVAL=3`), there are some
specifications which cannot be reached given some starting conditions.
This occurs whenever the interval is not coprime with the base of a
given unit and the difference between the starting position and the
ending position is not divisible by the greatest common divisor of
the interval and the base. For example, with a FREQ of hourly
starting at 17:00 and an interval of 4, the only valid values for
BYHOUR would be {21, 1, 5, 9, 13, 17}, because 4 and 24 are not
coprime.
:param:`start` specifies the starting position.
:param:`byxxx` is an iterable containing the list of allowed values.
:param:`base` is the largest allowable value for the specified
frequency (e.g. 24 hours, 60 minutes).
This does not preserve the type of the iterable, returning a set
instead; since the values should be unique and the order is irrelevant,
this also speeds up later lookups.
In the event of an empty set, raises a :exception:`ValueError`, as this
results in an empty rrule.
"""
cset = set()
# Support a single byxxx value.
if isinstance(byxxx, integer_types):
byxxx = (byxxx, )
for num in byxxx:
i_gcd = gcd(self._interval, base)
# Use divmod rather than % because we need to wrap negative nums.
if i_gcd == 1 or divmod(num - start, i_gcd)[1] == 0:
cset.add(num)
if len(cset) == 0:
raise ValueError("Invalid rrule byxxx generates an empty set.")
return cset
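
A worked instance of the reachability rule above (editor's sketch): with HOURLY frequency, interval=4 and a start hour of 17, gcd(4, 24) == 4, so a BYHOUR value survives only if it differs from 17 by a multiple of 4:

    # Editor's sketch of the filter applied by __construct_byset.
    reachable = {h for h in range(24) if (h - 17) % 4 == 0}
    # {1, 5, 9, 13, 17, 21} -- the valid BYHOUR set from the docstring.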
def __mod_distance(self, value, byxxx, base):
"""
Calculates the next value in a sequence where the `FREQ` parameter is
specified along with a `BYXXX` parameter at the same "level"
(e.g. `HOURLY` specified with `BYHOUR`).
:param:`value` is the old value of the component.
:param:`byxxx` is the `BYXXX` set, which should have been generated
by `rrule.__construct_byset`, or something else which
checks that a valid rule is present.
:param:`base` is the largest allowable value for the specified
frequency (e.g. 24 hours, 60 minutes).
If a valid value is not found after `base` iterations (the maximum
number before the sequence would start to repeat), this raises a
:exception:`ValueError`, as no valid values were found.
This returns a tuple of `divmod(n*interval, base)`, where `n` is the
smallest number of `interval` repetitions until the next specified
value in `byxxx` is found.
"""
accumulator = 0
for ii in range(1, base + 1):
# Using divmod() over % to account for negative intervals
div, value = divmod(value + self._interval, base)
accumulator += div
if value in byxxx:
return (accumulator, value)
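
Restated outside the class (editor's sketch, for illustration only), __mod_distance advances in interval-sized steps, counting carries into the next-larger unit until it lands on an allowed value:

    def mod_distance(value, byxxx, base, interval):
        # Standalone restatement of rrule.__mod_distance.
        accumulator = 0
        for _ in range(1, base + 1):
            div, value = divmod(value + interval, base)
            accumulator += div
            if value in byxxx:
                return accumulator, value

    mod_distance(21, {1, 5, 9, 13, 17, 21}, base=24, interval=4)
    # (1, 1): carry one day, land on hour 1.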
class _iterinfo(object):
__slots__ = ["rrule", "lastyear", "lastmonth",
"yearlen", "nextyearlen", "yearordinal", "yearweekday",
@ -669,13 +893,13 @@ class _iterinfo(object):
self.wnomask = None
else:
self.wnomask = [0]*(self.yearlen+7)
#no1wkst = firstwkst = self.wdaymask.index(rr._wkst)
no1wkst = firstwkst = (7-self.yearweekday+rr._wkst)%7
# no1wkst = firstwkst = self.wdaymask.index(rr._wkst)
no1wkst = firstwkst = (7-self.yearweekday+rr._wkst) % 7
if no1wkst >= 4:
no1wkst = 0
# Number of days in the year, plus the days we got
# from last year.
wyearlen = self.yearlen+(self.yearweekday-rr._wkst)%7
wyearlen = self.yearlen+(self.yearweekday-rr._wkst) % 7
else:
# Number of days in the year, minus the days we
# left in last year.
@ -721,22 +945,22 @@ class _iterinfo(object):
# this year.
if -1 not in rr._byweekno:
lyearweekday = datetime.date(year-1, 1, 1).weekday()
lno1wkst = (7-lyearweekday+rr._wkst)%7
lno1wkst = (7-lyearweekday+rr._wkst) % 7
lyearlen = 365+calendar.isleap(year-1)
if lno1wkst >= 4:
lno1wkst = 0
lnumweeks = 52+(lyearlen+
(lyearweekday-rr._wkst)%7)%7//4
lnumweeks = 52+(lyearlen +
(lyearweekday-rr._wkst) % 7) % 7//4
else:
lnumweeks = 52+(self.yearlen-no1wkst)%7//4
lnumweeks = 52+(self.yearlen-no1wkst) % 7//4
else:
lnumweeks = -1
if lnumweeks in rr._byweekno:
for i in range(no1wkst):
self.wnomask[i] = 1
if (rr._bynweekday and
(month != self.lastmonth or year != self.lastyear)):
if (rr._bynweekday and (month != self.lastmonth or
year != self.lastyear)):
ranges = []
if rr._freq == YEARLY:
if rr._bymonth:
@ -755,10 +979,10 @@ class _iterinfo(object):
for wday, n in rr._bynweekday:
if n < 0:
i = last+(n+1)*7
i -= (self.wdaymask[i]-wday)%7
i -= (self.wdaymask[i]-wday) % 7
else:
i = first+(n-1)*7
i += (7-self.wdaymask[i]+wday)%7
i += (7-self.wdaymask[i]+wday) % 7
if first <= i <= last:
self.nwdaymask[i] = 1
@ -775,50 +999,50 @@ class _iterinfo(object):
return list(range(self.yearlen)), 0, self.yearlen
def mdayset(self, year, month, day):
set = [None]*self.yearlen
dset = [None]*self.yearlen
start, end = self.mrange[month-1:month+1]
for i in range(start, end):
set[i] = i
return set, start, end
dset[i] = i
return dset, start, end
def wdayset(self, year, month, day):
# We need to handle cross-year weeks here.
set = [None]*(self.yearlen+7)
dset = [None]*(self.yearlen+7)
i = datetime.date(year, month, day).toordinal()-self.yearordinal
start = i
for j in range(7):
set[i] = i
dset[i] = i
i += 1
#if (not (0 <= i < self.yearlen) or
# if (not (0 <= i < self.yearlen) or
# self.wdaymask[i] == self.rrule._wkst):
# This will cross the year boundary, if necessary.
if self.wdaymask[i] == self.rrule._wkst:
break
return set, start, i
return dset, start, i
def ddayset(self, year, month, day):
set = [None]*self.yearlen
dset = [None]*self.yearlen
i = datetime.date(year, month, day).toordinal()-self.yearordinal
set[i] = i
return set, i, i+1
dset[i] = i
return dset, i, i+1
def htimeset(self, hour, minute, second):
set = []
tset = []
rr = self.rrule
for minute in rr._byminute:
for second in rr._bysecond:
set.append(datetime.time(hour, minute, second,
tset.append(datetime.time(hour, minute, second,
tzinfo=rr._tzinfo))
set.sort()
return set
tset.sort()
return tset
def mtimeset(self, hour, minute, second):
set = []
tset = []
rr = self.rrule
for second in rr._bysecond:
set.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo))
set.sort()
return set
tset.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo))
tset.sort()
return tset
def stimeset(self, hour, minute, second):
return (datetime.time(hour, minute, second,
@ -826,6 +1050,12 @@ class _iterinfo(object):
class rruleset(rrulebase):
""" The rruleset type allows more complex recurrence setups, mixing
multiple rules, dates, exclusion rules, and exclusion dates. The type
constructor takes the following keyword arguments:
:param cache: If True, caching of results will be enabled, improving
performance of multiple queries considerably. """
class _genitem(object):
def __init__(self, genlist, gen):
@ -865,15 +1095,26 @@ class rruleset(rrulebase):
self._exdate = []
def rrule(self, rrule):
""" Include the given :py:class:`rrule` instance in the recurrence set
generation. """
self._rrule.append(rrule)
def rdate(self, rdate):
""" Include the given :py:class:`datetime` instance in the recurrence
set generation. """
self._rdate.append(rdate)
def exrule(self, exrule):
""" Include the given rrule instance in the recurrence set exclusion
list. Dates which are part of the given recurrence rules will not
be generated, even if some inclusive rrule or rdate matches them.
"""
self._exrule.append(exrule)
def exdate(self, exdate):
""" Include the given datetime instance in the recurrence set
exclusion list. Dates included that way will not be generated,
even if some inclusive rrule or rdate matches them. """
self._exdate.append(exdate)
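
Together, the four methods above form a small set algebra over occurrences. A sketch (editor's example, assuming the vendored dateutil):

    import datetime
    from dateutil.rrule import rrule, rruleset, DAILY, WEEKLY, SA, SU

    start = datetime.datetime(2015, 1, 30)
    rset = rruleset()
    rset.rrule(rrule(DAILY, dtstart=start, count=7))               # a week of days
    rset.exrule(rrule(WEEKLY, dtstart=start, byweekday=(SA, SU)))  # drop weekends
    rset.exdate(datetime.datetime(2015, 2, 2))                     # drop one day
    list(rset)  # Jan 30, Feb 3, Feb 4, Feb 5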
def _iter(self):
@ -905,6 +1146,7 @@ class rruleset(rrulebase):
rlist.sort()
self._len = total
class _rrulestr(object):
_freq_map = {"YEARLY": YEARLY,
@ -915,7 +1157,8 @@ class _rrulestr(object):
"MINUTELY": MINUTELY,
"SECONDLY": SECONDLY}
_weekday_map = {"MO":0,"TU":1,"WE":2,"TH":3,"FR":4,"SA":5,"SU":6}
_weekday_map = {"MO": 0, "TU": 1, "WE": 2, "TH": 3,
"FR": 4, "SA": 5, "SU": 6}
def _handle_int(self, rrkwargs, name, value, **kwargs):
rrkwargs[name.lower()] = int(value)
@ -923,17 +1166,17 @@ class _rrulestr(object):
def _handle_int_list(self, rrkwargs, name, value, **kwargs):
rrkwargs[name.lower()] = [int(x) for x in value.split(',')]
_handle_INTERVAL = _handle_int
_handle_COUNT = _handle_int
_handle_BYSETPOS = _handle_int_list
_handle_BYMONTH = _handle_int_list
_handle_INTERVAL = _handle_int
_handle_COUNT = _handle_int
_handle_BYSETPOS = _handle_int_list
_handle_BYMONTH = _handle_int_list
_handle_BYMONTHDAY = _handle_int_list
_handle_BYYEARDAY = _handle_int_list
_handle_BYEASTER = _handle_int_list
_handle_BYWEEKNO = _handle_int_list
_handle_BYHOUR = _handle_int_list
_handle_BYMINUTE = _handle_int_list
_handle_BYSECOND = _handle_int_list
_handle_BYYEARDAY = _handle_int_list
_handle_BYEASTER = _handle_int_list
_handle_BYWEEKNO = _handle_int_list
_handle_BYHOUR = _handle_int_list
_handle_BYMINUTE = _handle_int_list
_handle_BYSECOND = _handle_int_list
def _handle_FREQ(self, rrkwargs, name, value, **kwargs):
rrkwargs["freq"] = self._freq_map[value]
@ -944,8 +1187,8 @@ class _rrulestr(object):
from dateutil import parser
try:
rrkwargs["until"] = parser.parse(value,
ignoretz=kwargs.get("ignoretz"),
tzinfos=kwargs.get("tzinfos"))
ignoretz=kwargs.get("ignoretz"),
tzinfos=kwargs.get("tzinfos"))
except ValueError:
raise ValueError("invalid until date")
@ -960,7 +1203,8 @@ class _rrulestr(object):
break
n = wday[:i] or None
w = wday[i:]
if n: n = int(n)
if n:
n = int(n)
l.append(weekdays[self._weekday_map[w]](n))
rrkwargs["byweekday"] = l
@ -1021,8 +1265,8 @@ class _rrulestr(object):
i += 1
else:
lines = s.split()
if (not forceset and len(lines) == 1 and
(s.find(':') == -1 or s.startswith('RRULE:'))):
if (not forceset and len(lines) == 1 and (s.find(':') == -1 or
s.startswith('RRULE:'))):
return self._parse_rfc_rrule(lines[0], cache=cache,
dtstart=dtstart, ignoretz=ignoretz,
tzinfos=tzinfos)
@ -1071,32 +1315,32 @@ class _rrulestr(object):
tzinfos=tzinfos)
else:
raise ValueError("unsupported property: "+name)
if (forceset or len(rrulevals) > 1 or
rdatevals or exrulevals or exdatevals):
if (forceset or len(rrulevals) > 1 or rdatevals
or exrulevals or exdatevals):
if not parser and (rdatevals or exdatevals):
from dateutil import parser
set = rruleset(cache=cache)
rset = rruleset(cache=cache)
for value in rrulevals:
set.rrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in rdatevals:
for datestr in value.split(','):
set.rdate(parser.parse(datestr,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exrulevals:
set.exrule(self._parse_rfc_rrule(value, dtstart=dtstart,
rset.rrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exdatevals:
for value in rdatevals:
for datestr in value.split(','):
set.exdate(parser.parse(datestr,
rset.rdate(parser.parse(datestr,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exrulevals:
rset.exrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exdatevals:
for datestr in value.split(','):
rset.exdate(parser.parse(datestr,
ignoretz=ignoretz,
tzinfos=tzinfos))
if compatible and dtstart:
set.rdate(dtstart)
return set
rset.rdate(dtstart)
return rset
else:
return self._parse_rfc_rrule(rrulevals[0],
dtstart=dtstart,

View File

@ -1,19 +1,25 @@
# -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
This module offers timezone implementations subclassing the abstract
:py:`datetime.tzinfo` type. There are classes to handle tzfile format files
(usually found in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, etc), TZ
environment string (in all known formats), given ranges (with help from
relative deltas), local machine timezone, fixed offset timezone, and UTC
timezone.
"""
__license__ = "Simplified BSD"
from six import string_types, PY3
import datetime
import struct
import time
import sys
import os
from six import string_types, PY3
try:
from dateutil.tzwin import tzwin, tzwinlocal
except ImportError:
tzwin = tzwinlocal = None
relativedelta = None
parser = None
rrule = None
@ -21,10 +27,6 @@ rrule = None
__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange",
"tzstr", "tzical", "tzwin", "tzwinlocal", "gettz"]
try:
from dateutil.tzwin import tzwin, tzwinlocal
except (ImportError, OSError):
tzwin, tzwinlocal = None, None
def tzname_in_python2(myfunc):
"""Change unicode output into bytestrings in Python 2
@ -42,11 +44,12 @@ def tzname_in_python2(myfunc):
ZERO = datetime.timedelta(0)
EPOCHORDINAL = datetime.datetime.utcfromtimestamp(0).toordinal()
class tzutc(datetime.tzinfo):
def utcoffset(self, dt):
return ZERO
def dst(self, dt):
return ZERO
@ -66,6 +69,7 @@ class tzutc(datetime.tzinfo):
__reduce__ = object.__reduce__
class tzoffset(datetime.tzinfo):
def __init__(self, name, offset):
@ -96,6 +100,7 @@ class tzoffset(datetime.tzinfo):
__reduce__ = object.__reduce__
class tzlocal(datetime.tzinfo):
_std_offset = datetime.timedelta(seconds=-time.timezone)
@ -123,25 +128,25 @@ class tzlocal(datetime.tzinfo):
def _isdst(self, dt):
# We can't use mktime here. It is unstable when deciding if
# the hour near to a change is DST or not.
#
#
# timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour,
# dt.minute, dt.second, dt.weekday(), 0, -1))
# return time.localtime(timestamp).tm_isdst
#
# The code above yields the following result:
#
#>>> import tz, datetime
#>>> t = tz.tzlocal()
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRDT'
#>>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname()
#'BRST'
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRST'
#>>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname()
#'BRDT'
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRDT'
# >>> import tz, datetime
# >>> t = tz.tzlocal()
# >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
# 'BRDT'
# >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname()
# 'BRST'
# >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
# 'BRST'
# >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname()
# 'BRDT'
# >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
# 'BRDT'
#
# Here is a more stable implementation:
#
@ -166,6 +171,7 @@ class tzlocal(datetime.tzinfo):
__reduce__ = object.__reduce__
class _ttinfo(object):
__slots__ = ["offset", "delta", "isdst", "abbr", "isstd", "isgmt"]
@ -205,15 +211,20 @@ class _ttinfo(object):
if name in state:
setattr(self, name, state[name])
class tzfile(datetime.tzinfo):
# http://www.twinsun.com/tz/tz-link.htm
# ftp://ftp.iana.org/tz/tz*.tar.gz
def __init__(self, fileobj):
def __init__(self, fileobj, filename=None):
file_opened_here = False
if isinstance(fileobj, string_types):
self._filename = fileobj
fileobj = open(fileobj, 'rb')
file_opened_here = True
elif filename is not None:
self._filename = filename
elif hasattr(fileobj, "name"):
self._filename = fileobj.name
else:
@ -228,125 +239,128 @@ class tzfile(datetime.tzinfo):
# six four-byte values of type long, written in a
# ``standard'' byte order (the high-order byte
# of the value is written first).
try:
if fileobj.read(4).decode() != "TZif":
raise ValueError("magic not found")
if fileobj.read(4).decode() != "TZif":
raise ValueError("magic not found")
fileobj.read(16)
fileobj.read(16)
(
# The number of UTC/local indicators stored in the file.
ttisgmtcnt,
(
# The number of UTC/local indicators stored in the file.
ttisgmtcnt,
# The number of standard/wall indicators stored in the file.
ttisstdcnt,
# The number of standard/wall indicators stored in the file.
ttisstdcnt,
# The number of leap seconds for which data is
# stored in the file.
leapcnt,
# The number of leap seconds for which data is
# stored in the file.
leapcnt,
# The number of "transition times" for which data
# is stored in the file.
timecnt,
# The number of "transition times" for which data
# is stored in the file.
timecnt,
# The number of "local time types" for which data
# is stored in the file (must not be zero).
typecnt,
# The number of "local time types" for which data
# is stored in the file (must not be zero).
typecnt,
# The number of characters of "time zone
# abbreviation strings" stored in the file.
charcnt,
# The number of characters of "time zone
# abbreviation strings" stored in the file.
charcnt,
) = struct.unpack(">6l", fileobj.read(24))
) = struct.unpack(">6l", fileobj.read(24))
# The above header is followed by tzh_timecnt four-byte
# values of type long, sorted in ascending order.
# These values are written in ``standard'' byte order.
# Each is used as a transition time (as returned by
# time(2)) at which the rules for computing local time
# change.
# The above header is followed by tzh_timecnt four-byte
# values of type long, sorted in ascending order.
# These values are written in ``standard'' byte order.
# Each is used as a transition time (as returned by
# time(2)) at which the rules for computing local time
# change.
if timecnt:
self._trans_list = struct.unpack(">%dl" % timecnt,
fileobj.read(timecnt*4))
else:
self._trans_list = []
if timecnt:
self._trans_list = struct.unpack(">%dl" % timecnt,
fileobj.read(timecnt*4))
else:
self._trans_list = []
# Next come tzh_timecnt one-byte values of type unsigned
# char; each one tells which of the different types of
# ``local time'' types described in the file is associated
# with the same-indexed transition time. These values
# serve as indices into an array of ttinfo structures that
# appears next in the file.
if timecnt:
self._trans_idx = struct.unpack(">%dB" % timecnt,
fileobj.read(timecnt))
else:
self._trans_idx = []
# Each ttinfo structure is written as a four-byte value
# for tt_gmtoff of type long, in a standard byte
# order, followed by a one-byte value for tt_isdst
# and a one-byte value for tt_abbrind. In each
# structure, tt_gmtoff gives the number of
# seconds to be added to UTC, tt_isdst tells whether
# tm_isdst should be set by localtime(3), and
# tt_abbrind serves as an index into the array of
# time zone abbreviation characters that follow the
# ttinfo structure(s) in the file.
# Next come tzh_timecnt one-byte values of type unsigned
# char; each one tells which of the different types of
# ``local time'' types described in the file is associated
# with the same-indexed transition time. These values
# serve as indices into an array of ttinfo structures that
# appears next in the file.
ttinfo = []
if timecnt:
self._trans_idx = struct.unpack(">%dB" % timecnt,
fileobj.read(timecnt))
else:
self._trans_idx = []
for i in range(typecnt):
ttinfo.append(struct.unpack(">lbb", fileobj.read(6)))
# Each ttinfo structure is written as a four-byte value
# for tt_gmtoff of type long, in a standard byte
# order, followed by a one-byte value for tt_isdst
# and a one-byte value for tt_abbrind. In each
# structure, tt_gmtoff gives the number of
# seconds to be added to UTC, tt_isdst tells whether
# tm_isdst should be set by localtime(3), and
# tt_abbrind serves as an index into the array of
# time zone abbreviation characters that follow the
# ttinfo structure(s) in the file.
abbr = fileobj.read(charcnt).decode()
ttinfo = []
# Then there are tzh_leapcnt pairs of four-byte
# values, written in standard byte order; the
# first value of each pair gives the time (as
# returned by time(2)) at which a leap second
# occurs; the second gives the total number of
# leap seconds to be applied after the given time.
# The pairs of values are sorted in ascending order
# by time.
for i in range(typecnt):
ttinfo.append(struct.unpack(">lbb", fileobj.read(6)))
# Not used, for now
if leapcnt:
leap = struct.unpack(">%dl" % (leapcnt*2),
fileobj.read(leapcnt*8))
abbr = fileobj.read(charcnt).decode()
# Then there are tzh_ttisstdcnt standard/wall
# indicators, each stored as a one-byte value;
# they tell whether the transition times associated
# with local time types were specified as standard
# time or wall clock time, and are used when
# a time zone file is used in handling POSIX-style
# time zone environment variables.
# Then there are tzh_leapcnt pairs of four-byte
# values, written in standard byte order; the
# first value of each pair gives the time (as
# returned by time(2)) at which a leap second
# occurs; the second gives the total number of
# leap seconds to be applied after the given time.
# The pairs of values are sorted in ascending order
# by time.
if ttisstdcnt:
isstd = struct.unpack(">%db" % ttisstdcnt,
fileobj.read(ttisstdcnt))
# Not used, for now
# if leapcnt:
# leap = struct.unpack(">%dl" % (leapcnt*2),
# fileobj.read(leapcnt*8))
# Finally, there are tzh_ttisgmtcnt UTC/local
# indicators, each stored as a one-byte value;
# they tell whether the transition times associated
# with local time types were specified as UTC or
# local time, and are used when a time zone file
# is used in handling POSIX-style time zone envi-
# ronment variables.
# Then there are tzh_ttisstdcnt standard/wall
# indicators, each stored as a one-byte value;
# they tell whether the transition times associated
# with local time types were specified as standard
# time or wall clock time, and are used when
# a time zone file is used in handling POSIX-style
# time zone environment variables.
if ttisgmtcnt:
isgmt = struct.unpack(">%db" % ttisgmtcnt,
fileobj.read(ttisgmtcnt))
if ttisstdcnt:
isstd = struct.unpack(">%db" % ttisstdcnt,
fileobj.read(ttisstdcnt))
# ** Everything has been read **
# Finally, there are tzh_ttisgmtcnt UTC/local
# indicators, each stored as a one-byte value;
# they tell whether the transition times associated
# with local time types were specified as UTC or
# local time, and are used when a time zone file
# is used in handling POSIX-style time zone envi-
# ronment variables.
if ttisgmtcnt:
isgmt = struct.unpack(">%db" % ttisgmtcnt,
fileobj.read(ttisgmtcnt))
# ** Everything has been read **
finally:
if file_opened_here:
fileobj.close()
# Build ttinfo list
self._ttinfo_list = []
for i in range(typecnt):
gmtoff, isdst, abbrind = ttinfo[i]
gmtoff, isdst, abbrind = ttinfo[i]
# Round to full-minutes if that's not the case. Python's
# datetime doesn't accept sub-minute timezones. Check
# http://python.org/sf/1447945 for some information.
@ -464,7 +478,7 @@ class tzfile(datetime.tzinfo):
# However, this class stores historical changes in the
# dst offset, so I believe that this wouldn't be the right
# way to implement this.
@tzname_in_python2
def tzname(self, dt):
if not self._ttinfo_std:
@ -481,7 +495,6 @@ class tzfile(datetime.tzinfo):
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, repr(self._filename))
@ -490,8 +503,8 @@ class tzfile(datetime.tzinfo):
raise ValueError("Unpickable %s class" % self.__class__.__name__)
return (self.__class__, (self._filename,))
class tzrange(datetime.tzinfo):
class tzrange(datetime.tzinfo):
def __init__(self, stdabbr, stdoffset=None,
dstabbr=None, dstoffset=None,
start=None, end=None):
@ -512,12 +525,12 @@ class tzrange(datetime.tzinfo):
self._dst_offset = ZERO
if dstabbr and start is None:
self._start_delta = relativedelta.relativedelta(
hours=+2, month=4, day=1, weekday=relativedelta.SU(+1))
hours=+2, month=4, day=1, weekday=relativedelta.SU(+1))
else:
self._start_delta = start
if dstabbr and end is None:
self._end_delta = relativedelta.relativedelta(
hours=+1, month=10, day=31, weekday=relativedelta.SU(-1))
hours=+1, month=10, day=31, weekday=relativedelta.SU(-1))
else:
self._end_delta = end
@ -570,8 +583,9 @@ class tzrange(datetime.tzinfo):
__reduce__ = object.__reduce__
class tzstr(tzrange):
def __init__(self, s):
global parser
if not parser:
@ -645,9 +659,10 @@ class tzstr(tzrange):
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, repr(self._s))
class _tzicalvtzcomp(object):
def __init__(self, tzoffsetfrom, tzoffsetto, isdst,
tzname=None, rrule=None):
tzname=None, rrule=None):
self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom)
self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto)
self.tzoffsetdiff = self.tzoffsetto-self.tzoffsetfrom
@ -655,6 +670,7 @@ class _tzicalvtzcomp(object):
self.tzname = tzname
self.rrule = rrule
class _tzicalvtz(datetime.tzinfo):
def __init__(self, tzid, comps=[]):
self._tzid = tzid
@ -718,6 +734,7 @@ class _tzicalvtz(datetime.tzinfo):
__reduce__ = object.__reduce__
class tzical(object):
def __init__(self, fileobj):
global rrule
@ -726,7 +743,8 @@ class tzical(object):
if isinstance(fileobj, string_types):
self._s = fileobj
# ical should be encoded in UTF-8 with CRLF
fileobj = open(fileobj, 'r')
elif hasattr(fileobj, "name"):
self._s = fileobj.name
else:
@ -754,7 +772,7 @@ class tzical(object):
if not s:
raise ValueError("empty offset")
if s[0] in ('+', '-'):
signal = (-1, +1)[s[0] == '+']
s = s[1:]
else:
signal = +1
@ -815,7 +833,8 @@ class tzical(object):
if not tzid:
raise ValueError("mandatory TZID not found")
if not comps:
raise ValueError("at least one component is needed")
raise ValueError(
"at least one component is needed")
# Process vtimezone
self._vtz[tzid] = _tzicalvtz(tzid, comps)
invtz = False
@ -823,9 +842,11 @@ class tzical(object):
if not founddtstart:
raise ValueError("mandatory DTSTART not found")
if tzoffsetfrom is None:
raise ValueError("mandatory TZOFFSETFROM not found")
raise ValueError(
"mandatory TZOFFSETFROM not found")
if tzoffsetto is None:
raise ValueError("mandatory TZOFFSETFROM not found")
raise ValueError(
"mandatory TZOFFSETFROM not found")
# Process component
rr = None
if rrulelines:
@ -848,15 +869,18 @@ class tzical(object):
rrulelines.append(line)
elif name == "TZOFFSETFROM":
if parms:
raise ValueError("unsupported %s parm: %s "%(name, parms[0]))
raise ValueError(
"unsupported %s parm: %s " % (name, parms[0]))
tzoffsetfrom = self._parse_offset(value)
elif name == "TZOFFSETTO":
if parms:
raise ValueError("unsupported TZOFFSETTO parm: "+parms[0])
raise ValueError(
"unsupported TZOFFSETTO parm: "+parms[0])
tzoffsetto = self._parse_offset(value)
elif name == "TZNAME":
if parms:
raise ValueError("unsupported TZNAME parm: "+parms[0])
raise ValueError(
"unsupported TZNAME parm: "+parms[0])
tzname = value
elif name == "COMMENT":
pass
@ -865,7 +889,8 @@ class tzical(object):
else:
if name == "TZID":
if parms:
raise ValueError("unsupported TZID parm: "+parms[0])
raise ValueError(
"unsupported TZID parm: "+parms[0])
tzid = value
elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"):
pass
@ -886,6 +911,7 @@ else:
TZFILES = []
TZPATHS = []
def gettz(name=None):
tz = None
if not name:
@ -933,11 +959,11 @@ def gettz(name=None):
pass
else:
tz = None
if tzwin is not None:
try:
tz = tzwin(name)
except WindowsError:
tz = None
if not tz:
from dateutil.zoneinfo import gettz
tz = gettz(name)
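# Illustrative lookups (a sketch; results depend on the host's zone data):
#     gettz()                    # the local time zone
#     gettz("America/New_York")  # tzfile from TZPATHS, else zoneinfo fallback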

View File

@ -1,8 +1,8 @@
# This code was originally contributed by Jeffrey Harris.
import datetime
import struct
import winreg
from six.moves import winreg
__all__ = ["tzwin", "tzwinlocal"]
@ -12,8 +12,8 @@ TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones"
TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones"
TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation"
def _settzkeyname():
global TZKEYNAME
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try:
winreg.OpenKey(handle, TZKEYNAMENT).Close()
@ -21,8 +21,10 @@ def _settzkeyname():
except WindowsError:
TZKEYNAME = TZKEYNAME9X
handle.Close()
return TZKEYNAME
TZKEYNAME = _settzkeyname()
_settzkeyname()
class tzwinbase(datetime.tzinfo):
"""tzinfo class based on win32's timezones available in the registry."""
@ -39,7 +41,7 @@ class tzwinbase(datetime.tzinfo):
return datetime.timedelta(minutes=minutes)
else:
return datetime.timedelta(0)
def tzname(self, dt):
if self._isdst(dt):
return self._dstname
@ -59,8 +61,11 @@ class tzwinbase(datetime.tzinfo):
def display(self):
return self._display
def _isdst(self, dt):
if not self._dstmonth:
# dstmonth == 0 signals the zone has no daylight saving time
return False
dston = picknthweekday(dt.year, self._dstmonth, self._dstdayofweek,
self._dsthour, self._dstminute,
self._dstweeknumber)
@ -78,31 +83,33 @@ class tzwin(tzwinbase):
def __init__(self, name):
self._name = name
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
tzkey = winreg.OpenKey(handle, "%s\%s" % (TZKEYNAME, name))
keydict = valuestodict(tzkey)
tzkey.Close()
handle.Close()
# multiple contexts only possible in 2.7 and 3.1, we still support 2.6
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
with winreg.OpenKey(handle,
"%s\%s" % (TZKEYNAME, name)) as tzkey:
keydict = valuestodict(tzkey)
self._stdname = keydict["Std"].encode("iso-8859-1")
self._dstname = keydict["Dlt"].encode("iso-8859-1")
self._display = keydict["Display"]
# See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm
tup = struct.unpack("=3l16h", keydict["TZI"])
self._stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1
self._dstoffset = self._stdoffset-tup[2] # + DaylightBias * -1
# for the meaning see the win32 TIME_ZONE_INFORMATION structure docs
# http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx
(self._stdmonth,
self._stddayofweek, # Sunday = 0
self._stdweeknumber, # Last = 5
self._stdhour,
self._stdminute) = tup[4:9]
(self._dstmonth,
self._dstdayofweek, # Sunday = 0
self._dstweeknumber, # Last = 5
self._dsthour,
self._dstminute) = tup[12:17]
@ -114,61 +121,59 @@ class tzwin(tzwinbase):
class tzwinlocal(tzwinbase):
def __init__(self):
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
tzlocalkey = winreg.OpenKey(handle, TZLOCALKEYNAME)
keydict = valuestodict(tzlocalkey)
tzlocalkey.Close()
with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey:
keydict = valuestodict(tzlocalkey)
self._stdname = keydict["StandardName"].encode("iso-8859-1")
self._dstname = keydict["DaylightName"].encode("iso-8859-1")
try:
tzkey = winreg.OpenKey(handle, "%s\%s"%(TZKEYNAME, self._stdname))
_keydict = valuestodict(tzkey)
self._display = _keydict["Display"]
tzkey.Close()
except OSError:
self._display = None
try:
with winreg.OpenKey(
handle, "%s\%s" % (TZKEYNAME, self._stdname)) as tzkey:
_keydict = valuestodict(tzkey)
self._display = _keydict["Display"]
except OSError:
self._display = None
handle.Close()
self._stdoffset = -keydict["Bias"]-keydict["StandardBias"]
self._dstoffset = self._stdoffset-keydict["DaylightBias"]
# See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm
tup = struct.unpack("=8h", keydict["StandardStart"])
(self._stdmonth,
self._stddayofweek, # Sunday = 0
self._stdweeknumber, # Last = 5
self._stdhour,
self._stdminute) = tup[1:6]
tup = struct.unpack("=8h", keydict["DaylightStart"])
(self._dstmonth,
self._dstdayofweek, # Sunday = 0
self._dstweeknumber, # Last = 5
self._dsthour,
self._dstminute) = tup[1:6]
def __reduce__(self):
return (self.__class__, ())
def picknthweekday(year, month, dayofweek, hour, minute, whichweek):
"""dayofweek == 0 means Sunday, whichweek 5 means last instance"""
first = datetime.datetime(year, month, 1, hour, minute)
weekdayone = first.replace(day=((dayofweek-first.isoweekday()) % 7+1))
for n in range(whichweek):
dt = weekdayone+(whichweek-n)*ONEWEEK
if dt.month == month:
return dt
def valuestodict(key):
"""Convert a registry key's values to a dictionary."""
dict = {}

View File

@ -1,109 +1,108 @@
# -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2005 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
import logging
import os
from subprocess import call
import warnings
import tempfile
import shutil
from subprocess import check_call
from tarfile import TarFile
from pkgutil import get_data
from io import BytesIO
from contextlib import closing
from dateutil.tz import tzfile
__author__ = "Tomi Pieviläinen <tomi.pievilainen@iki.fi>"
__license__ = "Simplified BSD"
__all__ = ["setcachesize", "gettz", "rebuild"]
CACHE = []
CACHESIZE = 10
_ZONEFILENAME = "dateutil-zoneinfo.tar.gz"
# python2.6 compatibility. Note that TarFile.__exit__ != TarFile.close, but
# it's close enough for python2.6
_tar_open = TarFile.open
if not hasattr(TarFile, '__exit__'):
def _tar_open(*args, **kwargs):
return closing(TarFile.open(*args, **kwargs))
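# With the shim in place, both 2.6 and later Pythons can use the context
# form (a sketch):
#     with _tar_open(ZONEINFOFILE) as tf:
#         names = tf.getnames()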
class tzfile(tzfile):
def __reduce__(self):
return (gettz, (self._filename,))
def getzoneinfofile():
filenames = sorted(os.listdir(os.path.join(os.path.dirname(__file__))))
filenames.reverse()
for entry in filenames:
if entry.startswith("zoneinfo") and ".tar." in entry:
return os.path.join(os.path.dirname(__file__), entry)
return None
ZONEINFOFILE = getzoneinfofile()
def getzoneinfofile_stream():
try:
return BytesIO(get_data(__name__, _ZONEFILENAME))
except IOError as e: # TODO switch to FileNotFoundError?
warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror))
return None
del getzoneinfofile
def setcachesize(size):
global CACHESIZE, CACHE
CACHESIZE = size
del CACHE[size:]
class ZoneInfoFile(object):
def __init__(self, zonefile_stream=None):
if zonefile_stream is not None:
with _tar_open(fileobj=zonefile_stream, mode='r') as tf:
# dict comprehension does not work on python2.6
# TODO: get back to the nicer syntax when we ditch python2.6
# self.zones = {zf.name: tzfile(tf.extractfile(zf),
# filename = zf.name)
# for zf in tf.getmembers() if zf.isfile()}
self.zones = dict((zf.name, tzfile(tf.extractfile(zf),
filename=zf.name))
for zf in tf.getmembers() if zf.isfile())
# deal with links: They'll point to their parent object. Less
# waste of memory
# links = {zl.name: self.zones[zl.linkname]
# for zl in tf.getmembers() if zl.islnk() or zl.issym()}
links = dict((zl.name, self.zones[zl.linkname])
for zl in tf.getmembers() if
zl.islnk() or zl.issym())
self.zones.update(links)
else:
self.zones = dict()
# The current API has gettz as a module function, although in fact it taps into
# a stateful class. So as a workaround for now, without changing the API, we
# will create a new "global" class instance the first time a user requests a
# timezone. Ugly, but adheres to the api.
#
# TODO: deprecate this.
_CLASS_ZONE_INSTANCE = list()
def gettz(name):
tzinfo = None
if ZONEINFOFILE:
for cachedname, tzinfo in CACHE:
if cachedname == name:
break
else:
tf = TarFile.open(ZONEINFOFILE)
try:
zonefile = tf.extractfile(name)
except KeyError:
tzinfo = None
else:
tzinfo = tzfile(zonefile)
tf.close()
CACHE.insert(0, (name, tzinfo))
del CACHE[CACHESIZE:]
return tzinfo
if len(_CLASS_ZONE_INSTANCE) == 0:
_CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
return _CLASS_ZONE_INSTANCE[0].zones.get(name)
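# Illustrative usage (a sketch):
#     from dateutil.zoneinfo import gettz
#     nyc = gettz("America/New_York")   # a tzfile instance, or None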
def rebuild(filename, tag=None, format="gz"):
def rebuild(filename, tag=None, format="gz", zonegroups=[]):
"""Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar*
filename is the timezone tarball from ftp.iana.org/tz.
"""
import tempfile, shutil
tmpdir = tempfile.mkdtemp()
zonedir = os.path.join(tmpdir, "zoneinfo")
moduledir = os.path.dirname(__file__)
if tag: tag = "-"+tag
targetname = "zoneinfo%s.tar.%s" % (tag, format)
try:
tf = TarFile.open(filename)
# The "backwards" zone file contains links to other files, so must be
# processed as last
for name in sorted(tf.getnames(),
key=lambda k: k != "backward" and k or "z"):
if not (name.endswith(".sh") or
name.endswith(".tab") or
name == "leapseconds"):
with _tar_open(filename) as tf:
for name in zonegroups:
tf.extract(name, tmpdir)
filepath = os.path.join(tmpdir, name)
try:
# zic will return errors for nontz files in the package
# such as the Makefile or README, so check_call cannot
# be used (or at least extra checks would be needed)
call(["zic", "-d", zonedir, filepath])
except OSError as e:
if e.errno == 2:
logging.error(
"Could not find zic. Perhaps you need to install "
"libc-bin or some other package that provides it, "
"or it's not in your PATH?")
filepaths = [os.path.join(tmpdir, n) for n in zonegroups]
try:
check_call(["zic", "-d", zonedir] + filepaths)
except OSError as e:
if e.errno == 2:
logging.error(
"Could not find zic. Perhaps you need to install "
"libc-bin or some other package that provides it, "
"or it's not in your PATH?")
raise
tf.close()
target = os.path.join(moduledir, targetname)
for entry in os.listdir(moduledir):
if entry.startswith("zoneinfo") and ".tar." in entry:
os.unlink(os.path.join(moduledir, entry))
tf = TarFile.open(target, "w:%s" % format)
for entry in os.listdir(zonedir):
entrypath = os.path.join(zonedir, entry)
tf.add(entrypath, entry)
tf.close()
target = os.path.join(moduledir, _ZONEFILENAME)
with _tar_open(target, "w:%s" % format) as tf:
for entry in os.listdir(zonedir):
entrypath = os.path.join(zonedir, entry)
tf.add(entrypath, entry)
finally:
shutil.rmtree(tmpdir)
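# Illustrative invocation (a sketch; the tarball and group names are
# assumptions modeled on IANA tzdata releases):
#     rebuild("tzdata2014j.tar.gz", tag="2014j",
#             zonegroups=["africa", "europe", "northamerica", "backward"])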

View File

@ -5,17 +5,27 @@ Keyring implementation support
from __future__ import absolute_import
import abc
import logging
try:
import importlib
except ImportError:
pass
try:
import pkg_resources
except ImportError:
pass
from . import errors, util
from . import backends
from .util import properties
from .py27compat import add_metaclass, filter
log = logging.getLogger(__name__)
class KeyringBackendMeta(abc.ABCMeta):
"""
A metaclass that's both an ABCMeta and a type that keeps a registry of
@ -127,6 +137,38 @@ def _load_backends():
backends = ('file', 'Gnome', 'Google', 'keyczar', 'kwallet', 'multi',
'OS_X', 'pyfs', 'SecretService', 'Windows')
list(map(_load_backend, backends))
_load_plugins()
def _load_plugins():
"""
Locate all setuptools entry points by the name 'keyring backends'
and initialize them.
Any third-party library may register an entry point by adding the
following to their setup.py::
entry_points = {
'keyring backends': [
'plugin_name = mylib.mymodule:initialize_func',
],
},
`plugin_name` can be anything, and is only used to display the name
of the plugin at initialization time.
`initialize_func` is optional, but will be invoked if callable.
"""
if 'pkg_resources' not in globals():
return
group = 'keyring backends'
entry_points = pkg_resources.iter_entry_points(group=group)
for ep in entry_points:
try:
log.info('Loading %s', ep.name)
init_func = ep.load()
if callable(init_func):
init_func()
except Exception:
log.exception("Error initializing plugin %s." % ep)
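# Sketch of an optional initialize_func a plugin might provide (the
# backend name here is hypothetical; set_keyring is keyring's public API):
#     def initialize_func():
#         import keyring
#         keyring.set_keyring(MyCustomBackend())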
@util.once
def get_all_keyring():

View File

@ -10,7 +10,7 @@ called from the command line.
import markdown
html = markdown.markdown(your_text_string)
See <http://packages.python.org/Markdown/> for more
See <https://pythonhosted.org/Markdown/> for more
information and instructions on how to extend the functionality of
Python Markdown. Read that before you try modifying this file.
@ -36,6 +36,8 @@ from .__version__ import version, version_info
import codecs
import sys
import logging
import warnings
import importlib
from . import util
from .preprocessors import build_preprocessors
from .blockprocessors import build_block_parser
@ -48,6 +50,7 @@ from .serializers import to_html_string, to_xhtml_string
__all__ = ['Markdown', 'markdown', 'markdownFromFile']
logger = logging.getLogger('MARKDOWN')
logging.captureWarnings(True)
class Markdown(object):
@ -96,8 +99,8 @@ class Markdown(object):
Note that it is suggested that the more specific formats ("xhtml1"
and "html4") be used as "xhtml" or "html" may change in the future
if it makes sense at that time.
* safe_mode: Disallow raw html. One of "remove", "replace" or "escape".
* html_replacement_text: Text used when safe_mode is set to "replace".
* safe_mode: Deprecated! Disallow raw html. One of "remove", "replace" or "escape".
* html_replacement_text: Deprecated! Text used when safe_mode is set to "replace".
* tab_length: Length of tabs in the source. Default: 4
* enable_attributes: Enable the conversion of attributes. Default: True
* smart_emphasis: Treat `_connected_words_` intelligently Default: True
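An illustrative call using these keyword options (a sketch):

    md = markdown.Markdown(output_format='html5', tab_length=4)
    html = md.convert(source_text)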
@ -107,14 +110,16 @@ class Markdown(object):
# For backward compatibility, loop through old positional args
pos = ['extensions', 'extension_configs', 'safe_mode', 'output_format']
c = 0
for arg in args:
for c, arg in enumerate(args):
if pos[c] not in kwargs:
kwargs[pos[c]] = arg
c += 1
if c == len(pos):
if c+1 == len(pos): #pragma: no cover
# ignore any additional args
break
if len(args):
warnings.warn('Positional arguments are pending deprecation in Markdown '
'and will be deprecated in version 2.6. Use keyword '
'arguments only.', PendingDeprecationWarning)
# Loop through kwargs and assign defaults
for option, default in self.option_defaults.items():
@ -125,6 +130,18 @@ class Markdown(object):
# Disable attributes in safeMode when not explicitly set
self.enable_attributes = False
if 'safe_mode' in kwargs:
warnings.warn('"safe_mode" is pending deprecation in Python-Markdown '
'and will be deprecated in version 2.6. Use an HTML '
'sanitizer (like Bleach http://bleach.readthedocs.org/) '
'if you are parsing untrusted markdown text. See the '
'2.5 release notes for more info', PendingDeprecationWarning)
if 'html_replacement_text' in kwargs:
warnings.warn('The "html_replacement_text" keyword is pending deprecation '
'in Python-Markdown and will be deprecated in version 2.6 '
'along with "safe_mode".', PendingDeprecationWarning)
self.registeredExtensions = []
self.docType = ""
self.stripTopLevelTags = True
@ -160,9 +177,11 @@ class Markdown(object):
"""
for ext in extensions:
if isinstance(ext, util.string_type):
ext = self.build_extension(ext, configs.get(ext, []))
ext = self.build_extension(ext, configs.get(ext, {}))
if isinstance(ext, Extension):
ext.extendMarkdown(self, globals())
logger.debug('Successfully loaded extension "%s.%s".'
% (ext.__class__.__module__, ext.__class__.__name__))
elif ext is not None:
raise TypeError(
'Extension "%s.%s" must be of type: "markdown.Extension"'
@ -170,52 +189,87 @@ class Markdown(object):
return self
def build_extension(self, ext_name, configs = []):
def build_extension(self, ext_name, configs):
"""Build extension by name, then return the module.
The extension name may contain arguments as part of the string in the
following format: "extname(key1=value1,key2=value2)"
"""
# Parse extensions config params (ignore the order)
configs = dict(configs)
pos = ext_name.find("(") # find the first "("
if pos > 0:
ext_args = ext_name[pos+1:-1]
ext_name = ext_name[:pos]
pairs = [x.split("=") for x in ext_args.split(",")]
configs.update([(x.strip(), y.strip()) for (x, y) in pairs])
warnings.warn('Setting configs in the Named Extension string is pending deprecation. '
'It is recommended that you pass an instance of the extension class to '
'Markdown or use the "extension_configs" keyword. The current behavior '
'will be deprecated in version 2.6 and raise an error in version 2.7. '
'See the Release Notes for Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
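# The recommended equivalent (a sketch):
#     markdown.markdown(text,
#                       extensions=[CodeHiliteExtension(linenums=True)])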
# Setup the module name
module_name = ext_name
if '.' not in ext_name:
module_name = '.'.join(['markdown.extensions', ext_name])
# Get class name (if provided): `path.to.module:ClassName`
ext_name, class_name = ext_name.split(':', 1) if ':' in ext_name else (ext_name, '')
# Try loading the extension first from one place, then another
try: # New style (markdown.extensions.<extension>)
module = __import__(module_name, {}, {}, [module_name.rpartition('.')[0]])
try:
# Assume string uses dot syntax (`path.to.some.module`)
module = importlib.import_module(ext_name)
logger.debug('Successfully imported extension module "%s".' % ext_name)
# For backward compat (until deprecation) check that this is an extension
if '.' not in ext_name and not (hasattr(module, 'extendMarkdown') or (class_name and hasattr(module, class_name))):
# We have a name conflict (eg: extensions=['tables'] and PyTables is installed)
raise ImportError
except ImportError:
module_name_old_style = '_'.join(['mdx', ext_name])
try: # Old style (mdx_<extension>)
module = __import__(module_name_old_style)
except ImportError as e:
message = "Failed loading extension '%s' from '%s' or '%s'" \
% (ext_name, module_name, module_name_old_style)
# Prepend `markdown.extensions.` to name
module_name = '.'.join(['markdown.extensions', ext_name])
try:
module = importlib.import_module(module_name)
logger.debug('Successfully imported extension module "%s".' % module_name)
warnings.warn('Using short names for Markdown\'s builtin extensions is pending deprecation. '
'Use the full path to the extension with Python\'s dot notation '
'(eg: "%s" instead of "%s"). The current behavior will be deprecated in '
'version 2.6 and raise an error in version 2.7. See the Release Notes for '
'Python-Markdown version 2.5 for more info.' % (module_name, ext_name),
PendingDeprecationWarning)
except ImportError:
# Prepend `mdx_` to name
module_name_old_style = '_'.join(['mdx', ext_name])
try:
module = importlib.import_module(module_name_old_style)
logger.debug('Successfully imported extension module "%s".' % module_name_old_style)
warnings.warn('Markdown\'s behavior of prepending "mdx_" to an extension name '
'is pending deprecation. Use the full path to the extension with '
'Python\'s dot notation (eg: "%s" instead of "%s"). The '
'current behavior will be deprecated in version 2.6 and raise an '
'error in version 2.7. See the Release Notes for Python-Markdown '
'version 2.5 for more info.' % (module_name_old_style, ext_name),
PendingDeprecationWarning)
except ImportError as e:
message = "Failed loading extension '%s' from '%s', '%s' or '%s'" \
% (ext_name, ext_name, module_name, module_name_old_style)
e.args = (message,) + e.args[1:]
raise
if class_name:
# Load given class name from module.
return getattr(module, class_name)(**configs)
else:
# Expect makeExtension() function to return a class.
try:
return module.makeExtension(**configs)
except AttributeError as e:
message = e.args[0]
message = "Failed to initiate extension " \
"'%s': %s" % (ext_name, message)
e.args = (message,) + e.args[1:]
raise
# If the module is loaded successfully, we expect it to define a
# function called makeExtension()
try:
return module.makeExtension(configs.items())
except AttributeError as e:
message = e.args[0]
message = "Failed to initiate extension " \
"'%s': %s" % (ext_name, message)
e.args = (message,) + e.args[1:]
raise
def registerExtension(self, extension):
""" This gets called by the extension """
self.registeredExtensions.append(extension)
@ -303,7 +357,7 @@ class Markdown(object):
start = output.index('<%s>'%self.doc_tag)+len(self.doc_tag)+2
end = output.rindex('</%s>'%self.doc_tag)
output = output[start:end].strip()
except ValueError:
except ValueError: #pragma: no cover
if output.strip().endswith('<%s />'%self.doc_tag):
# We have an empty document
output = ''
@ -434,6 +488,10 @@ def markdownFromFile(*args, **kwargs):
c += 1
if c == len(pos):
break
if len(args):
warnings.warn('Positional arguments are pending deprecation in Markdown '
'and will be deprecated in version 2.6. Use keyword '
'arguments only.', PendingDeprecationWarning)
md = Markdown(**kwargs)
md.convertFile(kwargs.get('input', None),

View File

@ -7,20 +7,25 @@ COMMAND-LINE SPECIFIC STUFF
import markdown
import sys
import optparse
import codecs
try:
import yaml
except ImportError: #pragma: no cover
import json as yaml
import logging
from logging import DEBUG, INFO, CRITICAL
logger = logging.getLogger('MARKDOWN')
def parse_options():
def parse_options(args=None, values=None):
"""
Define and parse `optparse` options for command-line usage.
"""
usage = """%prog [options] [INPUTFILE]
(STDIN is assumed if no INPUTFILE is given)"""
desc = "A Python implementation of John Gruber's Markdown. " \
"http://packages.python.org/Markdown/"
"https://pythonhosted.org/Markdown/"
ver = "%%prog %s" % markdown.version
parser = optparse.OptionParser(usage=usage, description=desc, version=ver)
@ -29,28 +34,36 @@ def parse_options():
metavar="OUTPUT_FILE")
parser.add_option("-e", "--encoding", dest="encoding",
help="Encoding for input and output files.",)
parser.add_option("-s", "--safe", dest="safe", default=False,
metavar="SAFE_MODE",
help="Deprecated! 'replace', 'remove' or 'escape' HTML tags in input")
parser.add_option("-o", "--output_format", dest="output_format",
default='xhtml1', metavar="OUTPUT_FORMAT",
help="'xhtml1' (default), 'html4' or 'html5'.")
parser.add_option("-n", "--no_lazy_ol", dest="lazy_ol",
action='store_false', default=True,
help="Observe number of first item of ordered lists.")
parser.add_option("-x", "--extension", action="append", dest="extensions",
help = "Load extension EXTENSION.", metavar="EXTENSION")
parser.add_option("-c", "--extension_configs", dest="configfile", default=None,
help="Read extension configurations from CONFIG_FILE. "
"CONFIG_FILE must be of JSON or YAML format. YAML format requires "
"that a python YAML library be installed. The parsed JSON or YAML "
"must result in a python dictionary which would be accepted by the "
"'extension_configs' keyword on the markdown.Markdown class. "
"The extensions must also be loaded with the `--extension` option.",
metavar="CONFIG_FILE")
parser.add_option("-q", "--quiet", default = CRITICAL,
action="store_const", const=CRITICAL+10, dest="verbose",
help="Suppress all warnings.")
parser.add_option("-v", "--verbose",
action="store_const", const=INFO, dest="verbose",
help="Print all warnings.")
parser.add_option("-s", "--safe", dest="safe", default=False,
metavar="SAFE_MODE",
help="'replace', 'remove' or 'escape' HTML tags in input")
parser.add_option("-o", "--output_format", dest="output_format",
default='xhtml1', metavar="OUTPUT_FORMAT",
help="'xhtml1' (default), 'html4' or 'html5'.")
parser.add_option("--noisy",
action="store_const", const=DEBUG, dest="verbose",
help="Print debug messages.")
parser.add_option("-x", "--extension", action="append", dest="extensions",
help = "Load extension EXTENSION.", metavar="EXTENSION")
parser.add_option("-n", "--no_lazy_ol", dest="lazy_ol",
action='store_false', default=True,
help="Observe number of first item of ordered lists.")
(options, args) = parser.parse_args()
(options, args) = parser.parse_args(args, values)
if len(args) == 0:
input_file = None
@ -60,15 +73,26 @@ def parse_options():
if not options.extensions:
options.extensions = []
extension_configs = {}
if options.configfile:
with codecs.open(options.configfile, mode="r", encoding=options.encoding) as fp:
try:
extension_configs = yaml.load(fp)
except Exception as e:
message = "Failed parsing extension config file: %s" % options.configfile
e.args = (message,) + e.args[1:]
raise
return {'input': input_file,
'output': options.filename,
'safe_mode': options.safe,
'extensions': options.extensions,
'extension_configs': extension_configs,
'encoding': options.encoding,
'output_format': options.output_format,
'lazy_ol': options.lazy_ol}, options.verbose
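# An example CONFIG_FILE for -c/--extension_configs (JSON, which the
# json-as-yaml fallback above also parses; a sketch):
#     {"markdown.extensions.codehilite": {"linenums": true}}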
def run():
def run(): #pragma: no cover
"""Run Markdown from the command line."""
# Parse options and adjust logging level if necessary
@ -80,7 +104,7 @@ def run():
# Run
markdown.markdownFromFile(**options)
if __name__ == '__main__':
if __name__ == '__main__': #pragma: no cover
# Support running module as a commandline command.
# Python 2.5 & 2.6 do: `python -m markdown.__main__ [options] [args]`.
# Python 2.7 & 3.x do: `python -m markdown [options] [args]`.

View File

@ -5,7 +5,7 @@
# (major, minor, micro, alpha/beta/rc/final, #)
# (1, 1, 2, 'alpha', 0) => "1.1.2.dev"
# (1, 2, 0, 'beta', 2) => "1.2b2"
version_info = (2, 4, 1, 'final', 0)
version_info = (2, 5, 2, 'final', 0)
def _get_version():
" Returns a PEP 386-compliant version number from version_info. "

View File

@ -99,7 +99,7 @@ class BlockProcessor:
* ``block``: A block of text from the source which has been split at
blank lines.
"""
pass
pass #pragma: no cover
def run(self, parent, blocks):
""" Run processor. Must be overridden by subclasses.
@ -123,7 +123,7 @@ class BlockProcessor:
* ``parent``: A etree element which is the parent of the current block.
* ``blocks``: A list of all remaining blocks of the document.
"""
pass
pass #pragma: no cover
class ListIndentProcessor(BlockProcessor):
@ -433,7 +433,7 @@ class HashHeaderProcessor(BlockProcessor):
if after:
# Insert remaining lines as first block for future parsing.
blocks.insert(0, after)
else:
else: #pragma: no cover
# This should never happen, but just in case...
logger.warn("We've got a problem header: %r" % block)

View File

@ -4,17 +4,45 @@ Extensions
"""
from __future__ import unicode_literals
from ..util import parseBoolValue
import warnings
class Extension(object):
""" Base class for extensions to subclass. """
def __init__(self, configs = {}):
    """Create an instance of an Extension.
    Keyword arguments:
    * configs: A dict of configuration settings used by an Extension.
    """
    self.config = configs
# Default config -- to be overridden by a subclass
# Must be of the following format:
# {
#     'key': ['value', 'description']
# }
# Note that Extension.setConfig will raise a KeyError
# if a default is not set here.
config = {}
def __init__(self, *args, **kwargs):
    """ Initiate Extension and set up configs. """
# check for configs arg for backward compat.
# (there only ever used to be one so we use arg[0])
if len(args):
self.setConfigs(args[0])
warnings.warn('Extension classes accepting positional args is pending deprecation. '
'Each setting should be passed into the Class as a keyword. Positional '
'args will be deprecated in version 2.6 and raise an error in version '
'2.7. See the Release Notes for Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
# check for configs kwarg for backward compat.
if 'configs' in kwargs.keys():
self.setConfigs(kwargs.pop('configs', {}))
warnings.warn('Extension classes accepting a dict on the single keyword "configs" is '
'pending deprecation. Each setting should be passed into the Class as '
'a keyword directly. The "configs" keyword will be deprecated in version '
'2.6 and raise an error in version 2.7. See the Release Notes for '
'Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
# finally, use kwargs
self.setConfigs(kwargs)
def getConfig(self, key, default=''):
""" Return a setting for the given key or an empty string. """
@ -33,8 +61,20 @@ class Extension(object):
def setConfig(self, key, value):
""" Set a config setting for `key` with the given `value`. """
if isinstance(self.config[key][0], bool):
value = parseBoolValue(value)
if self.config[key][0] is None:
value = parseBoolValue(value, preserve_none=True)
self.config[key][0] = value
def setConfigs(self, items):
""" Set multiple config settings given a dict or list of tuples. """
if hasattr(items, 'items'):
# it's a dict
items = items.items()
for key, value in items:
self.setConfig(key, value)
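# Sketch of a subclass using this machinery (a hypothetical extension):
#     class MyExtension(Extension):
#         config = {'level': [1, 'Base heading level']}
#     MyExtension(level=3).getConfig('level')  # -> 3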
def extendMarkdown(self, md, md_globals):
"""
Add the various processors and patterns to the Markdown Instance.

View File

@ -4,22 +4,15 @@ Abbreviation Extension for Python-Markdown
This extension adds abbreviation handling to Python-Markdown.
Simple Usage:
See <https://pythonhosted.org/Markdown/extensions/abbreviations.html>
for documentation.
>>> import markdown
>>> text = """
... Some text with an ABBR and a REF. Ignore REFERENCE and ref.
...
... *[ABBR]: Abbreviation
... *[REF]: Abbreviation Reference
... """
>>> print markdown.markdown(text, ['abbr'])
<p>Some text with an <abbr title="Abbreviation">ABBR</abbr> and a <abbr title="Abbreviation Reference">REF</abbr>. Ignore REFERENCE and ref.</p>
Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/) and
[Seemant Kulleen](http://www.kulleen.org/)
Copyright 2007-2008
* [Waylan Limberg](http://achinghead.com/)
* [Seemant Kulleen](http://www.kulleen.org/)
All changes Copyright 2008-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
'''
@ -92,5 +85,5 @@ class AbbrPattern(Pattern):
abbr.set('title', self.title)
return abbr
def makeExtension(configs=None):
return AbbrExtension(configs=configs)
def makeExtension(*args, **kwargs):
return AbbrExtension(*args, **kwargs)

View File

@ -4,39 +4,16 @@ Admonition extension for Python-Markdown
Adds rST-style admonitions. Inspired by [rST][] feature with the same name.
The syntax is (followed by an indented block with the contents):
!!! [type] [optional explicit title]
Where `type` is used as a CSS class name of the div. If not present, `title`
defaults to the capitalized `type`, so "note" -> "Note".
rST suggests the following `types`, but you're free to use whatever you want:
attention, caution, danger, error, hint, important, note, tip, warning
A simple example:
!!! note
This is the first line inside the box.
Outputs:
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>This is the first line inside the box</p>
</div>
You can also specify the title and CSS class of the admonition:
!!! custom "Did you know?"
Another line here.
Outputs:
<div class="admonition custom">
<p class="admonition-title">Did you know?</p>
<p>Another line here.</p>
</div>
[rST]: http://docutils.sourceforge.net/docs/ref/rst/directives.html#specific-admonitions
By [Tiago Serafim](http://www.tiagoserafim.com/).
See <https://pythonhosted.org/Markdown/extensions/admonition.html>
for documentation.
Original code Copyright [Tiago Serafim](http://www.tiagoserafim.com/).
All changes Copyright The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -114,5 +91,6 @@ class AdmonitionProcessor(BlockProcessor):
return klass, title
def makeExtension(configs={}):
return AdmonitionExtension(configs=configs)
def makeExtension(*args, **kwargs):
return AdmonitionExtension(*args, **kwargs)

View File

@ -6,15 +6,14 @@ Adds attribute list syntax. Inspired by
[maruku](http://maruku.rubyforge.org/proposal.html#attribute_lists)'s
feature of the same name.
Copyright 2011 [Waylan Limberg](http://achinghead.com/).
See <https://pythonhosted.org/Markdown/extensions/attr_list.html>
for documentation.
Contact: markdown@freewisdom.org
Original code Copyright 2011 [Waylan Limberg](http://achinghead.com/).
License: BSD (see ../LICENSE.md for details)
All changes Copyright 2011-2014 The Python Markdown Project
Dependencies:
* [Python 2.4+](http://python.org)
* [Markdown 2.1+](http://packages.python.org/Markdown/)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -27,7 +26,7 @@ import re
try:
Scanner = re.Scanner
except AttributeError:
except AttributeError: #pragma: no cover
# must be on Python 2.4
from sre import Scanner
@ -164,5 +163,5 @@ class AttrListExtension(Extension):
md.treeprocessors.add('attr_list', AttrListTreeprocessor(md), '>prettify')
def makeExtension(configs={}):
return AttrListExtension(configs=configs)
def makeExtension(*args, **kwargs):
return AttrListExtension(*args, **kwargs)

View File

@ -4,17 +4,14 @@ CodeHilite Extension for Python-Markdown
Adds code/syntax highlighting to standard Python-Markdown code blocks.
Copyright 2006-2008 [Waylan Limberg](http://achinghead.com/).
See <https://pythonhosted.org/Markdown/extensions/code_hilite.html>
for documentation.
Project website: <http://packages.python.org/Markdown/extensions/code_hilite.html>
Contact: markdown@freewisdom.org
Original code Copyright 2006-2008 [Waylan Limberg](http://achinghead.com/).
License: BSD (see ../LICENSE.md for details)
All changes Copyright 2008-2014 The Python Markdown Project
Dependencies:
* [Python 2.3+](http://python.org/)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
* [Pygments](http://pygments.org/)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -25,8 +22,8 @@ from ..treeprocessors import Treeprocessor
import warnings
try:
from pygments import highlight
from pygments.lexers import get_lexer_by_name, guess_lexer, TextLexer
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import get_formatter_by_name
pygments = True
except ImportError:
pygments = False
@ -112,14 +109,15 @@ class CodeHilite(object):
if self.guess_lang:
lexer = guess_lexer(self.src)
else:
lexer = TextLexer()
lexer = get_lexer_by_name('text')
except ValueError:
lexer = TextLexer()
formatter = HtmlFormatter(linenos=self.linenums,
cssclass=self.css_class,
style=self.style,
noclasses=self.noclasses,
hl_lines=self.hl_lines)
lexer = get_lexer_by_name('text')
formatter = get_formatter_by_name('html',
linenos=self.linenums,
cssclass=self.css_class,
style=self.style,
noclasses=self.noclasses,
hl_lines=self.hl_lines)
return highlight(self.src, lexer, formatter)
else:
# just escape and build markup usable by JS highlighting libs
@ -225,7 +223,7 @@ class HiliteTreeprocessor(Treeprocessor):
class CodeHiliteExtension(Extension):
""" Add source code hilighting to markdown codeblocks. """
def __init__(self, configs):
def __init__(self, *args, **kwargs):
# define default configs
self.config = {
'linenums': [None, "Use lines numbers. True=yes, False=no, None=auto"],
@ -237,22 +235,7 @@ class CodeHiliteExtension(Extension):
'noclasses': [False, 'Use inline styles instead of CSS classes - Default false']
}
# Override defaults with user settings
for key, value in configs:
# convert strings to booleans
if value == 'True': value = True
if value == 'False': value = False
if value == 'None': value = None
if key == 'force_linenos':
warnings.warn('The "force_linenos" config setting'
' to the CodeHilite extension is deprecated.'
' Use "linenums" instead.', DeprecationWarning)
if value:
# Carry 'force_linenos' over to new 'linenos'.
self.setConfig('linenums', True)
self.setConfig(key, value)
super(CodeHiliteExtension, self).__init__(*args, **kwargs)
def extendMarkdown(self, md, md_globals):
""" Add HilitePostprocessor to Markdown instance. """
@ -263,6 +246,5 @@ class CodeHiliteExtension(Extension):
md.registerExtension(self)
def makeExtension(configs={}):
return CodeHiliteExtension(configs=configs)
def makeExtension(*args, **kwargs):
return CodeHiliteExtension(*args, **kwargs)
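# For reference, the Pygments calls the branch above relies on, in
# isolation (a sketch):
#     from pygments import highlight
#     from pygments.lexers import get_lexer_by_name
#     from pygments.formatters import get_formatter_by_name
#     highlight("print('hi')", get_lexer_by_name('python'),
#               get_formatter_by_name('html', linenos=False))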

View File

@ -2,19 +2,16 @@
Definition List Extension for Python-Markdown
=============================================
Added parsing of Definition Lists to Python-Markdown.
Adds parsing of Definition Lists to Python-Markdown.
A simple example:
See <https://pythonhosted.org/Markdown/extensions/definition_lists.html>
for documentation.
Apple
: Pomaceous fruit of plants of the genus Malus in
the family Rosaceae.
: An american computer company.
Original code Copyright 2008 [Waylan Limberg](http://achinghead.com)
Orange
: The fruit of an evergreen tree of the genus Citrus.
All changes Copyright 2008-2014 The Python Markdown Project
Copyright 2008 - [Waylan Limberg](http://achinghead.com)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -113,6 +110,6 @@ class DefListExtension(Extension):
'>ulist')
def makeExtension(configs={}):
return DefListExtension(configs=configs)
def makeExtension(*args, **kwargs):
return DefListExtension(*args, **kwargs)

View File

@ -11,10 +11,6 @@ convenience so that only one extension needs to be listed when
initiating Markdown. See the documentation for each individual
extension for specifics about that extension.
In the event that one or more of the supported extensions are not
available for import, Markdown will issue a warning and simply continue
without that extension.
There may be additional extensions that are distributed with
Python-Markdown that are not included here in Extra. Those extensions
are not part of PHP Markdown Extra, and therefore, not part of
under a different name. You could also edit the `extensions` global
variable defined below, but be aware that such changes may be lost
when you upgrade to any future version of Python-Markdown.
See <https://pythonhosted.org/Markdown/extensions/extra.html>
for documentation.
Copyright The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
from __future__ import absolute_import
@ -33,19 +36,25 @@ from ..blockprocessors import BlockProcessor
from .. import util
import re
extensions = ['smart_strong',
'fenced_code',
'footnotes',
'attr_list',
'def_list',
'tables',
'abbr',
]
extensions = [
'markdown.extensions.smart_strong',
'markdown.extensions.fenced_code',
'markdown.extensions.footnotes',
'markdown.extensions.attr_list',
'markdown.extensions.def_list',
'markdown.extensions.tables',
'markdown.extensions.abbr'
]
class ExtraExtension(Extension):
""" Add various extensions to Markdown class."""
def __init__(self, *args, **kwargs):
""" config is just a dumb holder which gets passed to actual ext later. """
self.config = kwargs.pop('configs', {})
self.config.update(kwargs)
def extendMarkdown(self, md, md_globals):
""" Register extension instances. """
md.registerExtensions(extensions, self.config)
@ -60,8 +69,8 @@ class ExtraExtension(Extension):
r'^(p|h[1-6]|li|dd|dt|td|th|legend|address)$', re.IGNORECASE)
def makeExtension(configs={}):
return ExtraExtension(configs=dict(configs))
def makeExtension(*args, **kwargs):
return ExtraExtension(*args, **kwargs)
class MarkdownInHtmlProcessor(BlockProcessor):

View File

@ -4,87 +4,15 @@ Fenced Code Extension for Python Markdown
This extension adds Fenced Code Blocks to Python-Markdown.
>>> import markdown
>>> text = '''
... A paragraph before a fenced code block:
...
... ~~~
... Fenced code block
... ~~~
... '''
>>> html = markdown.markdown(text, extensions=['fenced_code'])
>>> print html
<p>A paragraph before a fenced code block:</p>
<pre><code>Fenced code block
</code></pre>
See <https://pythonhosted.org/Markdown/extensions/fenced_code_blocks.html>
for documentation.
Works with safe_mode also (we check this because we are using the HtmlStash):
Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/).
>>> print markdown.markdown(text, extensions=['fenced_code'], safe_mode='replace')
<p>A paragraph before a fenced code block:</p>
<pre><code>Fenced code block
</code></pre>
Include tildes in a code block and wrap with blank lines:
>>> text = '''
... ~~~~~~~~
...
... ~~~~
... ~~~~~~~~'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code>
~~~~
</code></pre>
Language tags:
>>> text = '''
... ~~~~{.python}
... # Some python code
... ~~~~'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code class="python"># Some python code
</code></pre>
Optionally backticks instead of tildes as per how github's code block markdown is identified:
>>> text = '''
... `````
... # Arbitrary code
... ~~~~~ # these tildes will not close the block
... `````'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code># Arbitrary code
~~~~~ # these tildes will not close the block
</code></pre>
If the codehighlite extension and Pygments are installed, lines can be highlighted:
>>> text = '''
... ```hl_lines="1 3"
... line 1
... line 2
... line 3
... ```'''
>>> print markdown.markdown(text, extensions=['codehilite', 'fenced_code'])
<pre><code><span class="hilight">line 1</span>
line 2
<span class="hilight">line 3</span>
</code></pre>
Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/).
Project website: <http://packages.python.org/Markdown/extensions/fenced_code_blocks.html>
Contact: markdown@freewisdom.org
License: BSD (see ../docs/LICENSE for details)
Dependencies:
* [Python 2.4+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
* [Pygments (optional)](http://pygments.org)
All changes Copyright 2008-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
from __future__ import absolute_import
@ -175,5 +103,6 @@ class FencedBlockPreprocessor(Preprocessor):
return txt
def makeExtension(configs=None):
return FencedCodeExtension(configs=configs)
def makeExtension(*args, **kwargs):
return FencedCodeExtension(*args, **kwargs)

View File

@ -1,25 +1,15 @@
"""
========================= FOOTNOTES =================================
Footnotes Extension for Python-Markdown
=======================================
This section adds footnote handling to markdown. It can be used as
an example for extending python-markdown with relatively complex
functionality. While in this case the extension is included inside
the module itself, it could just as easily be added from outside the
module. Note that all markdown classes above are ignorant about
footnotes. All footnote functionality is provided separately and
then added to the markdown instance at the run time.
Adds footnote handling to Python-Markdown.
Footnote functionality is attached by calling extendMarkdown()
method of FootnoteExtension. The method also registers the
extension to allow its state to be reset by a call to reset()
method.
See <https://pythonhosted.org/Markdown/extensions/footnotes.html>
for documentation.
Example:
Footnotes[^1] have a label[^label] and a definition[^!DEF].
Copyright The Python Markdown Project
[^1]: This is a footnote
[^label]: A footnote on "label"
[^!DEF]: The footnote for definition
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -42,23 +32,23 @@ TABBED_RE = re.compile(r'((\t)|( ))(.*)')
class FootnoteExtension(Extension):
""" Footnote Extension. """
def __init__ (self, configs):
def __init__ (self, *args, **kwargs):
""" Setup configs. """
self.config = {'PLACE_MARKER':
["///Footnotes Go Here///",
"The text string that marks where the footnotes go"],
'UNIQUE_IDS':
[False,
"Avoid name collisions across "
"multiple calls to reset()."],
"BACKLINK_TEXT":
["&#8617;",
"The text string that links from the footnote to the reader's place."]
}
for key, value in configs:
self.config[key][0] = value
self.config = {
'PLACE_MARKER':
["///Footnotes Go Here///",
"The text string that marks where the footnotes go"],
'UNIQUE_IDS':
[False,
"Avoid name collisions across "
"multiple calls to reset()."],
"BACKLINK_TEXT":
["&#8617;",
"The text string that links from the footnote to the reader's place."]
}
super(FootnoteExtension, self).__init__(*args, **kwargs)
# In multiple invocations, emit links that don't get tangled.
self.unique_prefix = 0
@ -309,7 +299,7 @@ class FootnotePostprocessor(Postprocessor):
text = text.replace(FN_BACKLINK_TEXT, self.footnotes.getConfig("BACKLINK_TEXT"))
return text.replace(NBSP_PLACEHOLDER, "&#160;")
def makeExtension(configs=[]):
def makeExtension(*args, **kwargs):
""" Return an instance of the FootnoteExtension """
return FootnoteExtension(configs=configs)
return FootnoteExtension(*args, **kwargs)

View File

@ -4,73 +4,14 @@ HeaderID Extension for Python-Markdown
Auto-generate id attributes for HTML headers.
Basic usage:
See <https://pythonhosted.org/Markdown/extensions/header_id.html>
for documentation.
>>> import markdown
>>> text = "# Some Header #"
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="some-header">Some Header</h1>
Original code Copyright 2007-2011 [Waylan Limberg](http://achinghead.com/).
All header IDs are unique:
All changes Copyright 2011-2014 The Python Markdown Project
>>> text = '''
... #Header
... #Header
... #Header'''
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="header">Header</h1>
<h1 id="header_1">Header</h1>
<h1 id="header_2">Header</h1>
To fit within a html template's hierarchy, set the header base level:
>>> text = '''
... #Some Header
... ## Next Level'''
>>> md = markdown.markdown(text, ['headerid(level=3)'])
>>> print md
<h3 id="some-header">Some Header</h3>
<h4 id="next-level">Next Level</h4>
Works with inline markup.
>>> text = '#Some *Header* with [markup](http://example.com).'
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="some-header-with-markup">Some <em>Header</em> with <a href="http://example.com">markup</a>.</h1>
Turn off auto generated IDs:
>>> text = '''
... # Some Header
... # Another Header'''
>>> md = markdown.markdown(text, ['headerid(forceid=False)'])
>>> print md
<h1>Some Header</h1>
<h1>Another Header</h1>
Use with MetaData extension:
>>> text = '''header_level: 2
... header_forceid: Off
...
... # A Header'''
>>> md = markdown.markdown(text, ['headerid', 'meta'])
>>> print md
<h2>A Header</h2>
Copyright 2007-2011 [Waylan Limberg](http://achinghead.com/).
Project website: <http://packages.python.org/Markdown/extensions/header_id.html>
Contact: markdown@freewisdom.org
License: BSD (see ../docs/LICENSE for details)
Dependencies:
* [Python 2.3+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -127,7 +68,7 @@ def stashedHTML2text(text, md):
def _html_sub(m):
""" Substitute raw html with plain text. """
try:
raw, safe = md.htmlStash.rawHtmlBlocks[int(m.group(1))]
except (IndexError, TypeError):
return m.group(0)
if md.safeMode and not safe:
@ -176,7 +117,7 @@ class HeaderIdTreeprocessor(Treeprocessor):
class HeaderIdExtension(Extension):
def __init__(self, configs):
def __init__(self, *args, **kwargs):
# set defaults
self.config = {
'level' : ['1', 'Base level for headers.'],
@ -185,8 +126,7 @@ class HeaderIdExtension(Extension):
'slugify' : [slugify, 'Callable to generate anchors'],
}
for key, value in configs:
self.setConfig(key, value)
super(HeaderIdExtension, self).__init__(*args, **kwargs)
def extendMarkdown(self, md, md_globals):
md.registerExtension(self)
@ -204,5 +144,6 @@ class HeaderIdExtension(Extension):
self.processor.IDs = set()
def makeExtension(configs=None):
return HeaderIdExtension(configs=configs)
def makeExtension(*args, **kwargs):
return HeaderIdExtension(*args, **kwargs)

View File

@ -4,38 +4,14 @@ Meta Data Extension for Python-Markdown
This extension adds Meta Data handling to markdown.
Basic Usage:
See <https://pythonhosted.org/Markdown/extensions/meta_data.html>
for documentation.
>>> import markdown
>>> text = '''Title: A Test Doc.
... Author: Waylan Limberg
... John Doe
... Blank_Data:
...
... The body. This is paragraph one.
... '''
>>> md = markdown.Markdown(['meta'])
>>> print md.convert(text)
<p>The body. This is paragraph one.</p>
>>> print md.Meta
{u'blank_data': [u''], u'author': [u'Waylan Limberg', u'John Doe'], u'title': [u'A Test Doc.']}
Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com).
Make sure text without Meta Data still works (markdown < 1.6b returns a <p>).
All changes Copyright 2008-2014 The Python Markdown Project
>>> text = ' Some Code - not extra lines of meta data.'
>>> md = markdown.Markdown(['meta'])
>>> print md.convert(text)
<pre><code>Some Code - not extra lines of meta data.
</code></pre>
>>> md.Meta
{}
Copyright 2007-2008 [Waylan Limberg](http://achinghead.com).
Project website: <http://packages.python.org/Markdown/meta_data.html>
Contact: markdown@freewisdom.org
License: BSD (see ../LICENSE.md for details)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -55,7 +31,7 @@ class MetaExtension (Extension):
def extendMarkdown(self, md, md_globals):
""" Add MetaPreprocessor to Markdown instance. """
md.preprocessors.add("meta", MetaPreprocessor(md), "_begin")
md.preprocessors.add("meta", MetaPreprocessor(md), ">normalize_whitespace")
class MetaPreprocessor(Preprocessor):
@ -89,5 +65,6 @@ class MetaPreprocessor(Preprocessor):
return lines
def makeExtension(configs={}):
return MetaExtension(configs=configs)
def makeExtension(*args, **kwargs):
return MetaExtension(*args, **kwargs)

View File

@ -5,18 +5,14 @@ NL2BR Extension
A Python-Markdown extension to treat newlines as hard breaks; like
GitHub-flavored Markdown does.
Usage:
See <https://pythonhosted.org/Markdown/extensions/nl2br.html>
for documentation.
>>> import markdown
>>> print markdown.markdown('line 1\\nline 2', extensions=['nl2br'])
<p>line 1<br />
line 2</p>
Original code Copyright 2011 [Brian Neal](http://deathofagremmie.com/)
Copyright 2011 [Brian Neal](http://deathofagremmie.com/)
All changes Copyright 2011-2014 The Python Markdown Project
Dependencies:
* [Python 2.4+](http://python.org)
* [Markdown 2.1+](http://packages.python.org/Markdown/)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -34,5 +30,6 @@ class Nl2BrExtension(Extension):
md.inlinePatterns.add('nl', br_tag, '_end')
def makeExtension(configs=None):
return Nl2BrExtension(configs)
def makeExtension(*args, **kwargs):
return Nl2BrExtension(*args, **kwargs)

View File

@ -2,19 +2,16 @@
Sane List Extension for Python-Markdown
=======================================
Modify the behavior of Lists in Python-Markdown t act in a sane manor.
Modify the behavior of Lists in Python-Markdown to act in a sane manner.
In standard Markdown syntax, the following would constitute a single
ordered list. However, with this extension, the output would include
two lists, the first an ordered list and the second an unordered list.
See <https://pythonhosted.org/Markdown/extensions/sane_lists.html>
for documentation.
1. ordered
2. list
Original code Copyright 2011 [Waylan Limberg](http://achinghead.com)
* unordered
* list
All changes Copyright 2011-2014 The Python Markdown Project
Copyright 2011 - [Waylan Limberg](http://achinghead.com)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -46,6 +43,6 @@ class SaneListExtension(Extension):
md.parser.blockprocessors['ulist'] = SaneUListProcessor(md.parser)
def makeExtension(configs={}):
return SaneListExtension(configs=configs)
def makeExtension(*args, **kwargs):
return SaneListExtension(*args, **kwargs)

View File

@ -4,21 +4,14 @@ Smart_Strong Extension for Python-Markdown
This extension adds smarter handling of double underscores within words.
Simple Usage:
See <https://pythonhosted.org/Markdown/extensions/smart_strong.html>
for documentation.
>>> import markdown
>>> print markdown.markdown('Text with double__underscore__words.',
... extensions=['smart_strong'])
<p>Text with double__underscore__words.</p>
>>> print markdown.markdown('__Strong__ still works.',
... extensions=['smart_strong'])
<p><strong>Strong</strong> still works.</p>
>>> print markdown.markdown('__this__works__too__.',
... extensions=['smart_strong'])
<p><strong>this__works__too</strong>.</p>
Original code Copyright 2011 [Waylan Limberg](http://achinghead.com)
Copyright 2011
[Waylan Limberg](http://achinghead.com)
All changes Copyright 2011-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
'''
@ -38,5 +31,5 @@ class SmartEmphasisExtension(Extension):
md.inlinePatterns['strong'] = SimpleTagPattern(STRONG_RE, 'strong')
md.inlinePatterns.add('strong2', SimpleTagPattern(SMART_STRONG_RE, 'strong'), '>emphasis2')
def makeExtension(configs={}):
return SmartEmphasisExtension(configs=dict(configs))
def makeExtension(*args, **kwargs):
return SmartEmphasisExtension(*args, **kwargs)

View File

@ -1,73 +1,91 @@
# -*- coding: utf-8 -*-
# Smarty extension for Python-Markdown
# Author: 2013, Dmitry Shachnev <mitya57@gmail.com>
'''
Smarty extension for Python-Markdown
====================================
Adds conversion of ASCII dashes, quotes and ellipses to their HTML
entity equivalents.
See <https://pythonhosted.org/Markdown/extensions/smarty.html>
for documentation.
Author: 2013, Dmitry Shachnev <mitya57@gmail.com>
All changes Copyright 2013-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
SmartyPants license:
Copyright (c) 2003 John Gruber <http://daringfireball.net/>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
* Neither the name "SmartyPants" nor the names of its contributors
may be used to endorse or promote products derived from this
software without specific prior written permission.
This software is provided by the copyright holders and contributors "as
is" and any express or implied warranties, including, but not limited
to, the implied warranties of merchantability and fitness for a
particular purpose are disclaimed. In no event shall the copyright
owner or contributors be liable for any direct, indirect, incidental,
special, exemplary, or consequential damages (including, but not
limited to, procurement of substitute goods or services; loss of use,
data, or profits; or business interruption) however caused and on any
theory of liability, whether in contract, strict liability, or tort
(including negligence or otherwise) arising in any way out of the use
of this software, even if advised of the possibility of such damage.
smartypants.py license:
smartypants.py is a derivative work of SmartyPants.
Copyright (c) 2004, 2007 Chad Miller <http://web.chad.org/>
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
This software is provided by the copyright holders and contributors "as
is" and any express or implied warranties, including, but not limited
to, the implied warranties of merchantability and fitness for a
particular purpose are disclaimed. In no event shall the copyright
owner or contributors be liable for any direct, indirect, incidental,
special, exemplary, or consequential damages (including, but not
limited to, procurement of substitute goods or services; loss of use,
data, or profits; or business interruption) however caused and on any
theory of liability, whether in contract, strict liability, or tort
(including negligence or otherwise) arising in any way out of the use
of this software, even if advised of the possibility of such damage.
'''
# SmartyPants license:
#
# Copyright (c) 2003 John Gruber <http://daringfireball.net/>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# * Neither the name "SmartyPants" nor the names of its contributors
# may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
#
#
# smartypants.py license:
#
# smartypants.py is a derivative work of SmartyPants.
# Copyright (c) 2004, 2007 Chad Miller <http://web.chad.org/>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
from __future__ import unicode_literals
from . import Extension
from ..inlinepatterns import HtmlPattern
from ..odict import OrderedDict
from ..treeprocessors import InlineProcessor
from ..util import parseBoolValue
# Constants for quote education.
@ -83,12 +101,25 @@ openingQuotesBase = (
'|&[mn]dash;' # or named dash entities
'|&#8211;|&#8212;' # or decimal entities
')'
)
)
substitutions = {
'mdash': '&mdash;',
'ndash': '&ndash;',
'ellipsis': '&hellip;',
'left-angle-quote': '&laquo;',
'right-angle-quote': '&raquo;',
'left-single-quote': '&lsquo;',
'right-single-quote': '&rsquo;',
'left-double-quote': '&ldquo;',
'right-double-quote': '&rdquo;',
}
# Special case if the very first character is a quote
# followed by punctuation at a non-word-break. Close the quotes by brute force:
singleQuoteStartRe = r"^'(?=%s\\B)" % punctClass
doubleQuoteStartRe = r'^"(?=%s\\B)' % punctClass
singleQuoteStartRe = r"^'(?=%s\B)" % punctClass
doubleQuoteStartRe = r'^"(?=%s\B)' % punctClass
# Special case for double sets of quotes, e.g.:
# <p>He said, "'Quoted' words in a larger quote."</p>
@ -113,8 +144,6 @@ closingSingleQuotesRegex2 = r"(?<=%s)'(\s|s\b)" % closeClass
remainingSingleQuotesRegex = "'"
remainingDoubleQuotesRegex = '"'
lsquo, rsquo, ldquo, rdquo = '&lsquo;', '&rsquo;', '&ldquo;', '&rdquo;'
class SubstituteTextPattern(HtmlPattern):
def __init__(self, pattern, replace, markdown_instance):
""" Replaces matches with some text. """
@ -132,35 +161,56 @@ class SubstituteTextPattern(HtmlPattern):
return result
class SmartyExtension(Extension):
def __init__(self, configs):
def __init__(self, *args, **kwargs):
self.config = {
'smart_quotes': [True, 'Educate quotes'],
'smart_angled_quotes': [False, 'Educate angled quotes'],
'smart_dashes': [True, 'Educate dashes'],
'smart_ellipses': [True, 'Educate ellipses']
'smart_ellipses': [True, 'Educate ellipses'],
'substitutions' : [{}, 'Overwrite default substitutions'],
}
for key, value in configs:
self.setConfig(key, parseBoolValue(value))
super(SmartyExtension, self).__init__(*args, **kwargs)
self.substitutions = dict(substitutions)
self.substitutions.update(self.getConfig('substitutions', default={}))
def _addPatterns(self, md, patterns, serie):
for ind, pattern in enumerate(patterns):
pattern += (md,)
pattern = SubstituteTextPattern(*pattern)
after = ('>smarty-%s-%d' % (serie, ind - 1) if ind else '>entity')
after = ('>smarty-%s-%d' % (serie, ind - 1) if ind else '_begin')
name = 'smarty-%s-%d' % (serie, ind)
md.inlinePatterns.add(name, pattern, after)
self.inlinePatterns.add(name, pattern, after)
def educateDashes(self, md):
emDashesPattern = SubstituteTextPattern(r'(?<!-)---(?!-)', ('&mdash;',), md)
enDashesPattern = SubstituteTextPattern(r'(?<!-)--(?!-)', ('&ndash;',), md)
md.inlinePatterns.add('smarty-em-dashes', emDashesPattern, '>entity')
md.inlinePatterns.add('smarty-en-dashes', enDashesPattern,
emDashesPattern = SubstituteTextPattern(r'(?<!-)---(?!-)',
(self.substitutions['mdash'],), md)
enDashesPattern = SubstituteTextPattern(r'(?<!-)--(?!-)',
(self.substitutions['ndash'],), md)
self.inlinePatterns.add('smarty-em-dashes', emDashesPattern, '_begin')
self.inlinePatterns.add('smarty-en-dashes', enDashesPattern,
'>smarty-em-dashes')
def educateEllipses(self, md):
ellipsesPattern = SubstituteTextPattern(r'(?<!\.)\.{3}(?!\.)', ('&hellip;',), md)
md.inlinePatterns.add('smarty-ellipses', ellipsesPattern, '>entity')
ellipsesPattern = SubstituteTextPattern(r'(?<!\.)\.{3}(?!\.)',
(self.substitutions['ellipsis'],), md)
self.inlinePatterns.add('smarty-ellipses', ellipsesPattern, '_begin')
def educateAngledQuotes(self, md):
leftAngledQuotePattern = SubstituteTextPattern(r'\<\<',
(self.substitutions['left-angle-quote'],), md)
rightAngledQuotePattern = SubstituteTextPattern(r'\>\>',
(self.substitutions['right-angle-quote'],), md)
self.inlinePatterns.add('smarty-left-angle-quotes',
leftAngledQuotePattern, '_begin')
self.inlinePatterns.add('smarty-right-angle-quotes',
rightAngledQuotePattern, '>smarty-left-angle-quotes')
def educateQuotes(self, md):
configs = self.getConfigs()
lsquo = self.substitutions['left-single-quote']
rsquo = self.substitutions['right-single-quote']
ldquo = self.substitutions['left-double-quote']
rdquo = self.substitutions['right-double-quote']
patterns = (
(singleQuoteStartRe, (rsquo,)),
(doubleQuoteStartRe, (rdquo,)),
@ -179,13 +229,19 @@ class SmartyExtension(Extension):
def extendMarkdown(self, md, md_globals):
configs = self.getConfigs()
if configs['smart_quotes']:
self.educateQuotes(md)
if configs['smart_dashes']:
self.educateDashes(md)
self.inlinePatterns = OrderedDict()
if configs['smart_ellipses']:
self.educateEllipses(md)
if configs['smart_quotes']:
self.educateQuotes(md)
if configs['smart_angled_quotes']:
self.educateAngledQuotes(md)
if configs['smart_dashes']:
self.educateDashes(md)
inlineProcessor = InlineProcessor(md)
inlineProcessor.inlinePatterns = self.inlinePatterns
md.treeprocessors.add('smarty', inlineProcessor, '_end')
md.ESCAPED_CHARS.extend(['"', "'"])
def makeExtension(configs=None):
return SmartyExtension(configs)
def makeExtension(*args, **kwargs):
return SmartyExtension(*args, **kwargs)
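# A minimal usage sketch, assuming the vendored Markdown 2.5.2 API shown
# above: the 'substitutions' config option overrides entries in the default
# substitutions table defined near the top of this module.
import markdown

html = markdown.markdown(
    'The "quoted" text -- with dashes...',
    extensions=['markdown.extensions.smarty'],
    extension_configs={'markdown.extensions.smarty': {
        'substitutions': {'left-double-quote': '&#x201E;'}}})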

View File

@ -4,14 +4,15 @@ Tables Extension for Python-Markdown
Added parsing of tables to Python-Markdown.
A simple example:
See <https://pythonhosted.org/Markdown/extensions/tables.html>
for documentation.
First Header | Second Header
------------- | -------------
Content Cell | Content Cell
Content Cell | Content Cell
Original code Copyright 2009 [Waylan Limberg](http://achinghead.com)
All changes Copyright 2008-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
Copyright 2009 - [Waylan Limberg](http://achinghead.com)
"""
from __future__ import absolute_import
@ -71,7 +72,7 @@ class TableProcessor(BlockProcessor):
c = etree.SubElement(tr, tag)
try:
c.text = cells[i].strip()
except IndexError:
except IndexError: #pragma: no cover
c.text = ""
if a:
c.set('align', a)
@ -96,5 +97,6 @@ class TableExtension(Extension):
'<hashheader')
def makeExtension(configs={}):
return TableExtension(configs=configs)
def makeExtension(*args, **kwargs):
return TableExtension(*args, **kwargs)
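# A minimal usage sketch, assuming the vendored Markdown 2.5.2 API shown
# above; it renders the pipe-table syntax from the docstring as an HTML table.
import markdown

table_src = (
    "First Header  | Second Header\n"
    "------------- | -------------\n"
    "Content Cell  | Content Cell\n"
)
html = markdown.markdown(table_src, extensions=['markdown.extensions.tables'])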

View File

@ -1,11 +1,15 @@
"""
Table of Contents Extension for Python-Markdown
* * *
===============================================
(c) 2008 [Jack Miller](http://codezen.org)
See <https://pythonhosted.org/Markdown/extensions/toc.html>
for documentation.
Dependencies:
* [Markdown 2.1+](http://packages.python.org/Markdown/)
Original code Copyright 2008 [Jack Miller](http://codezen.org)
All changes Copyright 2008-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
"""
@ -23,60 +27,59 @@ def order_toc_list(toc_list):
[{'level': 1}, {'level': 2}]
=>
[{'level': 1, 'children': [{'level': 2, 'children': []}]}]
A wrong list is also converted:
[{'level': 2}, {'level': 1}]
=>
[{'level': 2, 'children': []}, {'level': 1, 'children': []}]
"""
def build_correct(remaining_list, prev_elements=[{'level': 1000}]):
if not remaining_list:
return [], []
current = remaining_list.pop(0)
if not 'children' in current.keys():
current['children'] = []
if not prev_elements:
# This happens for instance with [8, 1, 1], ie. when some
# header level is outside a scope. We treat it as a
# top-level
next_elements, children = build_correct(remaining_list, [current])
current['children'].append(children)
return [current] + next_elements, []
prev_element = prev_elements.pop()
children = []
next_elements = []
# Is current part of the child list or next list?
if current['level'] > prev_element['level']:
#print "%d is a child of %d" % (current['level'], prev_element['level'])
prev_elements.append(prev_element)
prev_elements.append(current)
prev_element['children'].append(current)
next_elements2, children2 = build_correct(remaining_list, prev_elements)
children += children2
next_elements += next_elements2
else:
#print "%d is ancestor of %d" % (current['level'], prev_element['level'])
if not prev_elements:
#print "No previous elements, so appending to the next set"
next_elements.append(current)
prev_elements = [current]
next_elements2, children2 = build_correct(remaining_list, prev_elements)
current['children'].extend(children2)
ordered_list = []
if len(toc_list):
# Initialize everything by processing the first entry
last = toc_list.pop(0)
last['children'] = []
levels = [last['level']]
ordered_list.append(last)
parents = []
# Walk the rest nesting the entries properly
while toc_list:
t = toc_list.pop(0)
current_level = t['level']
t['children'] = []
# Reduce depth if current level < last item's level
if current_level < levels[-1]:
# Pop last level since we know we are less than it
levels.pop()
# Pop parents and levels we are less than or equal to
to_pop = 0
for p in reversed(parents):
if current_level <= p['level']:
to_pop += 1
else:
break
if to_pop:
levels = levels[:-to_pop]
parents = parents[:-to_pop]
# Note current level as last
levels.append(current_level)
# Level is the same, so append to the current parent (if available)
if current_level == levels[-1]:
(parents[-1]['children'] if parents else ordered_list).append(t)
# Current level is > last item's level,
# So make last item a parent and append current as child
else:
#print "Previous elements, comparing to those first"
remaining_list.insert(0, current)
next_elements2, children2 = build_correct(remaining_list, prev_elements)
children.extend(children2)
next_elements += next_elements2
return next_elements, children
ordered_list, __ = build_correct(toc_list)
last['children'].append(t)
parents.append(last)
levels.append(current_level)
last = t
return ordered_list
@ -204,26 +207,26 @@ class TocExtension(Extension):
TreeProcessorClass = TocTreeprocessor
def __init__(self, configs=[]):
self.config = { "marker" : ["[TOC]",
"Text to find and replace with Table of Contents -"
"Defaults to \"[TOC]\""],
"slugify" : [slugify,
"Function to generate anchors based on header text-"
"Defaults to the headerid ext's slugify function."],
"title" : [None,
"Title to insert into TOC <div> - "
"Defaults to None"],
"anchorlink" : [0,
"1 if header should be a self link"
"Defaults to 0"],
"permalink" : [0,
"1 or link text if a Sphinx-style permalink should be added",
"Defaults to 0"]
}
def __init__(self, *args, **kwargs):
self.config = {
"marker" : ["[TOC]",
"Text to find and replace with Table of Contents - "
"Defaults to \"[TOC]\""],
"slugify" : [slugify,
"Function to generate anchors based on header text - "
"Defaults to the headerid ext's slugify function."],
"title" : ["",
"Title to insert into TOC <div> - "
"Defaults to an empty string"],
"anchorlink" : [0,
"1 if header should be a self link - "
"Defaults to 0"],
"permalink" : [0,
"1 or link text if a Sphinx-style permalink should be added - "
"Defaults to 0"]
}
for key, value in configs:
self.setConfig(key, value)
super(TocExtension, self).__init__(*args, **kwargs)
def extendMarkdown(self, md, md_globals):
tocext = self.TreeProcessorClass(md)
@ -236,5 +239,5 @@ class TocExtension(Extension):
md.treeprocessors.add("toc", tocext, "_end")
def makeExtension(configs={}):
return TocExtension(configs=configs)
def makeExtension(*args, **kwargs):
return TocExtension(*args, **kwargs)

View File

@ -2,78 +2,17 @@
WikiLinks Extension for Python-Markdown
======================================
Converts [[WikiLinks]] to relative links. Requires Python-Markdown 2.0+
Converts [[WikiLinks]] to relative links.
Basic usage:
See <https://pythonhosted.org/Markdown/extensions/wikilinks.html>
for documentation.
>>> import markdown
>>> text = "Some text with a [[WikiLink]]."
>>> html = markdown.markdown(text, ['wikilinks'])
>>> print html
<p>Some text with a <a class="wikilink" href="/WikiLink/">WikiLink</a>.</p>
Original code Copyright [Waylan Limberg](http://achinghead.com/).
Whitespace behavior:
>>> print markdown.markdown('[[ foo bar_baz ]]', ['wikilinks'])
<p><a class="wikilink" href="/foo_bar_baz/">foo bar_baz</a></p>
>>> print markdown.markdown('foo [[ ]] bar', ['wikilinks'])
<p>foo bar</p>
To define custom settings the simple way:
>>> print markdown.markdown(text,
... ['wikilinks(base_url=/wiki/,end_url=.html,html_class=foo)']
... )
<p>Some text with a <a class="foo" href="/wiki/WikiLink.html">WikiLink</a>.</p>
Custom settings the complex way:
>>> md = markdown.Markdown(
... extensions = ['wikilinks'],
... extension_configs = {'wikilinks': [
... ('base_url', 'http://example.com/'),
... ('end_url', '.html'),
... ('html_class', '') ]},
... safe_mode = True)
>>> print md.convert(text)
<p>Some text with a <a href="http://example.com/WikiLink.html">WikiLink</a>.</p>
Use MetaData with mdx_meta.py (Note the blank html_class in MetaData):
>>> text = """wiki_base_url: http://example.com/
... wiki_end_url: .html
... wiki_html_class:
...
... Some text with a [[WikiLink]]."""
>>> md = markdown.Markdown(extensions=['meta', 'wikilinks'])
>>> print md.convert(text)
<p>Some text with a <a href="http://example.com/WikiLink.html">WikiLink</a>.</p>
MetaData should not carry over to next document:
>>> print md.convert("No [[MetaData]] here.")
<p>No <a class="wikilink" href="/MetaData/">MetaData</a> here.</p>
Define a custom URL builder:
>>> def my_url_builder(label, base, end):
... return '/bar/'
>>> md = markdown.Markdown(extensions=['wikilinks'],
... extension_configs={'wikilinks' : [('build_url', my_url_builder)]})
>>> print md.convert('[[foo]]')
<p><a class="wikilink" href="/bar/">foo</a></p>
From the command line:
python markdown.py -x wikilinks(base_url=http://example.com/,end_url=.html,html_class=foo) src.txt
By [Waylan Limberg](http://achinghead.com/).
All changes Copyright The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
Dependencies:
* [Python 2.3+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
'''
from __future__ import absolute_import
@ -90,19 +29,17 @@ def build_url(label, base, end):
class WikiLinkExtension(Extension):
def __init__(self, configs):
# set extension defaults
def __init__(self, *args, **kwargs):
self.config = {
'base_url' : ['/', 'String to append to beginning or URL.'],
'end_url' : ['/', 'String to append to end of URL.'],
'html_class' : ['wikilink', 'CSS hook. Leave blank for none.'],
'build_url' : [build_url, 'Callable formats URL from label.'],
'base_url' : ['/', 'String to append to beginning or URL.'],
'end_url' : ['/', 'String to append to end of URL.'],
'html_class' : ['wikilink', 'CSS hook. Leave blank for none.'],
'build_url' : [build_url, 'Callable formats URL from label.'],
}
configs = dict(configs) or {}
# Override defaults with user settings
for key, value in configs.items():
self.setConfig(key, value)
super(WikiLinkExtension, self).__init__(*args, **kwargs)
def extendMarkdown(self, md, md_globals):
self.md = md
@ -147,5 +84,5 @@ class WikiLinks(Pattern):
return base_url, end_url, html_class
def makeExtension(configs=None) :
return WikiLinkExtension(configs=configs)
def makeExtension(*args, **kwargs):
return WikiLinkExtension(*args, **kwargs)

View File

@ -46,13 +46,13 @@ from __future__ import unicode_literals
from . import util
from . import odict
import re
try:
try: #pragma: no cover
from urllib.parse import urlparse, urlunparse
except ImportError:
except ImportError: #pragma: no cover
from urlparse import urlparse, urlunparse
try:
try: #pragma: no cover
from html import entities
except ImportError:
except ImportError: #pragma: no cover
import htmlentitydefs as entities
@ -75,7 +75,8 @@ def build_inlinepatterns(md_instance, **kwargs):
inlinePatterns["html"] = HtmlPattern(HTML_RE, md_instance)
inlinePatterns["entity"] = HtmlPattern(ENTITY_RE, md_instance)
inlinePatterns["not_strong"] = SimpleTextPattern(NOT_STRONG_RE)
inlinePatterns["strong_em"] = DoubleTagPattern(STRONG_EM_RE, 'strong,em')
inlinePatterns["em_strong"] = DoubleTagPattern(EM_STRONG_RE, 'strong,em')
inlinePatterns["strong_em"] = DoubleTagPattern(STRONG_EM_RE, 'em,strong')
inlinePatterns["strong"] = SimpleTagPattern(STRONG_RE, 'strong')
inlinePatterns["emphasis"] = SimpleTagPattern(EMPHASIS_RE, 'em')
if md_instance.smart_emphasis:
@ -100,7 +101,8 @@ BACKTICK_RE = r'(?<!\\)(`+)(.+?)(?<!`)\2(?!`)' # `e=f()` or ``e=f("`")``
ESCAPE_RE = r'\\(.)' # \<
EMPHASIS_RE = r'(\*)([^\*]+)\2' # *emphasis*
STRONG_RE = r'(\*{2}|_{2})(.+?)\2' # **strong**
STRONG_EM_RE = r'(\*{3}|_{3})(.+?)\2' # ***strong***
EM_STRONG_RE = r'(\*|_)\2{2}(.+?)\2(.*?)\2{2}' # ***strongem*** or ***em*strong**
STRONG_EM_RE = r'(\*|_)\2{2}(.+?)\2{2}(.*?)\2' # ***strong**em*
SMART_EMPHASIS_RE = r'(?<!\w)(_)(?!_)(.+?)(?<!_)\2(?!\w)' # _smart_emphasis_
EMPHASIS_2_RE = r'(_)(.+?)\2' # _emphasis_
LINK_RE = NOIMG + BRK + \
@ -156,7 +158,7 @@ class Pattern(object):
"""
self.pattern = pattern
self.compiled_re = re.compile("^(.*?)%s(.*?)$" % pattern,
self.compiled_re = re.compile("^(.*?)%s(.*?)$" % pattern,
re.DOTALL | re.UNICODE)
# Api for Markdown to pass safe_mode into instance
@ -178,7 +180,7 @@ class Pattern(object):
* m: A re match object containing a match of the pattern.
"""
pass
pass #pragma: no cover
def type(self):
""" Return class name, to define pattern type """
@ -188,9 +190,9 @@ class Pattern(object):
""" Return unescaped text given text with an inline placeholder. """
try:
stash = self.markdown.treeprocessors['inline'].stashed_nodes
except KeyError:
except KeyError: #pragma: no cover
return text
def itertext(el):
def itertext(el): #pragma: no cover
' Reimplement Element.itertext for older python versions '
tag = el.tag
if not isinstance(tag, util.string_type) and tag is not None:
@ -210,17 +212,14 @@ class Pattern(object):
return value
else:
# An etree Element - return text content only
return ''.join(itertext(value))
return ''.join(itertext(value))
return util.INLINE_PLACEHOLDER_RE.sub(get_stash, text)
class SimpleTextPattern(Pattern):
""" Return a simple text of group(2) of a Pattern. """
def handleMatch(self, m):
text = m.group(2)
if text == util.INLINE_PLACEHOLDER_PREFIX:
return None
return text
return m.group(2)
class EscapePattern(Pattern):
@ -231,7 +230,7 @@ class EscapePattern(Pattern):
if char in self.markdown.ESCAPED_CHARS:
return '%s%s%s' % (util.STX, ord(char), util.ETX)
else:
return None
return None
class SimpleTagPattern(Pattern):
@ -279,6 +278,8 @@ class DoubleTagPattern(SimpleTagPattern):
el1 = util.etree.Element(tag1)
el2 = util.etree.SubElement(el1, tag2)
el2.text = m.group(3)
if len(m.groups())==5:
el2.tail = m.group(4)
return el1
@ -293,7 +294,7 @@ class HtmlPattern(Pattern):
""" Return unescaped text given text with an inline placeholder. """
try:
stash = self.markdown.treeprocessors['inline'].stashed_nodes
except KeyError:
except KeyError: #pragma: no cover
return text
def get_stash(m):
id = m.group(1)
@ -303,7 +304,7 @@ class HtmlPattern(Pattern):
return self.markdown.serializer(value)
except:
return '\%s' % value
return util.INLINE_PLACEHOLDER_RE.sub(get_stash, text)
@ -323,7 +324,7 @@ class LinkPattern(Pattern):
el.set("href", "")
if title:
title = dequote(self.unescape(title))
title = dequote(self.unescape(title))
el.set("title", title)
return el
@ -347,20 +348,20 @@ class LinkPattern(Pattern):
if not self.markdown.safeMode:
# Return immediately, bypassing parsing.
return url
try:
scheme, netloc, path, params, query, fragment = url = urlparse(url)
except ValueError:
except ValueError: #pragma: no cover
# Bad url - so bad it couldn't be parsed.
return ''
locless_schemes = ['', 'mailto', 'news']
allowed_schemes = locless_schemes + ['http', 'https', 'ftp', 'ftps']
if scheme not in allowed_schemes:
# Not a known (allowed) scheme. Not safe.
return ''
if netloc == '' and scheme not in locless_schemes:
if netloc == '' and scheme not in locless_schemes: #pragma: no cover
# This should not happen. Treat as suspect.
return ''

View File

@ -82,11 +82,11 @@ class OrderedDict(dict):
for key in self.keyOrder:
yield self[key]
if util.PY3:
if util.PY3: #pragma: no cover
items = _iteritems
keys = _iterkeys
values = _itervalues
else:
else: #pragma: no cover
iteritems = _iteritems
iterkeys = _iterkeys
itervalues = _itervalues

View File

@ -42,7 +42,7 @@ class Postprocessor(util.Processor):
(possibly modified) string.
"""
pass
pass #pragma: no cover
class RawHtmlPostprocessor(Postprocessor):

View File

@ -41,7 +41,7 @@ class Preprocessor(util.Processor):
the (possibly modified) list of lines.
"""
pass
pass #pragma: no cover
class NormalizeWhitespace(Preprocessor):
@ -174,9 +174,10 @@ class HtmlBlockPreprocessor(Preprocessor):
else: # raw html
if len(items) - right_listindex <= 1: # last element
right_listindex -= 1
offset = 1 if i == right_listindex else 0
placeholder = self.markdown.htmlStash.store('\n\n'.join(
items[i:right_listindex + 1]))
del items[i:right_listindex + 1]
items[i:right_listindex + offset]))
del items[i:right_listindex + offset]
items.insert(i, placeholder)
return items

View File

@ -42,9 +42,9 @@ from __future__ import unicode_literals
from . import util
ElementTree = util.etree.ElementTree
QName = util.etree.QName
if hasattr(util.etree, 'test_comment'):
if hasattr(util.etree, 'test_comment'): #pragma: no cover
Comment = util.etree.test_comment
else:
else: #pragma: no cover
Comment = util.etree.Comment
PI = util.etree.PI
ProcessingInstruction = util.etree.ProcessingInstruction
@ -56,7 +56,7 @@ HTML_EMPTY = ("area", "base", "basefont", "br", "col", "frame", "hr",
try:
HTML_EMPTY = set(HTML_EMPTY)
except NameError:
except NameError: #pragma: no cover
pass
_namespace_map = {
@ -73,7 +73,7 @@ _namespace_map = {
}
def _raise_serialization_error(text):
def _raise_serialization_error(text): #pragma: no cover
raise TypeError(
"cannot serialize %r (type %s)" % (text, type(text).__name__)
)
@ -81,7 +81,7 @@ def _raise_serialization_error(text):
def _encode(text, encoding):
try:
return text.encode(encoding, "xmlcharrefreplace")
except (TypeError, AttributeError):
except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text)
def _escape_cdata(text):
@ -97,7 +97,7 @@ def _escape_cdata(text):
if ">" in text:
text = text.replace(">", "&gt;")
return text
except (TypeError, AttributeError):
except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text)
@ -115,7 +115,7 @@ def _escape_attrib(text):
if "\n" in text:
text = text.replace("\n", "&#10;")
return text
except (TypeError, AttributeError):
except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text)
def _escape_attrib_html(text):
@ -130,7 +130,7 @@ def _escape_attrib_html(text):
if "\"" in text:
text = text.replace("\"", "&quot;")
return text
except (TypeError, AttributeError):
except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text)
@ -240,7 +240,7 @@ def _namespaces(elem, default_namespace=None):
"default_namespace option"
)
qnames[qname] = qname
except TypeError:
except TypeError: #pragma: no cover
_raise_serialization_error(qname)
# populate qname and namespaces table

View File

@ -34,11 +34,11 @@ class Treeprocessor(util.Processor):
def run(self, root):
"""
Subclasses of Treeprocessor should implement a `run` method, which
takes a root ElementTree. This method can return another ElementTree
object, and the existing root ElementTree will be replaced, or it can
takes a root ElementTree. This method can return another ElementTree
object, and the existing root ElementTree will be replaced, or it can
modify the current tree and return None.
"""
pass
pass #pragma: no cover
class InlineProcessor(Treeprocessor):
@ -53,6 +53,7 @@ class InlineProcessor(Treeprocessor):
+ len(self.__placeholder_suffix)
self.__placeholder_re = util.INLINE_PLACEHOLDER_RE
self.markdown = md
self.inlinePatterns = md.inlinePatterns
def __makePlaceholder(self, type):
""" Generate a placeholder """
@ -70,7 +71,7 @@ class InlineProcessor(Treeprocessor):
* index: index, from which we start search
Returns: placeholder id and string index, after the found placeholder.
"""
m = self.__placeholder_re.search(data, index)
if m:
@ -99,9 +100,9 @@ class InlineProcessor(Treeprocessor):
"""
if not isinstance(data, util.AtomicString):
startIndex = 0
while patternIndex < len(self.markdown.inlinePatterns):
while patternIndex < len(self.inlinePatterns):
data, matched, startIndex = self.__applyPattern(
self.markdown.inlinePatterns.value_for_index(patternIndex),
self.inlinePatterns.value_for_index(patternIndex),
data, patternIndex, startIndex)
if not matched:
patternIndex += 1
@ -128,11 +129,10 @@ class InlineProcessor(Treeprocessor):
text = subnode.tail
subnode.tail = None
childResult = self.__processPlaceholders(text, subnode)
childResult = self.__processPlaceholders(text, subnode, isText)
if not isText and node is not subnode:
pos = list(node).index(subnode)
node.remove(subnode)
pos = list(node).index(subnode) + 1
else:
pos = 0
@ -140,7 +140,7 @@ class InlineProcessor(Treeprocessor):
for newChild in childResult:
node.insert(pos, newChild)
def __processPlaceholders(self, data, parent):
def __processPlaceholders(self, data, parent, isText=True):
"""
Process string with placeholders and generate ElementTree tree.
@ -150,7 +150,7 @@ class InlineProcessor(Treeprocessor):
* parent: Element, which contains processing inline data
Returns: list with ElementTree elements with applied inline patterns.
"""
def linkText(text):
if text:
@ -159,6 +159,11 @@ class InlineProcessor(Treeprocessor):
result[-1].tail += text
else:
result[-1].tail = text
elif not isText:
if parent.tail:
parent.tail += text
else:
parent.tail = text
else:
if parent.text:
parent.text += text
@ -182,7 +187,7 @@ class InlineProcessor(Treeprocessor):
for child in [node] + list(node):
if child.tail:
if child.tail.strip():
self.__processElementText(node, child,False)
self.__processElementText(node, child, False)
if child.text:
if child.text.strip():
self.__processElementText(child, child)
@ -239,7 +244,7 @@ class InlineProcessor(Treeprocessor):
# We need to process current node too
for child in [node] + list(node):
if not isString(node):
if child.text:
if child.text:
child.text = self.__handleInline(child.text,
patternIndex + 1)
if child.tail:
@ -287,11 +292,10 @@ class InlineProcessor(Treeprocessor):
if child.tail:
tail = self.__handleInline(child.tail)
dumby = util.etree.Element('d')
tailResult = self.__processPlaceholders(tail, dumby)
if dumby.text:
child.tail = dumby.text
else:
child.tail = None
child.tail = None
tailResult = self.__processPlaceholders(tail, dumby, False)
if dumby.tail:
child.tail = dumby.tail
pos = list(currElement).index(child) + 1
tailResult.reverse()
for newChild in tailResult:
@ -303,7 +307,7 @@ class InlineProcessor(Treeprocessor):
if self.markdown.enable_attributes:
if element.text and isString(element.text):
element.text = \
inlinepatterns.handleAttributes(element.text,
inlinepatterns.handleAttributes(element.text,
element)
i = 0
for newChild in lst:
@ -357,4 +361,4 @@ class PrettifyTreeprocessor(Treeprocessor):
pres = root.getiterator('pre')
for pre in pres:
if len(pre) and pre[0].tag == 'code':
pre[0].text = pre[0].text.rstrip() + '\n'
pre[0].text = util.AtomicString(pre[0].text.rstrip() + '\n')

View File

@ -10,11 +10,11 @@ Python 3 Stuff
"""
PY3 = sys.version_info[0] == 3
if PY3:
if PY3: #pragma: no cover
string_type = str
text_type = str
int2str = chr
else:
else: #pragma: no cover
string_type = basestring
text_type = unicode
int2str = unichr
@ -58,14 +58,15 @@ RTL_BIDI_RANGES = ( ('\u0590', '\u07FF'),
# Extensions should use "markdown.util.etree" instead of "etree" (or do `from
# markdown.util import etree`). Do not import it by yourself.
try: # Is the C implementation of ElementTree available?
try: #pragma: no cover
# Is the C implementation of ElementTree available?
import xml.etree.cElementTree as etree
from xml.etree.ElementTree import Comment
# Serializers (including ours) test with non-c Comment
etree.test_comment = Comment
if etree.VERSION < "1.0.5":
raise RuntimeError("cElementTree version 1.0.5 or higher is required.")
except (ImportError, RuntimeError):
except (ImportError, RuntimeError): #pragma: no cover
# Use the Python implementation of ElementTree?
import xml.etree.ElementTree as etree
if etree.VERSION < "1.1":
@ -85,15 +86,20 @@ def isBlockLevel(tag):
# Some ElementTree tags are not strings, so return False.
return False
def parseBoolValue(value, fail_on_errors=True):
def parseBoolValue(value, fail_on_errors=True, preserve_none=False):
"""Parses a string representing bool value. If parsing was successful,
returns True or False. If parsing was not successful, raises
ValueError, or, if fail_on_errors=False, returns None."""
returns True or False. If preserve_none=True, returns True, False,
or None. If parsing was not successful, raises ValueError, or, if
fail_on_errors=False, returns None."""
if not isinstance(value, string_type):
if preserve_none and value is None:
return value
return bool(value)
elif preserve_none and value.lower() == 'none':
return None
elif value.lower() in ('true', 'yes', 'y', 'on', '1'):
return True
elif value.lower() in ('false', 'no', 'n', 'off', '0'):
elif value.lower() in ('false', 'no', 'n', 'off', '0', 'none'):
return False
elif fail_on_errors:
raise ValueError('Cannot parse bool value: %r' % value)
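# A few illustrative calls, following the parsing rules implemented above:
assert parseBoolValue('yes') is True
assert parseBoolValue('Off') is False
assert parseBoolValue('none') is False
assert parseBoolValue('none', preserve_none=True) is None
assert parseBoolValue(None, preserve_none=True) is None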

View File

@ -0,0 +1,31 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
__all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__",
"__email__", "__license__", "__copyright__",
]
__title__ = "packaging"
__summary__ = "Core utilities for Python packages"
__uri__ = "https://github.com/pypa/packaging"
__version__ = "15.0"
__author__ = "Donald Stufft"
__email__ = "donald@stufft.io"
__license__ = "Apache License, Version 2.0"
__copyright__ = "Copyright 2014 %s" % __author__

View File

@ -0,0 +1,24 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
from .__about__ import (
__author__, __copyright__, __email__, __license__, __summary__, __title__,
__uri__, __version__
)
__all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__",
"__email__", "__license__", "__copyright__",
]

View File

@ -0,0 +1,40 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import sys
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
# flake8: noqa
if PY3:
string_types = str,
else:
string_types = basestring,
def with_metaclass(meta, *bases):
"""
Create a base class with a metaclass.
"""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(meta):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
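# A small illustration of the dummy-metaclass trick above (Meta and Base are
# hypothetical names): the resulting class works on both Python 2 and 3.
class Meta(type):
    pass

class Base(with_metaclass(Meta, object)):
    pass

assert type(Base) is Meta
assert Base.__bases__ == (object,)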

View File

@ -0,0 +1,78 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
class Infinity(object):
def __repr__(self):
return "Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return False
def __le__(self, other):
return False
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return True
def __ge__(self, other):
return True
def __neg__(self):
return NegativeInfinity
Infinity = Infinity()
class NegativeInfinity(object):
def __repr__(self):
return "-Infinity"
def __hash__(self):
return hash(repr(self))
def __lt__(self, other):
return True
def __le__(self, other):
return True
def __eq__(self, other):
return isinstance(other, self.__class__)
def __ne__(self, other):
return not isinstance(other, self.__class__)
def __gt__(self, other):
return False
def __ge__(self, other):
return False
def __neg__(self):
return Infinity
NegativeInfinity = NegativeInfinity()
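# Illustrative comparisons: the two singletons above compare outside any
# other value, which lets version sort keys use them as extremes.
assert Infinity > 10 ** 9
assert NegativeInfinity < -10 ** 9
assert sorted([Infinity, 0, NegativeInfinity]) == [NegativeInfinity, 0, Infinity]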

View File

@ -0,0 +1,772 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import abc
import functools
import itertools
import re
from ._compat import string_types, with_metaclass
from .version import Version, LegacyVersion, parse
class InvalidSpecifier(ValueError):
"""
An invalid specifier was found; users should refer to PEP 440.
"""
class BaseSpecifier(with_metaclass(abc.ABCMeta, object)):
@abc.abstractmethod
def __str__(self):
"""
Returns the str representation of this Specifier like object. This
should be representative of the Specifier itself.
"""
@abc.abstractmethod
def __hash__(self):
"""
Returns a hash value for this Specifier like object.
"""
@abc.abstractmethod
def __eq__(self, other):
"""
Returns a boolean representing whether or not the two Specifier like
objects are equal.
"""
@abc.abstractmethod
def __ne__(self, other):
"""
Returns a boolean representing whether or not the two Specifier like
objects are not equal.
"""
@abc.abstractproperty
def prereleases(self):
"""
Returns whether or not pre-releases as a whole are allowed by this
specifier.
"""
@prereleases.setter
def prereleases(self, value):
"""
Sets whether or not pre-releases as a whole are allowed by this
specifier.
"""
@abc.abstractmethod
def contains(self, item, prereleases=None):
"""
Determines if the given item is contained within this specifier.
"""
@abc.abstractmethod
def filter(self, iterable, prereleases=None):
"""
Takes an iterable of items and filters them so that only items which
are contained within this specifier are allowed in it.
"""
class _IndividualSpecifier(BaseSpecifier):
_operators = {}
def __init__(self, spec="", prereleases=None):
match = self._regex.search(spec)
if not match:
raise InvalidSpecifier("Invalid specifier: '{0}'".format(spec))
self._spec = (
match.group("operator").strip(),
match.group("version").strip(),
)
# Store whether or not this Specifier should accept prereleases
self._prereleases = prereleases
def __repr__(self):
pre = (
", prereleases={0!r}".format(self.prereleases)
if self._prereleases is not None
else ""
)
return "<{0}({1!r}{2})>".format(
self.__class__.__name__,
str(self),
pre,
)
def __str__(self):
return "{0}{1}".format(*self._spec)
def __hash__(self):
return hash(self._spec)
def __eq__(self, other):
if isinstance(other, string_types):
try:
other = self.__class__(other)
except InvalidSpecifier:
return NotImplemented
elif not isinstance(other, self.__class__):
return NotImplemented
return self._spec == other._spec
def __ne__(self, other):
if isinstance(other, string_types):
try:
other = self.__class__(other)
except InvalidSpecifier:
return NotImplemented
elif not isinstance(other, self.__class__):
return NotImplemented
return self._spec != other._spec
def _get_operator(self, op):
return getattr(self, "_compare_{0}".format(self._operators[op]))
def _coerce_version(self, version):
if not isinstance(version, (LegacyVersion, Version)):
version = parse(version)
return version
@property
def prereleases(self):
return self._prereleases
@prereleases.setter
def prereleases(self, value):
self._prereleases = value
def contains(self, item, prereleases=None):
# Determine if prereleases are to be allowed or not.
if prereleases is None:
prereleases = self.prereleases
# Normalize item to a Version or LegacyVersion; this allows us to have
# a shortcut for ``"2.0" in Specifier(">=2")``.
item = self._coerce_version(item)
# Determine if we should be supporting prereleases in this specifier
# or not. If we do not support prereleases, then we can short-circuit
# the logic if this version is a prerelease.
if item.is_prerelease and not prereleases:
return False
# Actually do the comparison to determine if this item is contained
# within this Specifier or not.
return self._get_operator(self._spec[0])(item, self._spec[1])
def filter(self, iterable, prereleases=None):
yielded = False
found_prereleases = []
kw = {"prereleases": prereleases if prereleases is not None else True}
# Attempt to iterate over all the values in the iterable and if any of
# them match, yield them.
for version in iterable:
parsed_version = self._coerce_version(version)
if self.contains(parsed_version, **kw):
# If our version is a prerelease, and we were not set to allow
# prereleases, then we'll store it for later in case nothing
# else matches this specifier.
if (parsed_version.is_prerelease
and not (prereleases or self.prereleases)):
found_prereleases.append(version)
# Either this is not a prerelease, or we should have been
# accepting prereleases from the beginning.
else:
yielded = True
yield version
# Now that we've iterated over everything, determine if we've yielded
# any values, and if we have not and we have any prereleases stored up
# then we will go ahead and yield the prereleases.
if not yielded and found_prereleases:
for version in found_prereleases:
yield version
class LegacySpecifier(_IndividualSpecifier):
_regex = re.compile(
r"""
^
\s*
(?P<operator>(==|!=|<=|>=|<|>))
\s*
(?P<version>
[^\s]* # We just match everything, except for whitespace since this
# is a "legacy" specifier and the version string can be just
# about anything.
)
\s*
$
""",
re.VERBOSE | re.IGNORECASE,
)
_operators = {
"==": "equal",
"!=": "not_equal",
"<=": "less_than_equal",
">=": "greater_than_equal",
"<": "less_than",
">": "greater_than",
}
def _coerce_version(self, version):
if not isinstance(version, LegacyVersion):
version = LegacyVersion(str(version))
return version
def _compare_equal(self, prospective, spec):
return prospective == self._coerce_version(spec)
def _compare_not_equal(self, prospective, spec):
return prospective != self._coerce_version(spec)
def _compare_less_than_equal(self, prospective, spec):
return prospective <= self._coerce_version(spec)
def _compare_greater_than_equal(self, prospective, spec):
return prospective >= self._coerce_version(spec)
def _compare_less_than(self, prospective, spec):
return prospective < self._coerce_version(spec)
def _compare_greater_than(self, prospective, spec):
return prospective > self._coerce_version(spec)
def _require_version_compare(fn):
@functools.wraps(fn)
def wrapped(self, prospective, spec):
if not isinstance(prospective, Version):
return False
return fn(self, prospective, spec)
return wrapped
class Specifier(_IndividualSpecifier):
_regex = re.compile(
r"""
^
\s*
(?P<operator>(~=|==|!=|<=|>=|<|>|===))
(?P<version>
(?:
# The identity operators allow for an escape hatch that will
# do an exact string match of the version you wish to install.
# This will not be parsed by PEP 440 and we cannot determine
# any semantic meaning from it. This operator is discouraged
# but included entirely as an escape hatch.
(?<====) # Only match for the identity operator
\s*
[^\s]* # We just match everything, except for whitespace
# since we are only testing for strict identity.
)
|
(?:
# The (non)equality operators allow for wild card and local
# versions to be specified so we have to define these two
# operators separately to enable that.
(?<===|!=) # Only match for equals and not equals
\s*
v?
(?:[0-9]+!)? # epoch
[0-9]+(?:\.[0-9]+)* # release
(?: # pre release
[-_\.]?
(a|b|c|rc|alpha|beta|pre|preview)
[-_\.]?
[0-9]*
)?
(?: # post release
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
)?
# You cannot use a wild card and a dev or local version
# together so group them with a | and make them optional.
(?:
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
(?:\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*)? # local
|
\.\* # Wild card syntax of .*
)?
)
|
(?:
# The compatible operator requires at least two digits in the
# release segment.
(?<=~=) # Only match for the compatible operator
\s*
v?
(?:[0-9]+!)? # epoch
[0-9]+(?:\.[0-9]+)+ # release (We have a + instead of a *)
(?: # pre release
[-_\.]?
(a|b|c|rc|alpha|beta|pre|preview)
[-_\.]?
[0-9]*
)?
(?: # post release
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
)?
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
)
|
(?:
# All other operators only allow a subset of what the
# (non)equality operators do. Specifically they do not allow
# local versions to be specified nor do they allow the prefix
# matching wild cards.
(?<!==|!=|~=) # We have special cases for these
# operators so we want to make sure they
# don't match here.
\s*
v?
(?:[0-9]+!)? # epoch
[0-9]+(?:\.[0-9]+)* # release
(?: # pre release
[-_\.]?
(a|b|c|rc|alpha|beta|pre|preview)
[-_\.]?
[0-9]*
)?
(?: # post release
(?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
)?
(?:[-_\.]?dev[-_\.]?[0-9]*)? # dev release
)
)
\s*
$
""",
re.VERBOSE | re.IGNORECASE,
)
_operators = {
"~=": "compatible",
"==": "equal",
"!=": "not_equal",
"<=": "less_than_equal",
">=": "greater_than_equal",
"<": "less_than",
">": "greater_than",
"===": "arbitrary",
}
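# A usage sketch (doctest-style comments; the comparison methods are defined
# below, and each operator string maps to one _compare_* implementation):
#
#   >>> Specifier(">=1.0").contains("1.2")
#   True
#   >>> Specifier("~=2.2").contains("2.3")
#   True
#   >>> Specifier("~=2.2").contains("3.0")
#   False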
@_require_version_compare
def _compare_compatible(self, prospective, spec):
# Compatible releases have an equivalent combination of >= and ==. That
# is, ~=2.2 is equivalent to >=2.2,==2.*. This allows us to
# implement this in terms of the other specifiers instead of
# implementing it ourselves. The only thing we need to do is construct
# the other specifiers.
# We want everything but the last item in the version, but we want to
# ignore post and dev releases and we want to treat the pre-release as
# its own separate segment.
prefix = ".".join(
list(
itertools.takewhile(
lambda x: (not x.startswith("post")
and not x.startswith("dev")),
_version_split(spec),
)
)[:-1]
)
# Add the prefix notation to the end of our string
prefix += ".*"
return (self._get_operator(">=")(prospective, spec)
and self._get_operator("==")(prospective, prefix))
@_require_version_compare
def _compare_equal(self, prospective, spec):
# We need special logic to handle prefix matching
if spec.endswith(".*"):
# Split the spec out by dots, and pretend that there is an implicit
# dot in between a release segment and a pre-release segment.
spec = _version_split(spec[:-2]) # Remove the trailing .*
# Split the prospective version out by dots, and pretend that there
# is an implicit dot in between a release segment and a pre-release
# segment.
prospective = _version_split(str(prospective))
# Shorten the prospective version to be the same length as the spec
# so that we can determine if the specifier is a prefix of the
# prospective version or not.
prospective = prospective[:len(spec)]
# Pad out our two sides with zeros so that they both equal the same
# length.
spec, prospective = _pad_version(spec, prospective)
else:
# Convert our spec string into a Version
spec = Version(spec)
# If the specifier does not have a local segment, then we want to
# act as if the prospective version also does not have a local
# segment.
if not spec.local:
prospective = Version(prospective.public)
return prospective == spec
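# For example (doctest-style), the prefix-matching branch above gives:
#
#   >>> Specifier("==2.0.*").contains("2.0.7")
#   True
#   >>> Specifier("==2.0.*").contains("2.1.0")
#   False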
@_require_version_compare
def _compare_not_equal(self, prospective, spec):
return not self._compare_equal(prospective, spec)
@_require_version_compare
def _compare_less_than_equal(self, prospective, spec):
return prospective <= Version(spec)
@_require_version_compare
def _compare_greater_than_equal(self, prospective, spec):
return prospective >= Version(spec)
@_require_version_compare
def _compare_less_than(self, prospective, spec):
# Convert our spec to a Version instance, since we'll want to work with
# it as a version.
spec = Version(spec)
# Check to see if the prospective version is less than the spec
# version. If it's not we can short circuit and just return False now
# instead of doing extra unneeded work.
if not prospective < spec:
return False
# This special case is here so that, unless the specifier itself
# includes a pre-release version, we do not accept pre-release
# versions for the version mentioned in the specifier (e.g. <3.1 should
# not match 3.1.dev0, but should match 3.0.dev0).
if not spec.is_prerelease and prospective.is_prerelease:
if Version(prospective.base_version) == Version(spec.base_version):
return False
# If we've gotten here, it means that the prospective version is both
# less than the spec version *and* it's not a pre-release of the same
# version in the spec.
return True
@_require_version_compare
def _compare_greater_than(self, prospective, spec):
# Convert our spec to a Version instance, since we'll want to work with
# it as a version.
spec = Version(spec)
# Check to see if the prospective version is greater than the spec
# version. If it's not we can short circuit and just return False now
# instead of doing extra unneeded work.
if not prospective > spec:
return False
# This special case is here so that, unless the specifier itself
# includes a post-release version, we do not accept
# post-release versions for the version mentioned in the specifier
# (e.g. >3.1 should not match 3.0.post0, but should match 3.2.post0).
if not spec.is_postrelease and prospective.is_postrelease:
if Version(prospective.base_version) == Version(spec.base_version):
return False
# Ensure that we do not allow a local version of the version mentioned
# in the specifier, which is technically greater than, to match.
if prospective.local is not None:
if Version(prospective.base_version) == Version(spec.base_version):
return False
# If we've gotten here, it means that the prospective version is both
# greater than the spec version *and* it's not a pre-release of the
# same version in the spec.
return True
def _compare_arbitrary(self, prospective, spec):
return str(prospective).lower() == str(spec).lower()
@property
def prereleases(self):
# If there is an explicit prereleases set for this, then we'll just
# blindly use that.
if self._prereleases is not None:
return self._prereleases
# Look at all of our specifiers and determine if they are inclusive
# operators, and if they are if they are including an explicit
# prerelease.
operator, version = self._spec
if operator in ["==", ">=", "<=", "~="]:
# The == specifier can include a trailing .*; if it does, we
# want to remove it before parsing.
if operator == "==" and version.endswith(".*"):
version = version[:-2]
# Parse the version, and if it is a pre-release then this
# specifier allows pre-releases.
if parse(version).is_prerelease:
return True
return False
@prereleases.setter
def prereleases(self, value):
self._prereleases = value
_prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$")
def _version_split(version):
result = []
for item in version.split("."):
match = _prefix_regex.search(item)
if match:
result.extend(match.groups())
else:
result.append(item)
return result
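# For example, a pre-release marker is split into its own segment:
assert _version_split("1.0rc1") == ["1", "0", "rc1"]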
def _pad_version(left, right):
left_split, right_split = [], []
# Get the release segment of our versions
left_split.append(list(itertools.takewhile(lambda x: x.isdigit(), left)))
right_split.append(list(itertools.takewhile(lambda x: x.isdigit(), right)))
# Get the rest of our versions
left_split.append(left[len(left_split[0]):])
right_split.append(right[len(right_split[0]):])
# Insert our padding
left_split.insert(
1,
["0"] * max(0, len(right_split[0]) - len(left_split[0])),
)
right_split.insert(
1,
["0"] * max(0, len(left_split[0]) - len(right_split[0])),
)
return (
list(itertools.chain(*left_split)),
list(itertools.chain(*right_split)),
)
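# For example, padding aligns release segments of different lengths before
# comparison:
assert _pad_version(["1", "2"], ["1", "10", "3"]) == (
    ["1", "2", "0"], ["1", "10", "3"])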
class SpecifierSet(BaseSpecifier):
def __init__(self, specifiers="", prereleases=None):
# Split on , to break each individual specifier into its own item, and
# strip each item to remove leading/trailing whitespace.
specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]
# Parse each individual specifier, attempting first to make it a
# Specifier and falling back to a LegacySpecifier.
parsed = set()
for specifier in specifiers:
try:
parsed.add(Specifier(specifier))
except InvalidSpecifier:
parsed.add(LegacySpecifier(specifier))
# Turn our parsed specifiers into a frozen set and save them for later.
self._specs = frozenset(parsed)
# Store our prereleases value so we can use it later to determine if
# we accept prereleases or not.
self._prereleases = prereleases
def __repr__(self):
pre = (
", prereleases={0!r}".format(self.prereleases)
if self._prereleases is not None
else ""
)
return "<SpecifierSet({0!r}{1})>".format(str(self), pre)
def __str__(self):
return ",".join(sorted(str(s) for s in self._specs))
def __hash__(self):
return hash(self._specs)
def __and__(self, other):
if isinstance(other, string_types):
other = SpecifierSet(other)
elif not isinstance(other, SpecifierSet):
return NotImplemented
specifier = SpecifierSet()
specifier._specs = frozenset(self._specs | other._specs)
if self._prereleases is None and other._prereleases is not None:
specifier._prereleases = other._prereleases
elif self._prereleases is not None and other._prereleases is None:
specifier._prereleases = self._prereleases
elif self._prereleases == other._prereleases:
specifier._prereleases = self._prereleases
else:
raise ValueError(
"Cannot combine SpecifierSets with True and False prerelease "
"overrides."
)
return specifier
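# For example (doctest-style), intersection combines the constraints of
# both sets:
#
#   >>> (SpecifierSet(">=1.0") & SpecifierSet("<2.0")) == SpecifierSet(">=1.0,<2.0")
#   True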
def __eq__(self, other):
if isinstance(other, string_types):
other = SpecifierSet(other)
elif isinstance(other, _IndividualSpecifier):
other = SpecifierSet(str(other))
elif not isinstance(other, SpecifierSet):
return NotImplemented
return self._specs == other._specs
def __ne__(self, other):
if isinstance(other, string_types):
other = SpecifierSet(other)
elif isinstance(other, _IndividualSpecifier):
other = SpecifierSet(str(other))
elif not isinstance(other, SpecifierSet):
return NotImplemented
return self._specs != other._specs
@property
def prereleases(self):
# If we have been given an explicit prerelease modifier, then we'll
# pass that through here.
if self._prereleases is not None:
return self._prereleases
# Otherwise we'll see if any of the given specifiers accept
# prereleases, if any of them do we'll return True, otherwise False.
# Note: The use of any() here means that an empty set of specifiers
# will always return False; this is an explicit design decision.
return any(s.prereleases for s in self._specs)
@prereleases.setter
def prereleases(self, value):
self._prereleases = value
def contains(self, item, prereleases=None):
# Ensure that our item is a Version or LegacyVersion instance.
if not isinstance(item, (LegacyVersion, Version)):
item = parse(item)
# We can determine if we're going to allow pre-releases by looking to
# see if any of the underlying items supports them. If none of them do
# and this item is a pre-release then we do not allow it and we can
# short circuit that here.
# Note: This means that 1.0.dev1 would not be contained in something
# like >=1.0.devabc; however, it would be in >=1.0.devabc,>0.0.dev0
if (not (self.prereleases or prereleases)) and item.is_prerelease:
return False
# Determine if we're forcing a prerelease or not, we bypass
# self.prereleases here and use self._prereleases because we want to
# only take into consideration actual *forced* values. The underlying
# specifiers will handle the other logic.
# The logic here is: If prereleases is anything but None, we'll just
go ahead and continue to use that. However if
# prereleases is None, then we'll use whatever the
# value of self._prereleases is as long as it is not
# None itself.
if prereleases is None and self._prereleases is not None:
prereleases = self._prereleases
# We simply dispatch to the underlying specs here to make sure that the
# given version is contained within all of them.
# Note: This use of all() here means that an empty set of specifiers
# will always return True; this is an explicit design decision.
return all(
s.contains(item, prereleases=prereleases)
for s in self._specs
)
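# Illustrative sketch (not in the original source) of the prerelease gate
# described above:
#
#   >>> SpecifierSet(">=1.0").contains("1.1")
#   True
#   >>> SpecifierSet(">=1.0").contains("2.0.dev1")
#   False
#   >>> SpecifierSet(">=1.0").contains("2.0.dev1", prereleases=True)
#   True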
def filter(self, iterable, prereleases=None):
# Determine if we're forcing a prerelease or not, we bypass
# self.prereleases here and use self._prereleases because we want to
# only take into consideration actual *forced* values. The underlying
# specifiers will handle the other logic.
# The logic here is: If prereleases is anything but None, we'll just
go ahead and continue to use that. However if
# prereleases is None, then we'll use whatever the
# value of self._prereleases is as long as it is not
# None itself.
if prereleases is None and self._prereleases is not None:
prereleases = self._prereleases
# If we have any specifiers, then we want to wrap our iterable in the
# filter method for each one, this will act as a logical AND amongst
# each specifier.
if self._specs:
for spec in self._specs:
iterable = spec.filter(iterable, prereleases=prereleases)
return iterable
# If we do not have any specifiers, then we need to have a rough filter
# which will filter out any pre-releases, unless there are no final
# releases, and which will filter out LegacyVersion in general.
else:
filtered = []
found_prereleases = []
for item in iterable:
# Ensure that we have some kind of Version class for this item.
if not isinstance(item, (LegacyVersion, Version)):
parsed_version = parse(item)
else:
parsed_version = item
# Filter out any item which is parsed as a LegacyVersion
if isinstance(parsed_version, LegacyVersion):
continue
# Store any item which is a pre-release for later unless we've
# already found a final version or we are accepting prereleases
if parsed_version.is_prerelease and not prereleases:
if not filtered:
found_prereleases.append(item)
else:
filtered.append(item)
# If we've found no items except for pre-releases, then we'll go
# ahead and use the pre-releases
if not filtered and found_prereleases and prereleases is None:
return found_prereleases
return filtered
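# Illustrative sketch (not in the original source): filter() chains each
# specifier as a logical AND, while an empty set drops prereleases unless
# nothing else survives:
#
#   >>> list(SpecifierSet(">=1.0").filter(["0.9", "1.0", "1.1"]))
#   ['1.0', '1.1']
#   >>> list(SpecifierSet().filter(["1.0.dev1"]))
#   ['1.0.dev1']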

View File

@@ -0,0 +1,401 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import collections
import itertools
import re
from ._structures import Infinity
__all__ = [
"parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"
]
_Version = collections.namedtuple(
"_Version",
["epoch", "release", "dev", "pre", "post", "local"],
)
def parse(version):
"""
Parse the given version string and return either a :class:`Version` object
or a :class:`LegacyVersion` object depending on if the given version is
a valid PEP 440 version or a legacy version.
"""
try:
return Version(version)
except InvalidVersion:
return LegacyVersion(version)
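# Illustrative sketch (not in the original source) of the fallback above:
#
#   >>> parse("1.0")
#   <Version('1.0')>
#   >>> parse("french toast")
#   <LegacyVersion('french toast')>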
class InvalidVersion(ValueError):
"""
An invalid version was found, users should refer to PEP 440.
"""
class _BaseVersion(object):
def __hash__(self):
return hash(self._key)
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def _compare(self, other, method):
if not isinstance(other, _BaseVersion):
return NotImplemented
return method(self._key, other._key)
class LegacyVersion(_BaseVersion):
def __init__(self, version):
self._version = str(version)
self._key = _legacy_cmpkey(self._version)
def __str__(self):
return self._version
def __repr__(self):
return "<LegacyVersion({0})>".format(repr(str(self)))
@property
def public(self):
return self._version
@property
def base_version(self):
return self._version
@property
def local(self):
return None
@property
def is_prerelease(self):
return False
@property
def is_postrelease(self):
return False
_legacy_version_component_re = re.compile(
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
)
_legacy_version_replacement_map = {
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
}
def _parse_version_parts(s):
for part in _legacy_version_component_re.split(s):
part = _legacy_version_replacement_map.get(part, part)
if not part or part == ".":
continue
if part[:1] in "0123456789":
# pad for numeric comparison
yield part.zfill(8)
else:
yield "*" + part
# ensure that alpha/beta/candidate are before final
yield "*final"
def _legacy_cmpkey(version):
# We hardcode an epoch of -1 here. A PEP 440 version can only have an epoch
# greater than or equal to 0. This will effectively put the LegacyVersion,
# which uses the de facto standard originally implemented by setuptools,
# before all PEP 440 versions.
epoch = -1
# This scheme is taken from pkg_resources.parse_version in setuptools prior
# to its adoption of the packaging library.
parts = []
for part in _parse_version_parts(version.lower()):
if part.startswith("*"):
# remove "-" before a prerelease tag
if part < "*final":
while parts and parts[-1] == "*final-":
parts.pop()
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
parts = tuple(parts)
return epoch, parts
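# Illustrative sketch (not in the original source): every legacy key gets
# epoch -1, numeric parts are zero-padded for string comparison, trailing
# zero parts are stripped, and a "*final" marker is appended:
#
#   >>> _legacy_cmpkey("1.0")
#   (-1, ('00000001', '*final'))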
# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
_regex = re.compile(
r"^\s*" + VERSION_PATTERN + r"\s*$",
re.VERBOSE | re.IGNORECASE,
)
def __init__(self, version):
# Validate the version and parse it into pieces
match = self._regex.search(version)
if not match:
raise InvalidVersion("Invalid version: '{0}'".format(version))
# Store the parsed out pieces of the version
self._version = _Version(
epoch=int(match.group("epoch")) if match.group("epoch") else 0,
release=tuple(int(i) for i in match.group("release").split(".")),
pre=_parse_letter_version(
match.group("pre_l"),
match.group("pre_n"),
),
post=_parse_letter_version(
match.group("post_l"),
match.group("post_n1") or match.group("post_n2"),
),
dev=_parse_letter_version(
match.group("dev_l"),
match.group("dev_n"),
),
local=_parse_local_version(match.group("local")),
)
# Generate a key which will be used for sorting
self._key = _cmpkey(
self._version.epoch,
self._version.release,
self._version.pre,
self._version.post,
self._version.dev,
self._version.local,
)
def __repr__(self):
return "<Version({0})>".format(repr(str(self)))
def __str__(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append("{0}!".format(self._version.epoch))
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
# Pre-release
if self._version.pre is not None:
parts.append("".join(str(x) for x in self._version.pre))
# Post-release
if self._version.post is not None:
parts.append(".post{0}".format(self._version.post[1]))
# Development release
if self._version.dev is not None:
parts.append(".dev{0}".format(self._version.dev[1]))
# Local version segment
if self._version.local is not None:
parts.append(
"+{0}".format(".".join(str(x) for x in self._version.local))
)
return "".join(parts)
@property
def public(self):
return str(self).split("+", 1)[0]
@property
def base_version(self):
parts = []
# Epoch
if self._version.epoch != 0:
parts.append("{0}!".format(self._version.epoch))
# Release segment
parts.append(".".join(str(x) for x in self._version.release))
return "".join(parts)
@property
def local(self):
version_string = str(self)
if "+" in version_string:
return version_string.split("+", 1)[1]
@property
def is_prerelease(self):
return bool(self._version.dev or self._version.pre)
@property
def is_postrelease(self):
return bool(self._version.post)
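# Illustrative sketch (not in the original source) of the derived
# properties above:
#
#   >>> v = Version("1.0.post1.dev2+abc")
#   >>> v.public, v.base_version, v.local, v.is_prerelease
#   ('1.0.post1.dev2', '1.0', 'abc', True)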
def _parse_letter_version(letter, number):
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
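# Illustrative sketch (not in the original source): spellings are
# normalized and a bare number is treated as an implicit post release:
#
#   >>> _parse_letter_version("alpha", None)
#   ('a', 0)
#   >>> _parse_letter_version(None, "1")
#   ('post', 1)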
_local_version_separators = re.compile(r"[\._-]")
def _parse_local_version(local):
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in _local_version_separators.split(local)
)
def _cmpkey(epoch, release, pre, post, dev, local):
# When we compare a release version, we want to compare it with all of the
# trailing zeros removed. So we'll reverse the list, drop all of the now
# leading zeros until we come to something non-zero, then re-reverse the
# rest back into the correct order, and use the resulting tuple as our
# sorting key.
release = tuple(
reversed(list(
itertools.dropwhile(
lambda x: x == 0,
reversed(release),
)
))
)
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
# We'll do this by abusing the pre segment, but we _only_ want to do this
# if there is not a pre or a post segment. If we have one of those then
# the normal sorting rules will handle this case correctly.
if pre is None and post is None and dev is not None:
pre = -Infinity
# Versions without a pre-release (except as noted above) should sort after
# those with one.
elif pre is None:
pre = Infinity
# Versions without a post segment should sort before those with one.
if post is None:
post = -Infinity
# Versions without a development segment should sort after those with one.
if dev is None:
dev = Infinity
if local is None:
# Versions without a local segment should sort before those with one.
local = -Infinity
else:
# Versions with a local segment need that segment parsed to implement
# the sorting rules in PEP440.
# - Alpha numeric segments sort before numeric segments
# - Alpha numeric segments sort lexicographically
# - Numeric segments sort numerically
# - Shorter versions sort before longer versions when the prefixes
# match exactly
local = tuple(
(i, "") if isinstance(i, int) else (-Infinity, i)
for i in local
)
return epoch, release, pre, post, dev, local
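# Illustrative sketch (not in the original source): the net effect of the
# Infinity padding above is the PEP 440 ordering
#
#   >>> Version("1.0.dev0") < Version("1.0a1") < Version("1.0") < Version("1.0.post1")
#   True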

View File

@@ -0,0 +1 @@
packaging==15.0

View File

@@ -0,0 +1,419 @@
Pluggable Distributions of Python Software
==========================================
Distributions
-------------
A "Distribution" is a collection of files that represent a "Release" of a
"Project" as of a particular point in time, denoted by a
"Version"::
>>> import sys, pkg_resources
>>> from pkg_resources import Distribution
>>> Distribution(project_name="Foo", version="1.2")
Foo 1.2
Distributions have a location, which can be a filename, URL, or really anything
else you care to use::
>>> dist = Distribution(
... location="http://example.com/something",
... project_name="Bar", version="0.9"
... )
>>> dist
Bar 0.9 (http://example.com/something)
Distributions have various introspectable attributes::
>>> dist.location
'http://example.com/something'
>>> dist.project_name
'Bar'
>>> dist.version
'0.9'
>>> dist.py_version == sys.version[:3]
True
>>> print(dist.platform)
None
Including various computed attributes::
>>> from pkg_resources import parse_version
>>> dist.parsed_version == parse_version(dist.version)
True
>>> dist.key # case-insensitive form of the project name
'bar'
Distributions are compared (and hashed) by version first::
>>> Distribution(version='1.0') == Distribution(version='1.0')
True
>>> Distribution(version='1.0') == Distribution(version='1.1')
False
>>> Distribution(version='1.0') < Distribution(version='1.1')
True
but also by project name (case-insensitive), platform, Python version,
location, etc.::
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="Foo",version="1.0")
True
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="foo",version="1.0")
True
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="Foo",version="1.1")
False
>>> Distribution(project_name="Foo",py_version="2.3",version="1.0") == \
... Distribution(project_name="Foo",py_version="2.4",version="1.0")
False
>>> Distribution(location="spam",version="1.0") == \
... Distribution(location="spam",version="1.0")
True
>>> Distribution(location="spam",version="1.0") == \
... Distribution(location="baz",version="1.0")
False
Hash and compare distribution by prio/plat
Get version from metadata
provider capabilities
egg_name()
as_requirement()
from_location, from_filename (w/path normalization)
Releases may have zero or more "Requirements", which indicate
what releases of another project the release requires in order to
function. A Requirement names the other project, expresses some criteria
as to what releases of that project are acceptable, and lists any "Extras"
that the requiring release may need from that project. (An Extra is an
optional feature of a Release, that can only be used if its additional
Requirements are satisfied.)
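For instance (an illustrative example; ``Foo`` is a made-up project name)::
>>> req = pkg_resources.Requirement.parse("Foo[extra1]>=1.0,<2.0")
>>> req.project_name, req.extras
('Foo', ('extra1',))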
The Working Set
---------------
A collection of active distributions is called a Working Set. Note that a
Working Set can contain any importable distribution, not just pluggable ones.
For example, the Python standard library is an importable distribution that
will usually be part of the Working Set, even though it is not pluggable.
Similarly, when you are doing development work on a project, the files you are
editing are also a Distribution. (And, with a little attention to the
directory names used, and including some additional metadata, such a
"development distribution" can be made pluggable as well.)
>>> from pkg_resources import WorkingSet
A working set's entries are the sys.path entries that correspond to the active
distributions. By default, the working set's entries are the items on
``sys.path``::
>>> ws = WorkingSet()
>>> ws.entries == sys.path
True
But you can also create an empty working set explicitly, and add distributions
to it::
>>> ws = WorkingSet([])
>>> ws.add(dist)
>>> ws.entries
['http://example.com/something']
>>> dist in ws
True
>>> Distribution('foo',version="") in ws
False
And you can iterate over its distributions::
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
Adding the same distribution more than once is a no-op::
>>> ws.add(dist)
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
For that matter, adding multiple distributions for the same project also does
nothing, because a working set can only hold one active distribution per
project -- the first one added to it::
>>> ws.add(
... Distribution(
... 'http://example.com/something', project_name="Bar",
... version="7.2"
... )
... )
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
You can append a path entry to a working set using ``add_entry()``::
>>> ws.entries
['http://example.com/something']
>>> ws.add_entry(pkg_resources.__file__)
>>> ws.entries
['http://example.com/something', '...pkg_resources...']
Multiple additions result in multiple entries, even if the entry is already in
the working set (because ``sys.path`` can contain the same entry more than
once)::
>>> ws.add_entry(pkg_resources.__file__)
>>> ws.entries
['...example.com...', '...pkg_resources...', '...pkg_resources...']
And you can specify the path entry a distribution was found under, using the
optional second parameter to ``add()``::
>>> ws = WorkingSet([])
>>> ws.add(dist,"foo")
>>> ws.entries
['foo']
But even if a distribution is found under multiple path entries, it still only
shows up once when iterating the working set:
>>> ws.add_entry(ws.entries[0])
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
You can ask a WorkingSet to ``find()`` a distribution matching a requirement::
>>> from pkg_resources import Requirement
>>> print(ws.find(Requirement.parse("Foo==1.0"))) # no match, return None
None
>>> ws.find(Requirement.parse("Bar==0.9")) # match, return distribution
Bar 0.9 (http://example.com/something)
Note that asking for a conflicting version of a distribution already in a
working set triggers a ``pkg_resources.VersionConflict`` error:
>>> try:
... ws.find(Requirement.parse("Bar==1.0"))
... except pkg_resources.VersionConflict as exc:
... print(str(exc))
... else:
... raise AssertionError("VersionConflict was not raised")
(Bar 0.9 (http://example.com/something), Requirement.parse('Bar==1.0'))
You can subscribe a callback function to receive notifications whenever a new
distribution is added to a working set. The callback is immediately invoked
once for each existing distribution in the working set, and then is called
again for new distributions added thereafter::
>>> def added(dist): print("Added %s" % dist)
>>> ws.subscribe(added)
Added Bar 0.9
>>> foo12 = Distribution(project_name="Foo", version="1.2", location="f12")
>>> ws.add(foo12)
Added Foo 1.2
Note, however, that only the first distribution added for a given project name
will trigger a callback, even during the initial ``subscribe()`` callback::
>>> foo14 = Distribution(project_name="Foo", version="1.4", location="f14")
>>> ws.add(foo14) # no callback, because Foo 1.2 is already active
>>> ws = WorkingSet([])
>>> ws.add(foo12)
>>> ws.add(foo14)
>>> ws.subscribe(added)
Added Foo 1.2
And adding a callback more than once has no effect, either::
>>> ws.subscribe(added) # no callbacks
# and no double-callbacks on subsequent additions, either
>>> just_a_test = Distribution(project_name="JustATest", version="0.99")
>>> ws.add(just_a_test)
Added JustATest 0.99
Finding Plugins
---------------
``WorkingSet`` objects can be used to figure out what plugins in an
``Environment`` can be loaded without any resolution errors::
>>> from pkg_resources import Environment
>>> plugins = Environment([]) # normally, a list of plugin directories
>>> plugins.add(foo12)
>>> plugins.add(foo14)
>>> plugins.add(just_a_test)
In the simplest case, we just get the newest version of each distribution in
the plugin environment::
>>> ws = WorkingSet([])
>>> ws.find_plugins(plugins)
([JustATest 0.99, Foo 1.4 (f14)], {})
But if there's a problem with a version conflict or missing requirements, the
method falls back to older versions, and the error info dict will contain an
exception instance for each unloadable plugin::
>>> ws.add(foo12) # this will conflict with Foo 1.4
>>> ws.find_plugins(plugins)
([JustATest 0.99, Foo 1.2 (f12)], {Foo 1.4 (f14): VersionConflict(...)})
But if you disallow fallbacks, the failed plugin will be skipped instead of
trying older versions::
>>> ws.find_plugins(plugins, fallback=False)
([JustATest 0.99], {Foo 1.4 (f14): VersionConflict(...)})
Platform Compatibility Rules
----------------------------
On the Mac, there are potential compatibility issues for modules compiled
on newer versions of Mac OS X than what the user is running. Additionally,
Mac OS X will soon have two platforms to contend with: Intel and PowerPC.
Basic equality works as on other platforms::
>>> from pkg_resources import compatible_platforms as cp
>>> reqd = 'macosx-10.4-ppc'
>>> cp(reqd, reqd)
True
>>> cp("win32", reqd)
False
Distributions made on other machine types are not compatible::
>>> cp("macosx-10.4-i386", reqd)
False
Distributions made on earlier versions of the OS are compatible, as
long as they are from the same top-level version. The patchlevel version
number does not matter::
>>> cp("macosx-10.4-ppc", reqd)
True
>>> cp("macosx-10.3-ppc", reqd)
True
>>> cp("macosx-10.5-ppc", reqd)
False
>>> cp("macosx-9.5-ppc", reqd)
False
Backwards compatibility for packages made via earlier versions of
setuptools is provided as well::
>>> cp("darwin-8.2.0-Power_Macintosh", reqd)
True
>>> cp("darwin-7.2.0-Power_Macintosh", reqd)
True
>>> cp("darwin-8.2.0-Power_Macintosh", "macosx-10.3-ppc")
False
Environment Markers
-------------------
>>> from pkg_resources import invalid_marker as im, evaluate_marker as em
>>> import os
>>> print(im("sys_platform"))
Comparison or logical expression expected
>>> print(im("sys_platform=="))
invalid syntax
>>> print(im("sys_platform=='win32'"))
False
>>> print(im("sys=='x'"))
Unknown name 'sys'
>>> print(im("(extra)"))
Comparison or logical expression expected
>>> print(im("(extra"))
invalid syntax
>>> print(im("os.open('foo')=='y'"))
Language feature not supported in environment markers
>>> print(im("'x'=='y' and os.open('foo')=='y'")) # no short-circuit!
Language feature not supported in environment markers
>>> print(im("'x'=='x' or os.open('foo')=='y'")) # no short-circuit!
Language feature not supported in environment markers
>>> print(im("'x' < 'y'"))
'<' operator not allowed in environment markers
>>> print(im("'x' < 'y' < 'z'"))
Chained comparison not allowed in environment markers
>>> print(im("r'x'=='x'"))
Only plain strings allowed in environment markers
>>> print(im("'''x'''=='x'"))
Only plain strings allowed in environment markers
>>> print(im('"""x"""=="x"'))
Only plain strings allowed in environment markers
>>> print(im(r"'x\n'=='x'"))
Only plain strings allowed in environment markers
>>> print(im("os.open=='y'"))
Language feature not supported in environment markers
>>> em('"x"=="x"')
True
>>> em('"x"=="y"')
False
>>> em('"x"=="y" and "x"=="x"')
False
>>> em('"x"=="y" or "x"=="x"')
True
>>> em('"x"=="y" and "x"=="q" or "z"=="z"')
True
>>> em('"x"=="y" and ("x"=="q" or "z"=="z")')
False
>>> em('"x"=="y" and "z"=="z" or "x"=="q"')
False
>>> em('"x"=="x" and "z"=="z" or "x"=="q"')
True
>>> em("sys_platform=='win32'") == (sys.platform=='win32')
True
>>> em("'x' in 'yx'")
True
>>> em("'yx' in 'x'")
False

View File

@@ -0,0 +1,111 @@
import sys
import tempfile
import os
import zipfile
import datetime
import time
import subprocess
import pkg_resources
try:
unicode
except NameError:
unicode = str
def timestamp(dt):
"""
Return a timestamp for a local, naive datetime instance.
"""
try:
return dt.timestamp()
except AttributeError:
# Python 3.2 and earlier
return time.mktime(dt.timetuple())
class EggRemover(unicode):
def __call__(self):
if self in sys.path:
sys.path.remove(self)
if os.path.exists(self):
os.remove(self)
class TestZipProvider(object):
finalizers = []
ref_time = datetime.datetime(2013, 5, 12, 13, 25, 0)
"A reference time for a file modification"
@classmethod
def setup_class(cls):
"create a zip egg and add it to sys.path"
egg = tempfile.NamedTemporaryFile(suffix='.egg', delete=False)
zip_egg = zipfile.ZipFile(egg, 'w')
zip_info = zipfile.ZipInfo()
zip_info.filename = 'mod.py'
zip_info.date_time = cls.ref_time.timetuple()
zip_egg.writestr(zip_info, 'x = 3\n')
zip_info = zipfile.ZipInfo()
zip_info.filename = 'data.dat'
zip_info.date_time = cls.ref_time.timetuple()
zip_egg.writestr(zip_info, 'hello, world!')
zip_egg.close()
egg.close()
sys.path.append(egg.name)
cls.finalizers.append(EggRemover(egg.name))
@classmethod
def teardown_class(cls):
for finalizer in cls.finalizers:
finalizer()
def test_resource_filename_rewrites_on_change(self):
"""
If a previous call to get_resource_filename has saved the file, but
the file has been subsequently mutated with a different file of the
same size and modification time, it should be overwritten on a
subsequent call to get_resource_filename.
"""
import mod
manager = pkg_resources.ResourceManager()
zp = pkg_resources.ZipProvider(mod)
filename = zp.get_resource_filename(manager, 'data.dat')
actual = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
assert actual == self.ref_time
f = open(filename, 'w')
f.write('hello, world?')
f.close()
ts = timestamp(self.ref_time)
os.utime(filename, (ts, ts))
filename = zp.get_resource_filename(manager, 'data.dat')
f = open(filename)
assert f.read() == 'hello, world!'
manager.cleanup_resources()
class TestResourceManager(object):
def test_get_cache_path(self):
mgr = pkg_resources.ResourceManager()
path = mgr.get_cache_path('foo')
type_ = str(type(path))
message = "Unexpected type from get_cache_path: " + type_
assert isinstance(path, (unicode, str)), message
class TestIndependence:
"""
Tests to ensure that pkg_resources runs independently from setuptools.
"""
def test_setuptools_not_imported(self):
"""
In a separate Python environment, import pkg_resources and assert
that action doesn't cause setuptools to be imported.
"""
lines = (
'import pkg_resources',
'import sys',
'assert "setuptools" not in sys.modules, '
'"setuptools was imported"',
)
cmd = [sys.executable, '-c', '; '.join(lines)]
subprocess.check_call(cmd)

View File

@@ -0,0 +1,661 @@
import os
import sys
import tempfile
import shutil
import string
import pytest
import pkg_resources
from pkg_resources import (parse_requirements, VersionConflict, parse_version,
Distribution, EntryPoint, Requirement, safe_version, safe_name,
WorkingSet)
packaging = pkg_resources.packaging
def safe_repr(obj, short=False):
""" copied from Python2.7"""
try:
result = repr(obj)
except Exception:
result = object.__repr__(obj)
if not short or len(result) < pkg_resources._MAX_LENGTH:
return result
return result[:pkg_resources._MAX_LENGTH] + ' [truncated]...'
class Metadata(pkg_resources.EmptyProvider):
"""Mock object to return metadata as if from an on-disk distribution"""
def __init__(self, *pairs):
self.metadata = dict(pairs)
def has_metadata(self, name):
return name in self.metadata
def get_metadata(self, name):
return self.metadata[name]
def get_metadata_lines(self, name):
return pkg_resources.yield_lines(self.get_metadata(name))
dist_from_fn = pkg_resources.Distribution.from_filename
class TestDistro:
def testCollection(self):
# empty path should produce no distributions
ad = pkg_resources.Environment([], platform=None, python=None)
assert list(ad) == []
assert ad['FooPkg'] == []
ad.add(dist_from_fn("FooPkg-1.3_1.egg"))
ad.add(dist_from_fn("FooPkg-1.4-py2.4-win32.egg"))
ad.add(dist_from_fn("FooPkg-1.2-py2.4.egg"))
# Name is in there now
assert ad['FooPkg']
# But only 1 package
assert list(ad) == ['foopkg']
# Distributions sort by version
assert [dist.version for dist in ad['FooPkg']] == ['1.4','1.3-1','1.2']
# Removing a distribution leaves sequence alone
ad.remove(ad['FooPkg'][1])
assert [dist.version for dist in ad['FooPkg']] == ['1.4','1.2']
# And inserting adds them in order
ad.add(dist_from_fn("FooPkg-1.9.egg"))
assert [dist.version for dist in ad['FooPkg']] == ['1.9','1.4','1.2']
ws = WorkingSet([])
foo12 = dist_from_fn("FooPkg-1.2-py2.4.egg")
foo14 = dist_from_fn("FooPkg-1.4-py2.4-win32.egg")
req, = parse_requirements("FooPkg>=1.3")
# Nominal case: no distros on path, should yield all applicable
assert ad.best_match(req, ws).version == '1.9'
# If a matching distro is already installed, should return only that
ws.add(foo14)
assert ad.best_match(req, ws).version == '1.4'
# If the first matching distro is unsuitable, it's a version conflict
ws = WorkingSet([])
ws.add(foo12)
ws.add(foo14)
with pytest.raises(VersionConflict):
ad.best_match(req, ws)
# If more than one match on the path, the first one takes precedence
ws = WorkingSet([])
ws.add(foo14)
ws.add(foo12)
ws.add(foo14)
assert ad.best_match(req, ws).version == '1.4'
def checkFooPkg(self,d):
assert d.project_name == "FooPkg"
assert d.key == "foopkg"
assert d.version == "1.3.post1"
assert d.py_version == "2.4"
assert d.platform == "win32"
assert d.parsed_version == parse_version("1.3-1")
def testDistroBasics(self):
d = Distribution(
"/some/path",
project_name="FooPkg",version="1.3-1",py_version="2.4",platform="win32"
)
self.checkFooPkg(d)
d = Distribution("/some/path")
assert d.py_version == sys.version[:3]
assert d.platform == None
def testDistroParse(self):
d = dist_from_fn("FooPkg-1.3.post1-py2.4-win32.egg")
self.checkFooPkg(d)
d = dist_from_fn("FooPkg-1.3.post1-py2.4-win32.egg-info")
self.checkFooPkg(d)
def testDistroMetadata(self):
d = Distribution(
"/some/path", project_name="FooPkg", py_version="2.4", platform="win32",
metadata = Metadata(
('PKG-INFO',"Metadata-Version: 1.0\nVersion: 1.3-1\n")
)
)
self.checkFooPkg(d)
def distRequires(self, txt):
return Distribution("/foo", metadata=Metadata(('depends.txt', txt)))
def checkRequires(self, dist, txt, extras=()):
assert list(dist.requires(extras)) == list(parse_requirements(txt))
def testDistroDependsSimple(self):
for v in "Twisted>=1.5", "Twisted>=1.5\nZConfig>=2.0":
self.checkRequires(self.distRequires(v), v)
def testResolve(self):
ad = pkg_resources.Environment([])
ws = WorkingSet([])
# Resolving no requirements -> nothing to install
assert list(ws.resolve([], ad)) == []
# Request something not in the collection -> DistributionNotFound
with pytest.raises(pkg_resources.DistributionNotFound):
ws.resolve(parse_requirements("Foo"), ad)
Foo = Distribution.from_filename(
"/foo_dir/Foo-1.2.egg",
metadata=Metadata(('depends.txt', "[bar]\nBaz>=2.0"))
)
ad.add(Foo)
ad.add(Distribution.from_filename("Foo-0.9.egg"))
# Request thing(s) that are available -> list to activate
for i in range(3):
targets = list(ws.resolve(parse_requirements("Foo"), ad))
assert targets == [Foo]
list(map(ws.add,targets))
with pytest.raises(VersionConflict):
ws.resolve(parse_requirements("Foo==0.9"), ad)
ws = WorkingSet([]) # reset
# Request an extra that causes an unresolved dependency for "Baz"
with pytest.raises(pkg_resources.DistributionNotFound):
ws.resolve(parse_requirements("Foo[bar]"), ad)
Baz = Distribution.from_filename(
"/foo_dir/Baz-2.1.egg", metadata=Metadata(('depends.txt', "Foo"))
)
ad.add(Baz)
# Activation list now includes resolved dependency
assert list(ws.resolve(parse_requirements("Foo[bar]"), ad)) ==[Foo,Baz]
# Requests for conflicting versions produce VersionConflict
with pytest.raises(VersionConflict) as vc:
ws.resolve(parse_requirements("Foo==1.2\nFoo!=1.2"), ad)
msg = 'Foo 0.9 is installed but Foo==1.2 is required'
assert vc.value.report() == msg
def testDistroDependsOptions(self):
d = self.distRequires("""
Twisted>=1.5
[docgen]
ZConfig>=2.0
docutils>=0.3
[fastcgi]
fcgiapp>=0.1""")
self.checkRequires(d,"Twisted>=1.5")
self.checkRequires(
d,"Twisted>=1.5 ZConfig>=2.0 docutils>=0.3".split(), ["docgen"]
)
self.checkRequires(
d,"Twisted>=1.5 fcgiapp>=0.1".split(), ["fastcgi"]
)
self.checkRequires(
d,"Twisted>=1.5 ZConfig>=2.0 docutils>=0.3 fcgiapp>=0.1".split(),
["docgen","fastcgi"]
)
self.checkRequires(
d,"Twisted>=1.5 fcgiapp>=0.1 ZConfig>=2.0 docutils>=0.3".split(),
["fastcgi", "docgen"]
)
with pytest.raises(pkg_resources.UnknownExtra):
d.requires(["foo"])
class TestWorkingSet:
def test_find_conflicting(self):
ws = WorkingSet([])
Foo = Distribution.from_filename("/foo_dir/Foo-1.2.egg")
ws.add(Foo)
# create a requirement that conflicts with Foo 1.2
req = next(parse_requirements("Foo<1.2"))
with pytest.raises(VersionConflict) as vc:
ws.find(req)
msg = 'Foo 1.2 is installed but Foo<1.2 is required'
assert vc.value.report() == msg
def test_resolve_conflicts_with_prior(self):
"""
A ContextualVersionConflict should be raised when a requirement
conflicts with a prior requirement for a different package.
"""
# Create installation where Foo depends on Baz 1.0 and Bar depends on
# Baz 2.0.
ws = WorkingSet([])
md = Metadata(('depends.txt', "Baz==1.0"))
Foo = Distribution.from_filename("/foo_dir/Foo-1.0.egg", metadata=md)
ws.add(Foo)
md = Metadata(('depends.txt', "Baz==2.0"))
Bar = Distribution.from_filename("/foo_dir/Bar-1.0.egg", metadata=md)
ws.add(Bar)
Baz = Distribution.from_filename("/foo_dir/Baz-1.0.egg")
ws.add(Baz)
Baz = Distribution.from_filename("/foo_dir/Baz-2.0.egg")
ws.add(Baz)
with pytest.raises(VersionConflict) as vc:
ws.resolve(parse_requirements("Foo\nBar\n"))
msg = "Baz 1.0 is installed but Baz==2.0 is required by {'Bar'}"
if pkg_resources.PY2:
msg = msg.replace("{'Bar'}", "set(['Bar'])")
assert vc.value.report() == msg
class TestEntryPoints:
def assertfields(self, ep):
assert ep.name == "foo"
assert ep.module_name == "pkg_resources.tests.test_resources"
assert ep.attrs == ("TestEntryPoints",)
assert ep.extras == ("x",)
assert ep.load() is TestEntryPoints
expect = "foo = pkg_resources.tests.test_resources:TestEntryPoints [x]"
assert str(ep) == expect
def setup_method(self, method):
self.dist = Distribution.from_filename(
"FooPkg-1.2-py2.4.egg", metadata=Metadata(('requires.txt','[x]')))
def testBasics(self):
ep = EntryPoint(
"foo", "pkg_resources.tests.test_resources", ["TestEntryPoints"],
["x"], self.dist
)
self.assertfields(ep)
def testParse(self):
s = "foo = pkg_resources.tests.test_resources:TestEntryPoints [x]"
ep = EntryPoint.parse(s, self.dist)
self.assertfields(ep)
ep = EntryPoint.parse("bar baz= spammity[PING]")
assert ep.name == "bar baz"
assert ep.module_name == "spammity"
assert ep.attrs == ()
assert ep.extras == ("ping",)
ep = EntryPoint.parse(" fizzly = wocka:foo")
assert ep.name == "fizzly"
assert ep.module_name == "wocka"
assert ep.attrs == ("foo",)
assert ep.extras == ()
# plus in the name
spec = "html+mako = mako.ext.pygmentplugin:MakoHtmlLexer"
ep = EntryPoint.parse(spec)
assert ep.name == 'html+mako'
reject_specs = "foo", "x=a:b:c", "q=x/na", "fez=pish:tush-z", "x=f[a]>2"
@pytest.mark.parametrize("reject_spec", reject_specs)
def test_reject_spec(self, reject_spec):
with pytest.raises(ValueError):
EntryPoint.parse(reject_spec)
def test_printable_name(self):
"""
Allow any printable character in the name.
"""
# Create a name with all printable characters; strip the whitespace.
name = string.printable.strip()
spec = "{name} = module:attr".format(**locals())
ep = EntryPoint.parse(spec)
assert ep.name == name
def checkSubMap(self, m):
assert len(m) == len(self.submap_expect)
for key, ep in pkg_resources.iteritems(self.submap_expect):
assert repr(m.get(key)) == repr(ep)
submap_expect = dict(
feature1=EntryPoint('feature1', 'somemodule', ['somefunction']),
feature2=EntryPoint('feature2', 'another.module', ['SomeClass'], ['extra1','extra2']),
feature3=EntryPoint('feature3', 'this.module', extras=['something'])
)
submap_str = """
# define features for blah blah
feature1 = somemodule:somefunction
feature2 = another.module:SomeClass [extra1,extra2]
feature3 = this.module [something]
"""
def testParseList(self):
self.checkSubMap(EntryPoint.parse_group("xyz", self.submap_str))
with pytest.raises(ValueError):
EntryPoint.parse_group("x a", "foo=bar")
with pytest.raises(ValueError):
EntryPoint.parse_group("x", ["foo=baz", "foo=bar"])
def testParseMap(self):
m = EntryPoint.parse_map({'xyz':self.submap_str})
self.checkSubMap(m['xyz'])
assert list(m.keys()) == ['xyz']
m = EntryPoint.parse_map("[xyz]\n"+self.submap_str)
self.checkSubMap(m['xyz'])
assert list(m.keys()) == ['xyz']
with pytest.raises(ValueError):
EntryPoint.parse_map(["[xyz]", "[xyz]"])
with pytest.raises(ValueError):
EntryPoint.parse_map(self.submap_str)
class TestRequirements:
def testBasics(self):
r = Requirement.parse("Twisted>=1.2")
assert str(r) == "Twisted>=1.2"
assert repr(r) == "Requirement.parse('Twisted>=1.2')"
assert r == Requirement("Twisted", [('>=','1.2')], ())
assert r == Requirement("twisTed", [('>=','1.2')], ())
assert r != Requirement("Twisted", [('>=','2.0')], ())
assert r != Requirement("Zope", [('>=','1.2')], ())
assert r != Requirement("Zope", [('>=','3.0')], ())
assert r != Requirement.parse("Twisted[extras]>=1.2")
def testOrdering(self):
r1 = Requirement("Twisted", [('==','1.2c1'),('>=','1.2')], ())
r2 = Requirement("Twisted", [('>=','1.2'),('==','1.2c1')], ())
assert r1 == r2
assert str(r1) == str(r2)
assert str(r2) == "Twisted==1.2c1,>=1.2"
def testBasicContains(self):
r = Requirement("Twisted", [('>=','1.2')], ())
foo_dist = Distribution.from_filename("FooPkg-1.3_1.egg")
twist11 = Distribution.from_filename("Twisted-1.1.egg")
twist12 = Distribution.from_filename("Twisted-1.2.egg")
assert parse_version('1.2') in r
assert parse_version('1.1') not in r
assert '1.2' in r
assert '1.1' not in r
assert foo_dist not in r
assert twist11 not in r
assert twist12 in r
def testOptionsAndHashing(self):
r1 = Requirement.parse("Twisted[foo,bar]>=1.2")
r2 = Requirement.parse("Twisted[bar,FOO]>=1.2")
assert r1 == r2
assert r1.extras == ("foo","bar")
assert r2.extras == ("bar","foo") # extras are normalized
assert hash(r1) == hash(r2)
assert (
hash(r1)
==
hash((
"twisted",
packaging.specifiers.SpecifierSet(">=1.2"),
frozenset(["foo","bar"]),
))
)
def testVersionEquality(self):
r1 = Requirement.parse("foo==0.3a2")
r2 = Requirement.parse("foo!=0.3a4")
d = Distribution.from_filename
assert d("foo-0.3a4.egg") not in r1
assert d("foo-0.3a1.egg") not in r1
assert d("foo-0.3a4.egg") not in r2
assert d("foo-0.3a2.egg") in r1
assert d("foo-0.3a2.egg") in r2
assert d("foo-0.3a3.egg") in r2
assert d("foo-0.3a5.egg") in r2
def testSetuptoolsProjectName(self):
"""
The setuptools project should implement the setuptools package.
"""
assert (
Requirement.parse('setuptools').project_name == 'setuptools')
# setuptools 0.7 and higher means setuptools.
assert (
Requirement.parse('setuptools == 0.7').project_name == 'setuptools')
assert (
Requirement.parse('setuptools == 0.7a1').project_name == 'setuptools')
assert (
Requirement.parse('setuptools >= 0.7').project_name == 'setuptools')
class TestParsing:
def testEmptyParse(self):
assert list(parse_requirements('')) == []
def testYielding(self):
for inp,out in [
([], []), ('x',['x']), ([[]],[]), (' x\n y', ['x','y']),
(['x\n\n','y'], ['x','y']),
]:
assert list(pkg_resources.yield_lines(inp)) == out
def testSplitting(self):
sample = """
x
[Y]
z
a
[b ]
# foo
c
[ d]
[q]
v
"""
assert (
list(pkg_resources.split_sections(sample))
==
[
(None, ["x"]),
("Y", ["z", "a"]),
("b", ["c"]),
("d", []),
("q", ["v"]),
]
)
with pytest.raises(ValueError):
list(pkg_resources.split_sections("[foo"))
def testSafeName(self):
assert safe_name("adns-python") == "adns-python"
assert safe_name("WSGI Utils") == "WSGI-Utils"
assert safe_name("WSGI Utils") == "WSGI-Utils"
assert safe_name("Money$$$Maker") == "Money-Maker"
assert safe_name("peak.web") != "peak-web"
def testSafeVersion(self):
assert safe_version("1.2-1") == "1.2.post1"
assert safe_version("1.2 alpha") == "1.2.alpha"
assert safe_version("2.3.4 20050521") == "2.3.4.20050521"
assert safe_version("Money$$$Maker") == "Money-Maker"
assert safe_version("peak.web") == "peak.web"
def testSimpleRequirements(self):
assert (
list(parse_requirements('Twis-Ted>=1.2-1'))
==
[Requirement('Twis-Ted',[('>=','1.2-1')], ())]
)
assert (
list(parse_requirements('Twisted >=1.2, \ # more\n<2.0'))
==
[Requirement('Twisted',[('>=','1.2'),('<','2.0')], ())]
)
assert (
Requirement.parse("FooBar==1.99a3")
==
Requirement("FooBar", [('==','1.99a3')], ())
)
with pytest.raises(ValueError):
Requirement.parse(">=2.3")
with pytest.raises(ValueError):
Requirement.parse("x\\")
with pytest.raises(ValueError):
Requirement.parse("x==2 q")
with pytest.raises(ValueError):
Requirement.parse("X==1\nY==2")
with pytest.raises(ValueError):
Requirement.parse("#")
def testVersionEquality(self):
def c(s1,s2):
p1, p2 = parse_version(s1),parse_version(s2)
assert p1 == p2, (s1,s2,p1,p2)
c('1.2-rc1', '1.2rc1')
c('0.4', '0.4.0')
c('0.4.0.0', '0.4.0')
c('0.4.0-0', '0.4-0')
c('0post1', '0.0post1')
c('0pre1', '0.0c1')
c('0.0.0preview1', '0c1')
c('0.0c1', '0-rc1')
c('1.2a1', '1.2.a.1')
c('1.2.a', '1.2a')
def testVersionOrdering(self):
def c(s1,s2):
p1, p2 = parse_version(s1),parse_version(s2)
assert p1<p2, (s1,s2,p1,p2)
c('2.1','2.1.1')
c('2a1','2b0')
c('2a1','2.1')
c('2.3a1', '2.3')
c('2.1-1', '2.1-2')
c('2.1-1', '2.1.1')
c('2.1', '2.1post4')
c('2.1a0-20040501', '2.1')
c('1.1', '02.1')
c('3.2', '3.2.post0')
c('3.2post1', '3.2post2')
c('0.4', '4.0')
c('0.0.4', '0.4.0')
c('0post1', '0.4post1')
c('2.1.0-rc1','2.1.0')
c('2.1dev','2.1a0')
torture ="""
0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1
0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2
0.77.2-1 0.77.1-1 0.77.0-1
""".split()
for p,v1 in enumerate(torture):
for v2 in torture[p+1:]:
c(v2,v1)
def testVersionBuildout(self):
"""
Buildout has a function in its bootstrap.py that inspects the return
value of parse_version. The new parse_version returns a Version class
which needs to support this behavior, at least for now.
"""
def buildout(parsed_version):
_final_parts = '*final-', '*final'
def _final_version(parsed_version):
for part in parsed_version:
if (part[:1] == '*') and (part not in _final_parts):
return False
return True
return _final_version(parsed_version)
assert buildout(parse_version("1.0"))
assert not buildout(parse_version("1.0a1"))
def testVersionIndexable(self):
"""
Some projects were doing things like parse_version("v")[0], so we'll
support indexing the same as we support iterating.
"""
assert parse_version("1.0")[0] == "00000001"
def testVersionTupleSort(self):
"""
Some projects expected to be able to sort tuples against the return
value of parse_version. So again we'll add a warning-enabled shim to
make this possible.
"""
assert parse_version("1.0") < tuple(parse_version("2.0"))
assert parse_version("1.0") <= tuple(parse_version("2.0"))
assert parse_version("1.0") == tuple(parse_version("1.0"))
assert parse_version("3.0") > tuple(parse_version("2.0"))
assert parse_version("3.0") >= tuple(parse_version("2.0"))
assert parse_version("3.0") != tuple(parse_version("2.0"))
assert not (parse_version("3.0") != tuple(parse_version("3.0")))
def testVersionHashable(self):
"""
Ensure that our versions stay hashable even though we've subclassed
them and added some shim code to them.
"""
assert (
hash(parse_version("1.0"))
==
hash(parse_version("1.0"))
)
class TestNamespaces:
def setup_method(self, method):
self._ns_pkgs = pkg_resources._namespace_packages.copy()
self._tmpdir = tempfile.mkdtemp(prefix="tests-setuptools-")
os.makedirs(os.path.join(self._tmpdir, "site-pkgs"))
self._prev_sys_path = sys.path[:]
sys.path.append(os.path.join(self._tmpdir, "site-pkgs"))
def teardown_method(self, method):
shutil.rmtree(self._tmpdir)
pkg_resources._namespace_packages = self._ns_pkgs.copy()
sys.path = self._prev_sys_path[:]
@pytest.mark.skipif(os.path.islink(tempfile.gettempdir()),
reason="Test fails when /tmp is a symlink. See #231")
def test_two_levels_deep(self):
"""
Test nested namespace packages
Create namespace packages in the following tree:
site-packages-1/pkg1/pkg2
site-packages-2/pkg1/pkg2
Check both are in the _namespace_packages dict and that their __path__
is correct
"""
sys.path.append(os.path.join(self._tmpdir, "site-pkgs2"))
os.makedirs(os.path.join(self._tmpdir, "site-pkgs", "pkg1", "pkg2"))
os.makedirs(os.path.join(self._tmpdir, "site-pkgs2", "pkg1", "pkg2"))
ns_str = "__import__('pkg_resources').declare_namespace(__name__)\n"
for site in ["site-pkgs", "site-pkgs2"]:
pkg1_init = open(os.path.join(self._tmpdir, site,
"pkg1", "__init__.py"), "w")
pkg1_init.write(ns_str)
pkg1_init.close()
pkg2_init = open(os.path.join(self._tmpdir, site,
"pkg1", "pkg2", "__init__.py"), "w")
pkg2_init.write(ns_str)
pkg2_init.close()
import pkg1
assert "pkg1" in pkg_resources._namespace_packages
# attempt to import pkg2 from site-pkgs2
import pkg1.pkg2
# check the _namespace_packages dict
assert "pkg1.pkg2" in pkg_resources._namespace_packages
assert pkg_resources._namespace_packages["pkg1"] == ["pkg1.pkg2"]
# check the __path__ attribute contains both paths
expected = [
os.path.join(self._tmpdir, "site-pkgs", "pkg1", "pkg2"),
os.path.join(self._tmpdir, "site-pkgs2", "pkg1", "pkg2"),
]
assert pkg1.pkg2.__path__ == expected

View File

@@ -9,8 +9,8 @@ on how to use these modules.
'''
# The Olson database is updated several times a year.
OLSON_VERSION = '2014d'
VERSION = '2014.4' # Switching to pip compatible version numbering.
OLSON_VERSION = '2014j'
VERSION = '2014.10' # Switching to pip compatible version numbering.
__version__ = VERSION
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
@@ -735,6 +735,7 @@ all_timezones = \
'Asia/Bishkek',
'Asia/Brunei',
'Asia/Calcutta',
'Asia/Chita',
'Asia/Choibalsan',
'Asia/Chongqing',
'Asia/Chungking',
@@ -792,6 +793,7 @@ all_timezones = \
'Asia/Seoul',
'Asia/Shanghai',
'Asia/Singapore',
'Asia/Srednekolymsk',
'Asia/Taipei',
'Asia/Tashkent',
'Asia/Tbilisi',
@@ -1002,6 +1004,7 @@ all_timezones = \
'PST8PDT',
'Pacific/Apia',
'Pacific/Auckland',
'Pacific/Bougainville',
'Pacific/Chatham',
'Pacific/Chuuk',
'Pacific/Easter',
@@ -1297,8 +1300,8 @@ common_timezones = \
'Asia/Beirut',
'Asia/Bishkek',
'Asia/Brunei',
'Asia/Chita',
'Asia/Choibalsan',
'Asia/Chongqing',
'Asia/Colombo',
'Asia/Damascus',
'Asia/Dhaka',
@@ -1306,7 +1309,6 @@ common_timezones = \
'Asia/Dubai',
'Asia/Dushanbe',
'Asia/Gaza',
'Asia/Harbin',
'Asia/Hebron',
'Asia/Ho_Chi_Minh',
'Asia/Hong_Kong',
@@ -1318,7 +1320,6 @@ common_timezones = \
'Asia/Kabul',
'Asia/Kamchatka',
'Asia/Karachi',
'Asia/Kashgar',
'Asia/Kathmandu',
'Asia/Khandyga',
'Asia/Kolkata',
@@ -1348,6 +1349,7 @@ common_timezones = \
'Asia/Seoul',
'Asia/Shanghai',
'Asia/Singapore',
'Asia/Srednekolymsk',
'Asia/Taipei',
'Asia/Tashkent',
'Asia/Tbilisi',
@@ -1460,6 +1462,7 @@ common_timezones = \
'Indian/Reunion',
'Pacific/Apia',
'Pacific/Auckland',
'Pacific/Bougainville',
'Pacific/Chatham',
'Pacific/Chuuk',
'Pacific/Easter',

View File

@@ -142,7 +142,7 @@ class StaticTzInfo(BaseTzInfo):
def __reduce__(self):
# Special pickle to zone remains a singleton and to cope with
# database changes.
return pytz._p, (self.zone,)
@@ -369,13 +369,15 @@ class DstTzInfo(BaseTzInfo):
# hints to be passed in (such as the UTC offset or abbreviation),
# but that is just getting silly.
#
# Choose the earliest (by UTC) applicable timezone.
sorting_keys = {}
# Choose the earliest (by UTC) applicable timezone if is_dst=True
# Choose the latest (by UTC) applicable timezone if is_dst=False
# i.e., behave like end-of-DST transition
dates = {} # utc -> local
for local_dt in filtered_possible_loc_dt:
key = local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset
sorting_keys[key] = local_dt
first_key = sorted(sorting_keys)[0]
return sorting_keys[first_key]
utc_time = local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset
assert utc_time not in dates
dates[utc_time] = local_dt
return dates[[min, max][not is_dst](dates)]
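# Clarifying note (not in the original source): [min, max][not is_dst]
# selects min when is_dst is truthy and max otherwise, i.e. the earliest
# or latest UTC instant, which is then mapped back to its local datetime.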
def utcoffset(self, dt, is_dst=None):
'''See datetime.tzinfo.utcoffset
@@ -560,4 +562,3 @@ def unpickler(zone, utcoffset=None, dstoffset=None, tzname=None):
inf = (utcoffset, dstoffset, tzname)
tz._tzinfos[inf] = tz.__class__(inf, tz._tzinfos)
return tz._tzinfos[inf]

Some files were not shown because too many files have changed in this diff