Merge pull request #48 from matburt/upgrade_vendored_packages

Upgrade vendored packages
This commit is contained in:
Matthew Jones
2015-01-30 11:12:32 -05:00
367 changed files with 26588 additions and 25350 deletions

View File

@@ -110,7 +110,7 @@ push:
# locally downloaded packages). # locally downloaded packages).
requirements: requirements:
@if [ "$(VIRTUAL_ENV)" ]; then \ @if [ "$(VIRTUAL_ENV)" ]; then \
(cd requirements && pip install --no-index setuptools-2.2.tar.gz); \ (cd requirements && pip install --no-index setuptools-12.0.5.tar.gz); \
(cd requirements && pip install --no-index Django-1.6.7.tar.gz); \ (cd requirements && pip install --no-index Django-1.6.7.tar.gz); \
(cd requirements && pip install --no-index -r dev_local.txt); \ (cd requirements && pip install --no-index -r dev_local.txt); \
$(PYTHON) fix_virtualenv_setuptools.py; \ $(PYTHON) fix_virtualenv_setuptools.py; \
@@ -122,7 +122,7 @@ requirements:
# (downloading from PyPI if necessary). # (downloading from PyPI if necessary).
requirements_pypi: requirements_pypi:
@if [ "$(VIRTUAL_ENV)" ]; then \ @if [ "$(VIRTUAL_ENV)" ]; then \
pip install setuptools==2.2; \ pip install setuptools==12.0.5; \
pip install Django\>=1.6.7,\<1.7; \ pip install Django\>=1.6.7,\<1.7; \
pip install -r requirements/dev.txt; \ pip install -r requirements/dev.txt; \
$(PYTHON) fix_virtualenv_setuptools.py; \ $(PYTHON) fix_virtualenv_setuptools.py; \

View File

@@ -5,6 +5,7 @@ amqp==1.4.5 (amqp/*)
ansi2html==1.0.6 (ansi2html/*) ansi2html==1.0.6 (ansi2html/*)
anyjson==0.3.3 (anyjson/*) anyjson==0.3.3 (anyjson/*)
argparse==1.2.1 (argparse.py, needed for Python 2.6 support) argparse==1.2.1 (argparse.py, needed for Python 2.6 support)
azure==0.9.0 (azure/*)
Babel==1.3 (babel/*, excluded bin/pybabel) Babel==1.3 (babel/*, excluded bin/pybabel)
billiard==3.3.0.16 (billiard/*, funtests/*, excluded _billiard.so) billiard==3.3.0.16 (billiard/*, funtests/*, excluded _billiard.so)
boto==2.34.0 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin, boto==2.34.0 (boto/*, excluded bin/asadmin, bin/bundle_image, bin/cfadmin,
@@ -30,9 +31,9 @@ gevent-websocket==0.9.3 (geventwebsocket/*)
httplib2==0.9 (httplib2/*) httplib2==0.9 (httplib2/*)
importlib==1.0.3 (importlib/*, needed for Python 2.6 support) importlib==1.0.3 (importlib/*, needed for Python 2.6 support)
iso8601==0.1.10 (iso8601/*) iso8601==0.1.10 (iso8601/*)
keyring==4.0 (keyring/*, excluded bin/keyring) keyring==4.1 (keyring/*, excluded bin/keyring)
kombu==3.0.21 (kombu/*) kombu==3.0.21 (kombu/*)
Markdown==2.4.1 (markdown/*, excluded bin/markdown_py) Markdown==2.5.2 (markdown/*, excluded bin/markdown_py)
mock==1.0.1 (mock.py) mock==1.0.1 (mock.py)
ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support) ordereddict==1.1 (ordereddict.py, needed for Python 2.6 support)
os-diskconfig-python-novaclient-ext==0.1.2 (os_diskconfig_python_novaclient_ext/*) os-diskconfig-python-novaclient-ext==0.1.2 (os_diskconfig_python_novaclient_ext/*)
@@ -44,16 +45,16 @@ pexpect==3.1 (pexpect/*, excluded pxssh.py, fdpexpect.py, FSM.py, screen.py,
pip==1.5.4 (pip/*, excluded bin/pip*) pip==1.5.4 (pip/*, excluded bin/pip*)
prettytable==0.7.2 (prettytable.py) prettytable==0.7.2 (prettytable.py)
pyrax==1.9.0 (pyrax/*) pyrax==1.9.0 (pyrax/*)
python-dateutil==2.2 (dateutil/*) python-dateutil==2.4.0 (dateutil/*)
python-novaclient==2.18.1 (novaclient/*, excluded bin/nova) python-novaclient==2.18.1 (novaclient/*, excluded bin/nova)
python-swiftclient==2.2.0 (swiftclient/*, excluded bin/swift) python-swiftclient==2.2.0 (swiftclient/*, excluded bin/swift)
pytz==2014.4 (pytz/*) pytz==2014.10 (pytz/*)
rackspace-auth-openstack==1.3 (rackspace_auth_openstack/*) rackspace-auth-openstack==1.3 (rackspace_auth_openstack/*)
rackspace-novaclient==1.4 (no files) rackspace-novaclient==1.4 (no files)
rax-default-network-flags-python-novaclient-ext==0.2.3 (rax_default_network_flags_python_novaclient_ext/*) rax-default-network-flags-python-novaclient-ext==0.2.3 (rax_default_network_flags_python_novaclient_ext/*)
rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*) rax-scheduled-images-python-novaclient-ext==0.2.1 (rax_scheduled_images_python_novaclient_ext/*)
requests==2.3.0 (requests/*) requests==2.5.1 (requests/*)
setuptools==2.2 (setuptools/*, _markerlib/*, pkg_resources.py, easy_install.py, excluded bin/easy_install*) setuptools==12.0.5 (setuptools/*, _markerlib/*, pkg_resources/*, easy_install.py)
simplejson==3.6.0 (simplejson/*, excluded simplejson/_speedups.so) simplejson==3.6.0 (simplejson/*, excluded simplejson/_speedups.so)
six==1.7.3 (six.py) six==1.9.0 (six.py)
South==0.8.4 (south/*) South==0.8.4 (south/*)

View File

@@ -14,6 +14,8 @@
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
import ast import ast
import base64 import base64
import hashlib
import hmac
import sys import sys
import types import types
import warnings import warnings
@@ -34,7 +36,7 @@ from xml.sax.saxutils import escape as xml_escape
# constants # constants
__author__ = 'Microsoft Corp. <ptvshelp@microsoft.com>' __author__ = 'Microsoft Corp. <ptvshelp@microsoft.com>'
__version__ = '0.8.1' __version__ = '0.9.0'
# Live ServiceClient URLs # Live ServiceClient URLs
BLOB_SERVICE_HOST_BASE = '.blob.core.windows.net' BLOB_SERVICE_HOST_BASE = '.blob.core.windows.net'
@@ -103,10 +105,9 @@ class WindowsAzureData(object):
It is only used to check whether it is instance or not. ''' It is only used to check whether it is instance or not. '''
pass pass
class WindowsAzureError(Exception): class WindowsAzureError(Exception):
''' WindowsAzure Excpetion base class. ''' ''' WindowsAzure Exception base class. '''
def __init__(self, message): def __init__(self, message):
super(WindowsAzureError, self).__init__(message) super(WindowsAzureError, self).__init__(message)
@@ -188,25 +189,38 @@ def _get_readable_id(id_name, id_prefix_to_skip):
return id_name return id_name
def _get_entry_properties_from_node(entry, include_id, id_prefix_to_skip=None, use_title_as_id=False):
''' get properties from entry xml '''
properties = {}
etag = entry.getAttributeNS(METADATA_NS, 'etag')
if etag:
properties['etag'] = etag
for updated in _get_child_nodes(entry, 'updated'):
properties['updated'] = updated.firstChild.nodeValue
for name in _get_children_from_path(entry, 'author', 'name'):
if name.firstChild is not None:
properties['author'] = name.firstChild.nodeValue
if include_id:
if use_title_as_id:
for title in _get_child_nodes(entry, 'title'):
properties['name'] = title.firstChild.nodeValue
else:
for id in _get_child_nodes(entry, 'id'):
properties['name'] = _get_readable_id(
id.firstChild.nodeValue, id_prefix_to_skip)
return properties
def _get_entry_properties(xmlstr, include_id, id_prefix_to_skip=None): def _get_entry_properties(xmlstr, include_id, id_prefix_to_skip=None):
''' get properties from entry xml ''' ''' get properties from entry xml '''
xmldoc = minidom.parseString(xmlstr) xmldoc = minidom.parseString(xmlstr)
properties = {} properties = {}
for entry in _get_child_nodes(xmldoc, 'entry'): for entry in _get_child_nodes(xmldoc, 'entry'):
etag = entry.getAttributeNS(METADATA_NS, 'etag') properties.update(_get_entry_properties_from_node(entry, include_id, id_prefix_to_skip))
if etag:
properties['etag'] = etag
for updated in _get_child_nodes(entry, 'updated'):
properties['updated'] = updated.firstChild.nodeValue
for name in _get_children_from_path(entry, 'author', 'name'):
if name.firstChild is not None:
properties['author'] = name.firstChild.nodeValue
if include_id:
for id in _get_child_nodes(entry, 'id'):
properties['name'] = _get_readable_id(
id.firstChild.nodeValue, id_prefix_to_skip)
return properties return properties
@@ -284,6 +298,18 @@ _KNOWN_SERIALIZATION_XFORMS = {
'os': 'OS', 'os': 'OS',
'persistent_vm_downtime_info': 'PersistentVMDowntimeInfo', 'persistent_vm_downtime_info': 'PersistentVMDowntimeInfo',
'copy_id': 'CopyId', 'copy_id': 'CopyId',
'os_state': 'OSState',
'vm_image': 'VMImage',
'vm_images': 'VMImages',
'os_disk_configuration': 'OSDiskConfiguration',
'public_ips': 'PublicIPs',
'public_ip': 'PublicIP',
'supported_os': 'SupportedOS',
'reserved_ip': 'ReservedIP',
'reserved_ips': 'ReservedIPs',
'aad_tenant_id': 'AADTenantID',
'start_ip_address': 'StartIPAddress',
'end_ip_address': 'EndIPAddress',
} }
@@ -428,6 +454,25 @@ def _convert_response_to_feeds(response, convert_func):
return feeds return feeds
def _convert_xml_to_windows_azure_object(xmlstr, azure_type, include_id=True, use_title_as_id=True):
xmldoc = minidom.parseString(xmlstr)
return_obj = azure_type()
xml_name = azure_type._xml_name if hasattr(azure_type, '_xml_name') else azure_type.__name__
# Only one entry here
for xml_entry in _get_children_from_path(xmldoc,
'entry'):
for node in _get_children_from_path(xml_entry,
'content',
xml_name):
_fill_data_to_return_object(node, return_obj)
for name, value in _get_entry_properties_from_node(xml_entry,
include_id=include_id,
use_title_as_id=use_title_as_id).items():
setattr(return_obj, name, value)
return return_obj
def _validate_type_bytes(param_name, param): def _validate_type_bytes(param_name, param):
if not isinstance(param, bytes): if not isinstance(param, bytes):
raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name)) raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))
@@ -675,6 +720,13 @@ def _parse_response(response, return_type):
''' '''
return _parse_response_body_from_xml_text(response.body, return_type) return _parse_response_body_from_xml_text(response.body, return_type)
def _parse_service_resources_response(response, return_type):
'''
Parse the HTTPResponse's body and fill all the data into a class of
return_type.
'''
return _parse_response_body_from_service_resources_xml_text(response.body, return_type)
def _fill_data_to_return_object(node, return_obj): def _fill_data_to_return_object(node, return_obj):
members = dict(vars(return_obj)) members = dict(vars(return_obj))
@@ -700,6 +752,12 @@ def _fill_data_to_return_object(node, return_obj):
value.pair_xml_element_name, value.pair_xml_element_name,
value.key_xml_element_name, value.key_xml_element_name,
value.value_xml_element_name)) value.value_xml_element_name))
elif isinstance(value, _xml_attribute):
real_value = None
if node.hasAttribute(value.xml_element_name):
real_value = node.getAttribute(value.xml_element_name)
if real_value is not None:
setattr(return_obj, name, real_value)
elif isinstance(value, WindowsAzureData): elif isinstance(value, WindowsAzureData):
setattr(return_obj, setattr(return_obj,
name, name,
@@ -737,11 +795,24 @@ def _parse_response_body_from_xml_text(respbody, return_type):
''' '''
doc = minidom.parseString(respbody) doc = minidom.parseString(respbody)
return_obj = return_type() return_obj = return_type()
for node in _get_child_nodes(doc, return_type.__name__): xml_name = return_type._xml_name if hasattr(return_type, '_xml_name') else return_type.__name__
for node in _get_child_nodes(doc, xml_name):
_fill_data_to_return_object(node, return_obj) _fill_data_to_return_object(node, return_obj)
return return_obj return return_obj
def _parse_response_body_from_service_resources_xml_text(respbody, return_type):
'''
parse the xml and fill all the data into a class of return_type
'''
doc = minidom.parseString(respbody)
return_obj = _list_of(return_type)
for node in _get_children_from_path(doc, "ServiceResources", "ServiceResource"):
local_obj = return_type()
_fill_data_to_return_object(node, local_obj)
return_obj.append(local_obj)
return return_obj
class _dict_of(dict): class _dict_of(dict):
@@ -781,6 +852,15 @@ class _scalar_list_of(list):
self.xml_element_name = xml_element_name self.xml_element_name = xml_element_name
super(_scalar_list_of, self).__init__() super(_scalar_list_of, self).__init__()
class _xml_attribute:
"""a accessor to XML attributes
expected to go in it along with its xml element name.
Used for deserialization and construction"""
def __init__(self, xml_element_name):
self.xml_element_name = xml_element_name
def _update_request_uri_query_local_storage(request, use_local_storage): def _update_request_uri_query_local_storage(request, use_local_storage):
''' create correct uri and query for the request ''' ''' create correct uri and query for the request '''
@@ -903,3 +983,17 @@ def _parse_response_for_dict_filter(response, filter):
return return_dict return return_dict
else: else:
return None return None
def _sign_string(key, string_to_sign, key_is_base64=True):
if key_is_base64:
key = _decode_base64_to_bytes(key)
else:
if isinstance(key, _unicode_type):
key = key.encode('utf-8')
if isinstance(string_to_sign, _unicode_type):
string_to_sign = string_to_sign.encode('utf-8')
signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256)
digest = signed_hmac_sha256.digest()
encoded_digest = _encode_base64(digest)
return encoded_digest

View File

@@ -0,0 +1,81 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003" ToolsVersion="4.0">
<PropertyGroup>
<Configuration Condition=" '$(Configuration)' == '' ">Debug</Configuration>
<SchemaVersion>2.0</SchemaVersion>
<ProjectGuid>{25b2c65a-0553-4452-8907-8b5b17544e68}</ProjectGuid>
<ProjectHome>
</ProjectHome>
<StartupFile>storage\blobservice.py</StartupFile>
<SearchPath>..</SearchPath>
<WorkingDirectory>.</WorkingDirectory>
<OutputPath>.</OutputPath>
<Name>azure</Name>
<RootNamespace>azure</RootNamespace>
<IsWindowsApplication>False</IsWindowsApplication>
<LaunchProvider>Standard Python launcher</LaunchProvider>
<CommandLineArguments />
<InterpreterPath />
<InterpreterArguments />
<InterpreterId>{2af0f10d-7135-4994-9156-5d01c9c11b7e}</InterpreterId>
<InterpreterVersion>2.7</InterpreterVersion>
<SccProjectName>SAK</SccProjectName>
<SccProvider>SAK</SccProvider>
<SccAuxPath>SAK</SccAuxPath>
<SccLocalPath>SAK</SccLocalPath>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
<DebugSymbols>true</DebugSymbols>
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)' == 'Release' ">
<DebugSymbols>true</DebugSymbols>
<EnableUnmanagedDebugging>false</EnableUnmanagedDebugging>
</PropertyGroup>
<ItemGroup>
<Compile Include="http\batchclient.py" />
<Compile Include="http\httpclient.py" />
<Compile Include="http\requestsclient.py" />
<Compile Include="http\winhttp.py" />
<Compile Include="http\__init__.py" />
<Compile Include="servicemanagement\schedulermanagementservice.py" />
<Compile Include="servicemanagement\servicebusmanagementservice.py" />
<Compile Include="servicemanagement\servicemanagementclient.py" />
<Compile Include="servicemanagement\servicemanagementservice.py" />
<Compile Include="servicemanagement\sqldatabasemanagementservice.py" />
<Compile Include="servicemanagement\websitemanagementservice.py" />
<Compile Include="servicemanagement\__init__.py" />
<Compile Include="servicebus\servicebusservice.py" />
<Compile Include="storage\blobservice.py" />
<Compile Include="storage\queueservice.py" />
<Compile Include="storage\cloudstorageaccount.py" />
<Compile Include="storage\tableservice.py" />
<Compile Include="storage\sharedaccesssignature.py" />
<Compile Include="__init__.py" />
<Compile Include="servicebus\__init__.py" />
<Compile Include="storage\storageclient.py" />
<Compile Include="storage\__init__.py" />
</ItemGroup>
<ItemGroup>
<Folder Include="http" />
<Folder Include="servicemanagement" />
<Folder Include="servicebus" />
<Folder Include="storage" />
</ItemGroup>
<ItemGroup>
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.6" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\2.7" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.3" />
<InterpreterReference Include="{2af0f10d-7135-4994-9156-5d01c9c11b7e}\3.4" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\2.7" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.3" />
<InterpreterReference Include="{9a7a9026-48c1-4688-9d5d-e5699d47d074}\3.4" />
</ItemGroup>
<PropertyGroup>
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">10.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
<PtvsTargetsFile>$(VSToolsPath)\Python Tools\Microsoft.PythonTools.targets</PtvsTargetsFile>
</PropertyGroup>
<Import Condition="Exists($(PtvsTargetsFile))" Project="$(PtvsTargetsFile)" />
<Import Condition="!Exists($(PtvsTargetsFile))" Project="$(MSBuildToolsPath)\Microsoft.Common.targets" />
</Project>

View File

@@ -36,6 +36,8 @@ else:
from azure.http import HTTPError, HTTPResponse from azure.http import HTTPError, HTTPResponse
from azure import _USER_AGENT_STRING, _update_request_uri_query from azure import _USER_AGENT_STRING, _update_request_uri_query
DEBUG_REQUESTS = False
DEBUG_RESPONSES = False
class _HTTPClient(object): class _HTTPClient(object):
@@ -44,8 +46,7 @@ class _HTTPClient(object):
''' '''
def __init__(self, service_instance, cert_file=None, account_name=None, def __init__(self, service_instance, cert_file=None, account_name=None,
account_key=None, service_namespace=None, issuer=None, account_key=None, protocol='https', request_session=None):
protocol='https'):
''' '''
service_instance: service client instance. service_instance: service client instance.
cert_file: cert_file:
@@ -53,10 +54,9 @@ class _HTTPClient(object):
service management. service management.
account_name: the storage account. account_name: the storage account.
account_key: account_key:
the storage account access key for storage services or servicebus the storage account access key.
access key for service bus service. request_session:
service_namespace: the service namespace for service bus. session object created with requests library (or compatible).
issuer: the issuer for service bus service.
''' '''
self.service_instance = service_instance self.service_instance = service_instance
self.status = None self.status = None
@@ -65,14 +65,16 @@ class _HTTPClient(object):
self.cert_file = cert_file self.cert_file = cert_file
self.account_name = account_name self.account_name = account_name
self.account_key = account_key self.account_key = account_key
self.service_namespace = service_namespace
self.issuer = issuer
self.protocol = protocol self.protocol = protocol
self.proxy_host = None self.proxy_host = None
self.proxy_port = None self.proxy_port = None
self.proxy_user = None self.proxy_user = None
self.proxy_password = None self.proxy_password = None
self.use_httplib = self.should_use_httplib() self.request_session = request_session
if request_session:
self.use_httplib = True
else:
self.use_httplib = self.should_use_httplib()
def should_use_httplib(self): def should_use_httplib(self):
if sys.platform.lower().startswith('win') and self.cert_file: if sys.platform.lower().startswith('win') and self.cert_file:
@@ -111,6 +113,13 @@ class _HTTPClient(object):
self.proxy_user = user self.proxy_user = user
self.proxy_password = password self.proxy_password = password
def get_uri(self, request):
''' Return the target uri for the request.'''
protocol = request.protocol_override \
if request.protocol_override else self.protocol
port = HTTP_PORT if protocol == 'http' else HTTPS_PORT
return protocol + '://' + request.host + ':' + str(port) + request.path
def get_connection(self, request): def get_connection(self, request):
''' Create connection for the request. ''' ''' Create connection for the request. '''
protocol = request.protocol_override \ protocol = request.protocol_override \
@@ -118,7 +127,12 @@ class _HTTPClient(object):
target_host = request.host target_host = request.host
target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT target_port = HTTP_PORT if protocol == 'http' else HTTPS_PORT
if not self.use_httplib: if self.request_session:
import azure.http.requestsclient
connection = azure.http.requestsclient._RequestsConnection(
target_host, protocol, self.request_session)
#TODO: proxy stuff
elif not self.use_httplib:
import azure.http.winhttp import azure.http.winhttp
connection = azure.http.winhttp._HTTPConnection( connection = azure.http.winhttp._HTTPConnection(
target_host, cert_file=self.cert_file, protocol=protocol) target_host, cert_file=self.cert_file, protocol=protocol)
@@ -191,6 +205,13 @@ class _HTTPClient(object):
self.send_request_headers(connection, request.headers) self.send_request_headers(connection, request.headers)
self.send_request_body(connection, request.body) self.send_request_body(connection, request.body)
if DEBUG_REQUESTS and request.body:
print('request:')
try:
print(request.body)
except:
pass
resp = connection.getresponse() resp = connection.getresponse()
self.status = int(resp.status) self.status = int(resp.status)
self.message = resp.reason self.message = resp.reason
@@ -206,6 +227,13 @@ class _HTTPClient(object):
elif resp.length > 0: elif resp.length > 0:
respbody = resp.read(resp.length) respbody = resp.read(resp.length)
if DEBUG_RESPONSES and respbody:
print('response:')
try:
print(respbody)
except:
pass
response = HTTPResponse( response = HTTPResponse(
int(resp.status), resp.reason, headers, respbody) int(resp.status), resp.reason, headers, respbody)
if self.status == 307: if self.status == 307:

View File

@@ -0,0 +1,74 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
class _Response(object):
''' Response class corresponding to the response returned from httplib
HTTPConnection. '''
def __init__(self, response):
self.status = response.status_code
self.reason = response.reason
self.respbody = response.content
self.length = len(response.content)
self.headers = []
for key, name in response.headers.items():
self.headers.append((key.lower(), name))
def getheaders(self):
'''Returns response headers.'''
return self.headers
def read(self, _length):
'''Returns response body. '''
return self.respbody[:_length]
class _RequestsConnection(object):
def __init__(self, host, protocol, session):
self.host = host
self.protocol = protocol
self.session = session
self.headers = {}
self.method = None
self.body = None
self.response = None
self.uri = None
def close(self):
pass
def set_tunnel(self, host, port=None, headers=None):
pass
def set_proxy_credentials(self, user, password):
pass
def putrequest(self, method, uri):
self.method = method
self.uri = self.protocol + '://' + self.host + uri
def putheader(self, name, value):
self.headers[name] = value
def endheaders(self):
pass
def send(self, request_body):
self.response = self.session.request(self.method, self.uri, data=request_body, headers=self.headers)
def getresponse(self):
return _Response(self.response)

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
import ast import ast
import json
import sys import sys
from datetime import datetime from datetime import datetime
@@ -167,16 +168,16 @@ class Message(WindowsAzureData):
# extracts the topic and subscriptions name if it is topic message. # extracts the topic and subscriptions name if it is topic message.
if location: if location:
if '/subscriptions/' in location: if '/subscriptions/' in location:
pos = location.find('/subscriptions/') pos = location.find(service_bus_service.host_base.lower())+1
pos1 = location.rfind('/', 0, pos - 1) pos1 = location.find('/subscriptions/')
self._topic_name = location[pos1 + 1:pos] self._topic_name = location[pos+len(service_bus_service.host_base):pos1]
pos += len('/subscriptions/') pos = pos1 + len('/subscriptions/')
pos1 = location.find('/', pos) pos1 = location.find('/', pos)
self._subscription_name = location[pos:pos1] self._subscription_name = location[pos:pos1]
elif '/messages/' in location: elif '/messages/' in location:
pos = location.find('/messages/') pos = location.find(service_bus_service.host_base.lower())+1
pos1 = location.rfind('/', 0, pos - 1) pos1 = location.find('/messages/')
self._queue_name = location[pos1 + 1:pos] self._queue_name = location[pos+len(service_bus_service.host_base):pos1]
def delete(self): def delete(self):
''' Deletes itself if find queue name or topic name and subscription ''' Deletes itself if find queue name or topic name and subscription
@@ -255,7 +256,7 @@ def _create_message(response, service_instance):
# gets all information from respheaders. # gets all information from respheaders.
for name, value in response.headers: for name, value in response.headers:
if name.lower() == 'brokerproperties': if name.lower() == 'brokerproperties':
broker_properties = ast.literal_eval(value) broker_properties = json.loads(value)
elif name.lower() == 'content-type': elif name.lower() == 'content-type':
message_type = value message_type = value
elif name.lower() == 'location': elif name.lower() == 'location':

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
import datetime
import os import os
import time import time
@@ -21,10 +22,13 @@ from azure import (
_convert_response_to_feeds, _convert_response_to_feeds,
_dont_fail_not_exist, _dont_fail_not_exist,
_dont_fail_on_exist, _dont_fail_on_exist,
_encode_base64,
_get_request_body, _get_request_body,
_get_request_body_bytes_only, _get_request_body_bytes_only,
_int_or_none, _int_or_none,
_sign_string,
_str, _str,
_unicode_type,
_update_request_uri_query, _update_request_uri_query,
url_quote, url_quote,
url_unquote, url_unquote,
@@ -55,44 +59,94 @@ from azure.servicebus import (
_service_bus_error_handler, _service_bus_error_handler,
) )
# Token cache for Authentication
# Shared by the different instances of ServiceBusService
_tokens = {}
class ServiceBusService(object): class ServiceBusService(object):
def __init__(self, service_namespace=None, account_key=None, issuer=None, def __init__(self, service_namespace=None, account_key=None, issuer=None,
x_ms_version='2011-06-01', host_base=SERVICE_BUS_HOST_BASE): x_ms_version='2011-06-01', host_base=SERVICE_BUS_HOST_BASE,
# x_ms_version is not used, but the parameter is kept for backwards shared_access_key_name=None, shared_access_key_value=None,
# compatibility authentication=None):
'''
Initializes the service bus service for a namespace with the specified
authentication settings (SAS or ACS).
service_namespace:
Service bus namespace, required for all operations. If None,
the value is set to the AZURE_SERVICEBUS_NAMESPACE env variable.
account_key:
ACS authentication account key. If None, the value is set to the
AZURE_SERVICEBUS_ACCESS_KEY env variable.
Note that if both SAS and ACS settings are specified, SAS is used.
issuer:
ACS authentication issuer. If None, the value is set to the
AZURE_SERVICEBUS_ISSUER env variable.
Note that if both SAS and ACS settings are specified, SAS is used.
x_ms_version: Unused. Kept for backwards compatibility.
host_base:
Optional. Live host base url. Defaults to Azure url. Override this
for on-premise.
shared_access_key_name:
SAS authentication key name.
Note that if both SAS and ACS settings are specified, SAS is used.
shared_access_key_value:
SAS authentication key value.
Note that if both SAS and ACS settings are specified, SAS is used.
authentication:
Instance of authentication class. If this is specified, then
ACS and SAS parameters are ignored.
'''
self.requestid = None self.requestid = None
self.service_namespace = service_namespace self.service_namespace = service_namespace
self.account_key = account_key
self.issuer = issuer
self.host_base = host_base self.host_base = host_base
# Get service namespace, account key and issuer.
# If they are set when constructing, then use them, else find them
# from environment variables.
if not self.service_namespace: if not self.service_namespace:
self.service_namespace = os.environ.get(AZURE_SERVICEBUS_NAMESPACE) self.service_namespace = os.environ.get(AZURE_SERVICEBUS_NAMESPACE)
if not self.account_key:
self.account_key = os.environ.get(AZURE_SERVICEBUS_ACCESS_KEY)
if not self.issuer:
self.issuer = os.environ.get(AZURE_SERVICEBUS_ISSUER)
if not self.service_namespace or \ if not self.service_namespace:
not self.account_key or not self.issuer: raise WindowsAzureError('You need to provide servicebus namespace')
raise WindowsAzureError(
'You need to provide servicebus namespace, access key and Issuer')
self._httpclient = _HTTPClient(service_instance=self, if authentication:
service_namespace=self.service_namespace, self.authentication = authentication
account_key=self.account_key, else:
issuer=self.issuer) if not account_key:
account_key = os.environ.get(AZURE_SERVICEBUS_ACCESS_KEY)
if not issuer:
issuer = os.environ.get(AZURE_SERVICEBUS_ISSUER)
if shared_access_key_name and shared_access_key_value:
self.authentication = ServiceBusSASAuthentication(
shared_access_key_name,
shared_access_key_value)
elif account_key and issuer:
self.authentication = ServiceBusWrapTokenAuthentication(
account_key,
issuer)
else:
raise WindowsAzureError(
'You need to provide servicebus access key and Issuer OR shared access key and value')
self._httpclient = _HTTPClient(service_instance=self)
self._filter = self._httpclient.perform_request self._filter = self._httpclient.perform_request
# Backwards compatibility:
# account_key and issuer used to be stored on the service class, they are
# now stored on the authentication class.
@property
def account_key(self):
return self.authentication.account_key
@account_key.setter
def account_key(self, value):
self.authentication.account_key = value
@property
def issuer(self):
return self.authentication.issuer
@issuer.setter
def issuer(self, value):
self.authentication.issuer = value
def with_filter(self, filter): def with_filter(self, filter):
''' '''
Returns a new service which will process requests with the specified Returns a new service which will process requests with the specified
@@ -102,8 +156,10 @@ class ServiceBusService(object):
request, pass it off to the next lambda, and then perform any request, pass it off to the next lambda, and then perform any
post-processing on the response. post-processing on the response.
''' '''
res = ServiceBusService(self.service_namespace, self.account_key, res = ServiceBusService(
self.issuer) service_namespace=self.service_namespace,
authentication=self.authentication)
old_filter = self._filter old_filter = self._filter
def new_filter(request): def new_filter(request):
@@ -855,17 +911,30 @@ class ServiceBusService(object):
('Content-Type', ('Content-Type',
'application/atom+xml;type=entry;charset=utf-8')) 'application/atom+xml;type=entry;charset=utf-8'))
# Adds authoriaztion header for authentication. # Adds authorization header for authentication.
request.headers.append( self.authentication.sign_request(request, self._httpclient)
('Authorization', self._sign_service_bus_request(request)))
return request.headers return request.headers
def _sign_service_bus_request(self, request):
''' return the signed string with token. '''
# Token cache for Authentication
# Shared by the different instances of ServiceBusWrapTokenAuthentication
_tokens = {}
class ServiceBusWrapTokenAuthentication:
def __init__(self, account_key, issuer):
self.account_key = account_key
self.issuer = issuer
def sign_request(self, request, httpclient):
request.headers.append(
('Authorization', self._get_authorization(request, httpclient)))
def _get_authorization(self, request, httpclient):
''' return the signed string with token. '''
return 'WRAP access_token="' + \ return 'WRAP access_token="' + \
self._get_token(request.host, request.path) + '"' self._get_token(request.host, request.path, httpclient) + '"'
def _token_is_expired(self, token): def _token_is_expired(self, token):
''' Check if token expires or not. ''' ''' Check if token expires or not. '''
@@ -878,7 +947,7 @@ class ServiceBusService(object):
# token to server. # token to server.
return (token_expire_time - time_now) < 30 return (token_expire_time - time_now) < 30
def _get_token(self, host, path): def _get_token(self, host, path, httpclient):
''' '''
Returns token for the request. Returns token for the request.
@@ -905,10 +974,38 @@ class ServiceBusService(object):
'&wrap_scope=' + '&wrap_scope=' +
url_quote('http://' + host + path)).encode('utf-8') url_quote('http://' + host + path)).encode('utf-8')
request.headers.append(('Content-Length', str(len(request.body)))) request.headers.append(('Content-Length', str(len(request.body))))
resp = self._httpclient.perform_request(request) resp = httpclient.perform_request(request)
token = resp.body.decode('utf-8') token = resp.body.decode('utf-8')
token = url_unquote(token[token.find('=') + 1:token.rfind('&')]) token = url_unquote(token[token.find('=') + 1:token.rfind('&')])
_tokens[wrap_scope] = token _tokens[wrap_scope] = token
return token return token
class ServiceBusSASAuthentication:
    '''Signs Service Bus requests with a Shared Access Signature (SAS)
    built from a policy (key) name and its secret key value.'''

    def __init__(self, key_name, key_value):
        # SAS policy name and the secret used to sign requests.
        self.key_name = key_name
        self.key_value = key_value

    def sign_request(self, request, httpclient):
        '''Appends the SharedAccessSignature Authorization header.'''
        auth_header = ('Authorization',
                       self._get_authorization(request, httpclient))
        request.headers.append(auth_header)

    def _get_authorization(self, request, httpclient):
        '''Builds the SharedAccessSignature token for the request URI.'''
        encoded_uri = url_quote(httpclient.get_uri(request), '').lower()
        expiry = str(self._get_expiry())
        # The signature covers "<encoded-uri>\n<expiry>".
        signature = url_quote(
            _sign_string(self.key_value, encoded_uri + '\n' + expiry, False),
            '')
        return 'SharedAccessSignature sig={0}&se={1}&skn={2}&sr={3}'.format(
            signature, expiry, self.key_name, encoded_uri)

    def _get_expiry(self):
        '''Returns the UTC datetime, in seconds since Epoch, when this signed
        request expires (5 minutes from now).'''
        return int(round(time.time() + 300))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,70 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_str
)
from azure.servicemanagement import (
CloudServices,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class SchedulerManagementService(_ServiceManagementClient):
    '''Preliminary client for Azure Scheduler management.

    NOTE: this is an early implementation that covers only a small part of
    the Scheduler API; the final version may differ from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST, request_session=None):
        '''
        Initializes the scheduler management service.

        subscription_id: Subscription to manage.
        cert_file:
            Path to .pem certificate file (httplib), or location of the
            certificate in your Personal certificate store (winhttp) in the
            CURRENT_USER\\my\\CertificateName format.
            If a request_session is specified, then this is unused.
        host: Live ServiceClient URL. Defaults to Azure public cloud.
        request_session:
            Session object to use for http requests. If this is specified, it
            replaces the default use of httplib or winhttp. Also, the
            cert_file parameter is unused when a session is passed in.
            The session object handles authentication, and as such can
            support multiple types of authentication: .pem certificate,
            oauth. For example, you can pass in a Session instance from the
            requests library. To use .pem certificate authentication with
            requests library, set the path to the .pem file on the
            session.cert attribute.
        '''
        super(SchedulerManagementService, self).__init__(
            subscription_id, cert_file, host, request_session)

    #--Operations for scheduler ----------------------------------------

    def list_cloud_services(self):
        '''
        List the cloud services for scheduling defined on the account.
        '''
        # Response body is parsed into a CloudServices feed object.
        path = self._get_list_cloud_services_path()
        return self._perform_get(path, CloudServices)

    #--Helper functions --------------------------------------------------

    def _get_list_cloud_services_path(self):
        # No resource name: address the 'cloudservices' collection root.
        return self._get_path('cloudservices', None)

View File

@@ -17,23 +17,56 @@ from azure import (
_convert_response_to_feeds, _convert_response_to_feeds,
_str, _str,
_validate_not_none, _validate_not_none,
) _convert_xml_to_windows_azure_object,
)
from azure.servicemanagement import ( from azure.servicemanagement import (
_ServiceBusManagementXmlSerializer, _ServiceBusManagementXmlSerializer,
) QueueDescription,
TopicDescription,
NotificationHubDescription,
RelayDescription,
MetricProperties,
MetricValues,
MetricRollups,
)
from azure.servicemanagement.servicemanagementclient import ( from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient, _ServiceManagementClient,
) )
from functools import partial
X_MS_VERSION = '2012-03-01'
class ServiceBusManagementService(_ServiceManagementClient): class ServiceBusManagementService(_ServiceManagementClient):
def __init__(self, subscription_id=None, cert_file=None, def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST): host=MANAGEMENT_HOST, request_session=None):
super(ServiceBusManagementService, self).__init__( '''
subscription_id, cert_file, host) Initializes the service bus management service.
#--Operations for service bus ---------------------------------------- subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(ServiceBusManagementService, self).__init__(
subscription_id, cert_file, host, request_session)
self.x_ms_version = X_MS_VERSION
# Operations for service bus ----------------------------------------
def get_regions(self): def get_regions(self):
''' '''
Get list of available service bus regions. Get list of available service bus regions.
@@ -111,3 +144,391 @@ class ServiceBusManagementService(_ServiceManagementClient):
return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability( return _ServiceBusManagementXmlSerializer.xml_to_namespace_availability(
response.body) response.body)
def list_queues(self, name):
    '''
    Enumerates the queues in the service namespace.

    name: Name of the service bus namespace.
    '''
    _validate_not_none('name', name)
    response = self._perform_get(
        self._get_list_queues_path(name),
        None)
    # Each Atom feed entry is converted into a QueueDescription.
    return _convert_response_to_feeds(
        response,
        partial(_convert_xml_to_windows_azure_object,
                azure_type=QueueDescription))

def list_topics(self, name):
    '''
    Retrieves the topics in the service namespace.

    name: Name of the service bus namespace.
    '''
    # Validate eagerly for a clear error message, consistent with
    # list_queues.
    _validate_not_none('name', name)
    response = self._perform_get(
        self._get_list_topics_path(name),
        None)
    return _convert_response_to_feeds(
        response,
        partial(_convert_xml_to_windows_azure_object,
                azure_type=TopicDescription))

def list_notification_hubs(self, name):
    '''
    Retrieves the notification hubs in the service namespace.

    name: Name of the service bus namespace.
    '''
    _validate_not_none('name', name)
    response = self._perform_get(
        self._get_list_notification_hubs_path(name),
        None)
    return _convert_response_to_feeds(
        response,
        partial(_convert_xml_to_windows_azure_object,
                azure_type=NotificationHubDescription))

def list_relays(self, name):
    '''
    Retrieves the relays in the service namespace.

    name: Name of the service bus namespace.
    '''
    _validate_not_none('name', name)
    response = self._perform_get(
        self._get_list_relays_path(name),
        None)
    return _convert_response_to_feeds(
        response,
        partial(_convert_xml_to_windows_azure_object,
                azure_type=RelayDescription))
def get_supported_metrics_queue(self, name, queue_name):
    '''
    Retrieves the list of supported metrics for this namespace and queue.

    name: Name of the service bus namespace.
    queue_name: Name of the service bus queue in this namespace.
    '''
    path = self._get_get_supported_metrics_queue_path(name, queue_name)
    to_metrics = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricProperties)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_metrics)

def get_supported_metrics_topic(self, name, topic_name):
    '''
    Retrieves the list of supported metrics for this namespace and topic.

    name: Name of the service bus namespace.
    topic_name: Name of the service bus topic in this namespace.
    '''
    path = self._get_get_supported_metrics_topic_path(name, topic_name)
    to_metrics = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricProperties)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_metrics)

def get_supported_metrics_notification_hub(self, name, hub_name):
    '''
    Retrieves the list of supported metrics for this namespace and
    notification hub.

    name: Name of the service bus namespace.
    hub_name: Name of the service bus notification hub in this namespace.
    '''
    path = self._get_get_supported_metrics_hub_path(name, hub_name)
    to_metrics = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricProperties)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_metrics)

def get_supported_metrics_relay(self, name, relay_name):
    '''
    Retrieves the list of supported metrics for this namespace and relay.

    name: Name of the service bus namespace.
    relay_name: Name of the service bus relay in this namespace.
    '''
    path = self._get_get_supported_metrics_relay_path(name, relay_name)
    to_metrics = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricProperties)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_metrics)
def get_metrics_data_queue(self, name, queue_name, metric, rollup, filter_expresssion):
    '''
    Retrieves the metric values for this namespace and queue.

    name: Name of the service bus namespace.
    queue_name: Name of the service bus queue in this namespace.
    metric: name of a supported metric
    rollup: name of a supported rollup
    filter_expresssion:
        (sic - historical spelling) filter, for instance
        "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
    '''
    path = self._get_get_metrics_data_queue_path(
        name, queue_name, metric, rollup, filter_expresssion)
    to_values = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                        object_type=MetricValues)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_values)

def get_metrics_data_topic(self, name, topic_name, metric, rollup, filter_expresssion):
    '''
    Retrieves the metric values for this namespace and topic.

    name: Name of the service bus namespace.
    topic_name: Name of the service bus topic in this namespace.
    metric: name of a supported metric
    rollup: name of a supported rollup
    filter_expresssion:
        (sic - historical spelling) filter, for instance
        "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
    '''
    path = self._get_get_metrics_data_topic_path(
        name, topic_name, metric, rollup, filter_expresssion)
    to_values = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                        object_type=MetricValues)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_values)

def get_metrics_data_notification_hub(self, name, hub_name, metric, rollup, filter_expresssion):
    '''
    Retrieves the metric values for this namespace and notification hub.

    name: Name of the service bus namespace.
    hub_name: Name of the service bus notification hub in this namespace.
    metric: name of a supported metric
    rollup: name of a supported rollup
    filter_expresssion:
        (sic - historical spelling) filter, for instance
        "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
    '''
    path = self._get_get_metrics_data_hub_path(
        name, hub_name, metric, rollup, filter_expresssion)
    to_values = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                        object_type=MetricValues)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_values)

def get_metrics_data_relay(self, name, relay_name, metric, rollup, filter_expresssion):
    '''
    Retrieves the metric values for this namespace and relay.

    name: Name of the service bus namespace.
    relay_name: Name of the service bus relay in this namespace.
    metric: name of a supported metric
    rollup: name of a supported rollup
    filter_expresssion:
        (sic - historical spelling) filter, for instance
        "$filter=Timestamp gt datetime'2014-10-01T00:00:00Z'"
    '''
    path = self._get_get_metrics_data_relay_path(
        name, relay_name, metric, rollup, filter_expresssion)
    to_values = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                        object_type=MetricValues)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_values)
def get_metrics_rollups_queue(self, name, queue_name, metric):
    '''
    This operation gets rollup data for Service Bus metrics queue.
    Rollup data includes the time granularity for the telemetry
    aggregation as well as the retention settings for each granularity.

    name: Name of the service bus namespace.
    queue_name: Name of the service bus queue in this namespace.
    metric: name of a supported metric
    '''
    path = self._get_get_metrics_rollup_queue_path(name, queue_name, metric)
    to_rollups = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricRollups)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_rollups)

def get_metrics_rollups_topic(self, name, topic_name, metric):
    '''
    This operation gets rollup data for Service Bus metrics topic.
    Rollup data includes the time granularity for the telemetry
    aggregation as well as the retention settings for each granularity.

    name: Name of the service bus namespace.
    topic_name: Name of the service bus topic in this namespace.
    metric: name of a supported metric
    '''
    path = self._get_get_metrics_rollup_topic_path(name, topic_name, metric)
    to_rollups = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricRollups)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_rollups)

def get_metrics_rollups_notification_hub(self, name, hub_name, metric):
    '''
    This operation gets rollup data for Service Bus metrics notification
    hub. Rollup data includes the time granularity for the telemetry
    aggregation as well as the retention settings for each granularity.

    name: Name of the service bus namespace.
    hub_name: Name of the service bus notification hub in this namespace.
    metric: name of a supported metric
    '''
    path = self._get_get_metrics_rollup_hub_path(name, hub_name, metric)
    to_rollups = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricRollups)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_rollups)

def get_metrics_rollups_relay(self, name, relay_name, metric):
    '''
    This operation gets rollup data for Service Bus metrics relay.
    Rollup data includes the time granularity for the telemetry
    aggregation as well as the retention settings for each granularity.

    name: Name of the service bus namespace.
    relay_name: Name of the service bus relay in this namespace.
    metric: name of a supported metric
    '''
    path = self._get_get_metrics_rollup_relay_path(name, relay_name, metric)
    to_rollups = partial(_ServiceBusManagementXmlSerializer.xml_to_metrics,
                         object_type=MetricRollups)
    return _convert_response_to_feeds(self._perform_get(path, None),
                                      to_rollups)
# Helper functions --------------------------------------------------
def _get_list_queues_path(self, namespace_name):
    # Path of the Queues collection under the given namespace.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return base + '/Queues'

def _get_list_topics_path(self, namespace_name):
    # Path of the Topics collection under the given namespace.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return base + '/Topics'

def _get_list_notification_hubs_path(self, namespace_name):
    # Path of the NotificationHubs collection under the given namespace.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return base + '/NotificationHubs'

def _get_list_relays_path(self, namespace_name):
    # Path of the Relays collection under the given namespace.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return base + '/Relays'
def _get_get_supported_metrics_queue_path(self, namespace_name, queue_name):
    # Metrics collection for a single queue.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Queues/{1}/Metrics'.format(base, _str(queue_name))

def _get_get_supported_metrics_topic_path(self, namespace_name, topic_name):
    # Metrics collection for a single topic.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Topics/{1}/Metrics'.format(base, _str(topic_name))

def _get_get_supported_metrics_hub_path(self, namespace_name, hub_name):
    # Metrics collection for a single notification hub.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/NotificationHubs/{1}/Metrics'.format(base, _str(hub_name))

def _get_get_supported_metrics_relay_path(self, namespace_name, queue_name):
    # Metrics collection for a single relay.
    # NOTE(review): parameter is historically named queue_name; it holds
    # the relay name.
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Relays/{1}/Metrics'.format(base, _str(queue_name))
def _get_get_metrics_data_queue_path(self, namespace_name, queue_name,
                                     metric, rollup, filter_expr):
    # .../Queues/<q>/Metrics/<metric>/Rollups/<rollup>/Values?<filter>
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Queues/{1}/Metrics/{2}/Rollups/{3}/Values?{4}'.format(
        base, _str(queue_name), _str(metric), _str(rollup), filter_expr)

def _get_get_metrics_data_topic_path(self, namespace_name, queue_name,
                                     metric, rollup, filter_expr):
    # .../Topics/<t>/Metrics/<metric>/Rollups/<rollup>/Values?<filter>
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Topics/{1}/Metrics/{2}/Rollups/{3}/Values?{4}'.format(
        base, _str(queue_name), _str(metric), _str(rollup), filter_expr)

def _get_get_metrics_data_hub_path(self, namespace_name, queue_name,
                                   metric, rollup, filter_expr):
    # .../NotificationHubs/<h>/Metrics/<metric>/Rollups/<rollup>/Values?<f>
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/NotificationHubs/{1}/Metrics/{2}/Rollups/{3}/Values?{4}'.format(
        base, _str(queue_name), _str(metric), _str(rollup), filter_expr)

def _get_get_metrics_data_relay_path(self, namespace_name, queue_name,
                                     metric, rollup, filter_expr):
    # .../Relays/<r>/Metrics/<metric>/Rollups/<rollup>/Values?<filter>
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Relays/{1}/Metrics/{2}/Rollups/{3}/Values?{4}'.format(
        base, _str(queue_name), _str(metric), _str(rollup), filter_expr)
def _get_get_metrics_rollup_queue_path(self, namespace_name, queue_name,
                                       metric):
    # .../Queues/<q>/Metrics/<metric>/Rollups
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Queues/{1}/Metrics/{2}/Rollups'.format(
        base, _str(queue_name), _str(metric))

def _get_get_metrics_rollup_topic_path(self, namespace_name, queue_name,
                                       metric):
    # .../Topics/<t>/Metrics/<metric>/Rollups
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Topics/{1}/Metrics/{2}/Rollups'.format(
        base, _str(queue_name), _str(metric))

def _get_get_metrics_rollup_hub_path(self, namespace_name, queue_name,
                                     metric):
    # .../NotificationHubs/<h>/Metrics/<metric>/Rollups
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/NotificationHubs/{1}/Metrics/{2}/Rollups'.format(
        base, _str(queue_name), _str(metric))

def _get_get_metrics_rollup_relay_path(self, namespace_name, queue_name,
                                       metric):
    # .../Relays/<r>/Metrics/<metric>/Rollups
    base = self._get_path('services/serviceBus/Namespaces/', namespace_name)
    return '{0}/Relays/{1}/Metrics/{2}/Rollups'.format(
        base, _str(queue_name), _str(metric))

View File

@@ -31,21 +31,24 @@ from azure.servicemanagement import (
AZURE_MANAGEMENT_CERTFILE, AZURE_MANAGEMENT_CERTFILE,
AZURE_MANAGEMENT_SUBSCRIPTIONID, AZURE_MANAGEMENT_SUBSCRIPTIONID,
_management_error_handler, _management_error_handler,
_parse_response_for_async_op, parse_response_for_async_op,
_update_management_header, X_MS_VERSION,
) )
class _ServiceManagementClient(object): class _ServiceManagementClient(object):
def __init__(self, subscription_id=None, cert_file=None, def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST): host=MANAGEMENT_HOST, request_session=None):
self.requestid = None self.requestid = None
self.subscription_id = subscription_id self.subscription_id = subscription_id
self.cert_file = cert_file self.cert_file = cert_file
self.host = host self.host = host
self.request_session = request_session
self.x_ms_version = X_MS_VERSION
self.content_type = 'application/atom+xml;type=entry;charset=utf-8'
if not self.cert_file: if not self.cert_file and not request_session:
if AZURE_MANAGEMENT_CERTFILE in os.environ: if AZURE_MANAGEMENT_CERTFILE in os.environ:
self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE] self.cert_file = os.environ[AZURE_MANAGEMENT_CERTFILE]
@@ -54,12 +57,14 @@ class _ServiceManagementClient(object):
self.subscription_id = os.environ[ self.subscription_id = os.environ[
AZURE_MANAGEMENT_SUBSCRIPTIONID] AZURE_MANAGEMENT_SUBSCRIPTIONID]
if not self.cert_file or not self.subscription_id: if not self.request_session:
raise WindowsAzureError( if not self.cert_file or not self.subscription_id:
'You need to provide subscription id and certificate file') raise WindowsAzureError(
'You need to provide subscription id and certificate file')
self._httpclient = _HTTPClient( self._httpclient = _HTTPClient(
service_instance=self, cert_file=self.cert_file) service_instance=self, cert_file=self.cert_file,
request_session=self.request_session)
self._filter = self._httpclient.perform_request self._filter = self._httpclient.perform_request
def with_filter(self, filter): def with_filter(self, filter):
@@ -69,7 +74,8 @@ class _ServiceManagementClient(object):
and another lambda. The filter can perform any pre-processing on the and another lambda. The filter can perform any pre-processing on the
request, pass it off to the next lambda, and then perform any request, pass it off to the next lambda, and then perform any
post-processing on the response.''' post-processing on the response.'''
res = type(self)(self.subscription_id, self.cert_file, self.host) res = type(self)(self.subscription_id, self.cert_file, self.host,
self.request_session)
old_filter = self._filter old_filter = self._filter
def new_filter(request): def new_filter(request):
@@ -89,6 +95,96 @@ class _ServiceManagementClient(object):
''' '''
self._httpclient.set_proxy(host, port, user, password) self._httpclient.set_proxy(host, port, user, password)
def _management_request(self, method, path, x_ms_version, **kwargs):
    '''Builds, signs with management headers, and sends one HTTP request.

    Pass body=<value> (even None) to attach a request body; omit the
    keyword entirely for body-less verbs such as GET and DELETE.
    '''
    request = HTTPRequest()
    request.method = method
    request.host = self.host
    request.path = path
    if 'body' in kwargs:
        request.body = _get_request_body(kwargs['body'])
    request.path, request.query = _update_request_uri_query(request)
    request.headers = self._update_management_header(request, x_ms_version)
    return self._perform_request(request)

def perform_get(self, path, x_ms_version=None):
    '''
    Performs a GET request and returns the response.

    path:
        Path to the resource.
        Ex: '/<subscription-id>/services/hostedservices/<service-name>'
    x_ms_version:
        If specified, this is used for the x-ms-version header.
        Otherwise, self.x_ms_version is used.
    '''
    return self._management_request('GET', path, x_ms_version)

def perform_put(self, path, body, x_ms_version=None):
    '''
    Performs a PUT request and returns the response.

    path:
        Path to the resource.
        Ex: '/<subscription-id>/services/hostedservices/<service-name>'
    body:
        Body for the PUT request.
    x_ms_version:
        If specified, this is used for the x-ms-version header.
        Otherwise, self.x_ms_version is used.
    '''
    return self._management_request('PUT', path, x_ms_version, body=body)

def perform_post(self, path, body, x_ms_version=None):
    '''
    Performs a POST request and returns the response.

    path:
        Path to the resource.
        Ex: '/<subscription-id>/services/hostedservices/<service-name>'
    body:
        Body for the POST request.
    x_ms_version:
        If specified, this is used for the x-ms-version header.
        Otherwise, self.x_ms_version is used.
    '''
    return self._management_request('POST', path, x_ms_version, body=body)

def perform_delete(self, path, x_ms_version=None):
    '''
    Performs a DELETE request and returns the response.

    path:
        Path to the resource.
        Ex: '/<subscription-id>/services/hostedservices/<service-name>'
    x_ms_version:
        If specified, this is used for the x-ms-version header.
        Otherwise, self.x_ms_version is used.
    '''
    return self._management_request('DELETE', path, x_ms_version)
#--Helper functions -------------------------------------------------- #--Helper functions --------------------------------------------------
def _perform_request(self, request): def _perform_request(self, request):
try: try:
@@ -98,64 +194,60 @@ class _ServiceManagementClient(object):
return resp return resp
def _perform_get(self, path, response_type): def _update_management_header(self, request, x_ms_version):
request = HTTPRequest() ''' Add additional headers for management. '''
request.method = 'GET'
request.host = self.host if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
request.path = path request.headers.append(('Content-Length', str(len(request.body))))
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request) # append additional headers base on the service
response = self._perform_request(request) request.headers.append(('x-ms-version', x_ms_version or self.x_ms_version))
# if it is not GET or HEAD request, must set content-type.
if not request.method in ['GET', 'HEAD']:
for name, _ in request.headers:
if 'content-type' == name.lower():
break
else:
request.headers.append(
('Content-Type',
self.content_type))
return request.headers
def _perform_get(self, path, response_type, x_ms_version=None):
response = self.perform_get(path, x_ms_version)
if response_type is not None: if response_type is not None:
return _parse_response(response, response_type) return _parse_response(response, response_type)
return response return response
def _perform_put(self, path, body, async=False): def _perform_put(self, path, body, async=False, x_ms_version=None):
request = HTTPRequest() response = self.perform_put(path, body, x_ms_version)
request.method = 'PUT'
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if async: if async:
return _parse_response_for_async_op(response) return parse_response_for_async_op(response)
return None return None
def _perform_post(self, path, body, response_type=None, async=False): def _perform_post(self, path, body, response_type=None, async=False,
request = HTTPRequest() x_ms_version=None):
request.method = 'POST' response = self.perform_post(path, body, x_ms_version)
request.host = self.host
request.path = path
request.body = _get_request_body(body)
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if response_type is not None: if response_type is not None:
return _parse_response(response, response_type) return _parse_response(response, response_type)
if async: if async:
return _parse_response_for_async_op(response) return parse_response_for_async_op(response)
return None return None
def _perform_delete(self, path, async=False): def _perform_delete(self, path, async=False, x_ms_version=None):
request = HTTPRequest() response = self.perform_delete(path, x_ms_version)
request.method = 'DELETE'
request.host = self.host
request.path = path
request.path, request.query = _update_request_uri_query(request)
request.headers = _update_management_header(request)
response = self._perform_request(request)
if async: if async:
return _parse_response_for_async_op(response) return parse_response_for_async_op(response)
return None return None

View File

@@ -37,12 +37,19 @@ from azure.servicemanagement import (
OperatingSystemFamilies, OperatingSystemFamilies,
OSImage, OSImage,
PersistentVMRole, PersistentVMRole,
ResourceExtensions,
ReservedIP,
ReservedIPs,
RoleSize,
RoleSizes,
StorageService, StorageService,
StorageServices, StorageServices,
Subscription, Subscription,
Subscriptions,
SubscriptionCertificate, SubscriptionCertificate,
SubscriptionCertificates, SubscriptionCertificates,
VirtualNetworkSites, VirtualNetworkSites,
VMImages,
_XmlSerializer, _XmlSerializer,
) )
from azure.servicemanagement.servicemanagementclient import ( from azure.servicemanagement.servicemanagementclient import (
@@ -52,9 +59,49 @@ from azure.servicemanagement.servicemanagementclient import (
class ServiceManagementService(_ServiceManagementClient): class ServiceManagementService(_ServiceManagementClient):
def __init__(self, subscription_id=None, cert_file=None, def __init__(self, subscription_id=None, cert_file=None,
host=MANAGEMENT_HOST): host=MANAGEMENT_HOST, request_session=None):
'''
Initializes the management service.
subscription_id: Subscription to manage.
cert_file:
Path to .pem certificate file (httplib), or location of the
certificate in your Personal certificate store (winhttp) in the
CURRENT_USER\my\CertificateName format.
If a request_session is specified, then this is unused.
host: Live ServiceClient URL. Defaults to Azure public cloud.
request_session:
Session object to use for http requests. If this is specified, it
replaces the default use of httplib or winhttp. Also, the cert_file
parameter is unused when a session is passed in.
The session object handles authentication, and as such can support
multiple types of authentication: .pem certificate, oauth.
For example, you can pass in a Session instance from the requests
library. To use .pem certificate authentication with requests
library, set the path to the .pem file on the session.cert
attribute.
'''
super(ServiceManagementService, self).__init__( super(ServiceManagementService, self).__init__(
subscription_id, cert_file, host) subscription_id, cert_file, host, request_session)
#--Operations for subscriptions --------------------------------------
def list_role_sizes(self):
'''
Lists the role sizes that are available under the specified
subscription.
'''
return self._perform_get(self._get_role_sizes_path(),
RoleSizes)
def list_subscriptions(self):
'''
Returns a list of subscriptions that you can access.
You must make sure that the request that is made to the management
service is secure using an Active Directory access token.
'''
return self._perform_get(self._get_subscriptions_path(),
Subscriptions)
#--Operations for storage accounts ----------------------------------- #--Operations for storage accounts -----------------------------------
def list_storage_accounts(self): def list_storage_accounts(self):
@@ -107,8 +154,9 @@ class ServiceManagementService(_ServiceManagementClient):
def create_storage_account(self, service_name, description, label, def create_storage_account(self, service_name, description, label,
affinity_group=None, location=None, affinity_group=None, location=None,
geo_replication_enabled=True, geo_replication_enabled=None,
extended_properties=None): extended_properties=None,
account_type='Standard_GRS'):
''' '''
Creates a new storage account in Windows Azure. Creates a new storage account in Windows Azure.
@@ -131,12 +179,7 @@ class ServiceManagementService(_ServiceManagementClient):
The location where the storage account is created. You can specify The location where the storage account is created. You can specify
either a location or affinity_group, but not both. either a location or affinity_group, but not both.
geo_replication_enabled: geo_replication_enabled:
Specifies whether the storage account is created with the Deprecated. Replaced by the account_type parameter.
geo-replication enabled. If the element is not included in the
request body, the default value is true. If set to true, the data
in the storage account is replicated across more than one
geographic location so as to enable resilience in the face of
catastrophic service loss.
extended_properties: extended_properties:
Dictionary containing name/value pairs of storage account Dictionary containing name/value pairs of storage account
properties. You can have a maximum of 50 extended property properties. You can have a maximum of 50 extended property
@@ -144,6 +187,12 @@ class ServiceManagementService(_ServiceManagementClient):
characters, only alphanumeric characters and underscores are valid characters, only alphanumeric characters and underscores are valid
in the Name, and the name must start with a letter. The value has in the Name, and the name must start with a letter. The value has
a maximum length of 255 characters. a maximum length of 255 characters.
account_type:
Specifies whether the account supports locally-redundant storage,
geo-redundant storage, zone-redundant storage, or read access
geo-redundant storage.
Possible values are:
Standard_LRS, Standard_ZRS, Standard_GRS, Standard_RAGRS
''' '''
_validate_not_none('service_name', service_name) _validate_not_none('service_name', service_name)
_validate_not_none('description', description) _validate_not_none('description', description)
@@ -154,6 +203,8 @@ class ServiceManagementService(_ServiceManagementClient):
if affinity_group is not None and location is not None: if affinity_group is not None and location is not None:
raise WindowsAzureError( raise WindowsAzureError(
'Only one of location or affinity_group needs to be specified') 'Only one of location or affinity_group needs to be specified')
if geo_replication_enabled == False:
account_type = 'Standard_LRS'
return self._perform_post( return self._perform_post(
self._get_storage_service_path(), self._get_storage_service_path(),
_XmlSerializer.create_storage_service_input_to_xml( _XmlSerializer.create_storage_service_input_to_xml(
@@ -162,13 +213,14 @@ class ServiceManagementService(_ServiceManagementClient):
label, label,
affinity_group, affinity_group,
location, location,
geo_replication_enabled, account_type,
extended_properties), extended_properties),
async=True) async=True)
def update_storage_account(self, service_name, description=None, def update_storage_account(self, service_name, description=None,
label=None, geo_replication_enabled=None, label=None, geo_replication_enabled=None,
extended_properties=None): extended_properties=None,
account_type='Standard_GRS'):
''' '''
Updates the label, the description, and enables or disables the Updates the label, the description, and enables or disables the
geo-replication status for a storage account in Windows Azure. geo-replication status for a storage account in Windows Azure.
@@ -182,12 +234,7 @@ class ServiceManagementService(_ServiceManagementClient):
characters in length. The name can be used to identify the storage characters in length. The name can be used to identify the storage
account for your tracking purposes. account for your tracking purposes.
geo_replication_enabled: geo_replication_enabled:
Specifies whether the storage account is created with the Deprecated. Replaced by the account_type parameter.
geo-replication enabled. If the element is not included in the
request body, the default value is true. If set to true, the data
in the storage account is replicated across more than one
geographic location so as to enable resilience in the face of
catastrophic service loss.
extended_properties: extended_properties:
Dictionary containing name/value pairs of storage account Dictionary containing name/value pairs of storage account
properties. You can have a maximum of 50 extended property properties. You can have a maximum of 50 extended property
@@ -195,14 +242,22 @@ class ServiceManagementService(_ServiceManagementClient):
characters, only alphanumeric characters and underscores are valid characters, only alphanumeric characters and underscores are valid
in the Name, and the name must start with a letter. The value has in the Name, and the name must start with a letter. The value has
a maximum length of 255 characters. a maximum length of 255 characters.
account_type:
Specifies whether the account supports locally-redundant storage,
geo-redundant storage, zone-redundant storage, or read access
geo-redundant storage.
Possible values are:
Standard_LRS, Standard_ZRS, Standard_GRS, Standard_RAGRS
''' '''
_validate_not_none('service_name', service_name) _validate_not_none('service_name', service_name)
if geo_replication_enabled == False:
account_type = 'Standard_LRS'
return self._perform_put( return self._perform_put(
self._get_storage_service_path(service_name), self._get_storage_service_path(service_name),
_XmlSerializer.update_storage_service_input_to_xml( _XmlSerializer.update_storage_service_input_to_xml(
description, description,
label, label,
geo_replication_enabled, account_type,
extended_properties)) extended_properties))
def delete_storage_account(self, service_name): def delete_storage_account(self, service_name):
@@ -697,6 +752,50 @@ class ServiceManagementService(_ServiceManagementClient):
'', '',
async=True) async=True)
def rebuild_role_instance(self, service_name, deployment_name,
role_instance_name):
'''
Reinstalls the operating system on instances of web roles or worker
roles and initializes the storage resources that are used by them. If
you do not want to initialize storage resources, you can use
reimage_role_instance.
service_name: Name of the hosted service.
deployment_name: The name of the deployment.
role_instance_name: The name of the role instance.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('role_instance_name', role_instance_name)
return self._perform_post(
self._get_deployment_path_using_name(
service_name, deployment_name) + \
'/roleinstances/' + _str(role_instance_name) + \
'?comp=rebuild&resources=allLocalDrives',
'',
async=True)
def delete_role_instances(self, service_name, deployment_name,
role_instance_names):
'''
Reinstalls the operating system on instances of web roles or worker
roles and initializes the storage resources that are used by them. If
you do not want to initialize storage resources, you can use
reimage_role_instance.
service_name: Name of the hosted service.
deployment_name: The name of the deployment.
role_instance_names: List of role instance names.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('role_instance_names', role_instance_names)
return self._perform_post(
self._get_deployment_path_using_name(
service_name, deployment_name) + '/roleinstances/?comp=delete',
_XmlSerializer.role_instances_to_xml(role_instance_names),
async=True)
def check_hosted_service_name_availability(self, service_name): def check_hosted_service_name_availability(self, service_name):
''' '''
Checks to see if the specified hosted service name is available, or if Checks to see if the specified hosted service name is available, or if
@@ -980,6 +1079,53 @@ class ServiceManagementService(_ServiceManagementClient):
return self._perform_get('/' + self.subscription_id + '', return self._perform_get('/' + self.subscription_id + '',
Subscription) Subscription)
#--Operations for reserved ip addresses -----------------------------
def create_reserved_ip_address(self, name, label=None, location=None):
'''
Reserves an IPv4 address for the specified subscription.
name:
Required. Specifies the name for the reserved IP address.
label:
Optional. Specifies a label for the reserved IP address. The label
can be up to 100 characters long and can be used for your tracking
purposes.
location:
Required. Specifies the location of the reserved IP address. This
should be the same location that is assigned to the cloud service
containing the deployment that will use the reserved IP address.
To see the available locations, you can use list_locations.
'''
_validate_not_none('name', name)
return self._perform_post(
self._get_reserved_ip_path(),
_XmlSerializer.create_reserved_ip_to_xml(name, label, location))
def delete_reserved_ip_address(self, name):
'''
Deletes a reserved IP address from the specified subscription.
name: Required. Name of the reserved IP address.
'''
_validate_not_none('name', name)
return self._perform_delete(self._get_reserved_ip_path(name))
def get_reserved_ip_address(self, name):
'''
Retrieves information about the specified reserved IP address.
name: Required. Name of the reserved IP address.
'''
_validate_not_none('name', name)
return self._perform_get(self._get_reserved_ip_path(name), ReservedIP)
def list_reserved_ip_addresses(self):
'''
Lists the IP addresses that have been reserved for the specified
subscription.
'''
return self._perform_get(self._get_reserved_ip_path(), ReservedIPs)
#--Operations for virtual machines ----------------------------------- #--Operations for virtual machines -----------------------------------
def get_role(self, service_name, deployment_name, role_name): def get_role(self, service_name, deployment_name, role_name):
''' '''
@@ -1004,7 +1150,13 @@ class ServiceManagementService(_ServiceManagementClient):
data_virtual_hard_disks=None, data_virtual_hard_disks=None,
role_size=None, role_size=None,
role_type='PersistentVMRole', role_type='PersistentVMRole',
virtual_network_name=None): virtual_network_name=None,
resource_extension_references=None,
provision_guest_agent=None,
vm_image_name=None,
media_location=None,
dns_servers=None,
reserved_ip_name=None):
''' '''
Provisions a virtual machine based on the supplied configuration. Provisions a virtual machine based on the supplied configuration.
@@ -1025,7 +1177,8 @@ class ServiceManagementService(_ServiceManagementClient):
WindowsConfigurationSet or LinuxConfigurationSet. WindowsConfigurationSet or LinuxConfigurationSet.
os_virtual_hard_disk: os_virtual_hard_disk:
Contains the parameters Windows Azure uses to create the operating Contains the parameters Windows Azure uses to create the operating
system disk for the virtual machine. system disk for the virtual machine. If you are creating a Virtual
Machine by using a VM Image, this parameter is not used.
network_config: network_config:
Encapsulates the metadata required to create the virtual network Encapsulates the metadata required to create the virtual network
configuration for a virtual machine. If you do not include a configuration for a virtual machine. If you do not include a
@@ -1053,14 +1206,36 @@ class ServiceManagementService(_ServiceManagementClient):
virtual_network_name: virtual_network_name:
Specifies the name of an existing virtual network to which the Specifies the name of an existing virtual network to which the
deployment will belong. deployment will belong.
resource_extension_references:
Optional. Contains a collection of resource extensions that are to
be installed on the Virtual Machine. This element is used if
provision_guest_agent is set to True.
provision_guest_agent:
Optional. Indicates whether the VM Agent is installed on the
Virtual Machine. To run a resource extension in a Virtual Machine,
this service must be installed.
vm_image_name:
Optional. Specifies the name of the VM Image that is to be used to
create the Virtual Machine. If this is specified, the
system_config and network_config parameters are not used.
media_location:
Optional. Required if the Virtual Machine is being created from a
published VM Image. Specifies the location of the VHD file that is
created when VMImageName specifies a published VM Image.
dns_servers:
Optional. List of DNS servers (use DnsServer class) to associate
with the Virtual Machine.
reserved_ip_name:
Optional. Specifies the name of a reserved IP address that is to be
assigned to the deployment. You must run create_reserved_ip_address
before you can assign the address to the deployment using this
element.
''' '''
_validate_not_none('service_name', service_name) _validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name) _validate_not_none('deployment_name', deployment_name)
_validate_not_none('deployment_slot', deployment_slot) _validate_not_none('deployment_slot', deployment_slot)
_validate_not_none('label', label) _validate_not_none('label', label)
_validate_not_none('role_name', role_name) _validate_not_none('role_name', role_name)
_validate_not_none('system_config', system_config)
_validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk)
return self._perform_post( return self._perform_post(
self._get_deployment_path_using_name(service_name), self._get_deployment_path_using_name(service_name),
_XmlSerializer.virtual_machine_deployment_to_xml( _XmlSerializer.virtual_machine_deployment_to_xml(
@@ -1075,13 +1250,22 @@ class ServiceManagementService(_ServiceManagementClient):
availability_set_name, availability_set_name,
data_virtual_hard_disks, data_virtual_hard_disks,
role_size, role_size,
virtual_network_name), virtual_network_name,
resource_extension_references,
provision_guest_agent,
vm_image_name,
media_location,
dns_servers,
reserved_ip_name),
async=True) async=True)
def add_role(self, service_name, deployment_name, role_name, system_config, def add_role(self, service_name, deployment_name, role_name, system_config,
os_virtual_hard_disk, network_config=None, os_virtual_hard_disk, network_config=None,
availability_set_name=None, data_virtual_hard_disks=None, availability_set_name=None, data_virtual_hard_disks=None,
role_size=None, role_type='PersistentVMRole'): role_size=None, role_type='PersistentVMRole',
resource_extension_references=None,
provision_guest_agent=None, vm_image_name=None,
media_location=None):
''' '''
Adds a virtual machine to an existing deployment. Adds a virtual machine to an existing deployment.
@@ -1094,7 +1278,8 @@ class ServiceManagementService(_ServiceManagementClient):
WindowsConfigurationSet or LinuxConfigurationSet. WindowsConfigurationSet or LinuxConfigurationSet.
os_virtual_hard_disk: os_virtual_hard_disk:
Contains the parameters Windows Azure uses to create the operating Contains the parameters Windows Azure uses to create the operating
system disk for the virtual machine. system disk for the virtual machine. If you are creating a Virtual
Machine by using a VM Image, this parameter is not used.
network_config: network_config:
Encapsulates the metadata required to create the virtual network Encapsulates the metadata required to create the virtual network
configuration for a virtual machine. If you do not include a configuration for a virtual machine. If you do not include a
@@ -1119,12 +1304,26 @@ class ServiceManagementService(_ServiceManagementClient):
role_type: role_type:
The type of the role for the virtual machine. The only supported The type of the role for the virtual machine. The only supported
value is PersistentVMRole. value is PersistentVMRole.
resource_extension_references:
Optional. Contains a collection of resource extensions that are to
be installed on the Virtual Machine. This element is used if
provision_guest_agent is set to True.
provision_guest_agent:
Optional. Indicates whether the VM Agent is installed on the
Virtual Machine. To run a resource extension in a Virtual Machine,
this service must be installed.
vm_image_name:
Optional. Specifies the name of the VM Image that is to be used to
create the Virtual Machine. If this is specified, the
system_config and network_config parameters are not used.
media_location:
Optional. Required if the Virtual Machine is being created from a
published VM Image. Specifies the location of the VHD file that is
created when VMImageName specifies a published VM Image.
''' '''
_validate_not_none('service_name', service_name) _validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name) _validate_not_none('deployment_name', deployment_name)
_validate_not_none('role_name', role_name) _validate_not_none('role_name', role_name)
_validate_not_none('system_config', system_config)
_validate_not_none('os_virtual_hard_disk', os_virtual_hard_disk)
return self._perform_post( return self._perform_post(
self._get_role_path(service_name, deployment_name), self._get_role_path(service_name, deployment_name),
_XmlSerializer.add_role_to_xml( _XmlSerializer.add_role_to_xml(
@@ -1135,13 +1334,19 @@ class ServiceManagementService(_ServiceManagementClient):
network_config, network_config,
availability_set_name, availability_set_name,
data_virtual_hard_disks, data_virtual_hard_disks,
role_size), role_size,
resource_extension_references,
provision_guest_agent,
vm_image_name,
media_location),
async=True) async=True)
def update_role(self, service_name, deployment_name, role_name, def update_role(self, service_name, deployment_name, role_name,
os_virtual_hard_disk=None, network_config=None, os_virtual_hard_disk=None, network_config=None,
availability_set_name=None, data_virtual_hard_disks=None, availability_set_name=None, data_virtual_hard_disks=None,
role_size=None, role_type='PersistentVMRole'): role_size=None, role_type='PersistentVMRole',
resource_extension_references=None,
provision_guest_agent=None):
''' '''
Updates the specified virtual machine. Updates the specified virtual machine.
@@ -1175,6 +1380,14 @@ class ServiceManagementService(_ServiceManagementClient):
role_type: role_type:
The type of the role for the virtual machine. The only supported The type of the role for the virtual machine. The only supported
value is PersistentVMRole. value is PersistentVMRole.
resource_extension_references:
Optional. Contains a collection of resource extensions that are to
be installed on the Virtual Machine. This element is used if
provision_guest_agent is set to True.
provision_guest_agent:
Optional. Indicates whether the VM Agent is installed on the
Virtual Machine. To run a resource extension in a Virtual Machine,
this service must be installed.
''' '''
_validate_not_none('service_name', service_name) _validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name) _validate_not_none('deployment_name', deployment_name)
@@ -1188,7 +1401,9 @@ class ServiceManagementService(_ServiceManagementClient):
network_config, network_config,
availability_set_name, availability_set_name,
data_virtual_hard_disks, data_virtual_hard_disks,
role_size), role_size,
resource_extension_references,
provision_guest_agent),
async=True) async=True)
def delete_role(self, service_name, deployment_name, role_name): def delete_role(self, service_name, deployment_name, role_name):
@@ -1354,7 +1569,307 @@ class ServiceManagementService(_ServiceManagementClient):
role_names, post_shutdown_action), role_names, post_shutdown_action),
async=True) async=True)
def add_dns_server(self, service_name, deployment_name, dns_server_name, address):
'''
Adds a DNS server definition to an existing deployment.
service_name: The name of the service.
deployment_name: The name of the deployment.
dns_server_name: Specifies the name of the DNS server.
address: Specifies the IP address of the DNS server.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('dns_server_name', dns_server_name)
_validate_not_none('address', address)
return self._perform_post(
self._get_dns_server_path(service_name, deployment_name),
_XmlSerializer.dns_server_to_xml(dns_server_name, address),
async=True)
def update_dns_server(self, service_name, deployment_name, dns_server_name, address):
'''
Updates the ip address of a DNS server.
service_name: The name of the service.
deployment_name: The name of the deployment.
dns_server_name: Specifies the name of the DNS server.
address: Specifies the IP address of the DNS server.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('dns_server_name', dns_server_name)
_validate_not_none('address', address)
return self._perform_put(
self._get_dns_server_path(service_name,
deployment_name,
dns_server_name),
_XmlSerializer.dns_server_to_xml(dns_server_name, address),
async=True)
def delete_dns_server(self, service_name, deployment_name, dns_server_name):
'''
Deletes a DNS server from a deployment.
service_name: The name of the service.
deployment_name: The name of the deployment.
dns_server_name: Name of the DNS server that you want to delete.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('dns_server_name', dns_server_name)
return self._perform_delete(
self._get_dns_server_path(service_name,
deployment_name,
dns_server_name),
async=True)
def list_resource_extensions(self):
'''
Lists the resource extensions that are available to add to a
Virtual Machine.
'''
return self._perform_get(self._get_resource_extensions_path(),
ResourceExtensions)
def list_resource_extension_versions(self, publisher_name, extension_name):
'''
Lists the versions of a resource extension that are available to add
to a Virtual Machine.
publisher_name: Name of the resource extension publisher.
extension_name: Name of the resource extension.
'''
return self._perform_get(self._get_resource_extension_versions_path(
publisher_name, extension_name),
ResourceExtensions)
#--Operations for virtual machine images ----------------------------- #--Operations for virtual machine images -----------------------------
def capture_vm_image(self, service_name, deployment_name, role_name, options):
'''
Creates a copy of the operating system virtual hard disk (VHD) and all
of the data VHDs that are associated with the Virtual Machine, saves
the VHD copies in the same storage location as the original VHDs, and
registers the copies as a VM Image in the image repository that is
associated with the specified subscription.
service_name: The name of the service.
deployment_name: The name of the deployment.
role_name: The name of the role.
options: An instance of CaptureRoleAsVMImage class.
options.os_state:
Required. Specifies the state of the operating system in the image.
Possible values are: Generalized, Specialized
A Virtual Machine that is fully configured and running contains a
Specialized operating system. A Virtual Machine on which the
Sysprep command has been run with the generalize option contains a
Generalized operating system. If you capture an image from a
generalized Virtual Machine, the machine is deleted after the image
is captured. It is recommended that all Virtual Machines are shut
down before capturing an image.
options.vm_image_name:
Required. Specifies the name of the VM Image.
options.vm_image_name:
Required. Specifies the label of the VM Image.
options.description:
Optional. Specifies the description of the VM Image.
options.language:
Optional. Specifies the language of the VM Image.
options.image_family:
Optional. Specifies a value that can be used to group VM Images.
options.recommended_vm_size:
Optional. Specifies the size to use for the Virtual Machine that
is created from the VM Image.
'''
_validate_not_none('service_name', service_name)
_validate_not_none('deployment_name', deployment_name)
_validate_not_none('role_name', role_name)
_validate_not_none('options', options)
_validate_not_none('options.os_state', options.os_state)
_validate_not_none('options.vm_image_name', options.vm_image_name)
_validate_not_none('options.vm_image_label', options.vm_image_label)
return self._perform_post(
self._get_capture_vm_image_path(service_name, deployment_name, role_name),
_XmlSerializer.capture_vm_image_to_xml(options),
async=True)
def create_vm_image(self, vm_image):
'''
Creates a VM Image in the image repository that is associated with the
specified subscription using a specified set of virtual hard disks.
vm_image: An instance of VMImage class.
vm_image.name: Required. Specifies the name of the image.
vm_image.label: Required. Specifies an identifier for the image.
vm_image.description: Optional. Specifies the description of the image.
vm_image.os_disk_configuration:
Required. Specifies configuration information for the operating
system disk that is associated with the image.
vm_image.os_disk_configuration.host_caching:
Optional. Specifies the caching behavior of the operating system disk.
Possible values are: None, ReadOnly, ReadWrite
vm_image.os_disk_configuration.os_state:
Required. Specifies the state of the operating system in the image.
Possible values are: Generalized, Specialized
A Virtual Machine that is fully configured and running contains a
Specialized operating system. A Virtual Machine on which the
Sysprep command has been run with the generalize option contains a
Generalized operating system.
vm_image.os_disk_configuration.os:
Required. Specifies the operating system type of the image.
vm_image.os_disk_configuration.media_link:
Required. Specifies the location of the blob in Windows Azure
storage. The blob location belongs to a storage account in the
subscription specified by the <subscription-id> value in the
operation call.
vm_image.data_disk_configurations:
Optional. Specifies configuration information for the data disks
that are associated with the image. A VM Image might not have data
disks associated with it.
vm_image.data_disk_configurations[].host_caching:
Optional. Specifies the caching behavior of the data disk.
Possible values are: None, ReadOnly, ReadWrite
vm_image.data_disk_configurations[].lun:
Optional if the lun for the disk is 0. Specifies the Logical Unit
Number (LUN) for the data disk.
vm_image.data_disk_configurations[].media_link:
Required. Specifies the location of the blob in Windows Azure
storage. The blob location belongs to a storage account in the
subscription specified by the <subscription-id> value in the
operation call.
vm_image.data_disk_configurations[].logical_size_in_gb:
Required. Specifies the size, in GB, of the data disk.
vm_image.language: Optional. Specifies the language of the image.
vm_image.image_family:
Optional. Specifies a value that can be used to group VM Images.
vm_image.recommended_vm_size:
Optional. Specifies the size to use for the Virtual Machine that
is created from the VM Image.
vm_image.eula:
Optional. Specifies the End User License Agreement that is
associated with the image. The value for this element is a string,
but it is recommended that the value be a URL that points to a EULA.
vm_image.icon_uri:
Optional. Specifies the URI to the icon that is displayed for the
image in the Management Portal.
vm_image.small_icon_uri:
Optional. Specifies the URI to the small icon that is displayed for
the image in the Management Portal.
vm_image.privacy_uri:
Optional. Specifies the URI that points to a document that contains
the privacy policy related to the image.
vm_image.published_date:
Optional. Specifies the date when the image was added to the image
repository.
vm_image.show_in_gui:
Optional. Indicates whether the VM Images should be listed in the
portal.
'''
_validate_not_none('vm_image', vm_image)
_validate_not_none('vm_image.name', vm_image.name)
_validate_not_none('vm_image.label', vm_image.label)
_validate_not_none('vm_image.os_disk_configuration.os_state',
vm_image.os_disk_configuration.os_state)
_validate_not_none('vm_image.os_disk_configuration.os',
vm_image.os_disk_configuration.os)
_validate_not_none('vm_image.os_disk_configuration.media_link',
vm_image.os_disk_configuration.media_link)
return self._perform_post(
self._get_vm_image_path(),
_XmlSerializer.create_vm_image_to_xml(vm_image),
async=True)
def delete_vm_image(self, vm_image_name, delete_vhd=False):
'''
Deletes the specified VM Image from the image repository that is
associated with the specified subscription.
vm_image_name: The name of the image.
delete_vhd: Deletes the underlying vhd blob in Azure storage.
'''
_validate_not_none('vm_image_name', vm_image_name)
path = self._get_vm_image_path(vm_image_name)
if delete_vhd:
path += '?comp=media'
return self._perform_delete(path, async=True)
def list_vm_images(self, location=None, publisher=None, category=None):
'''
Retrieves a list of the VM Images from the image repository that is
associated with the specified subscription.
'''
path = self._get_vm_image_path()
query = ''
if location:
query += '&location=' + location
if publisher:
query += '&publisher=' + publisher
if category:
query += '&category=' + category
if query:
path = path + '?' + query.lstrip('&')
return self._perform_get(path, VMImages)
def update_vm_image(self, vm_image_name, vm_image):
'''
Updates a VM Image in the image repository that is associated with the
specified subscription.
vm_image_name: Name of image to update.
vm_image: An instance of VMImage class.
vm_image.label: Optional. Specifies an identifier for the image.
vm_image.os_disk_configuration:
Required. Specifies configuration information for the operating
system disk that is associated with the image.
vm_image.os_disk_configuration.host_caching:
Optional. Specifies the caching behavior of the operating system disk.
Possible values are: None, ReadOnly, ReadWrite
vm_image.data_disk_configurations:
Optional. Specifies configuration information for the data disks
that are associated with the image. A VM Image might not have data
disks associated with it.
vm_image.data_disk_configurations[].name:
Required. Specifies the name of the data disk.
vm_image.data_disk_configurations[].host_caching:
Optional. Specifies the caching behavior of the data disk.
Possible values are: None, ReadOnly, ReadWrite
vm_image.data_disk_configurations[].lun:
Optional if the lun for the disk is 0. Specifies the Logical Unit
Number (LUN) for the data disk.
vm_image.description: Optional. Specifies the description of the image.
vm_image.language: Optional. Specifies the language of the image.
vm_image.image_family:
Optional. Specifies a value that can be used to group VM Images.
vm_image.recommended_vm_size:
Optional. Specifies the size to use for the Virtual Machine that
is created from the VM Image.
vm_image.eula:
Optional. Specifies the End User License Agreement that is
associated with the image. The value for this element is a string,
but it is recommended that the value be a URL that points to a EULA.
vm_image.icon_uri:
Optional. Specifies the URI to the icon that is displayed for the
image in the Management Portal.
vm_image.small_icon_uri:
Optional. Specifies the URI to the small icon that is displayed for
the image in the Management Portal.
vm_image.privacy_uri:
Optional. Specifies the URI that points to a document that contains
the privacy policy related to the image.
vm_image.published_date:
Optional. Specifies the date when the image was added to the image
repository.
vm_image.show_in_gui:
Optional. Indicates whether the VM Images should be listed in the
portal.
'''
_validate_not_none('vm_image_name', vm_image_name)
_validate_not_none('vm_image', vm_image)
return self._perform_put(self._get_vm_image_path(vm_image_name),
_XmlSerializer.update_vm_image_to_xml(vm_image),
async=True)
#--Operations for operating system images ----------------------------
def list_os_images(self): def list_os_images(self):
''' '''
Retrieves a list of the OS images from the image repository. Retrieves a list of the OS images from the image repository.
@@ -1707,6 +2222,12 @@ class ServiceManagementService(_ServiceManagementClient):
return self._perform_get(self._get_virtual_network_site_path(), VirtualNetworkSites) return self._perform_get(self._get_virtual_network_site_path(), VirtualNetworkSites)
#--Helper functions -------------------------------------------------- #--Helper functions --------------------------------------------------
def _get_role_sizes_path(self):
return self._get_path('rolesizes', None)
def _get_subscriptions_path(self):
return '/subscriptions'
def _get_virtual_network_site_path(self): def _get_virtual_network_site_path(self):
return self._get_path('services/networking/virtualnetwork', None) return self._get_path('services/networking/virtualnetwork', None)
@@ -1741,6 +2262,31 @@ class ServiceManagementService(_ServiceManagementClient):
'/deployments/' + deployment_name + '/deployments/' + deployment_name +
'/roles/Operations', None) '/roles/Operations', None)
def _get_resource_extensions_path(self):
return self._get_path('services/resourceextensions', None)
def _get_resource_extension_versions_path(self, publisher_name, extension_name):
return self._get_path('services/resourceextensions',
publisher_name + '/' + extension_name)
def _get_dns_server_path(self, service_name, deployment_name,
dns_server_name=None):
return self._get_path('services/hostedservices/' + _str(service_name) +
'/deployments/' + deployment_name +
'/dnsservers', dns_server_name)
def _get_capture_vm_image_path(self, service_name, deployment_name, role_name):
return self._get_path('services/hostedservices/' + _str(service_name) +
'/deployments/' + _str(deployment_name) +
'/roleinstances/' + _str(role_name) + '/Operations',
None)
def _get_vm_image_path(self, image_name=None):
return self._get_path('services/vmimages', image_name)
def _get_reserved_ip_path(self, name=None):
return self._get_path('services/networking/reservedips', name)
def _get_data_disk_path(self, service_name, deployment_name, role_name, def _get_data_disk_path(self, service_name, deployment_name, role_name,
lun=None): lun=None):
return self._get_path('services/hostedservices/' + _str(service_name) + return self._get_path('services/hostedservices/' + _str(service_name) +

View File

@@ -0,0 +1,390 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_parse_service_resources_response,
_validate_not_none,
)
from azure.servicemanagement import (
EventLog,
ServerQuota,
Servers,
ServiceObjective,
Database,
FirewallRule,
_SqlManagementXmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class SqlDatabaseManagementService(_ServiceManagementClient):
    '''Client for the Azure SQL Database management REST API.

    Note that this class is preliminary work on SQL Database
    management. Since it lacks a lot of features, the final version
    can be slightly different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST, request_session=None):
        '''
        Initializes the sql database management service.

        subscription_id: Subscription to manage.
        cert_file:
            Path to .pem certificate file (httplib), or location of the
            certificate in your Personal certificate store (winhttp) in the
            CURRENT_USER\\my\\CertificateName format.
            If a request_session is specified, then this is unused.
        host: Live ServiceClient URL. Defaults to Azure public cloud.
        request_session:
            Session object to use for http requests. If this is specified, it
            replaces the default use of httplib or winhttp. Also, the cert_file
            parameter is unused when a session is passed in.
            The session object handles authentication, and as such can support
            multiple types of authentication: .pem certificate, oauth.
            For example, you can pass in a Session instance from the requests
            library. To use .pem certificate authentication with requests
            library, set the path to the .pem file on the session.cert
            attribute.
        '''
        super(SqlDatabaseManagementService, self).__init__(
            subscription_id, cert_file, host, request_session)
        # The SQL management endpoints expect XML request bodies.
        self.content_type = 'application/xml'

    #--Operations for sql servers ----------------------------------------
    def create_server(self, admin_login, admin_password, location):
        '''
        Create a new Azure SQL Database server.

        admin_login: The administrator login name for the new server.
        admin_password: The administrator login password for the new server.
        location: The region to deploy the new server.

        Returns the parsed create-server response (contains the generated
        server name).
        '''
        _validate_not_none('admin_login', admin_login)
        _validate_not_none('admin_password', admin_password)
        _validate_not_none('location', location)
        # NOTE(review): this uses the public perform_post (no underscore),
        # unlike every other call in this class. It looks deliberate, since
        # the raw response body is parsed just below — confirm against the
        # base client's API that perform_post returns the raw response.
        response = self.perform_post(
            self._get_servers_path(),
            _SqlManagementXmlSerializer.create_server_to_xml(
                admin_login,
                admin_password,
                location
            )
        )
        return _SqlManagementXmlSerializer.xml_to_create_server_response(
            response.body)

    def set_server_admin_password(self, server_name, admin_password):
        '''
        Reset the administrator password for a server.

        server_name: Name of the server to change the password.
        admin_password: The new administrator password for the server.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('admin_password', admin_password)
        # The password reset is exposed as an "op" on the server resource.
        return self._perform_post(
            self._get_servers_path(server_name) + '?op=ResetPassword',
            _SqlManagementXmlSerializer.set_server_admin_password_to_xml(
                admin_password
            )
        )

    def delete_server(self, server_name):
        '''
        Deletes an Azure SQL Database server (including all its databases).

        server_name: Name of the server you want to delete.
        '''
        _validate_not_none('server_name', server_name)
        return self._perform_delete(
            self._get_servers_path(server_name))

    def list_servers(self):
        '''
        List the SQL servers defined on the account.
        '''
        return self._perform_get(self._get_servers_path(),
                                 Servers)

    def list_quotas(self, server_name):
        '''
        Gets quotas for an Azure SQL Database Server.

        server_name: Name of the server.
        '''
        _validate_not_none('server_name', server_name)
        response = self._perform_get(self._get_quotas_path(server_name),
                                     None)
        return _parse_service_resources_response(response, ServerQuota)

    def get_server_event_logs(self, server_name, start_date,
                              interval_size_in_minutes, event_types=''):
        '''
        Gets the event logs for an Azure SQL Database Server.

        server_name: Name of the server to retrieve the event logs from.
        start_date:
            The starting date and time of the events to retrieve in UTC format,
            for example '2011-09-28 16:05:00'.
        interval_size_in_minutes:
            Size of the event logs to retrieve (in minutes).
            Valid values are: 5, 60, or 1440.
        event_types:
            The event type of the log entries you want to retrieve.
            Valid values are:
            - connection_successful
            - connection_failed
            - connection_terminated
            - deadlock
            - throttling
            - throttling_long_transaction
            To return all event types pass in an empty string.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('start_date', start_date)
        _validate_not_none('interval_size_in_minutes', interval_size_in_minutes)
        _validate_not_none('event_types', event_types)
        path = self._get_server_event_logs_path(server_name) + \
            '?startDate={0}&intervalSizeInMinutes={1}&eventTypes={2}'.format(
                start_date, interval_size_in_minutes, event_types)
        response = self._perform_get(path, None)
        return _parse_service_resources_response(response, EventLog)

    #--Operations for firewall rules ------------------------------------------
    def create_firewall_rule(self, server_name, name, start_ip_address,
                             end_ip_address):
        '''
        Creates an Azure SQL Database server firewall rule.

        server_name: Name of the server to set the firewall rule on.
        name: The name of the new firewall rule.
        start_ip_address:
            The lowest IP address in the range of the server-level firewall
            setting. IP addresses equal to or greater than this can attempt to
            connect to the server. The lowest possible IP address is 0.0.0.0.
        end_ip_address:
            The highest IP address in the range of the server-level firewall
            setting. IP addresses equal to or less than this can attempt to
            connect to the server. The highest possible IP address is
            255.255.255.255.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        _validate_not_none('start_ip_address', start_ip_address)
        _validate_not_none('end_ip_address', end_ip_address)
        return self._perform_post(
            self._get_firewall_rules_path(server_name),
            _SqlManagementXmlSerializer.create_firewall_rule_to_xml(
                name, start_ip_address, end_ip_address
            )
        )

    def update_firewall_rule(self, server_name, name, start_ip_address,
                             end_ip_address):
        '''
        Update a firewall rule for an Azure SQL Database server.

        server_name: Name of the server to set the firewall rule on.
        name: The name of the firewall rule to update.
        start_ip_address:
            The lowest IP address in the range of the server-level firewall
            setting. IP addresses equal to or greater than this can attempt to
            connect to the server. The lowest possible IP address is 0.0.0.0.
        end_ip_address:
            The highest IP address in the range of the server-level firewall
            setting. IP addresses equal to or less than this can attempt to
            connect to the server. The highest possible IP address is
            255.255.255.255.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        _validate_not_none('start_ip_address', start_ip_address)
        _validate_not_none('end_ip_address', end_ip_address)
        return self._perform_put(
            self._get_firewall_rules_path(server_name, name),
            _SqlManagementXmlSerializer.update_firewall_rule_to_xml(
                name, start_ip_address, end_ip_address
            )
        )

    def delete_firewall_rule(self, server_name, name):
        '''
        Deletes an Azure SQL Database server firewall rule.

        server_name:
            Name of the server with the firewall rule you want to delete.
        name:
            Name of the firewall rule you want to delete.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        return self._perform_delete(
            self._get_firewall_rules_path(server_name, name))

    def list_firewall_rules(self, server_name):
        '''
        Retrieves the set of firewall rules for an Azure SQL Database Server.

        server_name: Name of the server.
        '''
        _validate_not_none('server_name', server_name)
        response = self._perform_get(self._get_firewall_rules_path(server_name),
                                     None)
        return _parse_service_resources_response(response, FirewallRule)

    def list_service_level_objectives(self, server_name):
        '''
        Gets the service level objectives for an Azure SQL Database server.

        server_name: Name of the server.
        '''
        _validate_not_none('server_name', server_name)
        response = self._perform_get(
            self._get_service_objectives_path(server_name), None)
        return _parse_service_resources_response(response, ServiceObjective)

    #--Operations for sql databases ----------------------------------------
    def create_database(self, server_name, name, service_objective_id,
                        edition=None, collation_name=None,
                        max_size_bytes=None):
        '''
        Creates a new Azure SQL Database.

        server_name: Name of the server to contain the new database.
        name:
            Required. The name for the new database. See Naming Requirements
            in Azure SQL Database General Guidelines and Limitations and
            Database Identifiers for more information.
        service_objective_id:
            Required. The GUID corresponding to the performance level for
            Edition. See List Service Level Objectives for current values.
        edition:
            Optional. The Service Tier (Edition) for the new database. If
            omitted, the default is Web. Valid values are Web, Business,
            Basic, Standard, and Premium.
        collation_name:
            Optional. The database collation. This can be any collation
            supported by SQL. If omitted, the default collation is used.
        max_size_bytes:
            Optional. Sets the maximum size, in bytes, for the database. This
            value must be within the range of allowed values for Edition. If
            omitted, the default value for the edition is used.
            1 MB = 1048576 bytes. 1 GB = 1073741824 bytes.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        _validate_not_none('service_objective_id', service_objective_id)
        return self._perform_post(
            self._get_databases_path(server_name),
            _SqlManagementXmlSerializer.create_database_to_xml(
                name, service_objective_id, edition, collation_name,
                max_size_bytes
            )
        )

    def update_database(self, server_name, name, new_database_name=None,
                        service_objective_id=None, edition=None,
                        max_size_bytes=None):
        '''
        Updates existing database details.

        server_name: Name of the server hosting the database.
        name: Required. The current name of the database to update.
        new_database_name:
            Optional. The new name for the database.
        service_objective_id:
            Optional. The new service level to apply to the database. Use
            List Service Level Objectives to get the correct ID for the
            desired service objective.
        edition:
            Optional. The new edition for the database.
        max_size_bytes:
            Optional. The new size of the database in bytes.
        '''
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        return self._perform_put(
            self._get_databases_path(server_name, name),
            _SqlManagementXmlSerializer.update_database_to_xml(
                new_database_name, service_objective_id, edition,
                max_size_bytes
            )
        )

    def delete_database(self, server_name, name):
        '''
        Deletes an Azure SQL Database.

        server_name: Name of the server where the database is located.
        name: Name of the database to delete.
        '''
        # Validate like every other operation in this class does, so a None
        # argument fails fast instead of building a malformed request URL.
        _validate_not_none('server_name', server_name)
        _validate_not_none('name', name)
        return self._perform_delete(self._get_databases_path(server_name, name))

    def list_databases(self, name):
        '''
        List the SQL databases defined on the specified server name.

        name: Name of the server to enumerate databases on.
        '''
        _validate_not_none('name', name)
        response = self._perform_get(self._get_list_databases_path(name),
                                     None)
        return _parse_service_resources_response(response, Database)

    #--Helper functions --------------------------------------------------
    def _get_servers_path(self, server_name=None):
        # Base path for server resources; appends the server name if given.
        return self._get_path('services/sqlservers/servers', server_name)

    def _get_firewall_rules_path(self, server_name, name=None):
        path = self._get_servers_path(server_name) + '/firewallrules'
        if name:
            path = path + '/' + name
        return path

    def _get_databases_path(self, server_name, name=None):
        path = self._get_servers_path(server_name) + '/databases'
        if name:
            path = path + '/' + name
        return path

    def _get_server_event_logs_path(self, server_name):
        return self._get_servers_path(server_name) + '/events'

    def _get_service_objectives_path(self, server_name):
        return self._get_servers_path(server_name) + '/serviceobjectives'

    def _get_quotas_path(self, server_name, name=None):
        path = self._get_servers_path(server_name) + '/serverquotas'
        if name:
            path = path + '/' + name
        return path

    def _get_list_databases_path(self, name):
        # *contentview=generic is mandatory*
        return self._get_path('services/sqlservers/servers/',
                              name) + '/databases?contentview=generic'

View File

@@ -0,0 +1,256 @@
#-------------------------------------------------------------------------
# Copyright (c) Microsoft. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#--------------------------------------------------------------------------
from azure import (
MANAGEMENT_HOST,
_str,
)
from azure.servicemanagement import (
WebSpaces,
WebSpace,
Sites,
Site,
MetricResponses,
MetricDefinitions,
PublishData,
_XmlSerializer,
)
from azure.servicemanagement.servicemanagementclient import (
_ServiceManagementClient,
)
class WebsiteManagementService(_ServiceManagementClient):
    '''Preliminary client for Azure Web Sites management.

    Note that this class is preliminary work on WebSite management;
    since it lacks many features, the final version can be slightly
    different from the current one.
    '''

    def __init__(self, subscription_id=None, cert_file=None,
                 host=MANAGEMENT_HOST, request_session=None):
        '''
        Initializes the website management service.

        subscription_id: Subscription to manage.
        cert_file:
            Path to .pem certificate file (httplib), or location of the
            certificate in your Personal certificate store (winhttp) in the
            CURRENT_USER\\my\\CertificateName format.
            If a request_session is specified, then this is unused.
        host: Live ServiceClient URL. Defaults to Azure public cloud.
        request_session:
            Session object to use for http requests in place of the default
            httplib/winhttp transport. The session handles authentication
            itself (e.g. a requests.Session with the .pem path assigned to
            its session.cert attribute); cert_file is then unused.
        '''
        super(WebsiteManagementService, self).__init__(
            subscription_id, cert_file, host, request_session)

    #--Operations for web sites ----------------------------------------
    def list_webspaces(self):
        '''List the webspaces defined on the account.'''
        webspaces_path = self._get_list_webspaces_path()
        return self._perform_get(webspaces_path, WebSpaces)

    def get_webspace(self, webspace_name):
        '''
        Get details of a specific webspace.

        webspace_name: The name of the webspace.
        '''
        details_path = self._get_webspace_details_path(webspace_name)
        return self._perform_get(details_path, WebSpace)

    def list_sites(self, webspace_name):
        '''
        List the web sites defined on this webspace.

        webspace_name: The name of the webspace.
        '''
        sites_path = self._get_sites_path(webspace_name)
        return self._perform_get(sites_path, Sites)

    def get_site(self, webspace_name, website_name):
        '''
        Get details for a single web site in a webspace.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        site_path = self._get_sites_details_path(webspace_name, website_name)
        return self._perform_get(site_path, Site)

    def create_site(self, webspace_name, website_name, geo_region, host_names,
                    plan='VirtualDedicatedPlan', compute_mode='Shared',
                    server_farm=None, site_mode=None):
        '''
        Create a website.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        geo_region:
            The geographical region of the webspace that will be created.
        host_names:
            An array of fully qualified domain names for website. Only one
            hostname can be specified in the azurewebsites.net domain.
            The hostname should match the name of the website. Custom domains
            can only be specified for Shared or Standard websites.
        plan:
            This value must be 'VirtualDedicatedPlan'.
        compute_mode:
            'Shared' for the Free or Paid Shared offerings, 'Dedicated' for
            the Standard offering (in which case server_farm is required).
            The default value is 'Shared'.
        server_farm:
            The name of the Server Farm associated with this website. This is
            a required value for Standard mode.
        site_mode:
            Can be None, 'Limited' or 'Basic'. 'Limited' is the Free
            offering, 'Basic' the Paid Shared offering. Standard mode is
            driven by compute_mode and ignores site_mode.
        '''
        request_body = _XmlSerializer.create_website_to_xml(
            webspace_name, website_name, geo_region, plan, host_names,
            compute_mode, server_farm, site_mode)
        return self._perform_post(
            self._get_sites_path(webspace_name),
            request_body,
            Site)

    def delete_site(self, webspace_name, website_name,
                    delete_empty_server_farm=False, delete_metrics=False):
        '''
        Delete a website.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        delete_empty_server_farm:
            If the site being deleted is the last web site in a server farm,
            you can delete the server farm by setting this to True.
        delete_metrics:
            To also delete the metrics for the site that you are deleting, you
            can set this to True.
        '''
        # Collect optional query flags, then append them as a query string.
        flags = []
        if delete_empty_server_farm:
            flags.append('deleteEmptyServerFarm=true')
        if delete_metrics:
            flags.append('deleteMetrics=true')
        target = self._get_sites_details_path(webspace_name, website_name)
        if flags:
            target = target + '?' + '&'.join(flags)
        return self._perform_delete(target)

    def restart_site(self, webspace_name, website_name):
        '''
        Restart a web site.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        restart_path = self._get_restart_path(webspace_name, website_name)
        return self._perform_post(restart_path, '')

    def get_historical_usage_metrics(self, webspace_name, website_name,
            metrics=None, start_time=None, end_time=None, time_grain=None):
        '''
        Get historical usage metrics.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        metrics: Optional. List of metrics name. Otherwise, all metrics returned.
        start_time: Optional. An ISO8601 date. Otherwise, current hour is used.
        end_time: Optional. An ISO8601 date. Otherwise, current time is used.
        time_grain:
            Optional. A rollup name, as P1D. Otherwise, default rollup for
            the metrics is used.
        More information and metrics name at:
        http://msdn.microsoft.com/en-us/library/azure/dn166964.aspx
        '''
        # Build the query string from whichever filters were supplied.
        query_parts = []
        if metrics:
            query_parts.append('names=' + ','.join(metrics))
        if start_time:
            query_parts.append('StartTime=' + start_time)
        if end_time:
            query_parts.append('EndTime=' + end_time)
        if time_grain:
            query_parts.append('TimeGrain=' + time_grain)
        metrics_path = self._get_historical_usage_metrics_path(webspace_name,
                                                               website_name)
        if query_parts:
            metrics_path = metrics_path + '?' + '&'.join(query_parts)
        return self._perform_get(metrics_path, MetricResponses)

    def get_metric_definitions(self, webspace_name, website_name):
        '''
        Get metric definitions of metrics available of this web site.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        definitions_path = self._get_metric_definitions_path(webspace_name,
                                                             website_name)
        return self._perform_get(definitions_path, MetricDefinitions)

    def get_publish_profile_xml(self, webspace_name, website_name):
        '''
        Get a site's publish profile as a string.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        raw_response = self._perform_get(
            self._get_publishxml_path(webspace_name, website_name), None)
        return raw_response.body.decode("utf-8")

    def get_publish_profile(self, webspace_name, website_name):
        '''
        Get a site's publish profile as an object.

        webspace_name: The name of the webspace.
        website_name: The name of the website.
        '''
        profile_path = self._get_publishxml_path(webspace_name, website_name)
        return self._perform_get(profile_path, PublishData)

    #--Helper functions --------------------------------------------------
    def _get_list_webspaces_path(self):
        return self._get_path('services/webspaces', None)

    def _get_webspace_details_path(self, webspace_name):
        return self._get_path('services/webspaces/', webspace_name)

    def _get_sites_path(self, webspace_name):
        # Sites live under the webspace resource.
        return self._get_webspace_details_path(webspace_name) + '/sites'

    def _get_sites_details_path(self, webspace_name, website_name):
        return self._get_sites_path(webspace_name) + '/' + _str(website_name)

    def _get_restart_path(self, webspace_name, website_name):
        return (self._get_sites_details_path(webspace_name, website_name) +
                '/restart/')

    def _get_historical_usage_metrics_path(self, webspace_name, website_name):
        return (self._get_sites_details_path(webspace_name, website_name) +
                '/metrics/')

    def _get_metric_definitions_path(self, webspace_name, website_name):
        return (self._get_sites_details_path(webspace_name, website_name) +
                '/metricdefinitions/')

    def _get_publishxml_path(self, webspace_name, website_name):
        return (self._get_sites_details_path(webspace_name, website_name) +
                '/publishxml/')

View File

@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
import hashlib
import hmac
import sys import sys
import types import types
from datetime import datetime from datetime import datetime
from dateutil import parser
from dateutil.tz import tzutc
from xml.dom import minidom from xml.dom import minidom
from azure import (WindowsAzureData, from azure import (WindowsAzureData,
WindowsAzureError, WindowsAzureError,
@@ -36,6 +36,7 @@ from azure import (WindowsAzureData,
_general_error_handler, _general_error_handler,
_list_of, _list_of,
_parse_response_for_dict, _parse_response_for_dict,
_sign_string,
_unicode_type, _unicode_type,
_ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY, _ERROR_CANNOT_SERIALIZE_VALUE_TO_ENTITY,
) )
@@ -408,7 +409,7 @@ def _update_storage_header(request):
if request.body: if request.body:
assert isinstance(request.body, bytes) assert isinstance(request.body, bytes)
# if it is PUT, POST, MERGE, DELETE, need to add content-lengt to header. # if it is PUT, POST, MERGE, DELETE, need to add content-length to header.
if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']: if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
request.headers.append(('Content-Length', str(len(request.body)))) request.headers.append(('Content-Length', str(len(request.body))))
@@ -540,17 +541,6 @@ def _sign_storage_table_request(request, account_name, account_key):
return auth_string return auth_string
def _sign_string(account_key, string_to_sign):
decoded_account_key = _decode_base64_to_bytes(account_key)
if isinstance(string_to_sign, _unicode_type):
string_to_sign = string_to_sign.encode('utf-8')
signed_hmac_sha256 = hmac.HMAC(
decoded_account_key, string_to_sign, hashlib.sha256)
digest = signed_hmac_sha256.digest()
encoded_digest = _encode_base64(digest)
return encoded_digest
def _to_python_bool(value): def _to_python_bool(value):
if value.lower() == 'true': if value.lower() == 'true':
return True return True
@@ -572,7 +562,12 @@ def _to_entity_bool(value):
def _to_entity_datetime(value): def _to_entity_datetime(value):
return 'Edm.DateTime', value.strftime('%Y-%m-%dT%H:%M:%S') # Azure expects the date value passed in to be UTC.
# Azure will always return values as UTC.
# If a date is passed in without timezone info, it is assumed to be UTC.
if value.tzinfo:
value = value.astimezone(tzutc())
return 'Edm.DateTime', value.strftime('%Y-%m-%dT%H:%M:%SZ')
def _to_entity_float(value): def _to_entity_float(value):
@@ -607,12 +602,9 @@ def _from_entity_int(value):
def _from_entity_datetime(value): def _from_entity_datetime(value):
format = '%Y-%m-%dT%H:%M:%S' # Note that Azure always returns UTC datetime, and dateutil parser
if '.' in value: # will set the tzinfo on the date it returns
format = format + '.%f' return parser.parse(value)
if value.endswith('Z'):
format = format + 'Z'
return datetime.strptime(value, format)
_ENTITY_TO_PYTHON_CONVERSIONS = { _ENTITY_TO_PYTHON_CONVERSIONS = {
'Edm.Binary': _from_entity_binary, 'Edm.Binary': _from_entity_binary,
@@ -705,7 +697,7 @@ def _convert_entity_to_xml(source):
if sys.version_info < (3,): if sys.version_info < (3,):
if isinstance(properties_str, unicode): if isinstance(properties_str, unicode):
properties_str = properties_str.encode(encoding='utf-8') properties_str = properties_str.encode('utf-8')
# generate the entity_body # generate the entity_body
entity_body = entity_body.format(properties=properties_str) entity_body = entity_body.format(properties=properties_str)
@@ -835,10 +827,6 @@ def _convert_xml_to_entity(xmlstr):
# extract each property node and get the type from attribute and node value # extract each property node and get the type from attribute and node value
for xml_property in xml_properties[0].childNodes: for xml_property in xml_properties[0].childNodes:
name = _remove_prefix(xml_property.nodeName) name = _remove_prefix(xml_property.nodeName)
# exclude the Timestamp since it is auto added by azure when
# inserting entity. We don't want this to mix with real properties
if name in ['Timestamp']:
continue
if xml_property.firstChild: if xml_property.firstChild:
value = xml_property.firstChild.nodeValue value = xml_property.firstChild.nodeValue

View File

@@ -62,6 +62,7 @@ else:
# Keep this value sync with _ERROR_PAGE_BLOB_SIZE_ALIGNMENT # Keep this value sync with _ERROR_PAGE_BLOB_SIZE_ALIGNMENT
_PAGE_SIZE = 512 _PAGE_SIZE = 512
class BlobService(_StorageClient): class BlobService(_StorageClient):
''' '''
@@ -101,6 +102,7 @@ class BlobService(_StorageClient):
Live host base url. If not specified, uses the host base specified Live host base url. If not specified, uses the host base specified
when BlobService was initialized. when BlobService was initialized.
''' '''
if not account_name: if not account_name:
account_name = self.account_name account_name = self.account_name
if not protocol: if not protocol:
@@ -553,6 +555,7 @@ class BlobService(_StorageClient):
request, self.use_local_storage) request, self.use_local_storage)
request.headers = _update_storage_blob_header( request.headers = _update_storage_blob_header(
request, self.account_name, self.account_key) request, self.account_name, self.account_key)
response = self._perform_request(request) response = self._perform_request(request)
return _parse_response_for_dict(response) return _parse_response_for_dict(response)
@@ -740,26 +743,6 @@ class BlobService(_StorageClient):
_validate_not_none('container_name', container_name) _validate_not_none('container_name', container_name)
_validate_not_none('blob_name', blob_name) _validate_not_none('blob_name', blob_name)
_validate_not_none('file_path', file_path) _validate_not_none('file_path', file_path)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
request.headers = [
('x-ms-blob-type', 'BlockBlob'),
('Content-Encoding', _str_or_none(content_encoding)),
('Content-Language', _str_or_none(content_language)),
('Content-MD5', _str_or_none(content_md5)),
('Cache-Control', _str_or_none(cache_control)),
('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
('x-ms-blob-content-encoding',
_str_or_none(x_ms_blob_content_encoding)),
('x-ms-blob-content-language',
_str_or_none(x_ms_blob_content_language)),
('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
('x-ms-meta-name-values', x_ms_meta_name_values),
('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
]
count = path.getsize(file_path) count = path.getsize(file_path)
with open(file_path, 'rb') as stream: with open(file_path, 'rb') as stream:
@@ -833,26 +816,6 @@ class BlobService(_StorageClient):
_validate_not_none('container_name', container_name) _validate_not_none('container_name', container_name)
_validate_not_none('blob_name', blob_name) _validate_not_none('blob_name', blob_name)
_validate_not_none('stream', stream) _validate_not_none('stream', stream)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
request.headers = [
('x-ms-blob-type', 'BlockBlob'),
('Content-Encoding', _str_or_none(content_encoding)),
('Content-Language', _str_or_none(content_language)),
('Content-MD5', _str_or_none(content_md5)),
('Cache-Control', _str_or_none(cache_control)),
('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
('x-ms-blob-content-encoding',
_str_or_none(x_ms_blob_content_encoding)),
('x-ms-blob-content-language',
_str_or_none(x_ms_blob_content_language)),
('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
('x-ms-meta-name-values', x_ms_meta_name_values),
('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
]
if count and count < self._BLOB_MAX_DATA_SIZE: if count and count < self._BLOB_MAX_DATA_SIZE:
if progress_callback: if progress_callback:
@@ -922,7 +885,14 @@ class BlobService(_StorageClient):
else: else:
break break
self.put_block_list(container_name, blob_name, block_ids) self.put_block_list(container_name, blob_name, block_ids,
content_md5, x_ms_blob_cache_control,
x_ms_blob_content_type,
x_ms_blob_content_encoding,
x_ms_blob_content_language,
x_ms_blob_content_md5,
x_ms_meta_name_values,
x_ms_lease_id)
def put_block_blob_from_bytes(self, container_name, blob_name, blob, def put_block_blob_from_bytes(self, container_name, blob_name, blob,
index=0, count=None, content_encoding=None, index=0, count=None, content_encoding=None,
@@ -980,26 +950,6 @@ class BlobService(_StorageClient):
_validate_not_none('blob', blob) _validate_not_none('blob', blob)
_validate_not_none('index', index) _validate_not_none('index', index)
_validate_type_bytes('blob', blob) _validate_type_bytes('blob', blob)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
request.headers = [
('x-ms-blob-type', 'BlockBlob'),
('Content-Encoding', _str_or_none(content_encoding)),
('Content-Language', _str_or_none(content_language)),
('Content-MD5', _str_or_none(content_md5)),
('Cache-Control', _str_or_none(cache_control)),
('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
('x-ms-blob-content-encoding',
_str_or_none(x_ms_blob_content_encoding)),
('x-ms-blob-content-language',
_str_or_none(x_ms_blob_content_language)),
('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
('x-ms-meta-name-values', x_ms_meta_name_values),
('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
]
if index < 0: if index < 0:
raise TypeError(_ERROR_VALUE_NEGATIVE.format('index')) raise TypeError(_ERROR_VALUE_NEGATIVE.format('index'))
@@ -1101,26 +1051,6 @@ class BlobService(_StorageClient):
_validate_not_none('container_name', container_name) _validate_not_none('container_name', container_name)
_validate_not_none('blob_name', blob_name) _validate_not_none('blob_name', blob_name)
_validate_not_none('text', text) _validate_not_none('text', text)
request = HTTPRequest()
request.method = 'PUT'
request.host = self._get_host()
request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
request.headers = [
('x-ms-blob-type', 'BlockBlob'),
('Content-Encoding', _str_or_none(content_encoding)),
('Content-Language', _str_or_none(content_language)),
('Content-MD5', _str_or_none(content_md5)),
('Cache-Control', _str_or_none(cache_control)),
('x-ms-blob-content-type', _str_or_none(x_ms_blob_content_type)),
('x-ms-blob-content-encoding',
_str_or_none(x_ms_blob_content_encoding)),
('x-ms-blob-content-language',
_str_or_none(x_ms_blob_content_language)),
('x-ms-blob-content-md5', _str_or_none(x_ms_blob_content_md5)),
('x-ms-blob-cache-control', _str_or_none(x_ms_blob_cache_control)),
('x-ms-meta-name-values', x_ms_meta_name_values),
('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
]
if not isinstance(text, bytes): if not isinstance(text, bytes):
_validate_not_none('text_encoding', text_encoding) _validate_not_none('text_encoding', text_encoding)
@@ -1541,7 +1471,7 @@ class BlobService(_StorageClient):
index = 0 index = 0
while index < blob_size: while index < blob_size:
chunk_range = 'bytes={}-{}'.format( chunk_range = 'bytes={0}-{1}'.format(
index, index,
index + self._BLOB_MAX_CHUNK_DATA_SIZE - 1) index + self._BLOB_MAX_CHUNK_DATA_SIZE - 1)
data = self.get_blob( data = self.get_blob(
@@ -1862,6 +1792,7 @@ class BlobService(_StorageClient):
('x-ms-lease-id', _str_or_none(x_ms_lease_id)), ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
('x-ms-source-lease-id', _str_or_none(x_ms_source_lease_id)) ('x-ms-source-lease-id', _str_or_none(x_ms_source_lease_id))
] ]
request.path, request.query = _update_request_uri_query_local_storage( request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage) request, self.use_local_storage)
request.headers = _update_storage_blob_header( request.headers = _update_storage_blob_header(
@@ -1904,7 +1835,8 @@ class BlobService(_StorageClient):
self._perform_request(request) self._perform_request(request)
def delete_blob(self, container_name, blob_name, snapshot=None, def delete_blob(self, container_name, blob_name, snapshot=None,
x_ms_lease_id=None): timeout=None, x_ms_lease_id=None,
x_ms_delete_snapshots=None):
''' '''
Marks the specified blob or snapshot for deletion. The blob is later Marks the specified blob or snapshot for deletion. The blob is later
deleted during garbage collection. deleted during garbage collection.
@@ -1917,7 +1849,22 @@ class BlobService(_StorageClient):
snapshot: snapshot:
Optional. The snapshot parameter is an opaque DateTime value that, Optional. The snapshot parameter is an opaque DateTime value that,
when present, specifies the blob snapshot to delete. when present, specifies the blob snapshot to delete.
timeout:
Optional. The timeout parameter is expressed in seconds.
The Blob service returns an error when the timeout interval elapses
while processing the request.
x_ms_lease_id: Required if the blob has an active lease. x_ms_lease_id: Required if the blob has an active lease.
x_ms_delete_snapshots:
Required if the blob has associated snapshots. Specify one of the
following two options:
include: Delete the base blob and all of its snapshots.
only: Delete only the blob's snapshots and not the blob itself.
This header should be specified only for a request against the base
blob resource. If this header is specified on a request to delete
an individual snapshot, the Blob service returns status code 400
(Bad Request). If this header is not specified on the request and
the blob has associated snapshots, the Blob service returns status
code 409 (Conflict).
''' '''
_validate_not_none('container_name', container_name) _validate_not_none('container_name', container_name)
_validate_not_none('blob_name', blob_name) _validate_not_none('blob_name', blob_name)
@@ -1925,8 +1872,14 @@ class BlobService(_StorageClient):
request.method = 'DELETE' request.method = 'DELETE'
request.host = self._get_host() request.host = self._get_host()
request.path = '/' + _str(container_name) + '/' + _str(blob_name) + '' request.path = '/' + _str(container_name) + '/' + _str(blob_name) + ''
request.headers = [('x-ms-lease-id', _str_or_none(x_ms_lease_id))] request.headers = [
request.query = [('snapshot', _str_or_none(snapshot))] ('x-ms-lease-id', _str_or_none(x_ms_lease_id)),
('x-ms-delete-snapshots', _str_or_none(x_ms_delete_snapshots))
]
request.query = [
('snapshot', _str_or_none(snapshot)),
('timeout', _int_or_none(timeout))
]
request.path, request.query = _update_request_uri_query_local_storage( request.path, request.query = _update_request_uri_query_local_storage(
request, self.use_local_storage) request, self.use_local_storage)
request.headers = _update_storage_blob_header( request.headers = _update_storage_blob_header(

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#-------------------------------------------------------------------------- #--------------------------------------------------------------------------
from azure import url_quote from azure import _sign_string, url_quote
from azure.storage import _sign_string, X_MS_VERSION from azure.storage import X_MS_VERSION
#------------------------------------------------------------------------- #-------------------------------------------------------------------------
# Constants for the share access signature # Constants for the share access signature
SIGNED_VERSION = 'sv'
SIGNED_START = 'st' SIGNED_START = 'st'
SIGNED_EXPIRY = 'se' SIGNED_EXPIRY = 'se'
SIGNED_RESOURCE = 'sr' SIGNED_RESOURCE = 'sr'

View File

@@ -1,10 +1,2 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" __version__ = "2.4.0"
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
__author__ = "Tomi Pieviläinen <tomi.pievilainen@iki.fi>"
__license__ = "Simplified BSD"
__version__ = "2.2"

View File

@@ -1,18 +1,17 @@
# -*- coding: utf-8 -*-
""" """
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net> This module offers a generic easter computing method for any given year, using
Western, Orthodox or Julian algorithms.
This module offers extensions to the standard Python
datetime module.
""" """
__license__ = "Simplified BSD"
import datetime import datetime
__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"] __all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"]
EASTER_JULIAN = 1 EASTER_JULIAN = 1
EASTER_ORTHODOX = 2 EASTER_ORTHODOX = 2
EASTER_WESTERN = 3 EASTER_WESTERN = 3
def easter(year, method=EASTER_WESTERN): def easter(year, method=EASTER_WESTERN):
""" """
@@ -68,24 +67,23 @@ def easter(year, method=EASTER_WESTERN):
e = 0 e = 0
if method < 3: if method < 3:
# Old method # Old method
i = (19*g+15)%30 i = (19*g + 15) % 30
j = (y+y//4+i)%7 j = (y + y//4 + i) % 7
if method == 2: if method == 2:
# Extra dates to convert Julian to Gregorian date # Extra dates to convert Julian to Gregorian date
e = 10 e = 10
if y > 1600: if y > 1600:
e = e+y//100-16-(y//100-16)//4 e = e + y//100 - 16 - (y//100 - 16)//4
else: else:
# New method # New method
c = y//100 c = y//100
h = (c-c//4-(8*c+13)//25+19*g+15)%30 h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30
i = h-(h//28)*(1-(h//28)*(29//(h+1))*((21-g)//11)) i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11))
j = (y+y//4+i+2-c+c//4)%7 j = (y + y//4 + i + 2 - c + c//4) % 7
# p can be from -6 to 56 corresponding to dates 22 March to 23 May # p can be from -6 to 56 corresponding to dates 22 March to 23 May
# (later dates apply to method 2, although 23 May never actually occurs) # (later dates apply to method 2, although 23 May never actually occurs)
p = i-j+e p = i - j + e
d = 1+(p+27+(p+6)//40)%31 d = 1 + (p + 27 + (p + 6)//40) % 31
m = 3+(p+26)//30 m = 3 + (p + 26)//30
return datetime.date(int(y), int(m), int(d)) return datetime.date(int(y), int(m), int(d))

View File

@@ -1,32 +1,21 @@
# -*- coding:iso-8859-1 -*- # -*- coding:iso-8859-1 -*-
""" """
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net> This module offers a generic date/time string parser which is able to parse
most known formats to represent a date and/or time.
This module offers extensions to the standard Python
datetime module.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
__license__ = "Simplified BSD"
import datetime import datetime
import string import string
import time import time
import sys
import os
import collections import collections
from io import StringIO
try:
from io import StringIO
except ImportError:
from io import StringIO
from six import text_type, binary_type, integer_types from six import text_type, binary_type, integer_types
from . import relativedelta from . import relativedelta
from . import tz from . import tz
__all__ = ["parse", "parserinfo"] __all__ = ["parse", "parserinfo"]
@@ -83,9 +72,9 @@ class _timelex(object):
state = '0' state = '0'
elif nextchar in whitespace: elif nextchar in whitespace:
token = ' ' token = ' '
break # emit token break # emit token
else: else:
break # emit token break # emit token
elif state == 'a': elif state == 'a':
seenletters = True seenletters = True
if nextchar in wordchars: if nextchar in wordchars:
@@ -95,7 +84,7 @@ class _timelex(object):
state = 'a.' state = 'a.'
else: else:
self.charstack.append(nextchar) self.charstack.append(nextchar)
break # emit token break # emit token
elif state == '0': elif state == '0':
if nextchar in numchars: if nextchar in numchars:
token += nextchar token += nextchar
@@ -104,7 +93,7 @@ class _timelex(object):
state = '0.' state = '0.'
else: else:
self.charstack.append(nextchar) self.charstack.append(nextchar)
break # emit token break # emit token
elif state == 'a.': elif state == 'a.':
seenletters = True seenletters = True
if nextchar == '.' or nextchar in wordchars: if nextchar == '.' or nextchar in wordchars:
@@ -114,7 +103,7 @@ class _timelex(object):
state = '0.' state = '0.'
else: else:
self.charstack.append(nextchar) self.charstack.append(nextchar)
break # emit token break # emit token
elif state == '0.': elif state == '0.':
if nextchar == '.' or nextchar in numchars: if nextchar == '.' or nextchar in numchars:
token += nextchar token += nextchar
@@ -123,9 +112,9 @@ class _timelex(object):
state = 'a.' state = 'a.'
else: else:
self.charstack.append(nextchar) self.charstack.append(nextchar)
break # emit token break # emit token
if (state in ('a.', '0.') and if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or
(seenletters or token.count('.') > 1 or token[-1] == '.')): token[-1] == '.')):
l = token.split('.') l = token.split('.')
token = l[0] token = l[0]
for tok in l[1:]: for tok in l[1:]:
@@ -183,18 +172,18 @@ class parserinfo(object):
("Fri", "Friday"), ("Fri", "Friday"),
("Sat", "Saturday"), ("Sat", "Saturday"),
("Sun", "Sunday")] ("Sun", "Sunday")]
MONTHS = [("Jan", "January"), MONTHS = [("Jan", "January"),
("Feb", "February"), ("Feb", "February"),
("Mar", "March"), ("Mar", "March"),
("Apr", "April"), ("Apr", "April"),
("May", "May"), ("May", "May"),
("Jun", "June"), ("Jun", "June"),
("Jul", "July"), ("Jul", "July"),
("Aug", "August"), ("Aug", "August"),
("Sep", "Sept", "September"), ("Sep", "Sept", "September"),
("Oct", "October"), ("Oct", "October"),
("Nov", "November"), ("Nov", "November"),
("Dec", "December")] ("Dec", "December")]
HMS = [("h", "hour", "hours"), HMS = [("h", "hour", "hours"),
("m", "minute", "minutes"), ("m", "minute", "minutes"),
("s", "second", "seconds")] ("s", "second", "seconds")]
@@ -299,15 +288,16 @@ class parser(object):
def __init__(self, info=None): def __init__(self, info=None):
self.info = info or parserinfo() self.info = info or parserinfo()
def parse(self, timestr, default=None, def parse(self, timestr, default=None, ignoretz=False, tzinfos=None,
ignoretz=False, tzinfos=None, **kwargs):
**kwargs):
if not default: if not default:
default = datetime.datetime.now().replace(hour=0, minute=0, default = datetime.datetime.now().replace(hour=0, minute=0,
second=0, microsecond=0) second=0, microsecond=0)
if kwargs.get('fuzzy_with_tokens', False):
res, skipped_tokens = self._parse(timestr, **kwargs) res, skipped_tokens = self._parse(timestr, **kwargs)
else:
res = self._parse(timestr, **kwargs)
if res is None: if res is None:
raise ValueError("unknown string format") raise ValueError("unknown string format")
@@ -321,7 +311,8 @@ class parser(object):
if res.weekday is not None and not res.day: if res.weekday is not None and not res.day:
ret = ret+relativedelta.relativedelta(weekday=res.weekday) ret = ret+relativedelta.relativedelta(weekday=res.weekday)
if not ignoretz: if not ignoretz:
if isinstance(tzinfos, collections.Callable) or tzinfos and res.tzname in tzinfos: if (isinstance(tzinfos, collections.Callable) or
tzinfos and res.tzname in tzinfos):
if isinstance(tzinfos, collections.Callable): if isinstance(tzinfos, collections.Callable):
tzdata = tzinfos(res.tzname, res.tzoffset) tzdata = tzinfos(res.tzname, res.tzoffset)
else: else:
@@ -333,8 +324,8 @@ class parser(object):
elif isinstance(tzdata, integer_types): elif isinstance(tzdata, integer_types):
tzinfo = tz.tzoffset(res.tzname, tzdata) tzinfo = tz.tzoffset(res.tzname, tzdata)
else: else:
raise ValueError("offset must be tzinfo subclass, " \ raise ValueError("offset must be tzinfo subclass, "
"tz string, or int offset") "tz string, or int offset")
ret = ret.replace(tzinfo=tzinfo) ret = ret.replace(tzinfo=tzinfo)
elif res.tzname and res.tzname in time.tzname: elif res.tzname and res.tzname in time.tzname:
ret = ret.replace(tzinfo=tz.tzlocal()) ret = ret.replace(tzinfo=tz.tzlocal())
@@ -343,17 +334,18 @@ class parser(object):
elif res.tzoffset: elif res.tzoffset:
ret = ret.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset)) ret = ret.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset))
if skipped_tokens: if kwargs.get('fuzzy_with_tokens', False):
return ret, skipped_tokens return ret, skipped_tokens
else:
return ret return ret
class _result(_resultbase): class _result(_resultbase):
__slots__ = ["year", "month", "day", "weekday", __slots__ = ["year", "month", "day", "weekday",
"hour", "minute", "second", "microsecond", "hour", "minute", "second", "microsecond",
"tzname", "tzoffset"] "tzname", "tzoffset"]
def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, fuzzy_with_tokens=False): def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False,
fuzzy_with_tokens=False):
if fuzzy_with_tokens: if fuzzy_with_tokens:
fuzzy = True fuzzy = True
@@ -365,7 +357,6 @@ class parser(object):
res = self._result() res = self._result()
l = _timelex.split(timestr) l = _timelex.split(timestr)
# keep up with the last token skipped so we can recombine # keep up with the last token skipped so we can recombine
# consecutively skipped tokens (-2 for when i begins at 0). # consecutively skipped tokens (-2 for when i begins at 0).
last_skipped_token_i = -2 last_skipped_token_i = -2
@@ -440,12 +431,12 @@ class parser(object):
while True: while True:
if idx == 0: if idx == 0:
res.hour = int(value) res.hour = int(value)
if value%1: if value % 1:
res.minute = int(60*(value%1)) res.minute = int(60*(value % 1))
elif idx == 1: elif idx == 1:
res.minute = int(value) res.minute = int(value)
if value%1: if value % 1:
res.second = int(60*(value%1)) res.second = int(60*(value % 1))
elif idx == 2: elif idx == 2:
res.second, res.microsecond = \ res.second, res.microsecond = \
_parsems(value_repr) _parsems(value_repr)
@@ -465,16 +456,17 @@ class parser(object):
newidx = info.hms(l[i]) newidx = info.hms(l[i])
if newidx is not None: if newidx is not None:
idx = newidx idx = newidx
elif i == len_l and l[i-2] == ' ' and info.hms(l[i-3]) is not None: elif (i == len_l and l[i-2] == ' ' and
info.hms(l[i-3]) is not None):
# X h MM or X m SS # X h MM or X m SS
idx = info.hms(l[i-3]) + 1 idx = info.hms(l[i-3]) + 1
if idx == 1: if idx == 1:
res.minute = int(value) res.minute = int(value)
if value%1: if value % 1:
res.second = int(60*(value%1)) res.second = int(60*(value % 1))
elif idx == 2: elif idx == 2:
res.second, res.microsecond = \ res.second, res.microsecond = \
_parsems(value_repr) _parsems(value_repr)
i += 1 i += 1
elif i+1 < len_l and l[i] == ':': elif i+1 < len_l and l[i] == ':':
# HH:MM[:SS[.ss]] # HH:MM[:SS[.ss]]
@@ -482,8 +474,8 @@ class parser(object):
i += 1 i += 1
value = float(l[i]) value = float(l[i])
res.minute = int(value) res.minute = int(value)
if value%1: if value % 1:
res.second = int(60*(value%1)) res.second = int(60*(value % 1))
i += 1 i += 1
if i < len_l and l[i] == ':': if i < len_l and l[i] == ':':
res.second, res.microsecond = _parsems(l[i+1]) res.second, res.microsecond = _parsems(l[i+1])
@@ -597,8 +589,9 @@ class parser(object):
# Check for a timezone name # Check for a timezone name
if (res.hour is not None and len(l[i]) <= 5 and if (res.hour is not None and len(l[i]) <= 5 and
res.tzname is None and res.tzoffset is None and res.tzname is None and res.tzoffset is None and
not [x for x in l[i] if x not in string.ascii_uppercase]): not [x for x in l[i] if x not in
string.ascii_uppercase]):
res.tzname = l[i] res.tzname = l[i]
res.tzoffset = info.tzoffset(res.tzname) res.tzoffset = info.tzoffset(res.tzname)
i += 1 i += 1
@@ -643,7 +636,7 @@ class parser(object):
info.jump(l[i]) and l[i+1] == '(' and l[i+3] == ')' and info.jump(l[i]) and l[i+1] == '(' and l[i+3] == ')' and
3 <= len(l[i+2]) <= 5 and 3 <= len(l[i+2]) <= 5 and
not [x for x in l[i+2] not [x for x in l[i+2]
if x not in string.ascii_uppercase]): if x not in string.ascii_uppercase]):
# -0300 (BRST) # -0300 (BRST)
res.tzname = l[i+2] res.tzname = l[i+2]
i += 4 i += 4
@@ -732,10 +725,12 @@ class parser(object):
if fuzzy_with_tokens: if fuzzy_with_tokens:
return res, tuple(skipped_tokens) return res, tuple(skipped_tokens)
else:
return res, None return res
DEFAULTPARSER = parser() DEFAULTPARSER = parser()
def parse(timestr, parserinfo=None, **kwargs): def parse(timestr, parserinfo=None, **kwargs):
# Python 2.x support: datetimes return their string presentation as # Python 2.x support: datetimes return their string presentation as
# bytes in 2.x and unicode in 3.x, so it's reasonable to expect that # bytes in 2.x and unicode in 3.x, so it's reasonable to expect that
@@ -779,7 +774,7 @@ class _tzparser(object):
# BRST+3[BRDT[+2]] # BRST+3[BRDT[+2]]
j = i j = i
while j < len_l and not [x for x in l[j] while j < len_l and not [x for x in l[j]
if x in "0123456789:,-+"]: if x in "0123456789:,-+"]:
j += 1 j += 1
if j != i: if j != i:
if not res.stdabbr: if not res.stdabbr:
@@ -789,8 +784,8 @@ class _tzparser(object):
offattr = "dstoffset" offattr = "dstoffset"
res.dstabbr = "".join(l[i:j]) res.dstabbr = "".join(l[i:j])
i = j i = j
if (i < len_l and if (i < len_l and (l[i] in ('+', '-') or l[i][0] in
(l[i] in ('+', '-') or l[i][0] in "0123456789")): "0123456789")):
if l[i] in ('+', '-'): if l[i] in ('+', '-'):
# Yes, that's right. See the TZ variable # Yes, that's right. See the TZ variable
# documentation. # documentation.
@@ -801,8 +796,8 @@ class _tzparser(object):
len_li = len(l[i]) len_li = len(l[i])
if len_li == 4: if len_li == 4:
# -0300 # -0300
setattr(res, offattr, setattr(res, offattr, (int(l[i][:2])*3600 +
(int(l[i][:2])*3600+int(l[i][2:])*60)*signal) int(l[i][2:])*60)*signal)
elif i+1 < len_l and l[i+1] == ':': elif i+1 < len_l and l[i+1] == ':':
# -03:00 # -03:00
setattr(res, offattr, setattr(res, offattr,
@@ -822,7 +817,8 @@ class _tzparser(object):
if i < len_l: if i < len_l:
for j in range(i, len_l): for j in range(i, len_l):
if l[j] == ';': l[j] = ',' if l[j] == ';':
l[j] = ','
assert l[i] == ',' assert l[i] == ','
@@ -831,7 +827,7 @@ class _tzparser(object):
if i >= len_l: if i >= len_l:
pass pass
elif (8 <= l.count(',') <= 9 and elif (8 <= l.count(',') <= 9 and
not [y for x in l[i:] if x != ',' not [y for x in l[i:] if x != ','
for y in x if y not in "0123456789"]): for y in x if y not in "0123456789"]):
# GMT0BST,3,0,30,3600,10,0,26,7200[,3600] # GMT0BST,3,0,30,3600,10,0,26,7200[,3600]
for x in (res.start, res.end): for x in (res.start, res.end):
@@ -845,7 +841,7 @@ class _tzparser(object):
i += 2 i += 2
if value: if value:
x.week = value x.week = value
x.weekday = (int(l[i])-1)%7 x.weekday = (int(l[i])-1) % 7
else: else:
x.day = int(l[i]) x.day = int(l[i])
i += 2 i += 2
@@ -861,7 +857,7 @@ class _tzparser(object):
elif (l.count(',') == 2 and l[i:].count('/') <= 2 and elif (l.count(',') == 2 and l[i:].count('/') <= 2 and
not [y for x in l[i:] if x not in (',', '/', 'J', 'M', not [y for x in l[i:] if x not in (',', '/', 'J', 'M',
'.', '-', ':') '.', '-', ':')
for y in x if y not in "0123456789"]): for y in x if y not in "0123456789"]):
for x in (res.start, res.end): for x in (res.start, res.end):
if l[i] == 'J': if l[i] == 'J':
# non-leap year day (1 based) # non-leap year day (1 based)
@@ -880,7 +876,7 @@ class _tzparser(object):
i += 1 i += 1
assert l[i] in ('-', '.') assert l[i] in ('-', '.')
i += 1 i += 1
x.weekday = (int(l[i])-1)%7 x.weekday = (int(l[i])-1) % 7
else: else:
# year day (zero based) # year day (zero based)
x.yday = int(l[i])+1 x.yday = int(l[i])+1
@@ -921,6 +917,8 @@ class _tzparser(object):
DEFAULTTZPARSER = _tzparser() DEFAULTTZPARSER = _tzparser()
def _parsetz(tzstr): def _parsetz(tzstr):
return DEFAULTTZPARSER.parse(tzstr) return DEFAULTTZPARSER.parse(tzstr)

View File

@@ -1,11 +1,4 @@
""" # -*- coding: utf-8 -*-
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
__license__ = "Simplified BSD"
import datetime import datetime
import calendar import calendar
@@ -13,6 +6,7 @@ from six import integer_types
__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"] __all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"]
class weekday(object): class weekday(object):
__slots__ = ["weekday", "n"] __slots__ = ["weekday", "n"]
@@ -43,25 +37,35 @@ class weekday(object):
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)]) MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)])
class relativedelta(object): class relativedelta(object):
""" """
The relativedelta type is based on the specification of the excelent The relativedelta type is based on the specification of the excellent
work done by M.-A. Lemburg in his mx.DateTime extension. However, work done by M.-A. Lemburg in his
notice that this type does *NOT* implement the same algorithm as `mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension.
However, notice that this type does *NOT* implement the same algorithm as
his work. Do *NOT* expect it to behave like mx.DateTime's counterpart. his work. Do *NOT* expect it to behave like mx.DateTime's counterpart.
There's two different ways to build a relativedelta instance. The There are two different ways to build a relativedelta instance. The
first one is passing it two date/datetime classes: first one is passing it two date/datetime classes::
relativedelta(datetime1, datetime2) relativedelta(datetime1, datetime2)
And the other way is to use the following keyword arguments: The second one is passing it any number of the following keyword arguments::
relativedelta(arg1=x,arg2=y,arg3=z...)
year, month, day, hour, minute, second, microsecond: year, month, day, hour, minute, second, microsecond:
Absolute information. Absolute information (argument is singular); adding or subtracting a
relativedelta with absolute information does not perform an aritmetic
operation, but rather REPLACES the corresponding value in the
original datetime with the value(s) in relativedelta.
years, months, weeks, days, hours, minutes, seconds, microseconds: years, months, weeks, days, hours, minutes, seconds, microseconds:
Relative information, may be negative. Relative information, may be negative (argument is plural); adding
or subtracting a relativedelta with relative information performs
the corresponding aritmetic operation on the original datetime value
with the information in the relativedelta.
weekday: weekday:
One of the weekday instances (MO, TU, etc). These instances may One of the weekday instances (MO, TU, etc). These instances may
@@ -80,26 +84,26 @@ And the other way is to use the following keyword arguments:
Here is the behavior of operations with relativedelta: Here is the behavior of operations with relativedelta:
1) Calculate the absolute year, using the 'year' argument, or the 1. Calculate the absolute year, using the 'year' argument, or the
original datetime year, if the argument is not present. original datetime year, if the argument is not present.
2) Add the relative 'years' argument to the absolute year. 2. Add the relative 'years' argument to the absolute year.
3) Do steps 1 and 2 for month/months. 3. Do steps 1 and 2 for month/months.
4) Calculate the absolute day, using the 'day' argument, or the 4. Calculate the absolute day, using the 'day' argument, or the
original datetime day, if the argument is not present. Then, original datetime day, if the argument is not present. Then,
subtract from the day until it fits in the year and month subtract from the day until it fits in the year and month
found after their operations. found after their operations.
5) Add the relative 'days' argument to the absolute day. Notice 5. Add the relative 'days' argument to the absolute day. Notice
that the 'weeks' argument is multiplied by 7 and added to that the 'weeks' argument is multiplied by 7 and added to
'days'. 'days'.
6) Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds, 6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds,
microsecond/microseconds. microsecond/microseconds.
7) If the 'weekday' argument is present, calculate the weekday, 7. If the 'weekday' argument is present, calculate the weekday,
with the given (wday, nth) tuple. wday is the index of the with the given (wday, nth) tuple. wday is the index of the
weekday (0-6, 0=Mon), and nth is the number of weeks to add weekday (0-6, 0=Mon), and nth is the number of weeks to add
forward or backward, depending on its signal. Notice that if forward or backward, depending on its signal. Notice that if
@@ -114,9 +118,14 @@ Here is the behavior of operations with relativedelta:
yearday=None, nlyearday=None, yearday=None, nlyearday=None,
hour=None, minute=None, second=None, microsecond=None): hour=None, minute=None, second=None, microsecond=None):
if dt1 and dt2: if dt1 and dt2:
if (not isinstance(dt1, datetime.date)) or (not isinstance(dt2, datetime.date)): # datetime is a subclass of date. So both must be date
if not (isinstance(dt1, datetime.date) and
isinstance(dt2, datetime.date)):
raise TypeError("relativedelta only diffs datetime/date") raise TypeError("relativedelta only diffs datetime/date")
if not type(dt1) == type(dt2): #isinstance(dt1, type(dt2)): # We allow two dates, or two datetimes, so we coerce them to be
# of the same type
if (isinstance(dt1, datetime.datetime) !=
isinstance(dt2, datetime.datetime)):
if not isinstance(dt1, datetime.datetime): if not isinstance(dt1, datetime.datetime):
dt1 = datetime.datetime.fromordinal(dt1.toordinal()) dt1 = datetime.datetime.fromordinal(dt1.toordinal())
elif not isinstance(dt2, datetime.datetime): elif not isinstance(dt2, datetime.datetime):
@@ -185,7 +194,8 @@ Here is the behavior of operations with relativedelta:
if yearday > 59: if yearday > 59:
self.leapdays = -1 self.leapdays = -1
if yday: if yday:
ydayidx = [31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 366] ydayidx = [31, 59, 90, 120, 151, 181, 212,
243, 273, 304, 334, 366]
for idx, ydays in enumerate(ydayidx): for idx, ydays in enumerate(ydayidx):
if yday <= ydays: if yday <= ydays:
self.month = idx+1 self.month = idx+1
@@ -225,9 +235,9 @@ Here is the behavior of operations with relativedelta:
div, mod = divmod(self.months*s, 12) div, mod = divmod(self.months*s, 12)
self.months = mod*s self.months = mod*s
self.years += div*s self.years += div*s
if (self.hours or self.minutes or self.seconds or self.microseconds or if (self.hours or self.minutes or self.seconds or self.microseconds
self.hour is not None or self.minute is not None or or self.hour is not None or self.minute is not None or
self.second is not None or self.microsecond is not None): self.second is not None or self.microsecond is not None):
self._has_time = 1 self._has_time = 1
else: else:
self._has_time = 0 self._has_time = 0
@@ -245,21 +255,23 @@ Here is the behavior of operations with relativedelta:
def __add__(self, other): def __add__(self, other):
if isinstance(other, relativedelta): if isinstance(other, relativedelta):
return relativedelta(years=other.years+self.years, return relativedelta(years=other.years+self.years,
months=other.months+self.months, months=other.months+self.months,
days=other.days+self.days, days=other.days+self.days,
hours=other.hours+self.hours, hours=other.hours+self.hours,
minutes=other.minutes+self.minutes, minutes=other.minutes+self.minutes,
seconds=other.seconds+self.seconds, seconds=other.seconds+self.seconds,
microseconds=other.microseconds+self.microseconds, microseconds=(other.microseconds +
leapdays=other.leapdays or self.leapdays, self.microseconds),
year=other.year or self.year, leapdays=other.leapdays or self.leapdays,
month=other.month or self.month, year=other.year or self.year,
day=other.day or self.day, month=other.month or self.month,
weekday=other.weekday or self.weekday, day=other.day or self.day,
hour=other.hour or self.hour, weekday=other.weekday or self.weekday,
minute=other.minute or self.minute, hour=other.hour or self.hour,
second=other.second or self.second, minute=other.minute or self.minute,
microsecond=other.microsecond or self.microsecond) second=other.second or self.second,
microsecond=(other.microsecond or
self.microsecond))
if not isinstance(other, datetime.date): if not isinstance(other, datetime.date):
raise TypeError("unsupported type for add operation") raise TypeError("unsupported type for add operation")
elif self._has_time and not isinstance(other, datetime.datetime): elif self._has_time and not isinstance(other, datetime.datetime):
@@ -295,9 +307,9 @@ Here is the behavior of operations with relativedelta:
weekday, nth = self.weekday.weekday, self.weekday.n or 1 weekday, nth = self.weekday.weekday, self.weekday.n or 1
jumpdays = (abs(nth)-1)*7 jumpdays = (abs(nth)-1)*7
if nth > 0: if nth > 0:
jumpdays += (7-ret.weekday()+weekday)%7 jumpdays += (7-ret.weekday()+weekday) % 7
else: else:
jumpdays += (ret.weekday()-weekday)%7 jumpdays += (ret.weekday()-weekday) % 7
jumpdays *= -1 jumpdays *= -1
ret += datetime.timedelta(days=jumpdays) ret += datetime.timedelta(days=jumpdays)
return ret return ret

View File

@@ -1,21 +1,19 @@
# -*- coding: utf-8 -*-
""" """
Copyright (c) 2003-2010 Gustavo Niemeyer <gustavo@niemeyer.net> The rrule module offers a small, complete, and very fast, implementation of
the recurrence rules documented in the
This module offers extensions to the standard Python `iCalendar RFC <http://www.ietf.org/rfc/rfc2445.txt>`_,
datetime module. including support for caching of results.
""" """
__license__ = "Simplified BSD"
import itertools import itertools
import datetime import datetime
import calendar import calendar
try:
import _thread
except ImportError:
import thread as _thread
import sys import sys
from fractions import gcd
from six import advance_iterator, integer_types from six import advance_iterator, integer_types
from six.moves import _thread
__all__ = ["rrule", "rruleset", "rrulestr", __all__ = ["rrule", "rruleset", "rrulestr",
"YEARLY", "MONTHLY", "WEEKLY", "DAILY", "YEARLY", "MONTHLY", "WEEKLY", "DAILY",
@@ -23,7 +21,7 @@ __all__ = ["rrule", "rruleset", "rrulestr",
"MO", "TU", "WE", "TH", "FR", "SA", "SU"] "MO", "TU", "WE", "TH", "FR", "SA", "SU"]
# Every mask is 7 days longer to handle cross-year weekly periods. # Every mask is 7 days longer to handle cross-year weekly periods.
M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30+ M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30 +
[7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7) [7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7)
M365MASK = list(M366MASK) M365MASK = list(M366MASK)
M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32)) M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32))
@@ -51,6 +49,7 @@ M365MASK = tuple(M365MASK)
easter = None easter = None
parser = None parser = None
class weekday(object): class weekday(object):
__slots__ = ["weekday", "n"] __slots__ = ["weekday", "n"]
@@ -83,12 +82,13 @@ class weekday(object):
MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)]) MO, TU, WE, TH, FR, SA, SU = weekdays = tuple([weekday(x) for x in range(7)])
class rrulebase(object): class rrulebase(object):
def __init__(self, cache=False): def __init__(self, cache=False):
if cache: if cache:
self._cache = [] self._cache = []
self._cache_lock = _thread.allocate_lock() self._cache_lock = _thread.allocate_lock()
self._cache_gen = self._iter() self._cache_gen = self._iter()
self._cache_complete = False self._cache_complete = False
else: else:
self._cache = None self._cache = None
@@ -163,11 +163,17 @@ class rrulebase(object):
# __len__() introduces a large performance penality. # __len__() introduces a large performance penality.
def count(self): def count(self):
""" Returns the number of recurrences in this set. It will have go
trough the whole recurrence, if this hasn't been done before. """
if self._len is None: if self._len is None:
for x in self: pass for x in self:
pass
return self._len return self._len
def before(self, dt, inc=False): def before(self, dt, inc=False):
""" Returns the last recurrence before the given datetime instance. The
inc keyword defines what happens if dt is an occurrence. With
inc=True, if dt itself is an occurrence, it will be returned. """
if self._cache_complete: if self._cache_complete:
gen = self._cache gen = self._cache
else: else:
@@ -186,6 +192,9 @@ class rrulebase(object):
return last return last
def after(self, dt, inc=False): def after(self, dt, inc=False):
""" Returns the first recurrence after the given datetime instance. The
inc keyword defines what happens if dt is an occurrence. With
inc=True, if dt itself is an occurrence, it will be returned. """
if self._cache_complete: if self._cache_complete:
gen = self._cache gen = self._cache
else: else:
@@ -201,6 +210,10 @@ class rrulebase(object):
return None return None
def between(self, after, before, inc=False): def between(self, after, before, inc=False):
""" Returns all the occurrences of the rrule between after and before.
The inc keyword defines what happens if after and/or before are
themselves occurrences. With inc=True, they will be included in the
list, if they are found in the recurrence set. """
if self._cache_complete: if self._cache_complete:
gen = self._cache gen = self._cache
else: else:
@@ -229,7 +242,93 @@ class rrulebase(object):
l.append(i) l.append(i)
return l return l
class rrule(rrulebase): class rrule(rrulebase):
"""
That's the base of the rrule operation. It accepts all the keywords
defined in the RFC as its constructor parameters (except byday,
which was renamed to byweekday) and more. The constructor prototype is::
rrule(freq)
Where freq must be one of YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY,
or SECONDLY.
Additionally, it supports the following keyword arguments:
:param cache:
If given, it must be a boolean value specifying to enable or disable
caching of results. If you will use the same rrule instance multiple
times, enabling caching will improve the performance considerably.
:param dtstart:
The recurrence start. Besides being the base for the recurrence,
missing parameters in the final recurrence instances will also be
extracted from this date. If not given, datetime.now() will be used
instead.
:param interval:
The interval between each freq iteration. For example, when using
YEARLY, an interval of 2 means once every two years, but with HOURLY,
it means once every two hours. The default interval is 1.
:param wkst:
The week start day. Must be one of the MO, TU, WE constants, or an
integer, specifying the first day of the week. This will affect
recurrences based on weekly periods. The default week start is got
from calendar.firstweekday(), and may be modified by
calendar.setfirstweekday().
:param count:
How many occurrences will be generated.
:param until:
If given, this must be a datetime instance, that will specify the
limit of the recurrence. If a recurrence instance happens to be the
same as the datetime instance given in the until keyword, this will
be the last occurrence.
:param bysetpos:
If given, it must be either an integer, or a sequence of integers,
positive or negative. Each given integer will specify an occurrence
number, corresponding to the nth occurrence of the rule inside the
frequency period. For example, a bysetpos of -1 if combined with a
MONTHLY frequency, and a byweekday of (MO, TU, WE, TH, FR), will
result in the last work day of every month.
:param bymonth:
If given, it must be either an integer, or a sequence of integers,
meaning the months to apply the recurrence to.
:param bymonthday:
If given, it must be either an integer, or a sequence of integers,
meaning the month days to apply the recurrence to.
:param byyearday:
If given, it must be either an integer, or a sequence of integers,
meaning the year days to apply the recurrence to.
:param byweekno:
If given, it must be either an integer, or a sequence of integers,
meaning the week numbers to apply the recurrence to. Week numbers
have the meaning described in ISO8601, that is, the first week of
the year is that containing at least four days of the new year.
:param byweekday:
If given, it must be either an integer (0 == MO), a sequence of
integers, one of the weekday constants (MO, TU, etc), or a sequence
of these constants. When given, these variables will define the
weekdays where the recurrence will be applied. It's also possible to
use an argument n for the weekday instances, which will mean the nth
occurrence of this weekday in the period. For example, with MONTHLY,
or with YEARLY and BYMONTH, using FR(+1) in byweekday will specify the
first friday of the month where the recurrence happens. Notice that in
the RFC documentation, this is specified as BYDAY, but was renamed to
avoid the ambiguity of that keyword.
:param byhour:
If given, it must be either an integer, or a sequence of integers,
meaning the hours to apply the recurrence to.
:param byminute:
If given, it must be either an integer, or a sequence of integers,
meaning the minutes to apply the recurrence to.
:param bysecond:
If given, it must be either an integer, or a sequence of integers,
meaning the seconds to apply the recurrence to.
:param byeaster:
If given, it must be either an integer, or a sequence of integers,
positive or negative. Each integer will define an offset from the
Easter Sunday. Passing the offset 0 to byeaster will yield the Easter
Sunday itself. This is an extension to the RFC specification.
"""
def __init__(self, freq, dtstart=None, def __init__(self, freq, dtstart=None,
interval=1, wkst=None, count=None, until=None, bysetpos=None, interval=1, wkst=None, count=None, until=None, bysetpos=None,
bymonth=None, bymonthday=None, byyearday=None, byeaster=None, bymonth=None, bymonthday=None, byyearday=None, byeaster=None,
@@ -249,15 +348,18 @@ class rrule(rrulebase):
self._freq = freq self._freq = freq
self._interval = interval self._interval = interval
self._count = count self._count = count
if until and not isinstance(until, datetime.datetime): if until and not isinstance(until, datetime.datetime):
until = datetime.datetime.fromordinal(until.toordinal()) until = datetime.datetime.fromordinal(until.toordinal())
self._until = until self._until = until
if wkst is None: if wkst is None:
self._wkst = calendar.firstweekday() self._wkst = calendar.firstweekday()
elif isinstance(wkst, integer_types): elif isinstance(wkst, integer_types):
self._wkst = wkst self._wkst = wkst
else: else:
self._wkst = wkst.weekday self._wkst = wkst.weekday
if bysetpos is None: if bysetpos is None:
self._bysetpos = None self._bysetpos = None
elif isinstance(bysetpos, integer_types): elif isinstance(bysetpos, integer_types):
@@ -271,30 +373,36 @@ class rrule(rrulebase):
if pos == 0 or not (-366 <= pos <= 366): if pos == 0 or not (-366 <= pos <= 366):
raise ValueError("bysetpos must be between 1 and 366, " raise ValueError("bysetpos must be between 1 and 366, "
"or between -366 and -1") "or between -366 and -1")
if not (byweekno or byyearday or bymonthday or
byweekday is not None or byeaster is not None): if (byweekno is None and byyearday is None and bymonthday is None and
byweekday is None and byeaster is None):
if freq == YEARLY: if freq == YEARLY:
if not bymonth: if bymonth is None:
bymonth = dtstart.month bymonth = dtstart.month
bymonthday = dtstart.day bymonthday = dtstart.day
elif freq == MONTHLY: elif freq == MONTHLY:
bymonthday = dtstart.day bymonthday = dtstart.day
elif freq == WEEKLY: elif freq == WEEKLY:
byweekday = dtstart.weekday() byweekday = dtstart.weekday()
# bymonth # bymonth
if not bymonth: if bymonth is None:
self._bymonth = None self._bymonth = None
elif isinstance(bymonth, integer_types):
self._bymonth = (bymonth,)
else: else:
self._bymonth = tuple(bymonth) if isinstance(bymonth, integer_types):
bymonth = (bymonth,)
self._bymonth = set(bymonth)
# byyearday # byyearday
if not byyearday: if byyearday is None:
self._byyearday = None self._byyearday = None
elif isinstance(byyearday, integer_types):
self._byyearday = (byyearday,)
else: else:
self._byyearday = tuple(byyearday) if isinstance(byyearday, integer_types):
byyearday = (byyearday,)
self._byyearday = set(byyearday)
# byeaster # byeaster
if byeaster is not None: if byeaster is not None:
if not easter: if not easter:
@@ -305,87 +413,104 @@ class rrule(rrulebase):
self._byeaster = tuple(byeaster) self._byeaster = tuple(byeaster)
else: else:
self._byeaster = None self._byeaster = None
# bymonthay # bymonthay
if not bymonthday: if bymonthday is None:
self._bymonthday = () self._bymonthday = ()
self._bynmonthday = () self._bynmonthday = ()
elif isinstance(bymonthday, integer_types):
if bymonthday < 0:
self._bynmonthday = (bymonthday,)
self._bymonthday = ()
else:
self._bymonthday = (bymonthday,)
self._bynmonthday = ()
else: else:
self._bymonthday = tuple([x for x in bymonthday if x > 0]) if isinstance(bymonthday, integer_types):
self._bynmonthday = tuple([x for x in bymonthday if x < 0]) bymonthday = (bymonthday,)
self._bymonthday = set([x for x in bymonthday if x > 0])
self._bynmonthday = set([x for x in bymonthday if x < 0])
# byweekno # byweekno
if byweekno is None: if byweekno is None:
self._byweekno = None self._byweekno = None
elif isinstance(byweekno, integer_types):
self._byweekno = (byweekno,)
else: else:
self._byweekno = tuple(byweekno) if isinstance(byweekno, integer_types):
byweekno = (byweekno,)
self._byweekno = set(byweekno)
# byweekday / bynweekday # byweekday / bynweekday
if byweekday is None: if byweekday is None:
self._byweekday = None self._byweekday = None
self._bynweekday = None self._bynweekday = None
elif isinstance(byweekday, integer_types):
self._byweekday = (byweekday,)
self._bynweekday = None
elif hasattr(byweekday, "n"):
if not byweekday.n or freq > MONTHLY:
self._byweekday = (byweekday.weekday,)
self._bynweekday = None
else:
self._bynweekday = ((byweekday.weekday, byweekday.n),)
self._byweekday = None
else: else:
self._byweekday = [] if isinstance(byweekday, integer_types):
self._bynweekday = [] byweekday = (byweekday,)
elif hasattr(byweekday, "n"):
byweekday = (byweekday.weekday,)
self._byweekday = set()
self._bynweekday = set()
for wday in byweekday: for wday in byweekday:
if isinstance(wday, integer_types): if isinstance(wday, integer_types):
self._byweekday.append(wday) self._byweekday.add(wday)
elif not wday.n or freq > MONTHLY: elif not wday.n or freq > MONTHLY:
self._byweekday.append(wday.weekday) self._byweekday.add(wday.weekday)
else: else:
self._bynweekday.append((wday.weekday, wday.n)) self._bynweekday.add((wday.weekday, wday.n))
self._byweekday = tuple(self._byweekday)
self._bynweekday = tuple(self._bynweekday)
if not self._byweekday: if not self._byweekday:
self._byweekday = None self._byweekday = None
elif not self._bynweekday: elif not self._bynweekday:
self._bynweekday = None self._bynweekday = None
# byhour # byhour
if byhour is None: if byhour is None:
if freq < HOURLY: if freq < HOURLY:
self._byhour = (dtstart.hour,) self._byhour = set((dtstart.hour,))
else: else:
self._byhour = None self._byhour = None
elif isinstance(byhour, integer_types):
self._byhour = (byhour,)
else: else:
self._byhour = tuple(byhour) if isinstance(byhour, integer_types):
byhour = (byhour,)
if freq == HOURLY:
self._byhour = self.__construct_byset(start=dtstart.hour,
byxxx=byhour,
base=24)
else:
self._byhour = set(byhour)
# byminute # byminute
if byminute is None: if byminute is None:
if freq < MINUTELY: if freq < MINUTELY:
self._byminute = (dtstart.minute,) self._byminute = set((dtstart.minute,))
else: else:
self._byminute = None self._byminute = None
elif isinstance(byminute, integer_types):
self._byminute = (byminute,)
else: else:
self._byminute = tuple(byminute) if isinstance(byminute, integer_types):
byminute = (byminute,)
if freq == MINUTELY:
self._byminute = self.__construct_byset(start=dtstart.minute,
byxxx=byminute,
base=60)
else:
self._byminute = set(byminute)
# bysecond # bysecond
if bysecond is None: if bysecond is None:
if freq < SECONDLY: if freq < SECONDLY:
self._bysecond = (dtstart.second,) self._bysecond = ((dtstart.second,))
else: else:
self._bysecond = None self._bysecond = None
elif isinstance(bysecond, integer_types):
self._bysecond = (bysecond,)
else: else:
self._bysecond = tuple(bysecond) if isinstance(bysecond, integer_types):
bysecond = (bysecond,)
self._bysecond = set(bysecond)
if freq == SECONDLY:
self._bysecond = self.__construct_byset(start=dtstart.second,
byxxx=bysecond,
base=60)
else:
self._bysecond = set(bysecond)
if self._freq >= HOURLY: if self._freq >= HOURLY:
self._timeset = None self._timeset = None
@@ -395,8 +520,8 @@ class rrule(rrulebase):
for minute in self._byminute: for minute in self._byminute:
for second in self._bysecond: for second in self._bysecond:
self._timeset.append( self._timeset.append(
datetime.time(hour, minute, second, datetime.time(hour, minute, second,
tzinfo=self._tzinfo)) tzinfo=self._tzinfo))
self._timeset.sort() self._timeset.sort()
self._timeset = tuple(self._timeset) self._timeset = tuple(self._timeset)
@@ -424,20 +549,20 @@ class rrule(rrulebase):
ii = _iterinfo(self) ii = _iterinfo(self)
ii.rebuild(year, month) ii.rebuild(year, month)
getdayset = {YEARLY:ii.ydayset, getdayset = {YEARLY: ii.ydayset,
MONTHLY:ii.mdayset, MONTHLY: ii.mdayset,
WEEKLY:ii.wdayset, WEEKLY: ii.wdayset,
DAILY:ii.ddayset, DAILY: ii.ddayset,
HOURLY:ii.ddayset, HOURLY: ii.ddayset,
MINUTELY:ii.ddayset, MINUTELY: ii.ddayset,
SECONDLY:ii.ddayset}[freq] SECONDLY: ii.ddayset}[freq]
if freq < HOURLY: if freq < HOURLY:
timeset = self._timeset timeset = self._timeset
else: else:
gettimeset = {HOURLY:ii.htimeset, gettimeset = {HOURLY: ii.htimeset,
MINUTELY:ii.mtimeset, MINUTELY: ii.mtimeset,
SECONDLY:ii.stimeset}[freq] SECONDLY: ii.stimeset}[freq]
if ((freq >= HOURLY and if ((freq >= HOURLY and
self._byhour and hour not in self._byhour) or self._byhour and hour not in self._byhour) or
(freq >= MINUTELY and (freq >= MINUTELY and
@@ -466,11 +591,10 @@ class rrule(rrulebase):
ii.mdaymask[i] not in bymonthday and ii.mdaymask[i] not in bymonthday and
ii.nmdaymask[i] not in bynmonthday) or ii.nmdaymask[i] not in bynmonthday) or
(byyearday and (byyearday and
((i < ii.yearlen and i+1 not in byyearday ((i < ii.yearlen and i+1 not in byyearday and
and -ii.yearlen+i not in byyearday) or -ii.yearlen+i not in byyearday) or
(i >= ii.yearlen and i+1-ii.yearlen not in byyearday (i >= ii.yearlen and i+1-ii.yearlen not in byyearday and
and -ii.nextyearlen+i-ii.yearlen -ii.nextyearlen+i-ii.yearlen not in byyearday)))):
not in byyearday)))):
dayset[i] = None dayset[i] = None
filtered = True filtered = True
@@ -484,7 +608,7 @@ class rrule(rrulebase):
daypos, timepos = divmod(pos-1, len(timeset)) daypos, timepos = divmod(pos-1, len(timeset))
try: try:
i = [x for x in dayset[start:end] i = [x for x in dayset[start:end]
if x is not None][daypos] if x is not None][daypos]
time = timeset[timepos] time = timeset[timepos]
except IndexError: except IndexError:
pass pass
@@ -559,60 +683,86 @@ class rrule(rrulebase):
if filtered: if filtered:
# Jump to one iteration before next day # Jump to one iteration before next day
hour += ((23-hour)//interval)*interval hour += ((23-hour)//interval)*interval
while True:
hour += interval if byhour:
div, mod = divmod(hour, 24) ndays, hour = self.__mod_distance(value=hour,
if div: byxxx=self._byhour,
hour = mod base=24)
day += div else:
fixday = True ndays, hour = divmod(hour+interval, 24)
if not byhour or hour in byhour:
break if ndays:
day += ndays
fixday = True
timeset = gettimeset(hour, minute, second) timeset = gettimeset(hour, minute, second)
elif freq == MINUTELY: elif freq == MINUTELY:
if filtered: if filtered:
# Jump to one iteration before next day # Jump to one iteration before next day
minute += ((1439-(hour*60+minute))//interval)*interval minute += ((1439-(hour*60+minute))//interval)*interval
while True:
minute += interval valid = False
div, mod = divmod(minute, 60) rep_rate = (24*60)
for j in range(rep_rate // gcd(interval, rep_rate)):
if byminute:
nhours, minute = \
self.__mod_distance(value=minute,
byxxx=self._byminute,
base=60)
else:
nhours, minute = divmod(minute+interval, 60)
div, hour = divmod(hour+nhours, 24)
if div: if div:
minute = mod day += div
hour += div fixday = True
div, mod = divmod(hour, 24) filtered = False
if div:
hour = mod if not byhour or hour in byhour:
day += div valid = True
fixday = True
filtered = False
if ((not byhour or hour in byhour) and
(not byminute or minute in byminute)):
break break
if not valid:
raise ValueError('Invalid combination of interval and ' +
'byhour resulting in empty rule.')
timeset = gettimeset(hour, minute, second) timeset = gettimeset(hour, minute, second)
elif freq == SECONDLY: elif freq == SECONDLY:
if filtered: if filtered:
# Jump to one iteration before next day # Jump to one iteration before next day
second += (((86399-(hour*3600+minute*60+second)) second += (((86399-(hour*3600+minute*60+second))
//interval)*interval) // interval)*interval)
while True:
second += self._interval rep_rate = (24*3600)
div, mod = divmod(second, 60) valid = False
for j in range(0, rep_rate // gcd(interval, rep_rate)):
if bysecond:
nminutes, second = \
self.__mod_distance(value=second,
byxxx=self._bysecond,
base=60)
else:
nminutes, second = divmod(second+interval, 60)
div, minute = divmod(minute+nminutes, 60)
if div: if div:
second = mod hour += div
minute += div div, hour = divmod(hour, 24)
div, mod = divmod(minute, 60)
if div: if div:
minute = mod day += div
hour += div fixday = True
div, mod = divmod(hour, 24)
if div:
hour = mod
day += div
fixday = True
if ((not byhour or hour in byhour) and if ((not byhour or hour in byhour) and
(not byminute or minute in byminute) and (not byminute or minute in byminute) and
(not bysecond or second in bysecond)): (not bysecond or second in bysecond)):
valid = True
break break
if not valid:
raise ValueError('Invalid combination of interval, ' +
'byhour and byminute resulting in empty' +
' rule.')
timeset = gettimeset(hour, minute, second) timeset = gettimeset(hour, minute, second)
if fixday and day > 28: if fixday and day > 28:
@@ -630,6 +780,80 @@ class rrule(rrulebase):
daysinmonth = calendar.monthrange(year, month)[1] daysinmonth = calendar.monthrange(year, month)[1]
ii.rebuild(year, month) ii.rebuild(year, month)
def __construct_byset(self, start, byxxx, base):
"""
If a `BYXXX` sequence is passed to the constructor at the same level as
`FREQ` (e.g. `FREQ=HOURLY,BYHOUR={2,4,7},INTERVAL=3`), there are some
specifications which cannot be reached given some starting conditions.
This occurs whenever the interval is not coprime with the base of a
given unit and the difference between the starting position and the
ending position is not coprime with the greatest common denominator
between the interval and the base. For example, with a FREQ of hourly
starting at 17:00 and an interval of 4, the only valid values for
BYHOUR would be {21, 1, 5, 9, 13, 17}, because 4 and 24 are not
coprime.
:param:`start` specifies the starting position.
:param:`byxxx` is an iterable containing the list of allowed values.
:param:`base` is the largest allowable value for the specified
frequency (e.g. 24 hours, 60 minutes).
This does not preserve the type of the iterable, returning a set, since
the values should be unique and the order is irrelevant, this will
speed up later lookups.
In the event of an empty set, raises a :exception:`ValueError`, as this
results in an empty rrule.
"""
cset = set()
# Support a single byxxx value.
if isinstance(byxxx, integer_types):
byxxx = (byxxx)
for num in byxxx:
i_gcd = gcd(self._interval, base)
# Use divmod rather than % because we need to wrap negative nums.
if i_gcd == 1 or divmod(num - start, i_gcd)[1] == 0:
cset.add(num)
if len(cset) == 0:
raise ValueError("Invalid rrule byxxx generates an empty set.")
return cset
def __mod_distance(self, value, byxxx, base):
"""
Calculates the next value in a sequence where the `FREQ` parameter is
specified along with a `BYXXX` parameter at the same "level"
(e.g. `HOURLY` specified with `BYHOUR`).
:param:`value` is the old value of the component.
:param:`byxxx` is the `BYXXX` set, which should have been generated
by `rrule._construct_byset`, or something else which
checks that a valid rule is present.
:param:`base` is the largest allowable value for the specified
frequency (e.g. 24 hours, 60 minutes).
If a valid value is not found after `base` iterations (the maximum
number before the sequence would start to repeat), this raises a
:exception:`ValueError`, as no valid values were found.
This returns a tuple of `divmod(n*interval, base)`, where `n` is the
smallest number of `interval` repetitions until the next specified
value in `byxxx` is found.
"""
accumulator = 0
for ii in range(1, base + 1):
# Using divmod() over % to account for negative intervals
div, value = divmod(value + self._interval, base)
accumulator += div
if value in byxxx:
return (accumulator, value)
class _iterinfo(object): class _iterinfo(object):
__slots__ = ["rrule", "lastyear", "lastmonth", __slots__ = ["rrule", "lastyear", "lastmonth",
"yearlen", "nextyearlen", "yearordinal", "yearweekday", "yearlen", "nextyearlen", "yearordinal", "yearweekday",
@@ -669,13 +893,13 @@ class _iterinfo(object):
self.wnomask = None self.wnomask = None
else: else:
self.wnomask = [0]*(self.yearlen+7) self.wnomask = [0]*(self.yearlen+7)
#no1wkst = firstwkst = self.wdaymask.index(rr._wkst) # no1wkst = firstwkst = self.wdaymask.index(rr._wkst)
no1wkst = firstwkst = (7-self.yearweekday+rr._wkst)%7 no1wkst = firstwkst = (7-self.yearweekday+rr._wkst) % 7
if no1wkst >= 4: if no1wkst >= 4:
no1wkst = 0 no1wkst = 0
# Number of days in the year, plus the days we got # Number of days in the year, plus the days we got
# from last year. # from last year.
wyearlen = self.yearlen+(self.yearweekday-rr._wkst)%7 wyearlen = self.yearlen+(self.yearweekday-rr._wkst) % 7
else: else:
# Number of days in the year, minus the days we # Number of days in the year, minus the days we
# left in last year. # left in last year.
@@ -721,22 +945,22 @@ class _iterinfo(object):
# this year. # this year.
if -1 not in rr._byweekno: if -1 not in rr._byweekno:
lyearweekday = datetime.date(year-1, 1, 1).weekday() lyearweekday = datetime.date(year-1, 1, 1).weekday()
lno1wkst = (7-lyearweekday+rr._wkst)%7 lno1wkst = (7-lyearweekday+rr._wkst) % 7
lyearlen = 365+calendar.isleap(year-1) lyearlen = 365+calendar.isleap(year-1)
if lno1wkst >= 4: if lno1wkst >= 4:
lno1wkst = 0 lno1wkst = 0
lnumweeks = 52+(lyearlen+ lnumweeks = 52+(lyearlen +
(lyearweekday-rr._wkst)%7)%7//4 (lyearweekday-rr._wkst) % 7) % 7//4
else: else:
lnumweeks = 52+(self.yearlen-no1wkst)%7//4 lnumweeks = 52+(self.yearlen-no1wkst) % 7//4
else: else:
lnumweeks = -1 lnumweeks = -1
if lnumweeks in rr._byweekno: if lnumweeks in rr._byweekno:
for i in range(no1wkst): for i in range(no1wkst):
self.wnomask[i] = 1 self.wnomask[i] = 1
if (rr._bynweekday and if (rr._bynweekday and (month != self.lastmonth or
(month != self.lastmonth or year != self.lastyear)): year != self.lastyear)):
ranges = [] ranges = []
if rr._freq == YEARLY: if rr._freq == YEARLY:
if rr._bymonth: if rr._bymonth:
@@ -755,10 +979,10 @@ class _iterinfo(object):
for wday, n in rr._bynweekday: for wday, n in rr._bynweekday:
if n < 0: if n < 0:
i = last+(n+1)*7 i = last+(n+1)*7
i -= (self.wdaymask[i]-wday)%7 i -= (self.wdaymask[i]-wday) % 7
else: else:
i = first+(n-1)*7 i = first+(n-1)*7
i += (7-self.wdaymask[i]+wday)%7 i += (7-self.wdaymask[i]+wday) % 7
if first <= i <= last: if first <= i <= last:
self.nwdaymask[i] = 1 self.nwdaymask[i] = 1
@@ -775,50 +999,50 @@ class _iterinfo(object):
return list(range(self.yearlen)), 0, self.yearlen return list(range(self.yearlen)), 0, self.yearlen
def mdayset(self, year, month, day): def mdayset(self, year, month, day):
set = [None]*self.yearlen dset = [None]*self.yearlen
start, end = self.mrange[month-1:month+1] start, end = self.mrange[month-1:month+1]
for i in range(start, end): for i in range(start, end):
set[i] = i dset[i] = i
return set, start, end return dset, start, end
def wdayset(self, year, month, day): def wdayset(self, year, month, day):
# We need to handle cross-year weeks here. # We need to handle cross-year weeks here.
set = [None]*(self.yearlen+7) dset = [None]*(self.yearlen+7)
i = datetime.date(year, month, day).toordinal()-self.yearordinal i = datetime.date(year, month, day).toordinal()-self.yearordinal
start = i start = i
for j in range(7): for j in range(7):
set[i] = i dset[i] = i
i += 1 i += 1
#if (not (0 <= i < self.yearlen) or # if (not (0 <= i < self.yearlen) or
# self.wdaymask[i] == self.rrule._wkst): # self.wdaymask[i] == self.rrule._wkst):
# This will cross the year boundary, if necessary. # This will cross the year boundary, if necessary.
if self.wdaymask[i] == self.rrule._wkst: if self.wdaymask[i] == self.rrule._wkst:
break break
return set, start, i return dset, start, i
def ddayset(self, year, month, day): def ddayset(self, year, month, day):
set = [None]*self.yearlen dset = [None]*self.yearlen
i = datetime.date(year, month, day).toordinal()-self.yearordinal i = datetime.date(year, month, day).toordinal()-self.yearordinal
set[i] = i dset[i] = i
return set, i, i+1 return dset, i, i+1
def htimeset(self, hour, minute, second): def htimeset(self, hour, minute, second):
set = [] tset = []
rr = self.rrule rr = self.rrule
for minute in rr._byminute: for minute in rr._byminute:
for second in rr._bysecond: for second in rr._bysecond:
set.append(datetime.time(hour, minute, second, tset.append(datetime.time(hour, minute, second,
tzinfo=rr._tzinfo)) tzinfo=rr._tzinfo))
set.sort() tset.sort()
return set return tset
def mtimeset(self, hour, minute, second): def mtimeset(self, hour, minute, second):
set = [] tset = []
rr = self.rrule rr = self.rrule
for second in rr._bysecond: for second in rr._bysecond:
set.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo)) tset.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo))
set.sort() tset.sort()
return set return tset
def stimeset(self, hour, minute, second): def stimeset(self, hour, minute, second):
return (datetime.time(hour, minute, second, return (datetime.time(hour, minute, second,
@@ -826,6 +1050,12 @@ class _iterinfo(object):
class rruleset(rrulebase): class rruleset(rrulebase):
""" The rruleset type allows more complex recurrence setups, mixing
multiple rules, dates, exclusion rules, and exclusion dates. The type
constructor takes the following keyword arguments:
:param cache: If True, caching of results will be enabled, improving
performance of multiple queries considerably. """
class _genitem(object): class _genitem(object):
def __init__(self, genlist, gen): def __init__(self, genlist, gen):
@@ -865,15 +1095,26 @@ class rruleset(rrulebase):
self._exdate = [] self._exdate = []
def rrule(self, rrule): def rrule(self, rrule):
""" Include the given :py:class:`rrule` instance in the recurrence set
generation. """
self._rrule.append(rrule) self._rrule.append(rrule)
def rdate(self, rdate): def rdate(self, rdate):
""" Include the given :py:class:`datetime` instance in the recurrence
set generation. """
self._rdate.append(rdate) self._rdate.append(rdate)
def exrule(self, exrule): def exrule(self, exrule):
""" Include the given rrule instance in the recurrence set exclusion
list. Dates which are part of the given recurrence rules will not
be generated, even if some inclusive rrule or rdate matches them.
"""
self._exrule.append(exrule) self._exrule.append(exrule)
def exdate(self, exdate): def exdate(self, exdate):
""" Include the given datetime instance in the recurrence set
exclusion list. Dates included that way will not be generated,
even if some inclusive rrule or rdate matches them. """
self._exdate.append(exdate) self._exdate.append(exdate)
def _iter(self): def _iter(self):
@@ -905,6 +1146,7 @@ class rruleset(rrulebase):
rlist.sort() rlist.sort()
self._len = total self._len = total
class _rrulestr(object): class _rrulestr(object):
_freq_map = {"YEARLY": YEARLY, _freq_map = {"YEARLY": YEARLY,
@@ -915,7 +1157,8 @@ class _rrulestr(object):
"MINUTELY": MINUTELY, "MINUTELY": MINUTELY,
"SECONDLY": SECONDLY} "SECONDLY": SECONDLY}
_weekday_map = {"MO":0,"TU":1,"WE":2,"TH":3,"FR":4,"SA":5,"SU":6} _weekday_map = {"MO": 0, "TU": 1, "WE": 2, "TH": 3,
"FR": 4, "SA": 5, "SU": 6}
def _handle_int(self, rrkwargs, name, value, **kwargs): def _handle_int(self, rrkwargs, name, value, **kwargs):
rrkwargs[name.lower()] = int(value) rrkwargs[name.lower()] = int(value)
@@ -923,17 +1166,17 @@ class _rrulestr(object):
def _handle_int_list(self, rrkwargs, name, value, **kwargs): def _handle_int_list(self, rrkwargs, name, value, **kwargs):
rrkwargs[name.lower()] = [int(x) for x in value.split(',')] rrkwargs[name.lower()] = [int(x) for x in value.split(',')]
_handle_INTERVAL = _handle_int _handle_INTERVAL = _handle_int
_handle_COUNT = _handle_int _handle_COUNT = _handle_int
_handle_BYSETPOS = _handle_int_list _handle_BYSETPOS = _handle_int_list
_handle_BYMONTH = _handle_int_list _handle_BYMONTH = _handle_int_list
_handle_BYMONTHDAY = _handle_int_list _handle_BYMONTHDAY = _handle_int_list
_handle_BYYEARDAY = _handle_int_list _handle_BYYEARDAY = _handle_int_list
_handle_BYEASTER = _handle_int_list _handle_BYEASTER = _handle_int_list
_handle_BYWEEKNO = _handle_int_list _handle_BYWEEKNO = _handle_int_list
_handle_BYHOUR = _handle_int_list _handle_BYHOUR = _handle_int_list
_handle_BYMINUTE = _handle_int_list _handle_BYMINUTE = _handle_int_list
_handle_BYSECOND = _handle_int_list _handle_BYSECOND = _handle_int_list
def _handle_FREQ(self, rrkwargs, name, value, **kwargs): def _handle_FREQ(self, rrkwargs, name, value, **kwargs):
rrkwargs["freq"] = self._freq_map[value] rrkwargs["freq"] = self._freq_map[value]
@@ -944,8 +1187,8 @@ class _rrulestr(object):
from dateutil import parser from dateutil import parser
try: try:
rrkwargs["until"] = parser.parse(value, rrkwargs["until"] = parser.parse(value,
ignoretz=kwargs.get("ignoretz"), ignoretz=kwargs.get("ignoretz"),
tzinfos=kwargs.get("tzinfos")) tzinfos=kwargs.get("tzinfos"))
except ValueError: except ValueError:
raise ValueError("invalid until date") raise ValueError("invalid until date")
@@ -960,7 +1203,8 @@ class _rrulestr(object):
break break
n = wday[:i] or None n = wday[:i] or None
w = wday[i:] w = wday[i:]
if n: n = int(n) if n:
n = int(n)
l.append(weekdays[self._weekday_map[w]](n)) l.append(weekdays[self._weekday_map[w]](n))
rrkwargs["byweekday"] = l rrkwargs["byweekday"] = l
@@ -1021,8 +1265,8 @@ class _rrulestr(object):
i += 1 i += 1
else: else:
lines = s.split() lines = s.split()
if (not forceset and len(lines) == 1 and if (not forceset and len(lines) == 1 and (s.find(':') == -1 or
(s.find(':') == -1 or s.startswith('RRULE:'))): s.startswith('RRULE:'))):
return self._parse_rfc_rrule(lines[0], cache=cache, return self._parse_rfc_rrule(lines[0], cache=cache,
dtstart=dtstart, ignoretz=ignoretz, dtstart=dtstart, ignoretz=ignoretz,
tzinfos=tzinfos) tzinfos=tzinfos)
@@ -1071,32 +1315,32 @@ class _rrulestr(object):
tzinfos=tzinfos) tzinfos=tzinfos)
else: else:
raise ValueError("unsupported property: "+name) raise ValueError("unsupported property: "+name)
if (forceset or len(rrulevals) > 1 or if (forceset or len(rrulevals) > 1 or rdatevals
rdatevals or exrulevals or exdatevals): or exrulevals or exdatevals):
if not parser and (rdatevals or exdatevals): if not parser and (rdatevals or exdatevals):
from dateutil import parser from dateutil import parser
set = rruleset(cache=cache) rset = rruleset(cache=cache)
for value in rrulevals: for value in rrulevals:
set.rrule(self._parse_rfc_rrule(value, dtstart=dtstart, rset.rrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in rdatevals:
for datestr in value.split(','):
set.rdate(parser.parse(datestr,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exrulevals:
set.exrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz, ignoretz=ignoretz,
tzinfos=tzinfos)) tzinfos=tzinfos))
for value in exdatevals: for value in rdatevals:
for datestr in value.split(','): for datestr in value.split(','):
set.exdate(parser.parse(datestr, rset.rdate(parser.parse(datestr,
ignoretz=ignoretz, ignoretz=ignoretz,
tzinfos=tzinfos)) tzinfos=tzinfos))
for value in exrulevals:
rset.exrule(self._parse_rfc_rrule(value, dtstart=dtstart,
ignoretz=ignoretz,
tzinfos=tzinfos))
for value in exdatevals:
for datestr in value.split(','):
rset.exdate(parser.parse(datestr,
ignoretz=ignoretz,
tzinfos=tzinfos))
if compatible and dtstart: if compatible and dtstart:
set.rdate(dtstart) rset.rdate(dtstart)
return set return rset
else: else:
return self._parse_rfc_rrule(rrulevals[0], return self._parse_rfc_rrule(rrulevals[0],
dtstart=dtstart, dtstart=dtstart,

View File

@@ -1,19 +1,25 @@
# -*- coding: utf-8 -*-
""" """
Copyright (c) 2003-2007 Gustavo Niemeyer <gustavo@niemeyer.net> This module offers timezone implementations subclassing the abstract
:py:`datetime.tzinfo` type. There are classes to handle tzfile format files
This module offers extensions to the standard Python (usually are in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, etc), TZ
datetime module. environment string (in all known formats), given ranges (with help from
relative deltas), local machine timezone, fixed offset timezone, and UTC
timezone.
""" """
__license__ = "Simplified BSD"
from six import string_types, PY3
import datetime import datetime
import struct import struct
import time import time
import sys import sys
import os import os
from six import string_types, PY3
try:
from dateutil.tzwin import tzwin, tzwinlocal
except ImportError:
tzwin = tzwinlocal = None
relativedelta = None relativedelta = None
parser = None parser = None
rrule = None rrule = None
@@ -21,10 +27,6 @@ rrule = None
__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", __all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange",
"tzstr", "tzical", "tzwin", "tzwinlocal", "gettz"] "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz"]
try:
from dateutil.tzwin import tzwin, tzwinlocal
except (ImportError, OSError):
tzwin, tzwinlocal = None, None
def tzname_in_python2(myfunc): def tzname_in_python2(myfunc):
"""Change unicode output into bytestrings in Python 2 """Change unicode output into bytestrings in Python 2
@@ -42,6 +44,7 @@ def tzname_in_python2(myfunc):
ZERO = datetime.timedelta(0) ZERO = datetime.timedelta(0)
EPOCHORDINAL = datetime.datetime.utcfromtimestamp(0).toordinal() EPOCHORDINAL = datetime.datetime.utcfromtimestamp(0).toordinal()
class tzutc(datetime.tzinfo): class tzutc(datetime.tzinfo):
def utcoffset(self, dt): def utcoffset(self, dt):
@@ -66,6 +69,7 @@ class tzutc(datetime.tzinfo):
__reduce__ = object.__reduce__ __reduce__ = object.__reduce__
class tzoffset(datetime.tzinfo): class tzoffset(datetime.tzinfo):
def __init__(self, name, offset): def __init__(self, name, offset):
@@ -96,6 +100,7 @@ class tzoffset(datetime.tzinfo):
__reduce__ = object.__reduce__ __reduce__ = object.__reduce__
class tzlocal(datetime.tzinfo): class tzlocal(datetime.tzinfo):
_std_offset = datetime.timedelta(seconds=-time.timezone) _std_offset = datetime.timedelta(seconds=-time.timezone)
@@ -130,18 +135,18 @@ class tzlocal(datetime.tzinfo):
# #
# The code above yields the following result: # The code above yields the following result:
# #
#>>> import tz, datetime # >>> import tz, datetime
#>>> t = tz.tzlocal() # >>> t = tz.tzlocal()
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRDT' # 'BRDT'
#>>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() # >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname()
#'BRST' # 'BRST'
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRST' # 'BRST'
#>>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() # >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname()
#'BRDT' # 'BRDT'
#>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname()
#'BRDT' # 'BRDT'
# #
# Here is a more stable implementation: # Here is a more stable implementation:
# #
@@ -166,6 +171,7 @@ class tzlocal(datetime.tzinfo):
__reduce__ = object.__reduce__ __reduce__ = object.__reduce__
class _ttinfo(object): class _ttinfo(object):
__slots__ = ["offset", "delta", "isdst", "abbr", "isstd", "isgmt"] __slots__ = ["offset", "delta", "isdst", "abbr", "isstd", "isgmt"]
@@ -205,15 +211,20 @@ class _ttinfo(object):
if name in state: if name in state:
setattr(self, name, state[name]) setattr(self, name, state[name])
class tzfile(datetime.tzinfo): class tzfile(datetime.tzinfo):
# http://www.twinsun.com/tz/tz-link.htm # http://www.twinsun.com/tz/tz-link.htm
# ftp://ftp.iana.org/tz/tz*.tar.gz # ftp://ftp.iana.org/tz/tz*.tar.gz
def __init__(self, fileobj): def __init__(self, fileobj, filename=None):
file_opened_here = False
if isinstance(fileobj, string_types): if isinstance(fileobj, string_types):
self._filename = fileobj self._filename = fileobj
fileobj = open(fileobj, 'rb') fileobj = open(fileobj, 'rb')
file_opened_here = True
elif filename is not None:
self._filename = filename
elif hasattr(fileobj, "name"): elif hasattr(fileobj, "name"):
self._filename = fileobj.name self._filename = fileobj.name
else: else:
@@ -228,125 +239,128 @@ class tzfile(datetime.tzinfo):
# six four-byte values of type long, written in a # six four-byte values of type long, written in a
# ``standard'' byte order (the high-order byte # ``standard'' byte order (the high-order byte
# of the value is written first). # of the value is written first).
try:
if fileobj.read(4).decode() != "TZif":
raise ValueError("magic not found")
if fileobj.read(4).decode() != "TZif": fileobj.read(16)
raise ValueError("magic not found")
fileobj.read(16) (
# The number of UTC/local indicators stored in the file.
ttisgmtcnt,
( # The number of standard/wall indicators stored in the file.
# The number of UTC/local indicators stored in the file. ttisstdcnt,
ttisgmtcnt,
# The number of standard/wall indicators stored in the file. # The number of leap seconds for which data is
ttisstdcnt, # stored in the file.
leapcnt,
# The number of leap seconds for which data is # The number of "transition times" for which data
# stored in the file. # is stored in the file.
leapcnt, timecnt,
# The number of "transition times" for which data # The number of "local time types" for which data
# is stored in the file. # is stored in the file (must not be zero).
timecnt, typecnt,
# The number of "local time types" for which data # The number of characters of "time zone
# is stored in the file (must not be zero). # abbreviation strings" stored in the file.
typecnt, charcnt,
# The number of characters of "time zone ) = struct.unpack(">6l", fileobj.read(24))
# abbreviation strings" stored in the file.
charcnt,
) = struct.unpack(">6l", fileobj.read(24)) # The above header is followed by tzh_timecnt four-byte
# values of type long, sorted in ascending order.
# These values are written in ``standard'' byte order.
# Each is used as a transition time (as returned by
# time(2)) at which the rules for computing local time
# change.
# The above header is followed by tzh_timecnt four-byte if timecnt:
# values of type long, sorted in ascending order. self._trans_list = struct.unpack(">%dl" % timecnt,
# These values are written in ``standard'' byte order. fileobj.read(timecnt*4))
# Each is used as a transition time (as returned by else:
# time(2)) at which the rules for computing local time self._trans_list = []
# change.
if timecnt: # Next come tzh_timecnt one-byte values of type unsigned
self._trans_list = struct.unpack(">%dl" % timecnt, # char; each one tells which of the different types of
fileobj.read(timecnt*4)) # ``local time'' types described in the file is associated
else: # with the same-indexed transition time. These values
self._trans_list = [] # serve as indices into an array of ttinfo structures that
# appears next in the file.
# Next come tzh_timecnt one-byte values of type unsigned if timecnt:
# char; each one tells which of the different types of self._trans_idx = struct.unpack(">%dB" % timecnt,
# ``local time'' types described in the file is associated fileobj.read(timecnt))
# with the same-indexed transition time. These values else:
# serve as indices into an array of ttinfo structures that self._trans_idx = []
# appears next in the file.
if timecnt: # Each ttinfo structure is written as a four-byte value
self._trans_idx = struct.unpack(">%dB" % timecnt, # for tt_gmtoff of type long, in a standard byte
fileobj.read(timecnt)) # order, followed by a one-byte value for tt_isdst
else: # and a one-byte value for tt_abbrind. In each
self._trans_idx = [] # structure, tt_gmtoff gives the number of
# seconds to be added to UTC, tt_isdst tells whether
# tm_isdst should be set by localtime(3), and
# tt_abbrind serves as an index into the array of
# time zone abbreviation characters that follow the
# ttinfo structure(s) in the file.
# Each ttinfo structure is written as a four-byte value ttinfo = []
# for tt_gmtoff of type long, in a standard byte
# order, followed by a one-byte value for tt_isdst
# and a one-byte value for tt_abbrind. In each
# structure, tt_gmtoff gives the number of
# seconds to be added to UTC, tt_isdst tells whether
# tm_isdst should be set by localtime(3), and
# tt_abbrind serves as an index into the array of
# time zone abbreviation characters that follow the
# ttinfo structure(s) in the file.
ttinfo = [] for i in range(typecnt):
ttinfo.append(struct.unpack(">lbb", fileobj.read(6)))
for i in range(typecnt): abbr = fileobj.read(charcnt).decode()
ttinfo.append(struct.unpack(">lbb", fileobj.read(6)))
abbr = fileobj.read(charcnt).decode() # Then there are tzh_leapcnt pairs of four-byte
# values, written in standard byte order; the
# first value of each pair gives the time (as
# returned by time(2)) at which a leap second
# occurs; the second gives the total number of
# leap seconds to be applied after the given time.
# The pairs of values are sorted in ascending order
# by time.
# Then there are tzh_leapcnt pairs of four-byte # Not used, for now
# values, written in standard byte order; the # if leapcnt:
# first value of each pair gives the time (as # leap = struct.unpack(">%dl" % (leapcnt*2),
# returned by time(2)) at which a leap second # fileobj.read(leapcnt*8))
# occurs; the second gives the total number of
# leap seconds to be applied after the given time.
# The pairs of values are sorted in ascending order
# by time.
# Not used, for now # Then there are tzh_ttisstdcnt standard/wall
if leapcnt: # indicators, each stored as a one-byte value;
leap = struct.unpack(">%dl" % (leapcnt*2), # they tell whether the transition times associated
fileobj.read(leapcnt*8)) # with local time types were specified as standard
# time or wall clock time, and are used when
# a time zone file is used in handling POSIX-style
# time zone environment variables.
# Then there are tzh_ttisstdcnt standard/wall if ttisstdcnt:
# indicators, each stored as a one-byte value; isstd = struct.unpack(">%db" % ttisstdcnt,
# they tell whether the transition times associated fileobj.read(ttisstdcnt))
# with local time types were specified as standard
# time or wall clock time, and are used when
# a time zone file is used in handling POSIX-style
# time zone environment variables.
if ttisstdcnt: # Finally, there are tzh_ttisgmtcnt UTC/local
isstd = struct.unpack(">%db" % ttisstdcnt, # indicators, each stored as a one-byte value;
fileobj.read(ttisstdcnt)) # they tell whether the transition times associated
# with local time types were specified as UTC or
# local time, and are used when a time zone file
# is used in handling POSIX-style time zone envi-
# ronment variables.
# Finally, there are tzh_ttisgmtcnt UTC/local if ttisgmtcnt:
# indicators, each stored as a one-byte value; isgmt = struct.unpack(">%db" % ttisgmtcnt,
# they tell whether the transition times associated fileobj.read(ttisgmtcnt))
# with local time types were specified as UTC or
# local time, and are used when a time zone file
# is used in handling POSIX-style time zone envi-
# ronment variables.
if ttisgmtcnt: # ** Everything has been read **
isgmt = struct.unpack(">%db" % ttisgmtcnt, finally:
fileobj.read(ttisgmtcnt)) if file_opened_here:
fileobj.close()
# ** Everything has been read **
# Build ttinfo list # Build ttinfo list
self._ttinfo_list = [] self._ttinfo_list = []
for i in range(typecnt): for i in range(typecnt):
gmtoff, isdst, abbrind = ttinfo[i] gmtoff, isdst, abbrind = ttinfo[i]
# Round to full-minutes if that's not the case. Python's # Round to full-minutes if that's not the case. Python's
# datetime doesn't accept sub-minute timezones. Check # datetime doesn't accept sub-minute timezones. Check
# http://python.org/sf/1447945 for some information. # http://python.org/sf/1447945 for some information.
@@ -481,7 +495,6 @@ class tzfile(datetime.tzinfo):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __repr__(self): def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, repr(self._filename)) return "%s(%s)" % (self.__class__.__name__, repr(self._filename))
@@ -490,8 +503,8 @@ class tzfile(datetime.tzinfo):
raise ValueError("Unpickable %s class" % self.__class__.__name__) raise ValueError("Unpickable %s class" % self.__class__.__name__)
return (self.__class__, (self._filename,)) return (self.__class__, (self._filename,))
class tzrange(datetime.tzinfo):
class tzrange(datetime.tzinfo):
def __init__(self, stdabbr, stdoffset=None, def __init__(self, stdabbr, stdoffset=None,
dstabbr=None, dstoffset=None, dstabbr=None, dstoffset=None,
start=None, end=None): start=None, end=None):
@@ -512,12 +525,12 @@ class tzrange(datetime.tzinfo):
self._dst_offset = ZERO self._dst_offset = ZERO
if dstabbr and start is None: if dstabbr and start is None:
self._start_delta = relativedelta.relativedelta( self._start_delta = relativedelta.relativedelta(
hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) hours=+2, month=4, day=1, weekday=relativedelta.SU(+1))
else: else:
self._start_delta = start self._start_delta = start
if dstabbr and end is None: if dstabbr and end is None:
self._end_delta = relativedelta.relativedelta( self._end_delta = relativedelta.relativedelta(
hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) hours=+1, month=10, day=31, weekday=relativedelta.SU(-1))
else: else:
self._end_delta = end self._end_delta = end
@@ -570,6 +583,7 @@ class tzrange(datetime.tzinfo):
__reduce__ = object.__reduce__ __reduce__ = object.__reduce__
class tzstr(tzrange): class tzstr(tzrange):
def __init__(self, s): def __init__(self, s):
@@ -645,9 +659,10 @@ class tzstr(tzrange):
def __repr__(self): def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, repr(self._s)) return "%s(%s)" % (self.__class__.__name__, repr(self._s))
class _tzicalvtzcomp(object): class _tzicalvtzcomp(object):
def __init__(self, tzoffsetfrom, tzoffsetto, isdst, def __init__(self, tzoffsetfrom, tzoffsetto, isdst,
tzname=None, rrule=None): tzname=None, rrule=None):
self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom)
self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto)
self.tzoffsetdiff = self.tzoffsetto-self.tzoffsetfrom self.tzoffsetdiff = self.tzoffsetto-self.tzoffsetfrom
@@ -655,6 +670,7 @@ class _tzicalvtzcomp(object):
self.tzname = tzname self.tzname = tzname
self.rrule = rrule self.rrule = rrule
class _tzicalvtz(datetime.tzinfo): class _tzicalvtz(datetime.tzinfo):
def __init__(self, tzid, comps=[]): def __init__(self, tzid, comps=[]):
self._tzid = tzid self._tzid = tzid
@@ -718,6 +734,7 @@ class _tzicalvtz(datetime.tzinfo):
__reduce__ = object.__reduce__ __reduce__ = object.__reduce__
class tzical(object): class tzical(object):
def __init__(self, fileobj): def __init__(self, fileobj):
global rrule global rrule
@@ -726,7 +743,8 @@ class tzical(object):
if isinstance(fileobj, string_types): if isinstance(fileobj, string_types):
self._s = fileobj self._s = fileobj
fileobj = open(fileobj, 'r') # ical should be encoded in UTF-8 with CRLF # ical should be encoded in UTF-8 with CRLF
fileobj = open(fileobj, 'r')
elif hasattr(fileobj, "name"): elif hasattr(fileobj, "name"):
self._s = fileobj.name self._s = fileobj.name
else: else:
@@ -754,7 +772,7 @@ class tzical(object):
if not s: if not s:
raise ValueError("empty offset") raise ValueError("empty offset")
if s[0] in ('+', '-'): if s[0] in ('+', '-'):
signal = (-1, +1)[s[0]=='+'] signal = (-1, +1)[s[0] == '+']
s = s[1:] s = s[1:]
else: else:
signal = +1 signal = +1
@@ -815,7 +833,8 @@ class tzical(object):
if not tzid: if not tzid:
raise ValueError("mandatory TZID not found") raise ValueError("mandatory TZID not found")
if not comps: if not comps:
raise ValueError("at least one component is needed") raise ValueError(
"at least one component is needed")
# Process vtimezone # Process vtimezone
self._vtz[tzid] = _tzicalvtz(tzid, comps) self._vtz[tzid] = _tzicalvtz(tzid, comps)
invtz = False invtz = False
@@ -823,9 +842,11 @@ class tzical(object):
if not founddtstart: if not founddtstart:
raise ValueError("mandatory DTSTART not found") raise ValueError("mandatory DTSTART not found")
if tzoffsetfrom is None: if tzoffsetfrom is None:
raise ValueError("mandatory TZOFFSETFROM not found") raise ValueError(
"mandatory TZOFFSETFROM not found")
if tzoffsetto is None: if tzoffsetto is None:
raise ValueError("mandatory TZOFFSETFROM not found") raise ValueError(
"mandatory TZOFFSETFROM not found")
# Process component # Process component
rr = None rr = None
if rrulelines: if rrulelines:
@@ -848,15 +869,18 @@ class tzical(object):
rrulelines.append(line) rrulelines.append(line)
elif name == "TZOFFSETFROM": elif name == "TZOFFSETFROM":
if parms: if parms:
raise ValueError("unsupported %s parm: %s "%(name, parms[0])) raise ValueError(
"unsupported %s parm: %s " % (name, parms[0]))
tzoffsetfrom = self._parse_offset(value) tzoffsetfrom = self._parse_offset(value)
elif name == "TZOFFSETTO": elif name == "TZOFFSETTO":
if parms: if parms:
raise ValueError("unsupported TZOFFSETTO parm: "+parms[0]) raise ValueError(
"unsupported TZOFFSETTO parm: "+parms[0])
tzoffsetto = self._parse_offset(value) tzoffsetto = self._parse_offset(value)
elif name == "TZNAME": elif name == "TZNAME":
if parms: if parms:
raise ValueError("unsupported TZNAME parm: "+parms[0]) raise ValueError(
"unsupported TZNAME parm: "+parms[0])
tzname = value tzname = value
elif name == "COMMENT": elif name == "COMMENT":
pass pass
@@ -865,7 +889,8 @@ class tzical(object):
else: else:
if name == "TZID": if name == "TZID":
if parms: if parms:
raise ValueError("unsupported TZID parm: "+parms[0]) raise ValueError(
"unsupported TZID parm: "+parms[0])
tzid = value tzid = value
elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"):
pass pass
@@ -886,6 +911,7 @@ else:
TZFILES = [] TZFILES = []
TZPATHS = [] TZPATHS = []
def gettz(name=None): def gettz(name=None):
tz = None tz = None
if not name: if not name:
@@ -933,11 +959,11 @@ def gettz(name=None):
pass pass
else: else:
tz = None tz = None
if tzwin: if tzwin is not None:
try: try:
tz = tzwin(name) tz = tzwin(name)
except OSError: except WindowsError:
pass tz = None
if not tz: if not tz:
from dateutil.zoneinfo import gettz from dateutil.zoneinfo import gettz
tz = gettz(name) tz = gettz(name)

View File

@@ -1,8 +1,8 @@
# This code was originally contributed by Jeffrey Harris. # This code was originally contributed by Jeffrey Harris.
import datetime import datetime
import struct import struct
import winreg
from six.moves import winreg
__all__ = ["tzwin", "tzwinlocal"] __all__ = ["tzwin", "tzwinlocal"]
@@ -12,8 +12,8 @@ TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones"
TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones" TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones"
TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation" TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation"
def _settzkeyname(): def _settzkeyname():
global TZKEYNAME
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try: try:
winreg.OpenKey(handle, TZKEYNAMENT).Close() winreg.OpenKey(handle, TZKEYNAMENT).Close()
@@ -21,8 +21,10 @@ def _settzkeyname():
except WindowsError: except WindowsError:
TZKEYNAME = TZKEYNAME9X TZKEYNAME = TZKEYNAME9X
handle.Close() handle.Close()
return TZKEYNAME
TZKEYNAME = _settzkeyname()
_settzkeyname()
class tzwinbase(datetime.tzinfo): class tzwinbase(datetime.tzinfo):
"""tzinfo class based on win32's timezones available in the registry.""" """tzinfo class based on win32's timezones available in the registry."""
@@ -61,6 +63,9 @@ class tzwinbase(datetime.tzinfo):
return self._display return self._display
def _isdst(self, dt): def _isdst(self, dt):
if not self._dstmonth:
# dstmonth == 0 signals the zone has no daylight saving time
return False
dston = picknthweekday(dt.year, self._dstmonth, self._dstdayofweek, dston = picknthweekday(dt.year, self._dstmonth, self._dstdayofweek,
self._dsthour, self._dstminute, self._dsthour, self._dstminute,
self._dstweeknumber) self._dstweeknumber)
@@ -78,11 +83,11 @@ class tzwin(tzwinbase):
def __init__(self, name): def __init__(self, name):
self._name = name self._name = name
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) # multiple contexts only possible in 2.7 and 3.1, we still support 2.6
tzkey = winreg.OpenKey(handle, "%s\%s" % (TZKEYNAME, name)) with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
keydict = valuestodict(tzkey) with winreg.OpenKey(handle,
tzkey.Close() "%s\%s" % (TZKEYNAME, name)) as tzkey:
handle.Close() keydict = valuestodict(tzkey)
self._stdname = keydict["Std"].encode("iso-8859-1") self._stdname = keydict["Std"].encode("iso-8859-1")
self._dstname = keydict["Dlt"].encode("iso-8859-1") self._dstname = keydict["Dlt"].encode("iso-8859-1")
@@ -91,18 +96,20 @@ class tzwin(tzwinbase):
# See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm
tup = struct.unpack("=3l16h", keydict["TZI"]) tup = struct.unpack("=3l16h", keydict["TZI"])
self._stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1 self._stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1
self._dstoffset = self._stdoffset-tup[2] # + DaylightBias * -1 self._dstoffset = self._stdoffset-tup[2] # + DaylightBias * -1
# for the meaning see the win32 TIME_ZONE_INFORMATION structure docs
# http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx
(self._stdmonth, (self._stdmonth,
self._stddayofweek, # Sunday = 0 self._stddayofweek, # Sunday = 0
self._stdweeknumber, # Last = 5 self._stdweeknumber, # Last = 5
self._stdhour, self._stdhour,
self._stdminute) = tup[4:9] self._stdminute) = tup[4:9]
(self._dstmonth, (self._dstmonth,
self._dstdayofweek, # Sunday = 0 self._dstdayofweek, # Sunday = 0
self._dstweeknumber, # Last = 5 self._dstweeknumber, # Last = 5
self._dsthour, self._dsthour,
self._dstminute) = tup[12:17] self._dstminute) = tup[12:17]
@@ -117,58 +124,56 @@ class tzwinlocal(tzwinbase):
def __init__(self): def __init__(self):
handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle:
tzlocalkey = winreg.OpenKey(handle, TZLOCALKEYNAME) with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey:
keydict = valuestodict(tzlocalkey) keydict = valuestodict(tzlocalkey)
tzlocalkey.Close()
self._stdname = keydict["StandardName"].encode("iso-8859-1") self._stdname = keydict["StandardName"].encode("iso-8859-1")
self._dstname = keydict["DaylightName"].encode("iso-8859-1") self._dstname = keydict["DaylightName"].encode("iso-8859-1")
try: try:
tzkey = winreg.OpenKey(handle, "%s\%s"%(TZKEYNAME, self._stdname)) with winreg.OpenKey(
_keydict = valuestodict(tzkey) handle, "%s\%s" % (TZKEYNAME, self._stdname)) as tzkey:
self._display = _keydict["Display"] _keydict = valuestodict(tzkey)
tzkey.Close() self._display = _keydict["Display"]
except OSError: except OSError:
self._display = None self._display = None
handle.Close()
self._stdoffset = -keydict["Bias"]-keydict["StandardBias"] self._stdoffset = -keydict["Bias"]-keydict["StandardBias"]
self._dstoffset = self._stdoffset-keydict["DaylightBias"] self._dstoffset = self._stdoffset-keydict["DaylightBias"]
# See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm
tup = struct.unpack("=8h", keydict["StandardStart"]) tup = struct.unpack("=8h", keydict["StandardStart"])
(self._stdmonth, (self._stdmonth,
self._stddayofweek, # Sunday = 0 self._stddayofweek, # Sunday = 0
self._stdweeknumber, # Last = 5 self._stdweeknumber, # Last = 5
self._stdhour, self._stdhour,
self._stdminute) = tup[1:6] self._stdminute) = tup[1:6]
tup = struct.unpack("=8h", keydict["DaylightStart"]) tup = struct.unpack("=8h", keydict["DaylightStart"])
(self._dstmonth, (self._dstmonth,
self._dstdayofweek, # Sunday = 0 self._dstdayofweek, # Sunday = 0
self._dstweeknumber, # Last = 5 self._dstweeknumber, # Last = 5
self._dsthour, self._dsthour,
self._dstminute) = tup[1:6] self._dstminute) = tup[1:6]
def __reduce__(self): def __reduce__(self):
return (self.__class__, ()) return (self.__class__, ())
def picknthweekday(year, month, dayofweek, hour, minute, whichweek): def picknthweekday(year, month, dayofweek, hour, minute, whichweek):
"""dayofweek == 0 means Sunday, whichweek 5 means last instance""" """dayofweek == 0 means Sunday, whichweek 5 means last instance"""
first = datetime.datetime(year, month, 1, hour, minute) first = datetime.datetime(year, month, 1, hour, minute)
weekdayone = first.replace(day=((dayofweek-first.isoweekday())%7+1)) weekdayone = first.replace(day=((dayofweek-first.isoweekday()) % 7+1))
for n in range(whichweek): for n in range(whichweek):
dt = weekdayone+(whichweek-n)*ONEWEEK dt = weekdayone+(whichweek-n)*ONEWEEK
if dt.month == month: if dt.month == month:
return dt return dt
def valuestodict(key): def valuestodict(key):
"""Convert a registry key's values to a dictionary.""" """Convert a registry key's values to a dictionary."""
dict = {} dict = {}

View File

@@ -1,109 +1,108 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""
Copyright (c) 2003-2005 Gustavo Niemeyer <gustavo@niemeyer.net>
This module offers extensions to the standard Python
datetime module.
"""
import logging import logging
import os import os
from subprocess import call import warnings
import tempfile
import shutil
from subprocess import check_call
from tarfile import TarFile from tarfile import TarFile
from pkgutil import get_data
from io import BytesIO
from contextlib import closing
from dateutil.tz import tzfile from dateutil.tz import tzfile
__author__ = "Tomi Pieviläinen <tomi.pievilainen@iki.fi>"
__license__ = "Simplified BSD"
__all__ = ["setcachesize", "gettz", "rebuild"] __all__ = ["setcachesize", "gettz", "rebuild"]
CACHE = [] _ZONEFILENAME = "dateutil-zoneinfo.tar.gz"
CACHESIZE = 10
# python2.6 compatability. Note that TarFile.__exit__ != TarFile.close, but
# it's close enough for python2.6
_tar_open = TarFile.open
if not hasattr(TarFile, '__exit__'):
def _tar_open(*args, **kwargs):
return closing(TarFile.open(*args, **kwargs))
class tzfile(tzfile): class tzfile(tzfile):
def __reduce__(self): def __reduce__(self):
return (gettz, (self._filename,)) return (gettz, (self._filename,))
def getzoneinfofile():
filenames = sorted(os.listdir(os.path.join(os.path.dirname(__file__))))
filenames.reverse()
for entry in filenames:
if entry.startswith("zoneinfo") and ".tar." in entry:
return os.path.join(os.path.dirname(__file__), entry)
return None
ZONEINFOFILE = getzoneinfofile() def getzoneinfofile_stream():
try:
return BytesIO(get_data(__name__, _ZONEFILENAME))
except IOError as e: # TODO switch to FileNotFoundError?
warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror))
return None
del getzoneinfofile
def setcachesize(size): class ZoneInfoFile(object):
global CACHESIZE, CACHE def __init__(self, zonefile_stream=None):
CACHESIZE = size if zonefile_stream is not None:
del CACHE[size:] with _tar_open(fileobj=zonefile_stream, mode='r') as tf:
# dict comprehension does not work on python2.6
# TODO: get back to the nicer syntax when we ditch python2.6
# self.zones = {zf.name: tzfile(tf.extractfile(zf),
# filename = zf.name)
# for zf in tf.getmembers() if zf.isfile()}
self.zones = dict((zf.name, tzfile(tf.extractfile(zf),
filename=zf.name))
for zf in tf.getmembers() if zf.isfile())
# deal with links: They'll point to their parent object. Less
# waste of memory
# links = {zl.name: self.zones[zl.linkname]
# for zl in tf.getmembers() if zl.islnk() or zl.issym()}
links = dict((zl.name, self.zones[zl.linkname])
for zl in tf.getmembers() if
zl.islnk() or zl.issym())
self.zones.update(links)
else:
self.zones = dict()
# The current API has gettz as a module function, although in fact it taps into
# a stateful class. So as a workaround for now, without changing the API, we
# will create a new "global" class instance the first time a user requests a
# timezone. Ugly, but adheres to the api.
#
# TODO: deprecate this.
_CLASS_ZONE_INSTANCE = list()
def gettz(name): def gettz(name):
tzinfo = None if len(_CLASS_ZONE_INSTANCE) == 0:
if ZONEINFOFILE: _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream()))
for cachedname, tzinfo in CACHE: return _CLASS_ZONE_INSTANCE[0].zones.get(name)
if cachedname == name:
break
else:
tf = TarFile.open(ZONEINFOFILE)
try:
zonefile = tf.extractfile(name)
except KeyError:
tzinfo = None
else:
tzinfo = tzfile(zonefile)
tf.close()
CACHE.insert(0, (name, tzinfo))
del CACHE[CACHESIZE:]
return tzinfo
def rebuild(filename, tag=None, format="gz"):
def rebuild(filename, tag=None, format="gz", zonegroups=[]):
"""Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar* """Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar*
filename is the timezone tarball from ftp.iana.org/tz. filename is the timezone tarball from ftp.iana.org/tz.
""" """
import tempfile, shutil
tmpdir = tempfile.mkdtemp() tmpdir = tempfile.mkdtemp()
zonedir = os.path.join(tmpdir, "zoneinfo") zonedir = os.path.join(tmpdir, "zoneinfo")
moduledir = os.path.dirname(__file__) moduledir = os.path.dirname(__file__)
if tag: tag = "-"+tag
targetname = "zoneinfo%s.tar.%s" % (tag, format)
try: try:
tf = TarFile.open(filename) with _tar_open(filename) as tf:
# The "backwards" zone file contains links to other files, so must be for name in zonegroups:
# processed as last
for name in sorted(tf.getnames(),
key=lambda k: k != "backward" and k or "z"):
if not (name.endswith(".sh") or
name.endswith(".tab") or
name == "leapseconds"):
tf.extract(name, tmpdir) tf.extract(name, tmpdir)
filepath = os.path.join(tmpdir, name) filepaths = [os.path.join(tmpdir, n) for n in zonegroups]
try: try:
# zic will return errors for nontz files in the package check_call(["zic", "-d", zonedir] + filepaths)
# such as the Makefile or README, so check_call cannot except OSError as e:
# be used (or at least extra checks would be needed) if e.errno == 2:
call(["zic", "-d", zonedir, filepath]) logging.error(
except OSError as e: "Could not find zic. Perhaps you need to install "
if e.errno == 2: "libc-bin or some other package that provides it, "
logging.error( "or it's not in your PATH?")
"Could not find zic. Perhaps you need to install "
"libc-bin or some other package that provides it, "
"or it's not in your PATH?")
raise raise
tf.close() target = os.path.join(moduledir, _ZONEFILENAME)
target = os.path.join(moduledir, targetname) with _tar_open(target, "w:%s" % format) as tf:
for entry in os.listdir(moduledir): for entry in os.listdir(zonedir):
if entry.startswith("zoneinfo") and ".tar." in entry: entrypath = os.path.join(zonedir, entry)
os.unlink(os.path.join(moduledir, entry)) tf.add(entrypath, entry)
tf = TarFile.open(target, "w:%s" % format)
for entry in os.listdir(zonedir):
entrypath = os.path.join(zonedir, entry)
tf.add(entrypath, entry)
tf.close()
finally: finally:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)

View File

@@ -5,17 +5,27 @@ Keyring implementation support
from __future__ import absolute_import from __future__ import absolute_import
import abc import abc
import logging
try: try:
import importlib import importlib
except ImportError: except ImportError:
pass pass
try:
import pkg_resources
except ImportError:
pass
from . import errors, util from . import errors, util
from . import backends from . import backends
from .util import properties from .util import properties
from .py27compat import add_metaclass, filter from .py27compat import add_metaclass, filter
log = logging.getLogger(__name__)
class KeyringBackendMeta(abc.ABCMeta): class KeyringBackendMeta(abc.ABCMeta):
""" """
A metaclass that's both an ABCMeta and a type that keeps a registry of A metaclass that's both an ABCMeta and a type that keeps a registry of
@@ -127,6 +137,38 @@ def _load_backends():
backends = ('file', 'Gnome', 'Google', 'keyczar', 'kwallet', 'multi', backends = ('file', 'Gnome', 'Google', 'keyczar', 'kwallet', 'multi',
'OS_X', 'pyfs', 'SecretService', 'Windows') 'OS_X', 'pyfs', 'SecretService', 'Windows')
list(map(_load_backend, backends)) list(map(_load_backend, backends))
_load_plugins()
def _load_plugins():
"""
Locate all setuptools entry points by the name 'keyring backends'
and initialize them.
Any third-party library may register an entry point by adding the
following to their setup.py::
entry_points = {
'keyring backends': [
'plugin_name = mylib.mymodule:initialize_func',
],
},
`plugin_name` can be anything, and is only used to display the name
of the plugin at initialization time.
`initialize_func` is optional, but will be invoked if callable.
"""
if 'pkg_resources' not in globals():
return
group = 'keyring backends'
entry_points = pkg_resources.iter_entry_points(group=group)
for ep in entry_points:
try:
log.info('Loading %s', ep.name)
init_func = ep.load()
if callable(init_func):
init_func()
except Exception:
log.exception("Error initializing plugin %s." % ep)
@util.once @util.once
def get_all_keyring(): def get_all_keyring():

View File

@@ -10,7 +10,7 @@ called from the command line.
import markdown import markdown
html = markdown.markdown(your_text_string) html = markdown.markdown(your_text_string)
See <http://packages.python.org/Markdown/> for more See <https://pythonhosted.org/Markdown/> for more
information and instructions on how to extend the functionality of information and instructions on how to extend the functionality of
Python Markdown. Read that before you try modifying this file. Python Markdown. Read that before you try modifying this file.
@@ -36,6 +36,8 @@ from .__version__ import version, version_info
import codecs import codecs
import sys import sys
import logging import logging
import warnings
import importlib
from . import util from . import util
from .preprocessors import build_preprocessors from .preprocessors import build_preprocessors
from .blockprocessors import build_block_parser from .blockprocessors import build_block_parser
@@ -48,6 +50,7 @@ from .serializers import to_html_string, to_xhtml_string
__all__ = ['Markdown', 'markdown', 'markdownFromFile'] __all__ = ['Markdown', 'markdown', 'markdownFromFile']
logger = logging.getLogger('MARKDOWN') logger = logging.getLogger('MARKDOWN')
logging.captureWarnings(True)
class Markdown(object): class Markdown(object):
@@ -96,8 +99,8 @@ class Markdown(object):
Note that it is suggested that the more specific formats ("xhtml1" Note that it is suggested that the more specific formats ("xhtml1"
and "html4") be used as "xhtml" or "html" may change in the future and "html4") be used as "xhtml" or "html" may change in the future
if it makes sense at that time. if it makes sense at that time.
* safe_mode: Disallow raw html. One of "remove", "replace" or "escape". * safe_mode: Deprecated! Disallow raw html. One of "remove", "replace" or "escape".
* html_replacement_text: Text used when safe_mode is set to "replace". * html_replacement_text: Deprecated! Text used when safe_mode is set to "replace".
* tab_length: Length of tabs in the source. Default: 4 * tab_length: Length of tabs in the source. Default: 4
* enable_attributes: Enable the conversion of attributes. Default: True * enable_attributes: Enable the conversion of attributes. Default: True
* smart_emphasis: Treat `_connected_words_` intelligently Default: True * smart_emphasis: Treat `_connected_words_` intelligently Default: True
@@ -107,14 +110,16 @@ class Markdown(object):
# For backward compatibility, loop through old positional args # For backward compatibility, loop through old positional args
pos = ['extensions', 'extension_configs', 'safe_mode', 'output_format'] pos = ['extensions', 'extension_configs', 'safe_mode', 'output_format']
c = 0 for c, arg in enumerate(args):
for arg in args:
if pos[c] not in kwargs: if pos[c] not in kwargs:
kwargs[pos[c]] = arg kwargs[pos[c]] = arg
c += 1 if c+1 == len(pos): #pragma: no cover
if c == len(pos):
# ignore any additional args # ignore any additional args
break break
if len(args):
warnings.warn('Positional arguments are pending depreacted in Markdown '
'and will be deprecated in version 2.6. Use keyword '
'arguments only.', PendingDeprecationWarning)
# Loop through kwargs and assign defaults # Loop through kwargs and assign defaults
for option, default in self.option_defaults.items(): for option, default in self.option_defaults.items():
@@ -125,6 +130,18 @@ class Markdown(object):
# Disable attributes in safeMode when not explicitly set # Disable attributes in safeMode when not explicitly set
self.enable_attributes = False self.enable_attributes = False
if 'safe_mode' in kwargs:
warnings.warn('"safe_mode" is pending deprecation in Python-Markdown '
'and will be deprecated in version 2.6. Use an HTML '
'sanitizer (like Bleach http://bleach.readthedocs.org/) '
'if you are parsing untrusted markdown text. See the '
'2.5 release notes for more info', PendingDeprecationWarning)
if 'html_replacement_text' in kwargs:
warnings.warn('The "html_replacement_text" keyword is pending deprecation '
'in Python-Markdown and will be deprecated in version 2.6 '
'along with "safe_mode".', PendingDeprecationWarning)
self.registeredExtensions = [] self.registeredExtensions = []
self.docType = "" self.docType = ""
self.stripTopLevelTags = True self.stripTopLevelTags = True
@@ -160,9 +177,11 @@ class Markdown(object):
""" """
for ext in extensions: for ext in extensions:
if isinstance(ext, util.string_type): if isinstance(ext, util.string_type):
ext = self.build_extension(ext, configs.get(ext, [])) ext = self.build_extension(ext, configs.get(ext, {}))
if isinstance(ext, Extension): if isinstance(ext, Extension):
ext.extendMarkdown(self, globals()) ext.extendMarkdown(self, globals())
logger.debug('Successfully loaded extension "%s.%s".'
% (ext.__class__.__module__, ext.__class__.__name__))
elif ext is not None: elif ext is not None:
raise TypeError( raise TypeError(
'Extension "%s.%s" must be of type: "markdown.Extension"' 'Extension "%s.%s" must be of type: "markdown.Extension"'
@@ -170,7 +189,7 @@ class Markdown(object):
return self return self
def build_extension(self, ext_name, configs = []): def build_extension(self, ext_name, configs):
"""Build extension by name, then return the module. """Build extension by name, then return the module.
The extension name may contain arguments as part of the string in the The extension name may contain arguments as part of the string in the
@@ -178,44 +197,79 @@ class Markdown(object):
""" """
# Parse extensions config params (ignore the order)
configs = dict(configs) configs = dict(configs)
# Parse extensions config params (ignore the order)
pos = ext_name.find("(") # find the first "(" pos = ext_name.find("(") # find the first "("
if pos > 0: if pos > 0:
ext_args = ext_name[pos+1:-1] ext_args = ext_name[pos+1:-1]
ext_name = ext_name[:pos] ext_name = ext_name[:pos]
pairs = [x.split("=") for x in ext_args.split(",")] pairs = [x.split("=") for x in ext_args.split(",")]
configs.update([(x.strip(), y.strip()) for (x, y) in pairs]) configs.update([(x.strip(), y.strip()) for (x, y) in pairs])
warnings.warn('Setting configs in the Named Extension string is pending deprecation. '
'It is recommended that you pass an instance of the extension class to '
'Markdown or use the "extension_configs" keyword. The current behavior '
'will be deprecated in version 2.6 and raise an error in version 2.7. '
'See the Release Notes for Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
# Setup the module name # Get class name (if provided): `path.to.module:ClassName`
module_name = ext_name ext_name, class_name = ext_name.split(':', 1) if ':' in ext_name else (ext_name, '')
if '.' not in ext_name:
module_name = '.'.join(['markdown.extensions', ext_name])
# Try loading the extension first from one place, then another # Try loading the extension first from one place, then another
try: # New style (markdown.extensions.<extension>) try:
module = __import__(module_name, {}, {}, [module_name.rpartition('.')[0]]) # Assume string uses dot syntax (`path.to.some.module`)
module = importlib.import_module(ext_name)
logger.debug('Successfuly imported extension module "%s".' % ext_name)
# For backward compat (until deprecation) check that this is an extension
if '.' not in ext_name and not (hasattr(module, 'extendMarkdown') or (class_name and hasattr(module, class_name))):
# We have a name conflict (eg: extensions=['tables'] and PyTables is installed)
raise ImportError
except ImportError: except ImportError:
module_name_old_style = '_'.join(['mdx', ext_name]) # Preppend `markdown.extensions.` to name
try: # Old style (mdx_<extension>) module_name = '.'.join(['markdown.extensions', ext_name])
module = __import__(module_name_old_style) try:
except ImportError as e: module = importlib.import_module(module_name)
message = "Failed loading extension '%s' from '%s' or '%s'" \ logger.debug('Successfuly imported extension module "%s".' % module_name)
% (ext_name, module_name, module_name_old_style) warnings.warn('Using short names for Markdown\'s builtin extensions is pending deprecation. '
'Use the full path to the extension with Python\'s dot notation '
'(eg: "%s" instead of "%s"). The current behavior will be deprecated in '
'version 2.6 and raise an error in version 2.7. See the Release Notes for '
'Python-Markdown version 2.5 for more info.' % (module_name, ext_name),
PendingDeprecationWarning)
except ImportError:
# Preppend `mdx_` to name
module_name_old_style = '_'.join(['mdx', ext_name])
try:
module = importlib.import_module(module_name_old_style)
logger.debug('Successfuly imported extension module "%s".' % module_name_old_style)
warnings.warn('Markdown\'s behavuor of appending "mdx_" to an extension name '
'is pending deprecation. Use the full path to the extension with '
'Python\'s dot notation (eg: "%s" instead of "%s"). The '
'current behavior will be deprecated in version 2.6 and raise an '
'error in version 2.7. See the Release Notes for Python-Markdown '
'version 2.5 for more info.' % (module_name_old_style, ext_name),
PendingDeprecationWarning)
except ImportError as e:
message = "Failed loading extension '%s' from '%s', '%s' or '%s'" \
% (ext_name, ext_name, module_name, module_name_old_style)
e.args = (message,) + e.args[1:]
raise
if class_name:
# Load given class name from module.
return getattr(module, class_name)(**configs)
else:
# Expect makeExtension() function to return a class.
try:
return module.makeExtension(**configs)
except AttributeError as e:
message = e.args[0]
message = "Failed to initiate extension " \
"'%s': %s" % (ext_name, message)
e.args = (message,) + e.args[1:] e.args = (message,) + e.args[1:]
raise raise
# If the module is loaded successfully, we expect it to define a
# function called makeExtension()
try:
return module.makeExtension(configs.items())
except AttributeError as e:
message = e.args[0]
message = "Failed to initiate extension " \
"'%s': %s" % (ext_name, message)
e.args = (message,) + e.args[1:]
raise
def registerExtension(self, extension): def registerExtension(self, extension):
""" This gets called by the extension """ """ This gets called by the extension """
self.registeredExtensions.append(extension) self.registeredExtensions.append(extension)
@@ -303,7 +357,7 @@ class Markdown(object):
start = output.index('<%s>'%self.doc_tag)+len(self.doc_tag)+2 start = output.index('<%s>'%self.doc_tag)+len(self.doc_tag)+2
end = output.rindex('</%s>'%self.doc_tag) end = output.rindex('</%s>'%self.doc_tag)
output = output[start:end].strip() output = output[start:end].strip()
except ValueError: except ValueError: #pragma: no cover
if output.strip().endswith('<%s />'%self.doc_tag): if output.strip().endswith('<%s />'%self.doc_tag):
# We have an empty document # We have an empty document
output = '' output = ''
@@ -434,6 +488,10 @@ def markdownFromFile(*args, **kwargs):
c += 1 c += 1
if c == len(pos): if c == len(pos):
break break
if len(args):
warnings.warn('Positional arguments are pending depreacted in Markdown '
'and will be deprecated in version 2.6. Use keyword '
'arguments only.', PendingDeprecationWarning)
md = Markdown(**kwargs) md = Markdown(**kwargs)
md.convertFile(kwargs.get('input', None), md.convertFile(kwargs.get('input', None),

View File

@@ -7,20 +7,25 @@ COMMAND-LINE SPECIFIC STUFF
import markdown import markdown
import sys import sys
import optparse import optparse
import codecs
try:
import yaml
except ImportError: #pragma: no cover
import json as yaml
import logging import logging
from logging import DEBUG, INFO, CRITICAL from logging import DEBUG, INFO, CRITICAL
logger = logging.getLogger('MARKDOWN') logger = logging.getLogger('MARKDOWN')
def parse_options(): def parse_options(args=None, values=None):
""" """
Define and parse `optparse` options for command-line usage. Define and parse `optparse` options for command-line usage.
""" """
usage = """%prog [options] [INPUTFILE] usage = """%prog [options] [INPUTFILE]
(STDIN is assumed if no INPUTFILE is given)""" (STDIN is assumed if no INPUTFILE is given)"""
desc = "A Python implementation of John Gruber's Markdown. " \ desc = "A Python implementation of John Gruber's Markdown. " \
"http://packages.python.org/Markdown/" "https://pythonhosted.org/Markdown/"
ver = "%%prog %s" % markdown.version ver = "%%prog %s" % markdown.version
parser = optparse.OptionParser(usage=usage, description=desc, version=ver) parser = optparse.OptionParser(usage=usage, description=desc, version=ver)
@@ -29,28 +34,36 @@ def parse_options():
metavar="OUTPUT_FILE") metavar="OUTPUT_FILE")
parser.add_option("-e", "--encoding", dest="encoding", parser.add_option("-e", "--encoding", dest="encoding",
help="Encoding for input and output files.",) help="Encoding for input and output files.",)
parser.add_option("-s", "--safe", dest="safe", default=False,
metavar="SAFE_MODE",
help="Deprecated! 'replace', 'remove' or 'escape' HTML tags in input")
parser.add_option("-o", "--output_format", dest="output_format",
default='xhtml1', metavar="OUTPUT_FORMAT",
help="'xhtml1' (default), 'html4' or 'html5'.")
parser.add_option("-n", "--no_lazy_ol", dest="lazy_ol",
action='store_false', default=True,
help="Observe number of first item of ordered lists.")
parser.add_option("-x", "--extension", action="append", dest="extensions",
help = "Load extension EXTENSION.", metavar="EXTENSION")
parser.add_option("-c", "--extension_configs", dest="configfile", default=None,
help="Read extension configurations from CONFIG_FILE. "
"CONFIG_FILE must be of JSON or YAML format. YAML format requires "
"that a python YAML library be installed. The parsed JSON or YAML "
"must result in a python dictionary which would be accepted by the "
"'extension_configs' keyword on the markdown.Markdown class. "
"The extensions must also be loaded with the `--extension` option.",
metavar="CONFIG_FILE")
parser.add_option("-q", "--quiet", default = CRITICAL, parser.add_option("-q", "--quiet", default = CRITICAL,
action="store_const", const=CRITICAL+10, dest="verbose", action="store_const", const=CRITICAL+10, dest="verbose",
help="Suppress all warnings.") help="Suppress all warnings.")
parser.add_option("-v", "--verbose", parser.add_option("-v", "--verbose",
action="store_const", const=INFO, dest="verbose", action="store_const", const=INFO, dest="verbose",
help="Print all warnings.") help="Print all warnings.")
parser.add_option("-s", "--safe", dest="safe", default=False,
metavar="SAFE_MODE",
help="'replace', 'remove' or 'escape' HTML tags in input")
parser.add_option("-o", "--output_format", dest="output_format",
default='xhtml1', metavar="OUTPUT_FORMAT",
help="'xhtml1' (default), 'html4' or 'html5'.")
parser.add_option("--noisy", parser.add_option("--noisy",
action="store_const", const=DEBUG, dest="verbose", action="store_const", const=DEBUG, dest="verbose",
help="Print debug messages.") help="Print debug messages.")
parser.add_option("-x", "--extension", action="append", dest="extensions",
help = "Load extension EXTENSION.", metavar="EXTENSION")
parser.add_option("-n", "--no_lazy_ol", dest="lazy_ol",
action='store_false', default=True,
help="Observe number of first item of ordered lists.")
(options, args) = parser.parse_args() (options, args) = parser.parse_args(args, values)
if len(args) == 0: if len(args) == 0:
input_file = None input_file = None
@@ -60,15 +73,26 @@ def parse_options():
if not options.extensions: if not options.extensions:
options.extensions = [] options.extensions = []
extension_configs = {}
if options.configfile:
with codecs.open(options.configfile, mode="r", encoding=options.encoding) as fp:
try:
extension_configs = yaml.load(fp)
except Exception as e:
message = "Failed parsing extension config file: %s" % options.configfile
e.args = (message,) + e.args[1:]
raise
return {'input': input_file, return {'input': input_file,
'output': options.filename, 'output': options.filename,
'safe_mode': options.safe, 'safe_mode': options.safe,
'extensions': options.extensions, 'extensions': options.extensions,
'extension_configs': extension_configs,
'encoding': options.encoding, 'encoding': options.encoding,
'output_format': options.output_format, 'output_format': options.output_format,
'lazy_ol': options.lazy_ol}, options.verbose 'lazy_ol': options.lazy_ol}, options.verbose
def run(): def run(): #pragma: no cover
"""Run Markdown from the command line.""" """Run Markdown from the command line."""
# Parse options and adjust logging level if necessary # Parse options and adjust logging level if necessary
@@ -80,7 +104,7 @@ def run():
# Run # Run
markdown.markdownFromFile(**options) markdown.markdownFromFile(**options)
if __name__ == '__main__': if __name__ == '__main__': #pragma: no cover
# Support running module as a commandline command. # Support running module as a commandline command.
# Python 2.5 & 2.6 do: `python -m markdown.__main__ [options] [args]`. # Python 2.5 & 2.6 do: `python -m markdown.__main__ [options] [args]`.
# Python 2.7 & 3.x do: `python -m markdown [options] [args]`. # Python 2.7 & 3.x do: `python -m markdown [options] [args]`.

View File

@@ -5,7 +5,7 @@
# (major, minor, micro, alpha/beta/rc/final, #) # (major, minor, micro, alpha/beta/rc/final, #)
# (1, 1, 2, 'alpha', 0) => "1.1.2.dev" # (1, 1, 2, 'alpha', 0) => "1.1.2.dev"
# (1, 2, 0, 'beta', 2) => "1.2b2" # (1, 2, 0, 'beta', 2) => "1.2b2"
version_info = (2, 4, 1, 'final', 0) version_info = (2, 5, 2, 'final', 0)
def _get_version(): def _get_version():
" Returns a PEP 386-compliant version number from version_info. " " Returns a PEP 386-compliant version number from version_info. "

View File

@@ -99,7 +99,7 @@ class BlockProcessor:
* ``block``: A block of text from the source which has been split at * ``block``: A block of text from the source which has been split at
blank lines. blank lines.
""" """
pass pass #pragma: no cover
def run(self, parent, blocks): def run(self, parent, blocks):
""" Run processor. Must be overridden by subclasses. """ Run processor. Must be overridden by subclasses.
@@ -123,7 +123,7 @@ class BlockProcessor:
* ``parent``: A etree element which is the parent of the current block. * ``parent``: A etree element which is the parent of the current block.
* ``blocks``: A list of all remaining blocks of the document. * ``blocks``: A list of all remaining blocks of the document.
""" """
pass pass #pragma: no cover
class ListIndentProcessor(BlockProcessor): class ListIndentProcessor(BlockProcessor):
@@ -433,7 +433,7 @@ class HashHeaderProcessor(BlockProcessor):
if after: if after:
# Insert remaining lines as first block for future parsing. # Insert remaining lines as first block for future parsing.
blocks.insert(0, after) blocks.insert(0, after)
else: else: #pragma: no cover
# This should never happen, but just in case... # This should never happen, but just in case...
logger.warn("We've got a problem header: %r" % block) logger.warn("We've got a problem header: %r" % block)

View File

@@ -4,17 +4,45 @@ Extensions
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from ..util import parseBoolValue
import warnings
class Extension(object): class Extension(object):
""" Base class for extensions to subclass. """ """ Base class for extensions to subclass. """
def __init__(self, configs = {}):
"""Create an instance of an Extention.
Keyword arguments: # Default config -- to be overriden by a subclass
# Must be of the following format:
# {
# 'key': ['value', 'description']
# }
# Note that Extension.setConfig will raise a KeyError
# if a default is not set here.
config = {}
def __init__(self, *args, **kwargs):
""" Initiate Extension and set up configs. """
# check for configs arg for backward compat.
# (there only ever used to be one so we use arg[0])
if len(args):
self.setConfigs(args[0])
warnings.warn('Extension classes accepting positional args is pending Deprecation. '
'Each setting should be passed into the Class as a keyword. Positional '
'args will be deprecated in version 2.6 and raise an error in version '
'2.7. See the Release Notes for Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
# check for configs kwarg for backward compat.
if 'configs' in kwargs.keys():
self.setConfigs(kwargs.pop('configs', {}))
warnings.warn('Extension classes accepting a dict on the single keyword "config" is '
'pending Deprecation. Each setting should be passed into the Class as '
'a keyword directly. The "config" keyword will be deprecated in version '
'2.6 and raise an error in version 2.7. See the Release Notes for '
'Python-Markdown version 2.5 for more info.',
PendingDeprecationWarning)
# finally, use kwargs
self.setConfigs(kwargs)
* configs: A dict of configuration setting used by an Extension.
"""
self.config = configs
def getConfig(self, key, default=''): def getConfig(self, key, default=''):
""" Return a setting for the given key or an empty string. """ """ Return a setting for the given key or an empty string. """
@@ -33,8 +61,20 @@ class Extension(object):
def setConfig(self, key, value): def setConfig(self, key, value):
""" Set a config setting for `key` with the given `value`. """ """ Set a config setting for `key` with the given `value`. """
if isinstance(self.config[key][0], bool):
value = parseBoolValue(value)
if self.config[key][0] is None:
value = parseBoolValue(value, preserve_none=True)
self.config[key][0] = value self.config[key][0] = value
def setConfigs(self, items):
""" Set multiple config settings given a dict or list of tuples. """
if hasattr(items, 'items'):
# it's a dict
items = items.items()
for key, value in items:
self.setConfig(key, value)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
""" """
Add the various proccesors and patterns to the Markdown Instance. Add the various proccesors and patterns to the Markdown Instance.

View File

@@ -4,22 +4,15 @@ Abbreviation Extension for Python-Markdown
This extension adds abbreviation handling to Python-Markdown. This extension adds abbreviation handling to Python-Markdown.
Simple Usage: See <https://pythonhosted.org/Markdown/extensions/abbreviations.html>
for documentation.
>>> import markdown Oringinal code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/) and
>>> text = """ [Seemant Kulleen](http://www.kulleen.org/)
... Some text with an ABBR and a REF. Ignore REFERENCE and ref.
...
... *[ABBR]: Abbreviation
... *[REF]: Abbreviation Reference
... """
>>> print markdown.markdown(text, ['abbr'])
<p>Some text with an <abbr title="Abbreviation">ABBR</abbr> and a <abbr title="Abbreviation Reference">REF</abbr>. Ignore REFERENCE and ref.</p>
Copyright 2007-2008 All changes Copyright 2008-2014 The Python Markdown Project
* [Waylan Limberg](http://achinghead.com/)
* [Seemant Kulleen](http://www.kulleen.org/)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
''' '''
@@ -92,5 +85,5 @@ class AbbrPattern(Pattern):
abbr.set('title', self.title) abbr.set('title', self.title)
return abbr return abbr
def makeExtension(configs=None): def makeExtension(*args, **kwargs):
return AbbrExtension(configs=configs) return AbbrExtension(*args, **kwargs)

View File

@@ -4,39 +4,16 @@ Admonition extension for Python-Markdown
Adds rST-style admonitions. Inspired by [rST][] feature with the same name. Adds rST-style admonitions. Inspired by [rST][] feature with the same name.
The syntax is (followed by an indented block with the contents):
!!! [type] [optional explicit title]
Where `type` is used as a CSS class name of the div. If not present, `title`
defaults to the capitalized `type`, so "note" -> "Note".
rST suggests the following `types`, but you're free to use whatever you want:
attention, caution, danger, error, hint, important, note, tip, warning
A simple example:
!!! note
This is the first line inside the box.
Outputs:
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>This is the first line inside the box</p>
</div>
You can also specify the title and CSS class of the admonition:
!!! custom "Did you know?"
Another line here.
Outputs:
<div class="admonition custom">
<p class="admonition-title">Did you know?</p>
<p>Another line here.</p>
</div>
[rST]: http://docutils.sourceforge.net/docs/ref/rst/directives.html#specific-admonitions [rST]: http://docutils.sourceforge.net/docs/ref/rst/directives.html#specific-admonitions
By [Tiago Serafim](http://www.tiagoserafim.com/). See <https://pythonhosted.org/Markdown/extensions/admonition.html>
for documentation.
Original code Copyright [Tiago Serafim](http://www.tiagoserafim.com/).
All changes Copyright The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
@@ -114,5 +91,6 @@ class AdmonitionProcessor(BlockProcessor):
return klass, title return klass, title
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return AdmonitionExtension(configs=configs) return AdmonitionExtension(*args, **kwargs)

View File

@@ -6,15 +6,14 @@ Adds attribute list syntax. Inspired by
[maruku](http://maruku.rubyforge.org/proposal.html#attribute_lists)'s [maruku](http://maruku.rubyforge.org/proposal.html#attribute_lists)'s
feature of the same name. feature of the same name.
Copyright 2011 [Waylan Limberg](http://achinghead.com/). See <https://pythonhosted.org/Markdown/extensions/attr_list.html>
for documentation.
Contact: markdown@freewisdom.org Original code Copyright 2011 [Waylan Limberg](http://achinghead.com/).
License: BSD (see ../LICENSE.md for details) All changes Copyright 2011-2014 The Python Markdown Project
Dependencies: License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
* [Python 2.4+](http://python.org)
* [Markdown 2.1+](http://packages.python.org/Markdown/)
""" """
@@ -27,7 +26,7 @@ import re
try: try:
Scanner = re.Scanner Scanner = re.Scanner
except AttributeError: except AttributeError: #pragma: no cover
# must be on Python 2.4 # must be on Python 2.4
from sre import Scanner from sre import Scanner
@@ -164,5 +163,5 @@ class AttrListExtension(Extension):
md.treeprocessors.add('attr_list', AttrListTreeprocessor(md), '>prettify') md.treeprocessors.add('attr_list', AttrListTreeprocessor(md), '>prettify')
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return AttrListExtension(configs=configs) return AttrListExtension(*args, **kwargs)

View File

@@ -4,17 +4,14 @@ CodeHilite Extension for Python-Markdown
Adds code/syntax highlighting to standard Python-Markdown code blocks. Adds code/syntax highlighting to standard Python-Markdown code blocks.
Copyright 2006-2008 [Waylan Limberg](http://achinghead.com/). See <https://pythonhosted.org/Markdown/extensions/code_hilite.html>
for documentation.
Project website: <http://packages.python.org/Markdown/extensions/code_hilite.html> Original code Copyright 2006-2008 [Waylan Limberg](http://achinghead.com/).
Contact: markdown@freewisdom.org
License: BSD (see ../LICENSE.md for details) All changes Copyright 2008-2014 The Python Markdown Project
Dependencies: License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
* [Python 2.3+](http://python.org/)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
* [Pygments](http://pygments.org/)
""" """
@@ -25,8 +22,8 @@ from ..treeprocessors import Treeprocessor
import warnings import warnings
try: try:
from pygments import highlight from pygments import highlight
from pygments.lexers import get_lexer_by_name, guess_lexer, TextLexer from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.formatters import HtmlFormatter from pygments.formatters import get_formatter_by_name
pygments = True pygments = True
except ImportError: except ImportError:
pygments = False pygments = False
@@ -112,14 +109,15 @@ class CodeHilite(object):
if self.guess_lang: if self.guess_lang:
lexer = guess_lexer(self.src) lexer = guess_lexer(self.src)
else: else:
lexer = TextLexer() lexer = get_lexer_by_name('text')
except ValueError: except ValueError:
lexer = TextLexer() lexer = get_lexer_by_name('text')
formatter = HtmlFormatter(linenos=self.linenums, formatter = get_formatter_by_name('html',
cssclass=self.css_class, linenos=self.linenums,
style=self.style, cssclass=self.css_class,
noclasses=self.noclasses, style=self.style,
hl_lines=self.hl_lines) noclasses=self.noclasses,
hl_lines=self.hl_lines)
return highlight(self.src, lexer, formatter) return highlight(self.src, lexer, formatter)
else: else:
# just escape and build markup usable by JS highlighting libs # just escape and build markup usable by JS highlighting libs
@@ -225,7 +223,7 @@ class HiliteTreeprocessor(Treeprocessor):
class CodeHiliteExtension(Extension): class CodeHiliteExtension(Extension):
""" Add source code hilighting to markdown codeblocks. """ """ Add source code hilighting to markdown codeblocks. """
def __init__(self, configs): def __init__(self, *args, **kwargs):
# define default configs # define default configs
self.config = { self.config = {
'linenums': [None, "Use lines numbers. True=yes, False=no, None=auto"], 'linenums': [None, "Use lines numbers. True=yes, False=no, None=auto"],
@@ -237,22 +235,7 @@ class CodeHiliteExtension(Extension):
'noclasses': [False, 'Use inline styles instead of CSS classes - Default false'] 'noclasses': [False, 'Use inline styles instead of CSS classes - Default false']
} }
# Override defaults with user settings super(CodeHiliteExtension, self).__init__(*args, **kwargs)
for key, value in configs:
# convert strings to booleans
if value == 'True': value = True
if value == 'False': value = False
if value == 'None': value = None
if key == 'force_linenos':
warnings.warn('The "force_linenos" config setting'
' to the CodeHilite extension is deprecrecated.'
' Use "linenums" instead.', DeprecationWarning)
if value:
# Carry 'force_linenos' over to new 'linenos'.
self.setConfig('linenums', True)
self.setConfig(key, value)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
""" Add HilitePostprocessor to Markdown instance. """ """ Add HilitePostprocessor to Markdown instance. """
@@ -263,6 +246,5 @@ class CodeHiliteExtension(Extension):
md.registerExtension(self) md.registerExtension(self)
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return CodeHiliteExtension(configs=configs) return CodeHiliteExtension(*args, **kwargs)

View File

@@ -2,19 +2,16 @@
Definition List Extension for Python-Markdown Definition List Extension for Python-Markdown
============================================= =============================================
Added parsing of Definition Lists to Python-Markdown. Adds parsing of Definition Lists to Python-Markdown.
A simple example: See <https://pythonhosted.org/Markdown/extensions/definition_lists.html>
for documentation.
Apple Original code Copyright 2008 [Waylan Limberg](http://achinghead.com)
: Pomaceous fruit of plants of the genus Malus in
the family Rosaceae.
: An american computer company.
Orange All changes Copyright 2008-2014 The Python Markdown Project
: The fruit of an evergreen tree of the genus Citrus.
Copyright 2008 - [Waylan Limberg](http://achinghead.com) License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
@@ -113,6 +110,6 @@ class DefListExtension(Extension):
'>ulist') '>ulist')
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return DefListExtension(configs=configs) return DefListExtension(*args, **kwargs)

View File

@@ -11,10 +11,6 @@ convenience so that only one extension needs to be listed when
initiating Markdown. See the documentation for each individual initiating Markdown. See the documentation for each individual
extension for specifics about that extension. extension for specifics about that extension.
In the event that one or more of the supported extensions are not
available for import, Markdown will issue a warning and simply continue
without that extension.
There may be additional extensions that are distributed with There may be additional extensions that are distributed with
Python-Markdown that are not included here in Extra. Those extensions Python-Markdown that are not included here in Extra. Those extensions
are not part of PHP Markdown Extra, and therefore, not part of are not part of PHP Markdown Extra, and therefore, not part of
@@ -24,6 +20,13 @@ under a differant name. You could also edit the `extensions` global
variable defined below, but be aware that such changes may be lost variable defined below, but be aware that such changes may be lost
when you upgrade to any future version of Python-Markdown. when you upgrade to any future version of Python-Markdown.
See <https://pythonhosted.org/Markdown/extensions/extra.html>
for documentation.
Copyright The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
from __future__ import absolute_import from __future__ import absolute_import
@@ -33,19 +36,25 @@ from ..blockprocessors import BlockProcessor
from .. import util from .. import util
import re import re
extensions = ['smart_strong', extensions = [
'fenced_code', 'markdown.extensions.smart_strong',
'footnotes', 'markdown.extensions.fenced_code',
'attr_list', 'markdown.extensions.footnotes',
'def_list', 'markdown.extensions.attr_list',
'tables', 'markdown.extensions.def_list',
'abbr', 'markdown.extensions.tables',
] 'markdown.extensions.abbr'
]
class ExtraExtension(Extension): class ExtraExtension(Extension):
""" Add various extensions to Markdown class.""" """ Add various extensions to Markdown class."""
def __init__(self, *args, **kwargs):
""" config is just a dumb holder which gets passed to actual ext later. """
self.config = kwargs.pop('configs', {})
self.config.update(kwargs)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
""" Register extension instances. """ """ Register extension instances. """
md.registerExtensions(extensions, self.config) md.registerExtensions(extensions, self.config)
@@ -60,8 +69,8 @@ class ExtraExtension(Extension):
r'^(p|h[1-6]|li|dd|dt|td|th|legend|address)$', re.IGNORECASE) r'^(p|h[1-6]|li|dd|dt|td|th|legend|address)$', re.IGNORECASE)
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return ExtraExtension(configs=dict(configs)) return ExtraExtension(*args, **kwargs)
class MarkdownInHtmlProcessor(BlockProcessor): class MarkdownInHtmlProcessor(BlockProcessor):

View File

@@ -4,87 +4,15 @@ Fenced Code Extension for Python Markdown
This extension adds Fenced Code Blocks to Python-Markdown. This extension adds Fenced Code Blocks to Python-Markdown.
>>> import markdown See <https://pythonhosted.org/Markdown/extensions/fenced_code_blocks.html>
>>> text = ''' for documentation.
... A paragraph before a fenced code block:
...
... ~~~
... Fenced code block
... ~~~
... '''
>>> html = markdown.markdown(text, extensions=['fenced_code'])
>>> print html
<p>A paragraph before a fenced code block:</p>
<pre><code>Fenced code block
</code></pre>
Works with safe_mode also (we check this because we are using the HtmlStash): Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/).
>>> print markdown.markdown(text, extensions=['fenced_code'], safe_mode='replace')
<p>A paragraph before a fenced code block:</p>
<pre><code>Fenced code block
</code></pre>
Include tilde's in a code block and wrap with blank lines: All changes Copyright 2008-2014 The Python Markdown Project
>>> text = '''
... ~~~~~~~~
...
... ~~~~
... ~~~~~~~~'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code>
~~~~
</code></pre>
Language tags:
>>> text = '''
... ~~~~{.python}
... # Some python code
... ~~~~'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code class="python"># Some python code
</code></pre>
Optionally backticks instead of tildes as per how github's code block markdown is identified:
>>> text = '''
... `````
... # Arbitrary code
... ~~~~~ # these tildes will not close the block
... `````'''
>>> print markdown.markdown(text, extensions=['fenced_code'])
<pre><code># Arbitrary code
~~~~~ # these tildes will not close the block
</code></pre>
If the codehighlite extension and Pygments are installed, lines can be highlighted:
>>> text = '''
... ```hl_lines="1 3"
... line 1
... line 2
... line 3
... ```'''
>>> print markdown.markdown(text, extensions=['codehilite', 'fenced_code'])
<pre><code><span class="hilight">line 1</span>
line 2
<span class="hilight">line 3</span>
</code></pre>
Copyright 2007-2008 [Waylan Limberg](http://achinghead.com/).
Project website: <http://packages.python.org/Markdown/extensions/fenced_code_blocks.html>
Contact: markdown@freewisdom.org
License: BSD (see ../docs/LICENSE for details)
Dependencies:
* [Python 2.4+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
* [Pygments (optional)](http://pygments.org)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
from __future__ import absolute_import from __future__ import absolute_import
@@ -175,5 +103,6 @@ class FencedBlockPreprocessor(Preprocessor):
return txt return txt
def makeExtension(configs=None): def makeExtension(*args, **kwargs):
return FencedCodeExtension(configs=configs) return FencedCodeExtension(*args, **kwargs)

View File

@@ -1,25 +1,15 @@
""" """
========================= FOOTNOTES ================================= Footnotes Extension for Python-Markdown
=======================================
This section adds footnote handling to markdown. It can be used as Adds footnote handling to Python-Markdown.
an example for extending python-markdown with relatively complex
functionality. While in this case the extension is included inside
the module itself, it could just as easily be added from outside the
module. Not that all markdown classes above are ignorant about
footnotes. All footnote functionality is provided separately and
then added to the markdown instance at the run time.
Footnote functionality is attached by calling extendMarkdown() See <https://pythonhosted.org/Markdown/extensions/footnotes.html>
method of FootnoteExtension. The method also registers the for documentation.
extension to allow it's state to be reset by a call to reset()
method.
Example: Copyright The Python Markdown Project
Footnotes[^1] have a label[^label] and a definition[^!DEF].
[^1]: This is a footnote License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
[^label]: A footnote on "label"
[^!DEF]: The footnote for definition
""" """
@@ -42,22 +32,22 @@ TABBED_RE = re.compile(r'((\t)|( ))(.*)')
class FootnoteExtension(Extension): class FootnoteExtension(Extension):
""" Footnote Extension. """ """ Footnote Extension. """
def __init__ (self, configs): def __init__ (self, *args, **kwargs):
""" Setup configs. """ """ Setup configs. """
self.config = {'PLACE_MARKER':
["///Footnotes Go Here///",
"The text string that marks where the footnotes go"],
'UNIQUE_IDS':
[False,
"Avoid name collisions across "
"multiple calls to reset()."],
"BACKLINK_TEXT":
["&#8617;",
"The text string that links from the footnote to the reader's place."]
}
for key, value in configs: self.config = {
self.config[key][0] = value 'PLACE_MARKER':
["///Footnotes Go Here///",
"The text string that marks where the footnotes go"],
'UNIQUE_IDS':
[False,
"Avoid name collisions across "
"multiple calls to reset()."],
"BACKLINK_TEXT":
["&#8617;",
"The text string that links from the footnote to the reader's place."]
}
super(FootnoteExtension, self).__init__(*args, **kwargs)
# In multiple invocations, emit links that don't get tangled. # In multiple invocations, emit links that don't get tangled.
self.unique_prefix = 0 self.unique_prefix = 0
@@ -309,7 +299,7 @@ class FootnotePostprocessor(Postprocessor):
text = text.replace(FN_BACKLINK_TEXT, self.footnotes.getConfig("BACKLINK_TEXT")) text = text.replace(FN_BACKLINK_TEXT, self.footnotes.getConfig("BACKLINK_TEXT"))
return text.replace(NBSP_PLACEHOLDER, "&#160;") return text.replace(NBSP_PLACEHOLDER, "&#160;")
def makeExtension(configs=[]): def makeExtension(*args, **kwargs):
""" Return an instance of the FootnoteExtension """ """ Return an instance of the FootnoteExtension """
return FootnoteExtension(configs=configs) return FootnoteExtension(*args, **kwargs)

View File

@@ -4,73 +4,14 @@ HeaderID Extension for Python-Markdown
Auto-generate id attributes for HTML headers. Auto-generate id attributes for HTML headers.
Basic usage: See <https://pythonhosted.org/Markdown/extensions/header_id.html>
for documentation.
>>> import markdown Original code Copyright 2007-2011 [Waylan Limberg](http://achinghead.com/).
>>> text = "# Some Header #"
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="some-header">Some Header</h1>
All header IDs are unique: All changes Copyright 2011-2014 The Python Markdown Project
>>> text = ''' License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
... #Header
... #Header
... #Header'''
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="header">Header</h1>
<h1 id="header_1">Header</h1>
<h1 id="header_2">Header</h1>
To fit within a html template's hierarchy, set the header base level:
>>> text = '''
... #Some Header
... ## Next Level'''
>>> md = markdown.markdown(text, ['headerid(level=3)'])
>>> print md
<h3 id="some-header">Some Header</h3>
<h4 id="next-level">Next Level</h4>
Works with inline markup.
>>> text = '#Some *Header* with [markup](http://example.com).'
>>> md = markdown.markdown(text, ['headerid'])
>>> print md
<h1 id="some-header-with-markup">Some <em>Header</em> with <a href="http://example.com">markup</a>.</h1>
Turn off auto generated IDs:
>>> text = '''
... # Some Header
... # Another Header'''
>>> md = markdown.markdown(text, ['headerid(forceid=False)'])
>>> print md
<h1>Some Header</h1>
<h1>Another Header</h1>
Use with MetaData extension:
>>> text = '''header_level: 2
... header_forceid: Off
...
... # A Header'''
>>> md = markdown.markdown(text, ['headerid', 'meta'])
>>> print md
<h2>A Header</h2>
Copyright 2007-2011 [Waylan Limberg](http://achinghead.com/).
Project website: <http://packages.python.org/Markdown/extensions/header_id.html>
Contact: markdown@freewisdom.org
License: BSD (see ../docs/LICENSE for details)
Dependencies:
* [Python 2.3+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
""" """
@@ -127,7 +68,7 @@ def stashedHTML2text(text, md):
def _html_sub(m): def _html_sub(m):
""" Substitute raw html with plain text. """ """ Substitute raw html with plain text. """
try: try:
raw, safe = md.htmlStash.rawHtmlBlocks[int(m.group(1))] raw, safe = md.htmlStash.rawHtmlBlocks[int(m.group(1))]
except (IndexError, TypeError): except (IndexError, TypeError):
return m.group(0) return m.group(0)
if md.safeMode and not safe: if md.safeMode and not safe:
@@ -176,7 +117,7 @@ class HeaderIdTreeprocessor(Treeprocessor):
class HeaderIdExtension(Extension): class HeaderIdExtension(Extension):
def __init__(self, configs): def __init__(self, *args, **kwargs):
# set defaults # set defaults
self.config = { self.config = {
'level' : ['1', 'Base level for headers.'], 'level' : ['1', 'Base level for headers.'],
@@ -185,8 +126,7 @@ class HeaderIdExtension(Extension):
'slugify' : [slugify, 'Callable to generate anchors'], 'slugify' : [slugify, 'Callable to generate anchors'],
} }
for key, value in configs: super(HeaderIdExtension, self).__init__(*args, **kwargs)
self.setConfig(key, value)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
md.registerExtension(self) md.registerExtension(self)
@@ -204,5 +144,6 @@ class HeaderIdExtension(Extension):
self.processor.IDs = set() self.processor.IDs = set()
def makeExtension(configs=None): def makeExtension(*args, **kwargs):
return HeaderIdExtension(configs=configs) return HeaderIdExtension(*args, **kwargs)

View File

@@ -4,38 +4,14 @@ Meta Data Extension for Python-Markdown
This extension adds Meta Data handling to markdown. This extension adds Meta Data handling to markdown.
Basic Usage: See <https://pythonhosted.org/Markdown/extensions/meta_data.html>
for documentation.
>>> import markdown Original code Copyright 2007-2008 [Waylan Limberg](http://achinghead.com).
>>> text = '''Title: A Test Doc.
... Author: Waylan Limberg
... John Doe
... Blank_Data:
...
... The body. This is paragraph one.
... '''
>>> md = markdown.Markdown(['meta'])
>>> print md.convert(text)
<p>The body. This is paragraph one.</p>
>>> print md.Meta
{u'blank_data': [u''], u'author': [u'Waylan Limberg', u'John Doe'], u'title': [u'A Test Doc.']}
Make sure text without Meta Data still works (markdown < 1.6b returns a <p>). All changes Copyright 2008-2014 The Python Markdown Project
>>> text = ' Some Code - not extra lines of meta data.' License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
>>> md = markdown.Markdown(['meta'])
>>> print md.convert(text)
<pre><code>Some Code - not extra lines of meta data.
</code></pre>
>>> md.Meta
{}
Copyright 2007-2008 [Waylan Limberg](http://achinghead.com).
Project website: <http://packages.python.org/Markdown/meta_data.html>
Contact: markdown@freewisdom.org
License: BSD (see ../LICENSE.md for details)
""" """
@@ -55,7 +31,7 @@ class MetaExtension (Extension):
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
""" Add MetaPreprocessor to Markdown instance. """ """ Add MetaPreprocessor to Markdown instance. """
md.preprocessors.add("meta", MetaPreprocessor(md), "_begin") md.preprocessors.add("meta", MetaPreprocessor(md), ">normalize_whitespace")
class MetaPreprocessor(Preprocessor): class MetaPreprocessor(Preprocessor):
@@ -89,5 +65,6 @@ class MetaPreprocessor(Preprocessor):
return lines return lines
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return MetaExtension(configs=configs) return MetaExtension(*args, **kwargs)

View File

@@ -5,18 +5,14 @@ NL2BR Extension
A Python-Markdown extension to treat newlines as hard breaks; like A Python-Markdown extension to treat newlines as hard breaks; like
GitHub-flavored Markdown does. GitHub-flavored Markdown does.
Usage: See <https://pythonhosted.org/Markdown/extensions/nl2br.html>
for documentation.
>>> import markdown Oringinal code Copyright 2011 [Brian Neal](http://deathofagremmie.com/)
>>> print markdown.markdown('line 1\\nline 2', extensions=['nl2br'])
<p>line 1<br />
line 2</p>
Copyright 2011 [Brian Neal](http://deathofagremmie.com/) All changes Copyright 2011-2014 The Python Markdown Project
Dependencies: License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
* [Python 2.4+](http://python.org)
* [Markdown 2.1+](http://packages.python.org/Markdown/)
""" """
@@ -34,5 +30,6 @@ class Nl2BrExtension(Extension):
md.inlinePatterns.add('nl', br_tag, '_end') md.inlinePatterns.add('nl', br_tag, '_end')
def makeExtension(configs=None): def makeExtension(*args, **kwargs):
return Nl2BrExtension(configs) return Nl2BrExtension(*args, **kwargs)

View File

@@ -2,19 +2,16 @@
Sane List Extension for Python-Markdown Sane List Extension for Python-Markdown
======================================= =======================================
Modify the behavior of Lists in Python-Markdown t act in a sane manor. Modify the behavior of Lists in Python-Markdown to act in a sane manor.
In standard Markdown syntax, the following would constitute a single See <https://pythonhosted.org/Markdown/extensions/sane_lists.html>
ordered list. However, with this extension, the output would include for documentation.
two lists, the first an ordered list and the second and unordered list.
1. ordered Original code Copyright 2011 [Waylan Limberg](http://achinghead.com)
2. list
* unordered All changes Copyright 2011-2014 The Python Markdown Project
* list
Copyright 2011 - [Waylan Limberg](http://achinghead.com) License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
@@ -46,6 +43,6 @@ class SaneListExtension(Extension):
md.parser.blockprocessors['ulist'] = SaneUListProcessor(md.parser) md.parser.blockprocessors['ulist'] = SaneUListProcessor(md.parser)
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return SaneListExtension(configs=configs) return SaneListExtension(*args, **kwargs)

View File

@@ -4,21 +4,14 @@ Smart_Strong Extension for Python-Markdown
This extention adds smarter handling of double underscores within words. This extention adds smarter handling of double underscores within words.
Simple Usage: See <https://pythonhosted.org/Markdown/extensions/smart_strong.html>
for documentation.
>>> import markdown Original code Copyright 2011 [Waylan Limberg](http://achinghead.com)
>>> print markdown.markdown('Text with double__underscore__words.',
... extensions=['smart_strong'])
<p>Text with double__underscore__words.</p>
>>> print markdown.markdown('__Strong__ still works.',
... extensions=['smart_strong'])
<p><strong>Strong</strong> still works.</p>
>>> print markdown.markdown('__this__works__too__.',
... extensions=['smart_strong'])
<p><strong>this__works__too</strong>.</p>
Copyright 2011 All changes Copyright 2011-2014 The Python Markdown Project
[Waylan Limberg](http://achinghead.com)
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
''' '''
@@ -38,5 +31,5 @@ class SmartEmphasisExtension(Extension):
md.inlinePatterns['strong'] = SimpleTagPattern(STRONG_RE, 'strong') md.inlinePatterns['strong'] = SimpleTagPattern(STRONG_RE, 'strong')
md.inlinePatterns.add('strong2', SimpleTagPattern(SMART_STRONG_RE, 'strong'), '>emphasis2') md.inlinePatterns.add('strong2', SimpleTagPattern(SMART_STRONG_RE, 'strong'), '>emphasis2')
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return SmartEmphasisExtension(configs=dict(configs)) return SmartEmphasisExtension(*args, **kwargs)

View File

@@ -1,73 +1,91 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Smarty extension for Python-Markdown '''
# Author: 2013, Dmitry Shachnev <mitya57@gmail.com> Smarty extension for Python-Markdown
====================================
Adds conversion of ASCII dashes, quotes and ellipses to their HTML
entity equivalents.
See <https://pythonhosted.org/Markdown/extensions/smarty.html>
for documentation.
Author: 2013, Dmitry Shachnev <mitya57@gmail.com>
All changes Copyright 2013-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
SmartyPants license:
Copyright (c) 2003 John Gruber <http://daringfireball.net/>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
* Neither the name "SmartyPants" nor the names of its contributors
may be used to endorse or promote products derived from this
software without specific prior written permission.
This software is provided by the copyright holders and contributors "as
is" and any express or implied warranties, including, but not limited
to, the implied warranties of merchantability and fitness for a
particular purpose are disclaimed. In no event shall the copyright
owner or contributors be liable for any direct, indirect, incidental,
special, exemplary, or consequential damages (including, but not
limited to, procurement of substitute goods or services; loss of use,
data, or profits; or business interruption) however caused and on any
theory of liability, whether in contract, strict liability, or tort
(including negligence or otherwise) arising in any way out of the use
of this software, even if advised of the possibility of such damage.
smartypants.py license:
smartypants.py is a derivative work of SmartyPants.
Copyright (c) 2004, 2007 Chad Miller <http://web.chad.org/>
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
This software is provided by the copyright holders and contributors "as
is" and any express or implied warranties, including, but not limited
to, the implied warranties of merchantability and fitness for a
particular purpose are disclaimed. In no event shall the copyright
owner or contributors be liable for any direct, indirect, incidental,
special, exemplary, or consequential damages (including, but not
limited to, procurement of substitute goods or services; loss of use,
data, or profits; or business interruption) however caused and on any
theory of liability, whether in contract, strict liability, or tort
(including negligence or otherwise) arising in any way out of the use
of this software, even if advised of the possibility of such damage.
'''
# SmartyPants license:
#
# Copyright (c) 2003 John Gruber <http://daringfireball.net/>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# * Neither the name "SmartyPants" nor the names of its contributors
# may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
#
#
# smartypants.py license:
#
# smartypants.py is a derivative work of SmartyPants.
# Copyright (c) 2004, 2007 Chad Miller <http://web.chad.org/>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
#
# This software is provided by the copyright holders and contributors "as
# is" and any express or implied warranties, including, but not limited
# to, the implied warranties of merchantability and fitness for a
# particular purpose are disclaimed. In no event shall the copyright
# owner or contributors be liable for any direct, indirect, incidental,
# special, exemplary, or consequential damages (including, but not
# limited to, procurement of substitute goods or services; loss of use,
# data, or profits; or business interruption) however caused and on any
# theory of liability, whether in contract, strict liability, or tort
# (including negligence or otherwise) arising in any way out of the use
# of this software, even if advised of the possibility of such damage.
from __future__ import unicode_literals from __future__ import unicode_literals
from . import Extension from . import Extension
from ..inlinepatterns import HtmlPattern from ..inlinepatterns import HtmlPattern
from ..odict import OrderedDict
from ..treeprocessors import InlineProcessor
from ..util import parseBoolValue from ..util import parseBoolValue
# Constants for quote education. # Constants for quote education.
@@ -85,10 +103,23 @@ openingQuotesBase = (
')' ')'
) )
substitutions = {
'mdash': '&mdash;',
'ndash': '&ndash;',
'ellipsis': '&hellip;',
'left-angle-quote': '&laquo;',
'right-angle-quote': '&raquo;',
'left-single-quote': '&lsquo;',
'right-single-quote': '&rsquo;',
'left-double-quote': '&ldquo;',
'right-double-quote': '&rdquo;',
}
# Special case if the very first character is a quote # Special case if the very first character is a quote
# followed by punctuation at a non-word-break. Close the quotes by brute force: # followed by punctuation at a non-word-break. Close the quotes by brute force:
singleQuoteStartRe = r"^'(?=%s\\B)" % punctClass singleQuoteStartRe = r"^'(?=%s\B)" % punctClass
doubleQuoteStartRe = r'^"(?=%s\\B)' % punctClass doubleQuoteStartRe = r'^"(?=%s\B)' % punctClass
# Special case for double sets of quotes, e.g.: # Special case for double sets of quotes, e.g.:
# <p>He said, "'Quoted' words in a larger quote."</p> # <p>He said, "'Quoted' words in a larger quote."</p>
@@ -113,8 +144,6 @@ closingSingleQuotesRegex2 = r"(?<=%s)'(\s|s\b)" % closeClass
remainingSingleQuotesRegex = "'" remainingSingleQuotesRegex = "'"
remainingDoubleQuotesRegex = '"' remainingDoubleQuotesRegex = '"'
lsquo, rsquo, ldquo, rdquo = '&lsquo;', '&rsquo;', '&ldquo;', '&rdquo;'
class SubstituteTextPattern(HtmlPattern): class SubstituteTextPattern(HtmlPattern):
def __init__(self, pattern, replace, markdown_instance): def __init__(self, pattern, replace, markdown_instance):
""" Replaces matches with some text. """ """ Replaces matches with some text. """
@@ -132,35 +161,56 @@ class SubstituteTextPattern(HtmlPattern):
return result return result
class SmartyExtension(Extension): class SmartyExtension(Extension):
def __init__(self, configs): def __init__(self, *args, **kwargs):
self.config = { self.config = {
'smart_quotes': [True, 'Educate quotes'], 'smart_quotes': [True, 'Educate quotes'],
'smart_angled_quotes': [False, 'Educate angled quotes'],
'smart_dashes': [True, 'Educate dashes'], 'smart_dashes': [True, 'Educate dashes'],
'smart_ellipses': [True, 'Educate ellipses'] 'smart_ellipses': [True, 'Educate ellipses'],
'substitutions' : [{}, 'Overwrite default substitutions'],
} }
for key, value in configs: super(SmartyExtension, self).__init__(*args, **kwargs)
self.setConfig(key, parseBoolValue(value)) self.substitutions = dict(substitutions)
self.substitutions.update(self.getConfig('substitutions', default={}))
def _addPatterns(self, md, patterns, serie): def _addPatterns(self, md, patterns, serie):
for ind, pattern in enumerate(patterns): for ind, pattern in enumerate(patterns):
pattern += (md,) pattern += (md,)
pattern = SubstituteTextPattern(*pattern) pattern = SubstituteTextPattern(*pattern)
after = ('>smarty-%s-%d' % (serie, ind - 1) if ind else '>entity') after = ('>smarty-%s-%d' % (serie, ind - 1) if ind else '_begin')
name = 'smarty-%s-%d' % (serie, ind) name = 'smarty-%s-%d' % (serie, ind)
md.inlinePatterns.add(name, pattern, after) self.inlinePatterns.add(name, pattern, after)
def educateDashes(self, md): def educateDashes(self, md):
emDashesPattern = SubstituteTextPattern(r'(?<!-)---(?!-)', ('&mdash;',), md) emDashesPattern = SubstituteTextPattern(r'(?<!-)---(?!-)',
enDashesPattern = SubstituteTextPattern(r'(?<!-)--(?!-)', ('&ndash;',), md) (self.substitutions['mdash'],), md)
md.inlinePatterns.add('smarty-em-dashes', emDashesPattern, '>entity') enDashesPattern = SubstituteTextPattern(r'(?<!-)--(?!-)',
md.inlinePatterns.add('smarty-en-dashes', enDashesPattern, (self.substitutions['ndash'],), md)
self.inlinePatterns.add('smarty-em-dashes', emDashesPattern, '_begin')
self.inlinePatterns.add('smarty-en-dashes', enDashesPattern,
'>smarty-em-dashes') '>smarty-em-dashes')
def educateEllipses(self, md): def educateEllipses(self, md):
ellipsesPattern = SubstituteTextPattern(r'(?<!\.)\.{3}(?!\.)', ('&hellip;',), md) ellipsesPattern = SubstituteTextPattern(r'(?<!\.)\.{3}(?!\.)',
md.inlinePatterns.add('smarty-ellipses', ellipsesPattern, '>entity') (self.substitutions['ellipsis'],), md)
self.inlinePatterns.add('smarty-ellipses', ellipsesPattern, '_begin')
def educateAngledQuotes(self, md):
leftAngledQuotePattern = SubstituteTextPattern(r'\<\<',
(self.substitutions['left-angle-quote'],), md)
rightAngledQuotePattern = SubstituteTextPattern(r'\>\>',
(self.substitutions['right-angle-quote'],), md)
self.inlinePatterns.add('smarty-left-angle-quotes',
leftAngledQuotePattern, '_begin')
self.inlinePatterns.add('smarty-right-angle-quotes',
rightAngledQuotePattern, '>smarty-left-angle-quotes')
def educateQuotes(self, md): def educateQuotes(self, md):
configs = self.getConfigs()
lsquo = self.substitutions['left-single-quote']
rsquo = self.substitutions['right-single-quote']
ldquo = self.substitutions['left-double-quote']
rdquo = self.substitutions['right-double-quote']
patterns = ( patterns = (
(singleQuoteStartRe, (rsquo,)), (singleQuoteStartRe, (rsquo,)),
(doubleQuoteStartRe, (rdquo,)), (doubleQuoteStartRe, (rdquo,)),
@@ -179,13 +229,19 @@ class SmartyExtension(Extension):
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
configs = self.getConfigs() configs = self.getConfigs()
if configs['smart_quotes']: self.inlinePatterns = OrderedDict()
self.educateQuotes(md)
if configs['smart_dashes']:
self.educateDashes(md)
if configs['smart_ellipses']: if configs['smart_ellipses']:
self.educateEllipses(md) self.educateEllipses(md)
if configs['smart_quotes']:
self.educateQuotes(md)
if configs['smart_angled_quotes']:
self.educateAngledQuotes(md)
if configs['smart_dashes']:
self.educateDashes(md)
inlineProcessor = InlineProcessor(md)
inlineProcessor.inlinePatterns = self.inlinePatterns
md.treeprocessors.add('smarty', inlineProcessor, '_end')
md.ESCAPED_CHARS.extend(['"', "'"]) md.ESCAPED_CHARS.extend(['"', "'"])
def makeExtension(configs=None): def makeExtension(*args, **kwargs):
return SmartyExtension(configs) return SmartyExtension(*args, **kwargs)

View File

@@ -4,14 +4,15 @@ Tables Extension for Python-Markdown
Added parsing of tables to Python-Markdown. Added parsing of tables to Python-Markdown.
A simple example: See <https://pythonhosted.org/Markdown/extensions/tables.html>
for documentation.
First Header | Second Header Original code Copyright 2009 [Waylan Limberg](http://achinghead.com)
------------- | -------------
Content Cell | Content Cell All changes Copyright 2008-2014 The Python Markdown Project
Content Cell | Content Cell
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
Copyright 2009 - [Waylan Limberg](http://achinghead.com)
""" """
from __future__ import absolute_import from __future__ import absolute_import
@@ -71,7 +72,7 @@ class TableProcessor(BlockProcessor):
c = etree.SubElement(tr, tag) c = etree.SubElement(tr, tag)
try: try:
c.text = cells[i].strip() c.text = cells[i].strip()
except IndexError: except IndexError: #pragma: no cover
c.text = "" c.text = ""
if a: if a:
c.set('align', a) c.set('align', a)
@@ -96,5 +97,6 @@ class TableExtension(Extension):
'<hashheader') '<hashheader')
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return TableExtension(configs=configs) return TableExtension(*args, **kwargs)

View File

@@ -1,11 +1,15 @@
""" """
Table of Contents Extension for Python-Markdown Table of Contents Extension for Python-Markdown
* * * ===============================================
(c) 2008 [Jack Miller](http://codezen.org) See <https://pythonhosted.org/Markdown/extensions/toc.html>
for documentation.
Dependencies: Oringinal code Copyright 2008 [Jack Miller](http://codezen.org)
* [Markdown 2.1+](http://packages.python.org/Markdown/)
All changes Copyright 2008-2014 The Python Markdown Project
License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
""" """
@@ -30,53 +34,52 @@ def order_toc_list(toc_list):
[{'level': 2, 'children': []}, {'level': 1, 'children': []}] [{'level': 2, 'children': []}, {'level': 1, 'children': []}]
""" """
def build_correct(remaining_list, prev_elements=[{'level': 1000}]): ordered_list = []
if len(toc_list):
# Initialize everything by processing the first entry
last = toc_list.pop(0)
last['children'] = []
levels = [last['level']]
ordered_list.append(last)
parents = []
if not remaining_list: # Walk the rest nesting the entries properly
return [], [] while toc_list:
t = toc_list.pop(0)
current_level = t['level']
t['children'] = []
current = remaining_list.pop(0) # Reduce depth if current level < last item's level
if not 'children' in current.keys(): if current_level < levels[-1]:
current['children'] = [] # Pop last level since we know we are less than it
levels.pop()
if not prev_elements: # Pop parents and levels we are less than or equal to
# This happens for instance with [8, 1, 1], ie. when some to_pop = 0
# header level is outside a scope. We treat it as a for p in reversed(parents):
# top-level if current_level <= p['level']:
next_elements, children = build_correct(remaining_list, [current]) to_pop += 1
current['children'].append(children) else:
return [current] + next_elements, [] break
if to_pop:
levels = levels[:-to_pop]
parents = parents[:-to_pop]
prev_element = prev_elements.pop() # Note current level as last
children = [] levels.append(current_level)
next_elements = []
# Is current part of the child list or next list? # Level is the same, so append to the current parent (if available)
if current['level'] > prev_element['level']: if current_level == levels[-1]:
#print "%d is a child of %d" % (current['level'], prev_element['level']) (parents[-1]['children'] if parents else ordered_list).append(t)
prev_elements.append(prev_element)
prev_elements.append(current) # Current level is > last item's level,
prev_element['children'].append(current) # So make last item a parent and append current as child
next_elements2, children2 = build_correct(remaining_list, prev_elements)
children += children2
next_elements += next_elements2
else:
#print "%d is ancestor of %d" % (current['level'], prev_element['level'])
if not prev_elements:
#print "No previous elements, so appending to the next set"
next_elements.append(current)
prev_elements = [current]
next_elements2, children2 = build_correct(remaining_list, prev_elements)
current['children'].extend(children2)
else: else:
#print "Previous elements, comparing to those first" last['children'].append(t)
remaining_list.insert(0, current) parents.append(last)
next_elements2, children2 = build_correct(remaining_list, prev_elements) levels.append(current_level)
children.extend(children2) last = t
next_elements += next_elements2
return next_elements, children
ordered_list, __ = build_correct(toc_list)
return ordered_list return ordered_list
@@ -204,26 +207,26 @@ class TocExtension(Extension):
TreeProcessorClass = TocTreeprocessor TreeProcessorClass = TocTreeprocessor
def __init__(self, configs=[]): def __init__(self, *args, **kwargs):
self.config = { "marker" : ["[TOC]", self.config = {
"Text to find and replace with Table of Contents -" "marker" : ["[TOC]",
"Defaults to \"[TOC]\""], "Text to find and replace with Table of Contents - "
"slugify" : [slugify, "Defaults to \"[TOC]\""],
"Function to generate anchors based on header text-" "slugify" : [slugify,
"Defaults to the headerid ext's slugify function."], "Function to generate anchors based on header text - "
"title" : [None, "Defaults to the headerid ext's slugify function."],
"Title to insert into TOC <div> - " "title" : ["",
"Defaults to None"], "Title to insert into TOC <div> - "
"anchorlink" : [0, "Defaults to an empty string"],
"1 if header should be a self link" "anchorlink" : [0,
"Defaults to 0"], "1 if header should be a self link - "
"permalink" : [0, "Defaults to 0"],
"1 or link text if a Sphinx-style permalink should be added", "permalink" : [0,
"Defaults to 0"] "1 or link text if a Sphinx-style permalink should be added - "
} "Defaults to 0"]
}
for key, value in configs: super(TocExtension, self).__init__(*args, **kwargs)
self.setConfig(key, value)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
tocext = self.TreeProcessorClass(md) tocext = self.TreeProcessorClass(md)
@@ -236,5 +239,5 @@ class TocExtension(Extension):
md.treeprocessors.add("toc", tocext, "_end") md.treeprocessors.add("toc", tocext, "_end")
def makeExtension(configs={}): def makeExtension(*args, **kwargs):
return TocExtension(configs=configs) return TocExtension(*args, **kwargs)

View File

@@ -2,78 +2,17 @@
WikiLinks Extension for Python-Markdown WikiLinks Extension for Python-Markdown
====================================== ======================================
Converts [[WikiLinks]] to relative links. Requires Python-Markdown 2.0+ Converts [[WikiLinks]] to relative links.
Basic usage: See <https://pythonhosted.org/Markdown/extensions/wikilinks.html>
for documentation.
>>> import markdown Original code Copyright [Waylan Limberg](http://achinghead.com/).
>>> text = "Some text with a [[WikiLink]]."
>>> html = markdown.markdown(text, ['wikilinks'])
>>> print html
<p>Some text with a <a class="wikilink" href="/WikiLink/">WikiLink</a>.</p>
Whitespace behavior: All changes Copyright The Python Markdown Project
>>> print markdown.markdown('[[ foo bar_baz ]]', ['wikilinks'])
<p><a class="wikilink" href="/foo_bar_baz/">foo bar_baz</a></p>
>>> print markdown.markdown('foo [[ ]] bar', ['wikilinks'])
<p>foo bar</p>
To define custom settings the simple way:
>>> print markdown.markdown(text,
... ['wikilinks(base_url=/wiki/,end_url=.html,html_class=foo)']
... )
<p>Some text with a <a class="foo" href="/wiki/WikiLink.html">WikiLink</a>.</p>
Custom settings the complex way:
>>> md = markdown.Markdown(
... extensions = ['wikilinks'],
... extension_configs = {'wikilinks': [
... ('base_url', 'http://example.com/'),
... ('end_url', '.html'),
... ('html_class', '') ]},
... safe_mode = True)
>>> print md.convert(text)
<p>Some text with a <a href="http://example.com/WikiLink.html">WikiLink</a>.</p>
Use MetaData with mdx_meta.py (Note the blank html_class in MetaData):
>>> text = """wiki_base_url: http://example.com/
... wiki_end_url: .html
... wiki_html_class:
...
... Some text with a [[WikiLink]]."""
>>> md = markdown.Markdown(extensions=['meta', 'wikilinks'])
>>> print md.convert(text)
<p>Some text with a <a href="http://example.com/WikiLink.html">WikiLink</a>.</p>
MetaData should not carry over to next document:
>>> print md.convert("No [[MetaData]] here.")
<p>No <a class="wikilink" href="/MetaData/">MetaData</a> here.</p>
Define a custom URL builder:
>>> def my_url_builder(label, base, end):
... return '/bar/'
>>> md = markdown.Markdown(extensions=['wikilinks'],
... extension_configs={'wikilinks' : [('build_url', my_url_builder)]})
>>> print md.convert('[[foo]]')
<p><a class="wikilink" href="/bar/">foo</a></p>
From the command line:
python markdown.py -x wikilinks(base_url=http://example.com/,end_url=.html,html_class=foo) src.txt
By [Waylan Limberg](http://achinghead.com/).
License: [BSD](http://www.opensource.org/licenses/bsd-license.php) License: [BSD](http://www.opensource.org/licenses/bsd-license.php)
Dependencies:
* [Python 2.3+](http://python.org)
* [Markdown 2.0+](http://packages.python.org/Markdown/)
''' '''
from __future__ import absolute_import from __future__ import absolute_import
@@ -90,18 +29,16 @@ def build_url(label, base, end):
class WikiLinkExtension(Extension): class WikiLinkExtension(Extension):
def __init__(self, configs):
# set extension defaults def __init__ (self, *args, **kwargs):
self.config = { self.config = {
'base_url' : ['/', 'String to append to beginning or URL.'], 'base_url' : ['/', 'String to append to beginning or URL.'],
'end_url' : ['/', 'String to append to end of URL.'], 'end_url' : ['/', 'String to append to end of URL.'],
'html_class' : ['wikilink', 'CSS hook. Leave blank for none.'], 'html_class' : ['wikilink', 'CSS hook. Leave blank for none.'],
'build_url' : [build_url, 'Callable formats URL from label.'], 'build_url' : [build_url, 'Callable formats URL from label.'],
} }
configs = dict(configs) or {}
# Override defaults with user settings super(WikiLinkExtension, self).__init__(*args, **kwargs)
for key, value in configs.items():
self.setConfig(key, value)
def extendMarkdown(self, md, md_globals): def extendMarkdown(self, md, md_globals):
self.md = md self.md = md
@@ -147,5 +84,5 @@ class WikiLinks(Pattern):
return base_url, end_url, html_class return base_url, end_url, html_class
def makeExtension(configs=None) : def makeExtension(*args, **kwargs) :
return WikiLinkExtension(configs=configs) return WikiLinkExtension(*args, **kwargs)

View File

@@ -46,13 +46,13 @@ from __future__ import unicode_literals
from . import util from . import util
from . import odict from . import odict
import re import re
try: try: #pragma: no cover
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
except ImportError: except ImportError: #pragma: no cover
from urlparse import urlparse, urlunparse from urlparse import urlparse, urlunparse
try: try: #pragma: no cover
from html import entities from html import entities
except ImportError: except ImportError: #pragma: no cover
import htmlentitydefs as entities import htmlentitydefs as entities
@@ -75,7 +75,8 @@ def build_inlinepatterns(md_instance, **kwargs):
inlinePatterns["html"] = HtmlPattern(HTML_RE, md_instance) inlinePatterns["html"] = HtmlPattern(HTML_RE, md_instance)
inlinePatterns["entity"] = HtmlPattern(ENTITY_RE, md_instance) inlinePatterns["entity"] = HtmlPattern(ENTITY_RE, md_instance)
inlinePatterns["not_strong"] = SimpleTextPattern(NOT_STRONG_RE) inlinePatterns["not_strong"] = SimpleTextPattern(NOT_STRONG_RE)
inlinePatterns["strong_em"] = DoubleTagPattern(STRONG_EM_RE, 'strong,em') inlinePatterns["em_strong"] = DoubleTagPattern(EM_STRONG_RE, 'strong,em')
inlinePatterns["strong_em"] = DoubleTagPattern(STRONG_EM_RE, 'em,strong')
inlinePatterns["strong"] = SimpleTagPattern(STRONG_RE, 'strong') inlinePatterns["strong"] = SimpleTagPattern(STRONG_RE, 'strong')
inlinePatterns["emphasis"] = SimpleTagPattern(EMPHASIS_RE, 'em') inlinePatterns["emphasis"] = SimpleTagPattern(EMPHASIS_RE, 'em')
if md_instance.smart_emphasis: if md_instance.smart_emphasis:
@@ -100,7 +101,8 @@ BACKTICK_RE = r'(?<!\\)(`+)(.+?)(?<!`)\2(?!`)' # `e=f()` or ``e=f("`")``
ESCAPE_RE = r'\\(.)' # \< ESCAPE_RE = r'\\(.)' # \<
EMPHASIS_RE = r'(\*)([^\*]+)\2' # *emphasis* EMPHASIS_RE = r'(\*)([^\*]+)\2' # *emphasis*
STRONG_RE = r'(\*{2}|_{2})(.+?)\2' # **strong** STRONG_RE = r'(\*{2}|_{2})(.+?)\2' # **strong**
STRONG_EM_RE = r'(\*{3}|_{3})(.+?)\2' # ***strong*** EM_STRONG_RE = r'(\*|_)\2{2}(.+?)\2(.*?)\2{2}' # ***strongem*** or ***em*strong**
STRONG_EM_RE = r'(\*|_)\2{2}(.+?)\2{2}(.*?)\2' # ***strong**em*
SMART_EMPHASIS_RE = r'(?<!\w)(_)(?!_)(.+?)(?<!_)\2(?!\w)' # _smart_emphasis_ SMART_EMPHASIS_RE = r'(?<!\w)(_)(?!_)(.+?)(?<!_)\2(?!\w)' # _smart_emphasis_
EMPHASIS_2_RE = r'(_)(.+?)\2' # _emphasis_ EMPHASIS_2_RE = r'(_)(.+?)\2' # _emphasis_
LINK_RE = NOIMG + BRK + \ LINK_RE = NOIMG + BRK + \
@@ -178,7 +180,7 @@ class Pattern(object):
* m: A re match object containing a match of the pattern. * m: A re match object containing a match of the pattern.
""" """
pass pass #pragma: no cover
def type(self): def type(self):
""" Return class name, to define pattern type """ """ Return class name, to define pattern type """
@@ -188,9 +190,9 @@ class Pattern(object):
""" Return unescaped text given text with an inline placeholder. """ """ Return unescaped text given text with an inline placeholder. """
try: try:
stash = self.markdown.treeprocessors['inline'].stashed_nodes stash = self.markdown.treeprocessors['inline'].stashed_nodes
except KeyError: except KeyError: #pragma: no cover
return text return text
def itertext(el): def itertext(el): #pragma: no cover
' Reimplement Element.itertext for older python versions ' ' Reimplement Element.itertext for older python versions '
tag = el.tag tag = el.tag
if not isinstance(tag, util.string_type) and tag is not None: if not isinstance(tag, util.string_type) and tag is not None:
@@ -217,10 +219,7 @@ class Pattern(object):
class SimpleTextPattern(Pattern): class SimpleTextPattern(Pattern):
""" Return a simple text of group(2) of a Pattern. """ """ Return a simple text of group(2) of a Pattern. """
def handleMatch(self, m): def handleMatch(self, m):
text = m.group(2) return m.group(2)
if text == util.INLINE_PLACEHOLDER_PREFIX:
return None
return text
class EscapePattern(Pattern): class EscapePattern(Pattern):
@@ -279,6 +278,8 @@ class DoubleTagPattern(SimpleTagPattern):
el1 = util.etree.Element(tag1) el1 = util.etree.Element(tag1)
el2 = util.etree.SubElement(el1, tag2) el2 = util.etree.SubElement(el1, tag2)
el2.text = m.group(3) el2.text = m.group(3)
if len(m.groups())==5:
el2.tail = m.group(4)
return el1 return el1
@@ -293,7 +294,7 @@ class HtmlPattern(Pattern):
""" Return unescaped text given text with an inline placeholder. """ """ Return unescaped text given text with an inline placeholder. """
try: try:
stash = self.markdown.treeprocessors['inline'].stashed_nodes stash = self.markdown.treeprocessors['inline'].stashed_nodes
except KeyError: except KeyError: #pragma: no cover
return text return text
def get_stash(m): def get_stash(m):
id = m.group(1) id = m.group(1)
@@ -350,7 +351,7 @@ class LinkPattern(Pattern):
try: try:
scheme, netloc, path, params, query, fragment = url = urlparse(url) scheme, netloc, path, params, query, fragment = url = urlparse(url)
except ValueError: except ValueError: #pragma: no cover
# Bad url - so bad it couldn't be parsed. # Bad url - so bad it couldn't be parsed.
return '' return ''
@@ -360,7 +361,7 @@ class LinkPattern(Pattern):
# Not a known (allowed) scheme. Not safe. # Not a known (allowed) scheme. Not safe.
return '' return ''
if netloc == '' and scheme not in locless_schemes: if netloc == '' and scheme not in locless_schemes: #pragma: no cover
# This should not happen. Treat as suspect. # This should not happen. Treat as suspect.
return '' return ''

View File

@@ -82,11 +82,11 @@ class OrderedDict(dict):
for key in self.keyOrder: for key in self.keyOrder:
yield self[key] yield self[key]
if util.PY3: if util.PY3: #pragma: no cover
items = _iteritems items = _iteritems
keys = _iterkeys keys = _iterkeys
values = _itervalues values = _itervalues
else: else: #pragma: no cover
iteritems = _iteritems iteritems = _iteritems
iterkeys = _iterkeys iterkeys = _iterkeys
itervalues = _itervalues itervalues = _itervalues

View File

@@ -42,7 +42,7 @@ class Postprocessor(util.Processor):
(possibly modified) string. (possibly modified) string.
""" """
pass pass #pragma: no cover
class RawHtmlPostprocessor(Postprocessor): class RawHtmlPostprocessor(Postprocessor):

View File

@@ -41,7 +41,7 @@ class Preprocessor(util.Processor):
the (possibly modified) list of lines. the (possibly modified) list of lines.
""" """
pass pass #pragma: no cover
class NormalizeWhitespace(Preprocessor): class NormalizeWhitespace(Preprocessor):
@@ -174,9 +174,10 @@ class HtmlBlockPreprocessor(Preprocessor):
else: # raw html else: # raw html
if len(items) - right_listindex <= 1: # last element if len(items) - right_listindex <= 1: # last element
right_listindex -= 1 right_listindex -= 1
offset = 1 if i == right_listindex else 0
placeholder = self.markdown.htmlStash.store('\n\n'.join( placeholder = self.markdown.htmlStash.store('\n\n'.join(
items[i:right_listindex + 1])) items[i:right_listindex + offset]))
del items[i:right_listindex + 1] del items[i:right_listindex + offset]
items.insert(i, placeholder) items.insert(i, placeholder)
return items return items

View File

@@ -42,9 +42,9 @@ from __future__ import unicode_literals
from . import util from . import util
ElementTree = util.etree.ElementTree ElementTree = util.etree.ElementTree
QName = util.etree.QName QName = util.etree.QName
if hasattr(util.etree, 'test_comment'): if hasattr(util.etree, 'test_comment'): #pragma: no cover
Comment = util.etree.test_comment Comment = util.etree.test_comment
else: else: #pragma: no cover
Comment = util.etree.Comment Comment = util.etree.Comment
PI = util.etree.PI PI = util.etree.PI
ProcessingInstruction = util.etree.ProcessingInstruction ProcessingInstruction = util.etree.ProcessingInstruction
@@ -56,7 +56,7 @@ HTML_EMPTY = ("area", "base", "basefont", "br", "col", "frame", "hr",
try: try:
HTML_EMPTY = set(HTML_EMPTY) HTML_EMPTY = set(HTML_EMPTY)
except NameError: except NameError: #pragma: no cover
pass pass
_namespace_map = { _namespace_map = {
@@ -73,7 +73,7 @@ _namespace_map = {
} }
def _raise_serialization_error(text): def _raise_serialization_error(text): #pragma: no cover
raise TypeError( raise TypeError(
"cannot serialize %r (type %s)" % (text, type(text).__name__) "cannot serialize %r (type %s)" % (text, type(text).__name__)
) )
@@ -81,7 +81,7 @@ def _raise_serialization_error(text):
def _encode(text, encoding): def _encode(text, encoding):
try: try:
return text.encode(encoding, "xmlcharrefreplace") return text.encode(encoding, "xmlcharrefreplace")
except (TypeError, AttributeError): except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text) _raise_serialization_error(text)
def _escape_cdata(text): def _escape_cdata(text):
@@ -97,7 +97,7 @@ def _escape_cdata(text):
if ">" in text: if ">" in text:
text = text.replace(">", "&gt;") text = text.replace(">", "&gt;")
return text return text
except (TypeError, AttributeError): except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text) _raise_serialization_error(text)
@@ -115,7 +115,7 @@ def _escape_attrib(text):
if "\n" in text: if "\n" in text:
text = text.replace("\n", "&#10;") text = text.replace("\n", "&#10;")
return text return text
except (TypeError, AttributeError): except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text) _raise_serialization_error(text)
def _escape_attrib_html(text): def _escape_attrib_html(text):
@@ -130,7 +130,7 @@ def _escape_attrib_html(text):
if "\"" in text: if "\"" in text:
text = text.replace("\"", "&quot;") text = text.replace("\"", "&quot;")
return text return text
except (TypeError, AttributeError): except (TypeError, AttributeError): #pragma: no cover
_raise_serialization_error(text) _raise_serialization_error(text)
@@ -240,7 +240,7 @@ def _namespaces(elem, default_namespace=None):
"default_namespace option" "default_namespace option"
) )
qnames[qname] = qname qnames[qname] = qname
except TypeError: except TypeError: #pragma: no cover
_raise_serialization_error(qname) _raise_serialization_error(qname)
# populate qname and namespaces table # populate qname and namespaces table

View File

@@ -38,7 +38,7 @@ class Treeprocessor(util.Processor):
object, and the existing root ElementTree will be replaced, or it can object, and the existing root ElementTree will be replaced, or it can
modify the current tree and return None. modify the current tree and return None.
""" """
pass pass #pragma: no cover
class InlineProcessor(Treeprocessor): class InlineProcessor(Treeprocessor):
@@ -53,6 +53,7 @@ class InlineProcessor(Treeprocessor):
+ len(self.__placeholder_suffix) + len(self.__placeholder_suffix)
self.__placeholder_re = util.INLINE_PLACEHOLDER_RE self.__placeholder_re = util.INLINE_PLACEHOLDER_RE
self.markdown = md self.markdown = md
self.inlinePatterns = md.inlinePatterns
def __makePlaceholder(self, type): def __makePlaceholder(self, type):
""" Generate a placeholder """ """ Generate a placeholder """
@@ -99,9 +100,9 @@ class InlineProcessor(Treeprocessor):
""" """
if not isinstance(data, util.AtomicString): if not isinstance(data, util.AtomicString):
startIndex = 0 startIndex = 0
while patternIndex < len(self.markdown.inlinePatterns): while patternIndex < len(self.inlinePatterns):
data, matched, startIndex = self.__applyPattern( data, matched, startIndex = self.__applyPattern(
self.markdown.inlinePatterns.value_for_index(patternIndex), self.inlinePatterns.value_for_index(patternIndex),
data, patternIndex, startIndex) data, patternIndex, startIndex)
if not matched: if not matched:
patternIndex += 1 patternIndex += 1
@@ -128,11 +129,10 @@ class InlineProcessor(Treeprocessor):
text = subnode.tail text = subnode.tail
subnode.tail = None subnode.tail = None
childResult = self.__processPlaceholders(text, subnode) childResult = self.__processPlaceholders(text, subnode, isText)
if not isText and node is not subnode: if not isText and node is not subnode:
pos = list(node).index(subnode) pos = list(node).index(subnode) + 1
node.remove(subnode)
else: else:
pos = 0 pos = 0
@@ -140,7 +140,7 @@ class InlineProcessor(Treeprocessor):
for newChild in childResult: for newChild in childResult:
node.insert(pos, newChild) node.insert(pos, newChild)
def __processPlaceholders(self, data, parent): def __processPlaceholders(self, data, parent, isText=True):
""" """
Process string with placeholders and generate ElementTree tree. Process string with placeholders and generate ElementTree tree.
@@ -159,6 +159,11 @@ class InlineProcessor(Treeprocessor):
result[-1].tail += text result[-1].tail += text
else: else:
result[-1].tail = text result[-1].tail = text
elif not isText:
if parent.tail:
parent.tail += text
else:
parent.tail = text
else: else:
if parent.text: if parent.text:
parent.text += text parent.text += text
@@ -182,7 +187,7 @@ class InlineProcessor(Treeprocessor):
for child in [node] + list(node): for child in [node] + list(node):
if child.tail: if child.tail:
if child.tail.strip(): if child.tail.strip():
self.__processElementText(node, child,False) self.__processElementText(node, child, False)
if child.text: if child.text:
if child.text.strip(): if child.text.strip():
self.__processElementText(child, child) self.__processElementText(child, child)
@@ -287,11 +292,10 @@ class InlineProcessor(Treeprocessor):
if child.tail: if child.tail:
tail = self.__handleInline(child.tail) tail = self.__handleInline(child.tail)
dumby = util.etree.Element('d') dumby = util.etree.Element('d')
tailResult = self.__processPlaceholders(tail, dumby) child.tail = None
if dumby.text: tailResult = self.__processPlaceholders(tail, dumby, False)
child.tail = dumby.text if dumby.tail:
else: child.tail = dumby.tail
child.tail = None
pos = list(currElement).index(child) + 1 pos = list(currElement).index(child) + 1
tailResult.reverse() tailResult.reverse()
for newChild in tailResult: for newChild in tailResult:
@@ -357,4 +361,4 @@ class PrettifyTreeprocessor(Treeprocessor):
pres = root.getiterator('pre') pres = root.getiterator('pre')
for pre in pres: for pre in pres:
if len(pre) and pre[0].tag == 'code': if len(pre) and pre[0].tag == 'code':
pre[0].text = pre[0].text.rstrip() + '\n' pre[0].text = util.AtomicString(pre[0].text.rstrip() + '\n')

View File

@@ -10,11 +10,11 @@ Python 3 Stuff
""" """
PY3 = sys.version_info[0] == 3 PY3 = sys.version_info[0] == 3
if PY3: if PY3: #pragma: no cover
string_type = str string_type = str
text_type = str text_type = str
int2str = chr int2str = chr
else: else: #pragma: no cover
string_type = basestring string_type = basestring
text_type = unicode text_type = unicode
int2str = unichr int2str = unichr
@@ -58,14 +58,15 @@ RTL_BIDI_RANGES = ( ('\u0590', '\u07FF'),
# Extensions should use "markdown.util.etree" instead of "etree" (or do `from # Extensions should use "markdown.util.etree" instead of "etree" (or do `from
# markdown.util import etree`). Do not import it by yourself. # markdown.util import etree`). Do not import it by yourself.
try: # Is the C implementation of ElementTree available? try: #pragma: no cover
# Is the C implementation of ElementTree available?
import xml.etree.cElementTree as etree import xml.etree.cElementTree as etree
from xml.etree.ElementTree import Comment from xml.etree.ElementTree import Comment
# Serializers (including ours) test with non-c Comment # Serializers (including ours) test with non-c Comment
etree.test_comment = Comment etree.test_comment = Comment
if etree.VERSION < "1.0.5": if etree.VERSION < "1.0.5":
raise RuntimeError("cElementTree version 1.0.5 or higher is required.") raise RuntimeError("cElementTree version 1.0.5 or higher is required.")
except (ImportError, RuntimeError): except (ImportError, RuntimeError): #pragma: no cover
# Use the Python implementation of ElementTree? # Use the Python implementation of ElementTree?
import xml.etree.ElementTree as etree import xml.etree.ElementTree as etree
if etree.VERSION < "1.1": if etree.VERSION < "1.1":
@@ -85,15 +86,20 @@ def isBlockLevel(tag):
# Some ElementTree tags are not strings, so return False. # Some ElementTree tags are not strings, so return False.
return False return False
def parseBoolValue(value, fail_on_errors=True): def parseBoolValue(value, fail_on_errors=True, preserve_none=False):
"""Parses a string representing bool value. If parsing was successful, """Parses a string representing bool value. If parsing was successful,
returns True or False. If parsing was not successful, raises returns True or False. If preserve_none=True, returns True, False,
ValueError, or, if fail_on_errors=False, returns None.""" or None. If parsing was not successful, raises ValueError, or, if
fail_on_errors=False, returns None."""
if not isinstance(value, string_type): if not isinstance(value, string_type):
if preserve_none and value is None:
return value
return bool(value) return bool(value)
elif preserve_none and value.lower() == 'none':
return None
elif value.lower() in ('true', 'yes', 'y', 'on', '1'): elif value.lower() in ('true', 'yes', 'y', 'on', '1'):
return True return True
elif value.lower() in ('false', 'no', 'n', 'off', '0'): elif value.lower() in ('false', 'no', 'n', 'off', '0', 'none'):
return False return False
elif fail_on_errors: elif fail_on_errors:
raise ValueError('Cannot parse bool value: %r' % value) raise ValueError('Cannot parse bool value: %r' % value)

View File

@@ -0,0 +1,31 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
# Public names re-exported by packaging/__init__.py.
__all__ = [
    "__title__", "__summary__", "__uri__", "__version__", "__author__",
    "__email__", "__license__", "__copyright__",
]

# Distribution metadata for the vendored ``packaging`` library.
__title__ = "packaging"
__summary__ = "Core utilities for Python packages"
__uri__ = "https://github.com/pypa/packaging"

__version__ = "15.0"

__author__ = "Donald Stufft"
__email__ = "donald@stufft.io"

__license__ = "Apache License, Version 2.0"
__copyright__ = "Copyright 2014 %s" % __author__

View File

@@ -0,0 +1,24 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
from .__about__ import (
__author__, __copyright__, __email__, __license__, __summary__, __title__,
__uri__, __version__
)
__all__ = [
"__title__", "__summary__", "__uri__", "__version__", "__author__",
"__email__", "__license__", "__copyright__",
]

View File

@@ -0,0 +1,40 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import sys
# Interpreter-version flags used to select 2/3-compatible aliases below.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3

# flake8: noqa

# ``string_types`` mirrors six: the tuple of base string type(s) suitable
# for ``isinstance`` checks on this Python major version.
if PY3:
    string_types = str,
else:
    string_types = basestring,
def with_metaclass(meta, *bases):
    """
    Create a base class with a metaclass, portably across Python 2 and 3.
    """
    # Trick (borrowed from six): return a throwaway class whose metaclass,
    # at the moment a real subclass is created, replaces itself by calling
    # the *actual* metaclass with the real bases.  This sidesteps the
    # incompatible py2/py3 metaclass declaration syntaxes.
    class _TemporaryMeta(meta):
        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)

    return type.__new__(_TemporaryMeta, 'temporary_class', (), {})

View File

@@ -0,0 +1,78 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
class Infinity(object):
    """Singleton sentinel that orders strictly after every other object."""

    def __repr__(self):
        return "Infinity"

    def __hash__(self):
        return hash(repr(self))

    # Nothing is greater than Infinity ...
    def __lt__(self, other):
        return False

    def __le__(self, other):
        return False

    # ... and Infinity is greater than everything.
    def __gt__(self, other):
        return True

    def __ge__(self, other):
        return True

    # Equal only to itself (any instance of this class).
    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __neg__(self):
        return NegativeInfinity


# Replace the class with its single instance; exactly one Infinity exists.
Infinity = Infinity()
class NegativeInfinity(object):
    """Singleton sentinel that orders strictly before every other object."""

    def __repr__(self):
        return "-Infinity"

    def __hash__(self):
        return hash(repr(self))

    # -Infinity is smaller than everything ...
    def __lt__(self, other):
        return True

    def __le__(self, other):
        return True

    # ... and nothing is smaller than -Infinity.
    def __gt__(self, other):
        return False

    def __ge__(self, other):
        return False

    # Equal only to itself (any instance of this class).
    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __neg__(self):
        return Infinity


# Replace the class with its single instance; exactly one -Infinity exists.
NegativeInfinity = NegativeInfinity()

View File

@@ -0,0 +1,772 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import abc
import functools
import itertools
import re
from ._compat import string_types, with_metaclass
from .version import Version, LegacyVersion, parse
class InvalidSpecifier(ValueError):
    """
    An invalid specifier was found, users should refer to PEP 440.

    Raised by ``Specifier``/``LegacySpecifier`` when the given string does
    not match the expected ``<operator><version>`` grammar.
    """
class BaseSpecifier(with_metaclass(abc.ABCMeta, object)):
    """Abstract interface shared by single specifiers and specifier sets.

    Implemented by ``_IndividualSpecifier`` (one ``<op><version>`` clause)
    and ``SpecifierSet`` (a comma-separated conjunction of clauses).
    """

    @abc.abstractmethod
    def __str__(self):
        """
        Returns the str representation of this Specifier like object. This
        should be representative of the Specifier itself.
        """

    @abc.abstractmethod
    def __hash__(self):
        """
        Returns a hash value for this Specifier like object.
        """

    @abc.abstractmethod
    def __eq__(self, other):
        """
        Returns a boolean representing whether or not the two Specifier like
        objects are equal.
        """

    @abc.abstractmethod
    def __ne__(self, other):
        """
        Returns a boolean representing whether or not the two Specifier like
        objects are not equal.
        """

    @abc.abstractproperty
    def prereleases(self):
        """
        Returns whether or not pre-releases as a whole are allowed by this
        specifier.
        """

    @prereleases.setter
    def prereleases(self, value):
        """
        Sets whether or not pre-releases as a whole are allowed by this
        specifier.
        """

    @abc.abstractmethod
    def contains(self, item, prereleases=None):
        """
        Determines if the given item is contained within this specifier.
        """

    @abc.abstractmethod
    def filter(self, iterable, prereleases=None):
        """
        Takes an iterable of items and filters them so that only items which
        are contained within this specifier are allowed in it.
        """
class _IndividualSpecifier(BaseSpecifier):
    """Common implementation for a single ``<operator><version>`` clause.

    Subclasses supply ``_regex`` (the clause grammar) and ``_operators``
    (operator string -> ``_compare_*`` method-name suffix), plus the
    comparison methods themselves.
    """

    # Operator string -> suffix of the ``_compare_<suffix>`` method to call.
    _operators = {}

    def __init__(self, spec="", prereleases=None):
        """Parse *spec* into (operator, version); raise InvalidSpecifier on failure."""
        match = self._regex.search(spec)
        if not match:
            raise InvalidSpecifier("Invalid specifier: '{0}'".format(spec))

        self._spec = (
            match.group("operator").strip(),
            match.group("version").strip(),
        )

        # Store whether or not this Specifier should accept prereleases.
        # None means "no explicit preference" (decided later by the property).
        self._prereleases = prereleases

    def __repr__(self):
        pre = (
            ", prereleases={0!r}".format(self.prereleases)
            if self._prereleases is not None
            else ""
        )

        return "<{0}({1!r}{2})>".format(
            self.__class__.__name__,
            str(self),
            pre,
        )

    def __str__(self):
        return "{0}{1}".format(*self._spec)

    def __hash__(self):
        return hash(self._spec)

    def __eq__(self, other):
        # Strings are coerced to a specifier of the same class; an unparsable
        # string or a foreign type defers to the other operand.
        if isinstance(other, string_types):
            try:
                other = self.__class__(other)
            except InvalidSpecifier:
                return NotImplemented
        elif not isinstance(other, self.__class__):
            return NotImplemented

        return self._spec == other._spec

    def __ne__(self, other):
        if isinstance(other, string_types):
            try:
                other = self.__class__(other)
            except InvalidSpecifier:
                return NotImplemented
        elif not isinstance(other, self.__class__):
            return NotImplemented

        return self._spec != other._spec

    def _get_operator(self, op):
        # Dispatch table lookup: "<=" -> self._compare_less_than_equal, etc.
        return getattr(self, "_compare_{0}".format(self._operators[op]))

    def _coerce_version(self, version):
        # Normalize arbitrary input to a Version/LegacyVersion instance.
        if not isinstance(version, (LegacyVersion, Version)):
            version = parse(version)
        return version

    @property
    def prereleases(self):
        return self._prereleases

    @prereleases.setter
    def prereleases(self, value):
        self._prereleases = value

    def contains(self, item, prereleases=None):
        """Return True if *item* (version or string) satisfies this clause."""
        # Determine if prereleases are to be allowed or not.
        if prereleases is None:
            prereleases = self.prereleases

        # Normalize item to a Version or LegacyVersion, this allows us to have
        # a shortcut for ``"2.0" in Specifier(">=2")``.
        item = self._coerce_version(item)

        # Determine if we should be supporting prereleases in this specifier
        # or not, if we do not support prereleases than we can short circuit
        # logic if this version is a prereleases.
        if item.is_prerelease and not prereleases:
            return False

        # Actually do the comparison to determine if this item is contained
        # within this Specifier or not.
        return self._get_operator(self._spec[0])(item, self._spec[1])

    def filter(self, iterable, prereleases=None):
        """Yield the items of *iterable* that satisfy this clause.

        Pre-releases that match are withheld unless allowed, but are yielded
        as a fallback when nothing else matched.
        """
        yielded = False
        found_prereleases = []

        kw = {"prereleases": prereleases if prereleases is not None else True}

        # Attempt to iterate over all the values in the iterable and if any of
        # them match, yield them.
        for version in iterable:
            parsed_version = self._coerce_version(version)

            if self.contains(parsed_version, **kw):
                # If our version is a prerelease, and we were not set to allow
                # prereleases, then we'll store it for later incase nothing
                # else matches this specifier.
                if (parsed_version.is_prerelease
                        and not (prereleases or self.prereleases)):
                    found_prereleases.append(version)
                # Either this is not a prerelease, or we should have been
                # accepting prereleases from the begining.
                else:
                    yielded = True
                    yield version

        # Now that we've iterated over everything, determine if we've yielded
        # any values, and if we have not and we have any prereleases stored up
        # then we will go ahead and yield the prereleases.
        if not yielded and found_prereleases:
            for version in found_prereleases:
                yield version
class LegacySpecifier(_IndividualSpecifier):
    """A specifier clause for versions that do not conform to PEP 440.

    Accepts the six classic comparison operators and an essentially
    arbitrary version string; all comparisons coerce to ``LegacyVersion``.
    """

    _regex = re.compile(
        r"""
        ^
        \s*
        (?P<operator>(==|!=|<=|>=|<|>))
        \s*
        (?P<version>
            [^\s]*  # We just match everything, except for whitespace since this
                    # is a "legacy" specifier and the version string can be just
                    # about anything.
        )
        \s*
        $
        """,
        re.VERBOSE | re.IGNORECASE,
    )

    _operators = {
        "==": "equal",
        "!=": "not_equal",
        "<=": "less_than_equal",
        ">=": "greater_than_equal",
        "<": "less_than",
        ">": "greater_than",
    }

    def _coerce_version(self, version):
        # Unlike the base class, always compare in LegacyVersion terms.
        if not isinstance(version, LegacyVersion):
            version = LegacyVersion(str(version))
        return version

    def _compare_equal(self, prospective, spec):
        return prospective == self._coerce_version(spec)

    def _compare_not_equal(self, prospective, spec):
        return prospective != self._coerce_version(spec)

    def _compare_less_than_equal(self, prospective, spec):
        return prospective <= self._coerce_version(spec)

    def _compare_greater_than_equal(self, prospective, spec):
        return prospective >= self._coerce_version(spec)

    def _compare_less_than(self, prospective, spec):
        return prospective < self._coerce_version(spec)

    def _compare_greater_than(self, prospective, spec):
        return prospective > self._coerce_version(spec)
def _require_version_compare(fn):
    """Wrap a Specifier comparison so it only applies to PEP 440 versions.

    A strict ``Specifier`` can never be satisfied by a non-PEP440 version
    (e.g. a ``LegacyVersion``), so the guarded method short-circuits to
    False for anything that is not a ``Version``.
    """
    @functools.wraps(fn)
    def guarded(self, prospective, spec):
        return isinstance(prospective, Version) and fn(self, prospective, spec)
    return guarded
class Specifier(_IndividualSpecifier):
    """A single PEP 440 specifier clause, e.g. ``">=1.0"`` or ``"~=2.2"``.

    The grammar below anchors each version-syntax alternative to its
    operator with a lookbehind, because ``==``/``!=`` allow wildcards and
    local versions, ``~=`` requires at least two release digits, and
    ``===`` accepts an arbitrary string.
    """

    _regex = re.compile(
        r"""
        ^
        \s*
        (?P<operator>(~=|==|!=|<=|>=|<|>|===))
        (?P<version>
            (?:
                # The identity operators allow for an escape hatch that will
                # do an exact string match of the version you wish to install.
                # This will not be parsed by PEP 440 and we cannot determine
                # any semantic meaning from it. This operator is discouraged
                # but included entirely as an escape hatch.
                (?<====)  # Only match for the identity operator
                \s*
                [^\s]*    # We just match everything, except for whitespace
                          # since we are only testing for strict identity.
            )
            |
            (?:
                # The (non)equality operators allow for wild card and local
                # versions to be specified so we have to define these two
                # operators separately to enable that.
                (?<===|!=)            # Only match for equals and not equals
                \s*
                v?
                (?:[0-9]+!)?          # epoch
                [0-9]+(?:\.[0-9]+)*   # release
                (?:                   # pre release
                    [-_\.]?
                    (a|b|c|rc|alpha|beta|pre|preview)
                    [-_\.]?
                    [0-9]*
                )?
                (?:                   # post release
                    (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
                )?

                # You cannot use a wild card and a dev or local version
                # together so group them with a | and make them optional.
                (?:
                    (?:[-_\.]?dev[-_\.]?[0-9]*)?          # dev release
                    (?:\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*)?  # local
                    |
                    \.\*  # Wild card syntax of .*
                )?
            )
            |
            (?:
                # The compatible operator requires at least two digits in the
                # release segment.
                (?<=~=)               # Only match for the compatible operator
                \s*
                v?
                (?:[0-9]+!)?          # epoch
                [0-9]+(?:\.[0-9]+)+   # release (We have a + instead of a *)
                (?:                   # pre release
                    [-_\.]?
                    (a|b|c|rc|alpha|beta|pre|preview)
                    [-_\.]?
                    [0-9]*
                )?
                (?:                   # post release
                    (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
                )?
                (?:[-_\.]?dev[-_\.]?[0-9]*)?  # dev release
            )
            |
            (?:
                # All other operators only allow a sub set of what the
                # (non)equality operators do. Specifically they do not allow
                # local versions to be specified nor do they allow the prefix
                # matching wild cards.
                (?<!==|!=|~=)         # We have special cases for these
                                      # operators so we want to make sure they
                                      # don't match here.
                \s*
                v?
                (?:[0-9]+!)?          # epoch
                [0-9]+(?:\.[0-9]+)*   # release
                (?:                   # pre release
                    [-_\.]?
                    (a|b|c|rc|alpha|beta|pre|preview)
                    [-_\.]?
                    [0-9]*
                )?
                (?:                   # post release
                    (?:-[0-9]+)|(?:[-_\.]?(post|rev|r)[-_\.]?[0-9]*)
                )?
                (?:[-_\.]?dev[-_\.]?[0-9]*)?  # dev release
            )
        )
        \s*
        $
        """,
        re.VERBOSE | re.IGNORECASE,
    )

    _operators = {
        "~=": "compatible",
        "==": "equal",
        "!=": "not_equal",
        "<=": "less_than_equal",
        ">=": "greater_than_equal",
        "<": "less_than",
        ">": "greater_than",
        "===": "arbitrary",
    }

    @_require_version_compare
    def _compare_compatible(self, prospective, spec):
        """``~=``: equivalent to ``>= spec`` AND ``== prefix.*``."""
        # Compatible releases have an equivalent combination of >= and ==. That
        # is that ~=2.2 is equivalent to >=2.2,==2.*. This allows us to
        # implement this in terms of the other specifiers instead of
        # implementing it ourselves. The only thing we need to do is construct
        # the other specifiers.

        # We want everything but the last item in the version, but we want to
        # ignore post and dev releases and we want to treat the pre-release as
        # it's own separate segment.
        prefix = ".".join(
            list(
                itertools.takewhile(
                    lambda x: (not x.startswith("post")
                               and not x.startswith("dev")),
                    _version_split(spec),
                )
            )[:-1]
        )

        # Add the prefix notation to the end of our string
        prefix += ".*"

        return (self._get_operator(">=")(prospective, spec)
                and self._get_operator("==")(prospective, prefix))

    @_require_version_compare
    def _compare_equal(self, prospective, spec):
        """``==``: exact equality, with ``.*`` prefix-match support."""
        # We need special logic to handle prefix matching
        if spec.endswith(".*"):
            # Split the spec out by dots, and pretend that there is an implicit
            # dot in between a release segment and a pre-release segment.
            spec = _version_split(spec[:-2])  # Remove the trailing .*

            # Split the prospective version out by dots, and pretend that there
            # is an implicit dot in between a release segment and a pre-release
            # segment.
            prospective = _version_split(str(prospective))

            # Shorten the prospective version to be the same length as the spec
            # so that we can determine if the specifier is a prefix of the
            # prospective version or not.
            prospective = prospective[:len(spec)]

            # Pad out our two sides with zeros so that they both equal the same
            # length.
            spec, prospective = _pad_version(spec, prospective)
        else:
            # Convert our spec string into a Version
            spec = Version(spec)

            # If the specifier does not have a local segment, then we want to
            # act as if the prospective version also does not have a local
            # segment.
            if not spec.local:
                prospective = Version(prospective.public)

        return prospective == spec

    @_require_version_compare
    def _compare_not_equal(self, prospective, spec):
        return not self._compare_equal(prospective, spec)

    @_require_version_compare
    def _compare_less_than_equal(self, prospective, spec):
        return prospective <= Version(spec)

    @_require_version_compare
    def _compare_greater_than_equal(self, prospective, spec):
        return prospective >= Version(spec)

    @_require_version_compare
    def _compare_less_than(self, prospective, spec):
        """``<``: strict less-than, excluding pre-releases of the spec itself."""
        # Convert our spec to a Version instance, since we'll want to work with
        # it as a version.
        spec = Version(spec)

        # Check to see if the prospective version is less than the spec
        # version. If it's not we can short circuit and just return False now
        # instead of doing extra unneeded work.
        if not prospective < spec:
            return False

        # This special case is here so that, unless the specifier itself
        # includes is a pre-release version, that we do not accept pre-release
        # versions for the version mentioned in the specifier (e.g. <3.1 should
        # not match 3.1.dev0, but should match 3.0.dev0).
        if not spec.is_prerelease and prospective.is_prerelease:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # If we've gotten to here, it means that prospective version is both
        # less than the spec version *and* it's not a pre-release of the same
        # version in the spec.
        return True

    @_require_version_compare
    def _compare_greater_than(self, prospective, spec):
        """``>``: strict greater-than, excluding posts/locals of the spec itself."""
        # Convert our spec to a Version instance, since we'll want to work with
        # it as a version.
        spec = Version(spec)

        # Check to see if the prospective version is greater than the spec
        # version. If it's not we can short circuit and just return False now
        # instead of doing extra unneeded work.
        if not prospective > spec:
            return False

        # This special case is here so that, unless the specifier itself
        # includes is a post-release version, that we do not accept
        # post-release versions for the version mentioned in the specifier
        # (e.g. >3.1 should not match 3.0.post0, but should match 3.2.post0).
        if not spec.is_postrelease and prospective.is_postrelease:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # Ensure that we do not allow a local version of the version mentioned
        # in the specifier, which is techincally greater than, to match.
        if prospective.local is not None:
            if Version(prospective.base_version) == Version(spec.base_version):
                return False

        # If we've gotten to here, it means that prospective version is both
        # greater than the spec version *and* it's not a pre-release of the
        # same version in the spec.
        return True

    def _compare_arbitrary(self, prospective, spec):
        # ``===``: case-insensitive exact string identity; no PEP 440 parsing.
        return str(prospective).lower() == str(spec).lower()

    @property
    def prereleases(self):
        """Whether this clause admits pre-releases (explicit flag, or implied
        by a pre-release version in an inclusive operator)."""
        # If there is an explicit prereleases set for this, then we'll just
        # blindly use that.
        if self._prereleases is not None:
            return self._prereleases

        # Look at all of our specifiers and determine if they are inclusive
        # operators, and if they are if they are including an explicit
        # prerelease.
        operator, version = self._spec
        if operator in ["==", ">=", "<=", "~="]:
            # The == specifier can include a trailing .*, if it does we
            # want to remove before parsing.
            if operator == "==" and version.endswith(".*"):
                version = version[:-2]

            # Parse the version, and if it is a pre-release than this
            # specifier allows pre-releases.
            if parse(version).is_prerelease:
                return True

        return False

    @prereleases.setter
    def prereleases(self, value):
        self._prereleases = value
_prefix_regex = re.compile(r"^([0-9]+)((?:a|b|c|rc)[0-9]+)$")
def _version_split(version):
result = []
for item in version.split("."):
match = _prefix_regex.search(item)
if match:
result.extend(match.groups())
else:
result.append(item)
return result
def _pad_version(left, right):
left_split, right_split = [], []
# Get the release segment of our versions
left_split.append(list(itertools.takewhile(lambda x: x.isdigit(), left)))
right_split.append(list(itertools.takewhile(lambda x: x.isdigit(), right)))
# Get the rest of our versions
left_split.append(left[len(left_split):])
right_split.append(left[len(right_split):])
# Insert our padding
left_split.insert(
1,
["0"] * max(0, len(right_split[0]) - len(left_split[0])),
)
right_split.insert(
1,
["0"] * max(0, len(left_split[0]) - len(right_split[0])),
)
return (
list(itertools.chain(*left_split)),
list(itertools.chain(*right_split)),
)
class SpecifierSet(BaseSpecifier):
    """A comma-separated conjunction of individual specifier clauses.

    Each clause is parsed as a ``Specifier`` when possible, otherwise as a
    ``LegacySpecifier``; a version must satisfy every clause to be
    contained in the set.
    """

    def __init__(self, specifiers="", prereleases=None):
        # Split on , to break each individual specifier into its own item, and
        # strip each item to remove leading/trailing whitespace.
        specifiers = [s.strip() for s in specifiers.split(",") if s.strip()]

        # Parse each individual specifier, attempting first to make it a
        # Specifier and falling back to a LegacySpecifier.
        parsed = set()
        for specifier in specifiers:
            try:
                parsed.add(Specifier(specifier))
            except InvalidSpecifier:
                parsed.add(LegacySpecifier(specifier))

        # Turn our parsed specifiers into a frozen set and save them for later.
        self._specs = frozenset(parsed)

        # Store our prereleases value so we can use it later to determine if
        # we accept prereleases or not.
        self._prereleases = prereleases

    def __repr__(self):
        pre = (
            ", prereleases={0!r}".format(self.prereleases)
            if self._prereleases is not None
            else ""
        )

        return "<SpecifierSet({0!r}{1})>".format(str(self), pre)

    def __str__(self):
        # Sorted so that equal sets render identically.
        return ",".join(sorted(str(s) for s in self._specs))

    def __hash__(self):
        return hash(self._specs)

    def __and__(self, other):
        """Return the conjunction of two sets (``a & b``); prerelease
        overrides must agree or a ValueError is raised."""
        if isinstance(other, string_types):
            other = SpecifierSet(other)
        elif not isinstance(other, SpecifierSet):
            return NotImplemented

        specifier = SpecifierSet()
        specifier._specs = frozenset(self._specs | other._specs)

        if self._prereleases is None and other._prereleases is not None:
            specifier._prereleases = other._prereleases
        elif self._prereleases is not None and other._prereleases is None:
            specifier._prereleases = self._prereleases
        elif self._prereleases == other._prereleases:
            specifier._prereleases = self._prereleases
        else:
            raise ValueError(
                "Cannot combine SpecifierSets with True and False prerelease "
                "overrides."
            )

        return specifier

    def __eq__(self, other):
        # Strings and single specifiers are coerced to sets before comparing.
        if isinstance(other, string_types):
            other = SpecifierSet(other)
        elif isinstance(other, _IndividualSpecifier):
            other = SpecifierSet(str(other))
        elif not isinstance(other, SpecifierSet):
            return NotImplemented

        return self._specs == other._specs

    def __ne__(self, other):
        if isinstance(other, string_types):
            other = SpecifierSet(other)
        elif isinstance(other, _IndividualSpecifier):
            other = SpecifierSet(str(other))
        elif not isinstance(other, SpecifierSet):
            return NotImplemented

        return self._specs != other._specs

    @property
    def prereleases(self):
        # If we have been given an explicit prerelease modifier, then we'll
        # pass that through here.
        if self._prereleases is not None:
            return self._prereleases

        # Otherwise we'll see if any of the given specifiers accept
        # prereleases, if any of them do we'll return True, otherwise False.
        # Note: The use of any() here means that an empty set of specifiers
        # will always return False, this is an explicit design decision.
        return any(s.prereleases for s in self._specs)

    @prereleases.setter
    def prereleases(self, value):
        self._prereleases = value

    def contains(self, item, prereleases=None):
        """Return True if *item* satisfies every clause in the set."""
        # Ensure that our item is a Version or LegacyVersion instance.
        if not isinstance(item, (LegacyVersion, Version)):
            item = parse(item)

        # We can determine if we're going to allow pre-releases by looking to
        # see if any of the underlying items supports them. If none of them do
        # and this item is a pre-release then we do not allow it and we can
        # short circuit that here.
        # Note: This means that 1.0.dev1 would not be contained in something
        # like >=1.0.devabc however it would be in >=1.0.debabc,>0.0.dev0
        if (not (self.prereleases or prereleases)) and item.is_prerelease:
            return False

        # Determine if we're forcing a prerelease or not, we bypass
        # self.prereleases here and use self._prereleases because we want to
        # only take into consideration actual *forced* values. The underlying
        # specifiers will handle the other logic.
        # The logic here is: If prereleases is anything but None, we'll just
        #                    go aheand and continue to use that. However if
        #                    prereleases is None, then we'll use whatever the
        #                    value of self._prereleases is as long as it is not
        #                    None itself.
        if prereleases is None and self._prereleases is not None:
            prereleases = self._prereleases

        # We simply dispatch to the underlying specs here to make sure that the
        # given version is contained within all of them.
        # Note: This use of all() here means that an empty set of specifiers
        # will always return True, this is an explicit design decision.
        return all(
            s.contains(item, prereleases=prereleases)
            for s in self._specs
        )

    def filter(self, iterable, prereleases=None):
        """Filter *iterable* down to the versions allowed by every clause.

        With no clauses at all, legacy versions are dropped and
        pre-releases are used only when no final release survives.
        """
        # Determine if we're forcing a prerelease or not, we bypass
        # self.prereleases here and use self._prereleases because we want to
        # only take into consideration actual *forced* values. The underlying
        # specifiers will handle the other logic.
        # The logic here is: If prereleases is anything but None, we'll just
        #                    go aheand and continue to use that. However if
        #                    prereleases is None, then we'll use whatever the
        #                    value of self._prereleases is as long as it is not
        #                    None itself.
        if prereleases is None and self._prereleases is not None:
            prereleases = self._prereleases

        # If we have any specifiers, then we want to wrap our iterable in the
        # filter method for each one, this will act as a logical AND amongst
        # each specifier.
        if self._specs:
            for spec in self._specs:
                iterable = spec.filter(iterable, prereleases=prereleases)
            return iterable
        # If we do not have any specifiers, then we need to have a rough filter
        # which will filter out any pre-releases, unless there are no final
        # releases, and which will filter out LegacyVersion in general.
        else:
            filtered = []
            found_prereleases = []

            for item in iterable:
                # Ensure that we some kind of Version class for this item.
                if not isinstance(item, (LegacyVersion, Version)):
                    parsed_version = parse(item)
                else:
                    parsed_version = item

                # Filter out any item which is parsed as a LegacyVersion
                if isinstance(parsed_version, LegacyVersion):
                    continue

                # Store any item which is a pre-release for later unless we've
                # already found a final version or we are accepting prereleases
                if parsed_version.is_prerelease and not prereleases:
                    if not filtered:
                        found_prereleases.append(item)
                else:
                    filtered.append(item)

            # If we've found no items except for pre-releases, then we'll go
            # ahead and use the pre-releases
            if not filtered and found_prereleases and prereleases is None:
                return found_prereleases

            return filtered

View File

@@ -0,0 +1,401 @@
# Copyright 2014 Donald Stufft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import collections
import itertools
import re
from ._structures import Infinity
# Public API of the version module.
__all__ = [
    "parse", "Version", "LegacyVersion", "InvalidVersion", "VERSION_PATTERN"
]


# Parsed representation of a PEP 440 version; holds the internal state of
# ``Version`` and the raw material for its ordering key.
_Version = collections.namedtuple(
    "_Version",
    ["epoch", "release", "dev", "pre", "post", "local"],
)
def parse(version):
    """
    Parse the given version string and return either a :class:`Version` object
    or a :class:`LegacyVersion` object depending on if the given version is
    a valid PEP 440 version or a legacy version.
    """
    # EAFP: attempt strict PEP 440 parsing first, and only fall back to the
    # permissive legacy scheme when that fails.
    try:
        parsed = Version(version)
    except InvalidVersion:
        parsed = LegacyVersion(version)
    return parsed
class InvalidVersion(ValueError):
    """
    An invalid version was found, users should refer to PEP 440.

    Raised by ``Version`` when the string does not match VERSION_PATTERN.
    """
class _BaseVersion(object):
def __hash__(self):
return hash(self._key)
def __lt__(self, other):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
return self._compare(other, lambda s, o: s != o)
def _compare(self, other, method):
if not isinstance(other, _BaseVersion):
return NotImplemented
return method(self._key, other._key)
class LegacyVersion(_BaseVersion):
    """A version string that does not conform to PEP 440.

    Ordered by the historical setuptools scheme (see ``_legacy_cmpkey``),
    which sorts every legacy version before any valid PEP 440 version.
    """

    def __init__(self, version):
        self._version = str(version)
        self._key = _legacy_cmpkey(self._version)

    def __repr__(self):
        return "<LegacyVersion({0})>".format(repr(str(self)))

    def __str__(self):
        return self._version

    @property
    def public(self):
        # Legacy versions carry no local segment, so the full string is public.
        return self._version

    @property
    def base_version(self):
        # Likewise, there are no suffixes to strip off.
        return self._version

    @property
    def local(self):
        return None

    @property
    def is_prerelease(self):
        # Legacy versions never take part in pre-release filtering.
        return False

    @property
    def is_postrelease(self):
        return False
_legacy_version_component_re = re.compile(
r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE,
)
_legacy_version_replacement_map = {
"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@",
}
def _parse_version_parts(s):
for part in _legacy_version_component_re.split(s):
part = _legacy_version_replacement_map.get(part, part)
if not part or part == ".":
continue
if part[:1] in "0123456789":
# pad for numeric comparison
yield part.zfill(8)
else:
yield "*" + part
# ensure that alpha/beta/candidate are before final
yield "*final"
def _legacy_cmpkey(version):
    """Build the (epoch, parts) comparison key for a legacy version string."""
    # We hardcode an epoch of -1 here. A PEP 440 version can only have an
    # epoch greater than or equal to 0. This will effectively put the
    # LegacyVersion, which uses the de facto standard originally implemented
    # by setuptools, as before all PEP 440 versions.
    epoch = -1
    # This scheme is taken from pkg_resources.parse_version setuptools prior
    # to its adoption of the packaging library.
    parts = []
    for part in _parse_version_parts(version.lower()):
        if part.startswith("*"):
            # remove "-" before a prerelease tag
            if part < "*final":
                while parts and parts[-1] == "*final-":
                    parts.pop()
            # remove trailing zeros from each series of numeric parts
            while parts and parts[-1] == "00000000":
                parts.pop()
        parts.append(part)
    parts = tuple(parts)
    return epoch, parts
# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
VERSION_PATTERN = r"""
v?
(?:
(?:(?P<epoch>[0-9]+)!)? # epoch
(?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
(?P<pre> # pre-release
[-_\.]?
(?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
[-_\.]?
(?P<pre_n>[0-9]+)?
)?
(?P<post> # post release
(?:-(?P<post_n1>[0-9]+))
|
(?:
[-_\.]?
(?P<post_l>post|rev|r)
[-_\.]?
(?P<post_n2>[0-9]+)?
)
)?
(?P<dev> # dev release
[-_\.]?
(?P<dev_l>dev)
[-_\.]?
(?P<dev_n>[0-9]+)?
)?
)
(?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
"""
class Version(_BaseVersion):
    """A PEP 440 compliant version, parsed into comparable components.

    Raises :class:`InvalidVersion` when the string cannot be parsed.
    """
    # VERSION_PATTERN is deliberately unanchored; anchor it here so the
    # entire string (modulo surrounding whitespace) must be a valid version.
    _regex = re.compile(
        r"^\s*" + VERSION_PATTERN + r"\s*$",
        re.VERBOSE | re.IGNORECASE,
    )
    def __init__(self, version):
        """Parse *version* (a string) or raise :class:`InvalidVersion`."""
        # Validate the version and parse it into pieces
        match = self._regex.search(version)
        if not match:
            raise InvalidVersion("Invalid version: '{0}'".format(version))
        # Store the parsed out pieces of the version
        self._version = _Version(
            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
            release=tuple(int(i) for i in match.group("release").split(".")),
            pre=_parse_letter_version(
                match.group("pre_l"),
                match.group("pre_n"),
            ),
            post=_parse_letter_version(
                match.group("post_l"),
                # the implicit "-N" spelling and the explicit ".postN"
                # spelling capture into different groups
                match.group("post_n1") or match.group("post_n2"),
            ),
            dev=_parse_letter_version(
                match.group("dev_l"),
                match.group("dev_n"),
            ),
            local=_parse_local_version(match.group("local")),
        )
        # Generate a key which will be used for sorting
        self._key = _cmpkey(
            self._version.epoch,
            self._version.release,
            self._version.pre,
            self._version.post,
            self._version.dev,
            self._version.local,
        )
    def __repr__(self):
        return "<Version({0})>".format(repr(str(self)))
    def __str__(self):
        """Render the canonical (normalized) form of the version."""
        parts = []
        # Epoch
        if self._version.epoch != 0:
            parts.append("{0}!".format(self._version.epoch))
        # Release segment
        parts.append(".".join(str(x) for x in self._version.release))
        # Pre-release
        if self._version.pre is not None:
            parts.append("".join(str(x) for x in self._version.pre))
        # Post-release
        if self._version.post is not None:
            parts.append(".post{0}".format(self._version.post[1]))
        # Development release
        if self._version.dev is not None:
            parts.append(".dev{0}".format(self._version.dev[1]))
        # Local version segment
        if self._version.local is not None:
            parts.append(
                "+{0}".format(".".join(str(x) for x in self._version.local))
            )
        return "".join(parts)
    @property
    def public(self):
        # Everything before the "+local" segment.
        return str(self).split("+", 1)[0]
    @property
    def base_version(self):
        # Epoch and release only: pre/post/dev/local segments are dropped.
        parts = []
        # Epoch
        if self._version.epoch != 0:
            parts.append("{0}!".format(self._version.epoch))
        # Release segment
        parts.append(".".join(str(x) for x in self._version.release))
        return "".join(parts)
    @property
    def local(self):
        # Implicitly returns None when there is no local segment.
        version_string = str(self)
        if "+" in version_string:
            return version_string.split("+", 1)[1]
    @property
    def is_prerelease(self):
        # Dev releases count as pre-releases for exclusion purposes.
        return bool(self._version.dev or self._version.pre)
    @property
    def is_postrelease(self):
        return bool(self._version.post)
def _parse_letter_version(letter, number):
if letter:
# We consider there to be an implicit 0 in a pre-release if there is
# not a numeral associated with it.
if number is None:
number = 0
# We normalize any letters to their lower case form
letter = letter.lower()
# We consider some words to be alternate spellings of other words and
# in those cases we want to normalize the spellings to our preferred
# spelling.
if letter == "alpha":
letter = "a"
elif letter == "beta":
letter = "b"
elif letter in ["c", "pre", "preview"]:
letter = "rc"
return letter, int(number)
if not letter and number:
# We assume if we are given a number, but we are not given a letter
# then this is using the implicit post release syntax (e.g. 1.0-1)
letter = "post"
return letter, int(number)
_local_version_seperators = re.compile(r"[\._-]")
def _parse_local_version(local):
"""
Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
"""
if local is not None:
return tuple(
part.lower() if not part.isdigit() else int(part)
for part in _local_version_seperators.split(local)
)
def _cmpkey(epoch, release, pre, post, dev, local):
# When we compare a release version, we want to compare it with all of the
# trailing zeros removed. So we'll use a reverse the list, drop all the now
# leading zeros until we come to something non zero, then take the rest
# re-reverse it back into the correct order and make it a tuple and use
# that for our sorting key.
release = tuple(
reversed(list(
itertools.dropwhile(
lambda x: x == 0,
reversed(release),
)
))
)
# We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
# We'll do this by abusing the pre segment, but we _only_ want to do this
# if there is not a pre or a post segment. If we have one of those then
# the normal sorting rules will handle this case correctly.
if pre is None and post is None and dev is not None:
pre = -Infinity
# Versions without a pre-release (except as noted above) should sort after
# those with one.
elif pre is None:
pre = Infinity
# Versions without a post segment should sort before those with one.
if post is None:
post = -Infinity
# Versions without a development segment should sort after those with one.
if dev is None:
dev = Infinity
if local is None:
# Versions without a local segment should sort before those with one.
local = -Infinity
else:
# Versions with a local segment need that segment parsed to implement
# the sorting rules in PEP440.
# - Alpha numeric segments sort before numeric segments
# - Alpha numeric segments sort lexicographically
# - Numeric segments sort numerically
# - Shorter versions sort before longer versions when the prefixes
# match exactly
local = tuple(
(i, "") if isinstance(i, int) else (-Infinity, i)
for i in local
)
return epoch, release, pre, post, dev, local

View File

@@ -0,0 +1 @@
packaging==15.0

View File

@@ -0,0 +1,419 @@
Pluggable Distributions of Python Software
==========================================
Distributions
-------------
A "Distribution" is a collection of files that represent a "Release" of a
"Project" as of a particular point in time, denoted by a
"Version"::
>>> import sys, pkg_resources
>>> from pkg_resources import Distribution
>>> Distribution(project_name="Foo", version="1.2")
Foo 1.2
Distributions have a location, which can be a filename, URL, or really anything
else you care to use::
>>> dist = Distribution(
... location="http://example.com/something",
... project_name="Bar", version="0.9"
... )
>>> dist
Bar 0.9 (http://example.com/something)
Distributions have various introspectable attributes::
>>> dist.location
'http://example.com/something'
>>> dist.project_name
'Bar'
>>> dist.version
'0.9'
>>> dist.py_version == sys.version[:3]
True
>>> print(dist.platform)
None
Including various computed attributes::
>>> from pkg_resources import parse_version
>>> dist.parsed_version == parse_version(dist.version)
True
>>> dist.key # case-insensitive form of the project name
'bar'
Distributions are compared (and hashed) by version first::
>>> Distribution(version='1.0') == Distribution(version='1.0')
True
>>> Distribution(version='1.0') == Distribution(version='1.1')
False
>>> Distribution(version='1.0') < Distribution(version='1.1')
True
but also by project name (case-insensitive), platform, Python version,
location, etc.::
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="Foo",version="1.0")
True
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="foo",version="1.0")
True
>>> Distribution(project_name="Foo",version="1.0") == \
... Distribution(project_name="Foo",version="1.1")
False
>>> Distribution(project_name="Foo",py_version="2.3",version="1.0") == \
... Distribution(project_name="Foo",py_version="2.4",version="1.0")
False
>>> Distribution(location="spam",version="1.0") == \
... Distribution(location="spam",version="1.0")
True
>>> Distribution(location="spam",version="1.0") == \
... Distribution(location="baz",version="1.0")
False
Hash and compare distribution by prio/plat
Get version from metadata
provider capabilities
egg_name()
as_requirement()
from_location, from_filename (w/path normalization)
Releases may have zero or more "Requirements", which indicate
what releases of another project the release requires in order to
function. A Requirement names the other project, expresses some criteria
as to what releases of that project are acceptable, and lists any "Extras"
that the requiring release may need from that project. (An Extra is an
optional feature of a Release, that can only be used if its additional
Requirements are satisfied.)
The Working Set
---------------
A collection of active distributions is called a Working Set. Note that a
Working Set can contain any importable distribution, not just pluggable ones.
For example, the Python standard library is an importable distribution that
will usually be part of the Working Set, even though it is not pluggable.
Similarly, when you are doing development work on a project, the files you are
editing are also a Distribution. (And, with a little attention to the
directory names used, and including some additional metadata, such a
"development distribution" can be made pluggable as well.)
>>> from pkg_resources import WorkingSet
A working set's entries are the sys.path entries that correspond to the active
distributions. By default, the working set's entries are the items on
``sys.path``::
>>> ws = WorkingSet()
>>> ws.entries == sys.path
True
But you can also create an empty working set explicitly, and add distributions
to it::
>>> ws = WorkingSet([])
>>> ws.add(dist)
>>> ws.entries
['http://example.com/something']
>>> dist in ws
True
>>> Distribution('foo',version="") in ws
False
And you can iterate over its distributions::
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
Adding the same distribution more than once is a no-op::
>>> ws.add(dist)
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
For that matter, adding multiple distributions for the same project also does
nothing, because a working set can only hold one active distribution per
project -- the first one added to it::
>>> ws.add(
... Distribution(
... 'http://example.com/something', project_name="Bar",
... version="7.2"
... )
... )
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
You can append a path entry to a working set using ``add_entry()``::
>>> ws.entries
['http://example.com/something']
>>> ws.add_entry(pkg_resources.__file__)
>>> ws.entries
['http://example.com/something', '...pkg_resources...']
Multiple additions result in multiple entries, even if the entry is already in
the working set (because ``sys.path`` can contain the same entry more than
once)::
>>> ws.add_entry(pkg_resources.__file__)
>>> ws.entries
['...example.com...', '...pkg_resources...', '...pkg_resources...']
And you can specify the path entry a distribution was found under, using the
optional second parameter to ``add()``::
>>> ws = WorkingSet([])
>>> ws.add(dist,"foo")
>>> ws.entries
['foo']
But even if a distribution is found under multiple path entries, it still only
shows up once when iterating the working set:
>>> ws.add_entry(ws.entries[0])
>>> list(ws)
[Bar 0.9 (http://example.com/something)]
You can ask a WorkingSet to ``find()`` a distribution matching a requirement::
>>> from pkg_resources import Requirement
>>> print(ws.find(Requirement.parse("Foo==1.0"))) # no match, return None
None
>>> ws.find(Requirement.parse("Bar==0.9")) # match, return distribution
Bar 0.9 (http://example.com/something)
Note that asking for a conflicting version of a distribution already in a
working set triggers a ``pkg_resources.VersionConflict`` error:
>>> try:
... ws.find(Requirement.parse("Bar==1.0"))
... except pkg_resources.VersionConflict as exc:
... print(str(exc))
... else:
... raise AssertionError("VersionConflict was not raised")
(Bar 0.9 (http://example.com/something), Requirement.parse('Bar==1.0'))
You can subscribe a callback function to receive notifications whenever a new
distribution is added to a working set. The callback is immediately invoked
once for each existing distribution in the working set, and then is called
again for new distributions added thereafter::
>>> def added(dist): print("Added %s" % dist)
>>> ws.subscribe(added)
Added Bar 0.9
>>> foo12 = Distribution(project_name="Foo", version="1.2", location="f12")
>>> ws.add(foo12)
Added Foo 1.2
Note, however, that only the first distribution added for a given project name
will trigger a callback, even during the initial ``subscribe()`` callback::
>>> foo14 = Distribution(project_name="Foo", version="1.4", location="f14")
>>> ws.add(foo14) # no callback, because Foo 1.2 is already active
>>> ws = WorkingSet([])
>>> ws.add(foo12)
>>> ws.add(foo14)
>>> ws.subscribe(added)
Added Foo 1.2
And adding a callback more than once has no effect, either::
>>> ws.subscribe(added) # no callbacks
# and no double-callbacks on subsequent additions, either
>>> just_a_test = Distribution(project_name="JustATest", version="0.99")
>>> ws.add(just_a_test)
Added JustATest 0.99
Finding Plugins
---------------
``WorkingSet`` objects can be used to figure out what plugins in an
``Environment`` can be loaded without any resolution errors::
>>> from pkg_resources import Environment
>>> plugins = Environment([]) # normally, a list of plugin directories
>>> plugins.add(foo12)
>>> plugins.add(foo14)
>>> plugins.add(just_a_test)
In the simplest case, we just get the newest version of each distribution in
the plugin environment::
>>> ws = WorkingSet([])
>>> ws.find_plugins(plugins)
([JustATest 0.99, Foo 1.4 (f14)], {})
But if there's a problem with a version conflict or missing requirements, the
method falls back to older versions, and the error info dict will contain an
exception instance for each unloadable plugin::
>>> ws.add(foo12) # this will conflict with Foo 1.4
>>> ws.find_plugins(plugins)
([JustATest 0.99, Foo 1.2 (f12)], {Foo 1.4 (f14): VersionConflict(...)})
But if you disallow fallbacks, the failed plugin will be skipped instead of
trying older versions::
>>> ws.find_plugins(plugins, fallback=False)
([JustATest 0.99], {Foo 1.4 (f14): VersionConflict(...)})
Platform Compatibility Rules
----------------------------
On the Mac, there are potential compatibility issues for modules compiled
on newer versions of Mac OS X than what the user is running. Additionally,
Mac OS X will soon have two platforms to contend with: Intel and PowerPC.
Basic equality works as on other platforms::
>>> from pkg_resources import compatible_platforms as cp
>>> reqd = 'macosx-10.4-ppc'
>>> cp(reqd, reqd)
True
>>> cp("win32", reqd)
False
Distributions made on other machine types are not compatible::
>>> cp("macosx-10.4-i386", reqd)
False
Distributions made on earlier versions of the OS are compatible, as
long as they are from the same top-level version. The patchlevel version
number does not matter::
>>> cp("macosx-10.4-ppc", reqd)
True
>>> cp("macosx-10.3-ppc", reqd)
True
>>> cp("macosx-10.5-ppc", reqd)
False
>>> cp("macosx-9.5-ppc", reqd)
False
Backwards compatibility for packages made via earlier versions of
setuptools is provided as well::
>>> cp("darwin-8.2.0-Power_Macintosh", reqd)
True
>>> cp("darwin-7.2.0-Power_Macintosh", reqd)
True
>>> cp("darwin-8.2.0-Power_Macintosh", "macosx-10.3-ppc")
False
Environment Markers
-------------------
>>> from pkg_resources import invalid_marker as im, evaluate_marker as em
>>> import os
>>> print(im("sys_platform"))
Comparison or logical expression expected
>>> print(im("sys_platform=="))
invalid syntax
>>> print(im("sys_platform=='win32'"))
False
>>> print(im("sys=='x'"))
Unknown name 'sys'
>>> print(im("(extra)"))
Comparison or logical expression expected
>>> print(im("(extra"))
invalid syntax
>>> print(im("os.open('foo')=='y'"))
Language feature not supported in environment markers
>>> print(im("'x'=='y' and os.open('foo')=='y'")) # no short-circuit!
Language feature not supported in environment markers
>>> print(im("'x'=='x' or os.open('foo')=='y'")) # no short-circuit!
Language feature not supported in environment markers
>>> print(im("'x' < 'y'"))
'<' operator not allowed in environment markers
>>> print(im("'x' < 'y' < 'z'"))
Chained comparison not allowed in environment markers
>>> print(im("r'x'=='x'"))
Only plain strings allowed in environment markers
>>> print(im("'''x'''=='x'"))
Only plain strings allowed in environment markers
>>> print(im('"""x"""=="x"'))
Only plain strings allowed in environment markers
>>> print(im(r"'x\n'=='x'"))
Only plain strings allowed in environment markers
>>> print(im("os.open=='y'"))
Language feature not supported in environment markers
>>> em('"x"=="x"')
True
>>> em('"x"=="y"')
False
>>> em('"x"=="y" and "x"=="x"')
False
>>> em('"x"=="y" or "x"=="x"')
True
>>> em('"x"=="y" and "x"=="q" or "z"=="z"')
True
>>> em('"x"=="y" and ("x"=="q" or "z"=="z")')
False
>>> em('"x"=="y" and "z"=="z" or "x"=="q"')
False
>>> em('"x"=="x" and "z"=="z" or "x"=="q"')
True
>>> em("sys_platform=='win32'") == (sys.platform=='win32')
True
>>> em("'x' in 'yx'")
True
>>> em("'yx' in 'x'")
False

View File

@@ -0,0 +1,111 @@
import sys
import tempfile
import os
import zipfile
import datetime
import time
import subprocess
import pkg_resources
# Python 2/3 compatibility: Python 3 removed the ``unicode`` builtin, so
# alias it to ``str`` there; on Python 2 the existing builtin is kept.
try:
    unicode
except NameError:
    unicode = str
def timestamp(dt):
    """
    Return a POSIX timestamp (float) for a local, naive datetime instance.
    """
    if hasattr(dt, 'timestamp'):
        # datetime.timestamp() is available on Python 3.3+
        return dt.timestamp()
    # Fallback for Python 3.2 and earlier
    return time.mktime(dt.timetuple())
class EggRemover(unicode):
    # A callable path string used as a test-teardown finalizer: calling the
    # instance unregisters the path from sys.path and deletes the file.
    def __call__(self):
        # Drop the path from sys.path if it is still registered.
        if self in sys.path:
            sys.path.remove(self)
        # Remove the on-disk file if it still exists.
        if os.path.exists(self):
            os.remove(self)
class TestZipProvider(object):
    # Tests for resource extraction from a zipped egg (ZipProvider).
    finalizers = []
    ref_time = datetime.datetime(2013, 5, 12, 13, 25, 0)
    "A reference time for a file modification"
    @classmethod
    def setup_class(cls):
        "create a zip egg and add it to sys.path"
        egg = tempfile.NamedTemporaryFile(suffix='.egg', delete=False)
        zip_egg = zipfile.ZipFile(egg, 'w')
        # A trivial importable module so the egg can be imported as `mod`.
        zip_info = zipfile.ZipInfo()
        zip_info.filename = 'mod.py'
        zip_info.date_time = cls.ref_time.timetuple()
        zip_egg.writestr(zip_info, 'x = 3\n')
        # A data resource with a known mtime inside the zip.
        zip_info = zipfile.ZipInfo()
        zip_info.filename = 'data.dat'
        zip_info.date_time = cls.ref_time.timetuple()
        zip_egg.writestr(zip_info, 'hello, world!')
        zip_egg.close()
        egg.close()
        # Make the egg importable and register cleanup of the temp file.
        sys.path.append(egg.name)
        cls.finalizers.append(EggRemover(egg.name))
    @classmethod
    def teardown_class(cls):
        # Run every registered cleanup callable (removes the temp egg).
        for finalizer in cls.finalizers:
            finalizer()
    def test_resource_filename_rewrites_on_change(self):
        """
        If a previous call to get_resource_filename has saved the file, but
        the file has been subsequently mutated with different file of the
        same size and modification time, it should not be overwritten on a
        subsequent call to get_resource_filename.
        """
        # NOTE(review): the final assertion expects the original zip contents
        # back, i.e. the mutated extracted copy IS replaced — the docstring
        # wording above appears inverted relative to the asserted behavior;
        # confirm against the upstream setuptools test.
        import mod
        manager = pkg_resources.ResourceManager()
        zp = pkg_resources.ZipProvider(mod)
        filename = zp.get_resource_filename(manager, 'data.dat')
        actual = datetime.datetime.fromtimestamp(os.stat(filename).st_mtime)
        assert actual == self.ref_time
        # Mutate the extracted copy while keeping size and mtime identical,
        # so a staleness check cannot rely on either.
        f = open(filename, 'w')
        f.write('hello, world?')
        f.close()
        ts = timestamp(self.ref_time)
        os.utime(filename, (ts, ts))
        filename = zp.get_resource_filename(manager, 'data.dat')
        f = open(filename)
        assert f.read() == 'hello, world!'
        manager.cleanup_resources()
class TestResourceManager(object):
    """Sanity checks for pkg_resources.ResourceManager."""
    def test_get_cache_path(self):
        # get_cache_path must return a native/unicode string path.
        manager = pkg_resources.ResourceManager()
        cache_path = manager.get_cache_path('foo')
        kind = str(type(cache_path))
        message = "Unexpected type from get_cache_path: " + kind
        assert isinstance(cache_path, (unicode, str)), message
class TestIndependence:
    """
    Tests to ensure that pkg_resources runs independently from setuptools.
    """
    def test_setuptools_not_imported(self):
        """
        In a separate Python environment, import pkg_resources and assert
        that action doesn't cause setuptools to be imported.
        """
        # Join the statements into a single -c script for a fresh interpreter.
        script = '; '.join((
            'import pkg_resources',
            'import sys',
            'assert "setuptools" not in sys.modules, '
            '"setuptools was imported"',
        ))
        subprocess.check_call([sys.executable, '-c', script])

View File

@@ -0,0 +1,661 @@
import os
import sys
import tempfile
import shutil
import string
import pytest
import pkg_resources
from pkg_resources import (parse_requirements, VersionConflict, parse_version,
Distribution, EntryPoint, Requirement, safe_version, safe_name,
WorkingSet)
packaging = pkg_resources.packaging
def safe_repr(obj, short=False):
    """Return repr(obj), guarding against a broken __repr__.

    Copied from Python 2.7's unittest helpers. With ``short=True`` the
    result is truncated to ``pkg_resources._MAX_LENGTH`` characters.
    """
    try:
        result = repr(obj)
    except Exception:
        # Fall back to the default object repr when __repr__ raises.
        result = object.__repr__(obj)
    if not short:
        return result
    if len(result) < pkg_resources._MAX_LENGTH:
        return result
    return result[:pkg_resources._MAX_LENGTH] + ' [truncated]...'
class Metadata(pkg_resources.EmptyProvider):
    """Mock object to return metadata as if from an on-disk distribution"""
    def __init__(self, *pairs):
        # Each pair is a (metadata-name, contents) 2-tuple.
        self.metadata = dict(pairs)
    def has_metadata(self, name):
        return name in self.metadata
    def get_metadata(self, name):
        return self.metadata[name]
    def get_metadata_lines(self, name):
        # Split the raw metadata text into meaningful (non-blank) lines.
        return pkg_resources.yield_lines(self.get_metadata(name))
# Shorthand for building Distribution objects from egg filenames below.
dist_from_fn = pkg_resources.Distribution.from_filename
class TestDistro:
    # Tests covering Environment collection, Distribution parsing/metadata,
    # and WorkingSet dependency resolution.
    def testCollection(self):
        # empty path should produce no distributions
        ad = pkg_resources.Environment([], platform=None, python=None)
        assert list(ad) == []
        assert ad['FooPkg'] == []
        ad.add(dist_from_fn("FooPkg-1.3_1.egg"))
        ad.add(dist_from_fn("FooPkg-1.4-py2.4-win32.egg"))
        ad.add(dist_from_fn("FooPkg-1.2-py2.4.egg"))
        # Name is in there now
        assert ad['FooPkg']
        # But only 1 package
        assert list(ad) == ['foopkg']
        # Distributions sort by version
        assert [dist.version for dist in ad['FooPkg']] == ['1.4','1.3-1','1.2']
        # Removing a distribution leaves sequence alone
        ad.remove(ad['FooPkg'][1])
        assert [dist.version for dist in ad['FooPkg']] == ['1.4','1.2']
        # And inserting adds them in order
        ad.add(dist_from_fn("FooPkg-1.9.egg"))
        assert [dist.version for dist in ad['FooPkg']] == ['1.9','1.4','1.2']
        ws = WorkingSet([])
        foo12 = dist_from_fn("FooPkg-1.2-py2.4.egg")
        foo14 = dist_from_fn("FooPkg-1.4-py2.4-win32.egg")
        req, = parse_requirements("FooPkg>=1.3")
        # Nominal case: no distros on path, should yield all applicable
        assert ad.best_match(req, ws).version == '1.9'
        # If a matching distro is already installed, should return only that
        ws.add(foo14)
        assert ad.best_match(req, ws).version == '1.4'
        # If the first matching distro is unsuitable, it's a version conflict
        ws = WorkingSet([])
        ws.add(foo12)
        ws.add(foo14)
        with pytest.raises(VersionConflict):
            ad.best_match(req, ws)
        # If more than one match on the path, the first one takes precedence
        ws = WorkingSet([])
        ws.add(foo14)
        ws.add(foo12)
        ws.add(foo14)
        assert ad.best_match(req, ws).version == '1.4'
    def checkFooPkg(self,d):
        # Shared assertions for a FooPkg 1.3-1 / py2.4 / win32 distribution.
        assert d.project_name == "FooPkg"
        assert d.key == "foopkg"
        assert d.version == "1.3.post1"
        assert d.py_version == "2.4"
        assert d.platform == "win32"
        assert d.parsed_version == parse_version("1.3-1")
    def testDistroBasics(self):
        d = Distribution(
            "/some/path",
            project_name="FooPkg",version="1.3-1",py_version="2.4",platform="win32"
        )
        self.checkFooPkg(d)
        # Defaults: current Python version, no platform restriction.
        d = Distribution("/some/path")
        assert d.py_version == sys.version[:3]
        assert d.platform == None
    def testDistroParse(self):
        # Both .egg and .egg-info filenames must parse to the same metadata.
        d = dist_from_fn("FooPkg-1.3.post1-py2.4-win32.egg")
        self.checkFooPkg(d)
        d = dist_from_fn("FooPkg-1.3.post1-py2.4-win32.egg-info")
        self.checkFooPkg(d)
    def testDistroMetadata(self):
        # The version can come from PKG-INFO rather than the filename.
        d = Distribution(
            "/some/path", project_name="FooPkg", py_version="2.4", platform="win32",
            metadata = Metadata(
                ('PKG-INFO',"Metadata-Version: 1.0\nVersion: 1.3-1\n")
            )
        )
        self.checkFooPkg(d)
    def distRequires(self, txt):
        # Build a Distribution whose requirements come from depends.txt text.
        return Distribution("/foo", metadata=Metadata(('depends.txt', txt)))
    def checkRequires(self, dist, txt, extras=()):
        assert list(dist.requires(extras)) == list(parse_requirements(txt))
    def testDistroDependsSimple(self):
        for v in "Twisted>=1.5", "Twisted>=1.5\nZConfig>=2.0":
            self.checkRequires(self.distRequires(v), v)
    def testResolve(self):
        ad = pkg_resources.Environment([])
        ws = WorkingSet([])
        # Resolving no requirements -> nothing to install
        assert list(ws.resolve([], ad)) == []
        # Request something not in the collection -> DistributionNotFound
        with pytest.raises(pkg_resources.DistributionNotFound):
            ws.resolve(parse_requirements("Foo"), ad)
        Foo = Distribution.from_filename(
            "/foo_dir/Foo-1.2.egg",
            metadata=Metadata(('depends.txt', "[bar]\nBaz>=2.0"))
        )
        ad.add(Foo)
        ad.add(Distribution.from_filename("Foo-0.9.egg"))
        # Request thing(s) that are available -> list to activate
        for i in range(3):
            targets = list(ws.resolve(parse_requirements("Foo"), ad))
            assert targets == [Foo]
            list(map(ws.add,targets))
        with pytest.raises(VersionConflict):
            ws.resolve(parse_requirements("Foo==0.9"), ad)
        ws = WorkingSet([]) # reset
        # Request an extra that causes an unresolved dependency for "Baz"
        with pytest.raises(pkg_resources.DistributionNotFound):
            ws.resolve(parse_requirements("Foo[bar]"), ad)
        Baz = Distribution.from_filename(
            "/foo_dir/Baz-2.1.egg", metadata=Metadata(('depends.txt', "Foo"))
        )
        ad.add(Baz)
        # Activation list now includes resolved dependency
        assert list(ws.resolve(parse_requirements("Foo[bar]"), ad)) ==[Foo,Baz]
        # Requests for conflicting versions produce VersionConflict
        with pytest.raises(VersionConflict) as vc:
            ws.resolve(parse_requirements("Foo==1.2\nFoo!=1.2"), ad)
        msg = 'Foo 0.9 is installed but Foo==1.2 is required'
        assert vc.value.report() == msg
    def testDistroDependsOptions(self):
        d = self.distRequires("""
            Twisted>=1.5
            [docgen]
            ZConfig>=2.0
            docutils>=0.3
            [fastcgi]
            fcgiapp>=0.1""")
        self.checkRequires(d,"Twisted>=1.5")
        self.checkRequires(
            d,"Twisted>=1.5 ZConfig>=2.0 docutils>=0.3".split(), ["docgen"]
        )
        self.checkRequires(
            d,"Twisted>=1.5 fcgiapp>=0.1".split(), ["fastcgi"]
        )
        self.checkRequires(
            d,"Twisted>=1.5 ZConfig>=2.0 docutils>=0.3 fcgiapp>=0.1".split(),
            ["docgen","fastcgi"]
        )
        # Extra order must not matter.
        self.checkRequires(
            d,"Twisted>=1.5 fcgiapp>=0.1 ZConfig>=2.0 docutils>=0.3".split(),
            ["fastcgi", "docgen"]
        )
        # Unknown extras raise UnknownExtra.
        with pytest.raises(pkg_resources.UnknownExtra):
            d.requires(["foo"])
class TestWorkingSet:
    # Tests for WorkingSet conflict detection and its error reporting.
    def test_find_conflicting(self):
        ws = WorkingSet([])
        Foo = Distribution.from_filename("/foo_dir/Foo-1.2.egg")
        ws.add(Foo)
        # create a requirement that conflicts with Foo 1.2
        req = next(parse_requirements("Foo<1.2"))
        with pytest.raises(VersionConflict) as vc:
            ws.find(req)
        msg = 'Foo 1.2 is installed but Foo<1.2 is required'
        assert vc.value.report() == msg
    def test_resolve_conflicts_with_prior(self):
        """
        A ContextualVersionConflict should be raised when a requirement
        conflicts with a prior requirement for a different package.
        """
        # Create installation where Foo depends on Baz 1.0 and Bar depends on
        # Baz 2.0.
        ws = WorkingSet([])
        md = Metadata(('depends.txt', "Baz==1.0"))
        Foo = Distribution.from_filename("/foo_dir/Foo-1.0.egg", metadata=md)
        ws.add(Foo)
        md = Metadata(('depends.txt', "Baz==2.0"))
        Bar = Distribution.from_filename("/foo_dir/Bar-1.0.egg", metadata=md)
        ws.add(Bar)
        Baz = Distribution.from_filename("/foo_dir/Baz-1.0.egg")
        ws.add(Baz)
        Baz = Distribution.from_filename("/foo_dir/Baz-2.0.egg")
        ws.add(Baz)
        with pytest.raises(VersionConflict) as vc:
            ws.resolve(parse_requirements("Foo\nBar\n"))
        # The repr of a set differs between Python 2 and Python 3.
        msg = "Baz 1.0 is installed but Baz==2.0 is required by {'Bar'}"
        if pkg_resources.PY2:
            msg = msg.replace("{'Bar'}", "set(['Bar'])")
        assert vc.value.report() == msg
class TestEntryPoints:
    # Tests for EntryPoint construction, parsing, formatting and loading.
    def assertfields(self, ep):
        # Shared assertions for the canonical "foo" entry point.
        assert ep.name == "foo"
        assert ep.module_name == "pkg_resources.tests.test_resources"
        assert ep.attrs == ("TestEntryPoints",)
        assert ep.extras == ("x",)
        assert ep.load() is TestEntryPoints
        expect = "foo = pkg_resources.tests.test_resources:TestEntryPoints [x]"
        assert str(ep) == expect
    def setup_method(self, method):
        # A distribution declaring the "x" extra, required by the entry point.
        self.dist = Distribution.from_filename(
            "FooPkg-1.2-py2.4.egg", metadata=Metadata(('requires.txt','[x]')))
    def testBasics(self):
        ep = EntryPoint(
            "foo", "pkg_resources.tests.test_resources", ["TestEntryPoints"],
            ["x"], self.dist
        )
        self.assertfields(ep)
    def testParse(self):
        s = "foo = pkg_resources.tests.test_resources:TestEntryPoints [x]"
        ep = EntryPoint.parse(s, self.dist)
        self.assertfields(ep)
        # Name may contain spaces; extras are lower-cased.
        ep = EntryPoint.parse("bar baz= spammity[PING]")
        assert ep.name == "bar baz"
        assert ep.module_name == "spammity"
        assert ep.attrs == ()
        assert ep.extras == ("ping",)
        # Surrounding whitespace is stripped.
        ep = EntryPoint.parse(" fizzly = wocka:foo")
        assert ep.name == "fizzly"
        assert ep.module_name == "wocka"
        assert ep.attrs == ("foo",)
        assert ep.extras == ()
        # plus in the name
        spec = "html+mako = mako.ext.pygmentplugin:MakoHtmlLexer"
        ep = EntryPoint.parse(spec)
        assert ep.name == 'html+mako'
    # Malformed specs that must be rejected with ValueError.
    reject_specs = "foo", "x=a:b:c", "q=x/na", "fez=pish:tush-z", "x=f[a]>2"
    @pytest.mark.parametrize("reject_spec", reject_specs)
    def test_reject_spec(self, reject_spec):
        with pytest.raises(ValueError):
            EntryPoint.parse(reject_spec)
    def test_printable_name(self):
        """
        Allow any printable character in the name.
        """
        # Create a name with all printable characters; strip the whitespace.
        name = string.printable.strip()
        spec = "{name} = module:attr".format(**locals())
        ep = EntryPoint.parse(spec)
        assert ep.name == name
    def checkSubMap(self, m):
        # Compare a parsed group against the expected entry-point mapping.
        assert len(m) == len(self.submap_expect)
        for key, ep in pkg_resources.iteritems(self.submap_expect):
            assert repr(m.get(key)) == repr(ep)
    submap_expect = dict(
        feature1=EntryPoint('feature1', 'somemodule', ['somefunction']),
        feature2=EntryPoint('feature2', 'another.module', ['SomeClass'], ['extra1','extra2']),
        feature3=EntryPoint('feature3', 'this.module', extras=['something'])
    )
    submap_str = """
            # define features for blah blah
            feature1 = somemodule:somefunction
            feature2 = another.module:SomeClass [extra1,extra2]
            feature3 = this.module [something]
        """
    def testParseList(self):
        self.checkSubMap(EntryPoint.parse_group("xyz", self.submap_str))
        # Invalid group names and duplicate entry names are rejected.
        with pytest.raises(ValueError):
            EntryPoint.parse_group("x a", "foo=bar")
        with pytest.raises(ValueError):
            EntryPoint.parse_group("x", ["foo=baz", "foo=bar"])
    def testParseMap(self):
        m = EntryPoint.parse_map({'xyz':self.submap_str})
        self.checkSubMap(m['xyz'])
        assert list(m.keys()) == ['xyz']
        m = EntryPoint.parse_map("[xyz]\n"+self.submap_str)
        self.checkSubMap(m['xyz'])
        assert list(m.keys()) == ['xyz']
        # Duplicate sections and missing section headers are rejected.
        with pytest.raises(ValueError):
            EntryPoint.parse_map(["[xyz]", "[xyz]"])
        with pytest.raises(ValueError):
            EntryPoint.parse_map(self.submap_str)
class TestRequirements:
    # Tests for Requirement parsing, equality, hashing and containment.
    def testBasics(self):
        r = Requirement.parse("Twisted>=1.2")
        assert str(r) == "Twisted>=1.2"
        assert repr(r) == "Requirement.parse('Twisted>=1.2')"
        # Project-name comparison is case-insensitive.
        assert r == Requirement("Twisted", [('>=','1.2')], ())
        assert r == Requirement("twisTed", [('>=','1.2')], ())
        assert r != Requirement("Twisted", [('>=','2.0')], ())
        assert r != Requirement("Zope", [('>=','1.2')], ())
        assert r != Requirement("Zope", [('>=','3.0')], ())
        assert r != Requirement.parse("Twisted[extras]>=1.2")
    def testOrdering(self):
        # Specifier order must not affect equality or the canonical string.
        r1 = Requirement("Twisted", [('==','1.2c1'),('>=','1.2')], ())
        r2 = Requirement("Twisted", [('>=','1.2'),('==','1.2c1')], ())
        assert r1 == r2
        assert str(r1) == str(r2)
        assert str(r2) == "Twisted==1.2c1,>=1.2"
    def testBasicContains(self):
        # `in` accepts parsed versions, version strings and Distributions.
        r = Requirement("Twisted", [('>=','1.2')], ())
        foo_dist = Distribution.from_filename("FooPkg-1.3_1.egg")
        twist11 = Distribution.from_filename("Twisted-1.1.egg")
        twist12 = Distribution.from_filename("Twisted-1.2.egg")
        assert parse_version('1.2') in r
        assert parse_version('1.1') not in r
        assert '1.2' in r
        assert '1.1' not in r
        assert foo_dist not in r
        assert twist11 not in r
        assert twist12 in r
    def testOptionsAndHashing(self):
        r1 = Requirement.parse("Twisted[foo,bar]>=1.2")
        r2 = Requirement.parse("Twisted[bar,FOO]>=1.2")
        assert r1 == r2
        assert r1.extras == ("foo","bar")
        assert r2.extras == ("bar","foo") # extras are normalized
        assert hash(r1) == hash(r2)
        # Hash is defined over (key, specifier set, frozenset of extras).
        assert (
            hash(r1)
            ==
            hash((
                "twisted",
                packaging.specifiers.SpecifierSet(">=1.2"),
                frozenset(["foo","bar"]),
            ))
        )
    def testVersionEquality(self):
        r1 = Requirement.parse("foo==0.3a2")
        r2 = Requirement.parse("foo!=0.3a4")
        d = Distribution.from_filename
        assert d("foo-0.3a4.egg") not in r1
        assert d("foo-0.3a1.egg") not in r1
        assert d("foo-0.3a4.egg") not in r2
        assert d("foo-0.3a2.egg") in r1
        assert d("foo-0.3a2.egg") in r2
        assert d("foo-0.3a3.egg") in r2
        assert d("foo-0.3a5.egg") in r2
    def testSetuptoolsProjectName(self):
        """
        The setuptools project should implement the setuptools package.
        """
        assert (
            Requirement.parse('setuptools').project_name == 'setuptools')
        # setuptools 0.7 and higher means setuptools.
        assert (
            Requirement.parse('setuptools == 0.7').project_name == 'setuptools')
        assert (
            Requirement.parse('setuptools == 0.7a1').project_name == 'setuptools')
        assert (
            Requirement.parse('setuptools >= 0.7').project_name == 'setuptools')
class TestParsing:
    """Parsing of requirement strings, sections, safe names, and versions."""
    def testEmptyParse(self):
        # An empty requirements string yields no Requirement objects.
        assert list(parse_requirements('')) == []
    def testYielding(self):
        # yield_lines accepts strings and (nested) iterables of strings,
        # flattening them and dropping blank lines.
        for inp,out in [
            ([], []), ('x',['x']), ([[]],[]), (' x\n y', ['x','y']),
            (['x\n\n','y'], ['x','y']),
        ]:
            assert list(pkg_resources.yield_lines(inp)) == out
    def testSplitting(self):
        # split_sections groups lines under INI-style [section] headers:
        # lines before any header fall under the None section, '#' comment
        # lines are dropped, and header names are stripped of whitespace.
        sample = """
                    x
                    [Y]
                    z
                    a
                    [b ]
                    # foo
                    c
                    [ d]
                    [q]
                    v
                    """
        assert (
            list(pkg_resources.split_sections(sample))
            ==
            [
                (None, ["x"]),
                ("Y", ["z", "a"]),
                ("b", ["c"]),
                ("d", []),
                ("q", ["v"]),
            ]
        )
        # An unclosed section header is rejected.
        with pytest.raises(ValueError):
            list(pkg_resources.split_sections("[foo"))
    def testSafeName(self):
        # safe_name collapses runs of illegal characters to single dashes,
        # but leaves '.' alone.
        assert safe_name("adns-python") == "adns-python"
        assert safe_name("WSGI Utils") == "WSGI-Utils"
        assert safe_name("WSGI Utils") == "WSGI-Utils"
        assert safe_name("Money$$$Maker") == "Money-Maker"
        assert safe_name("peak.web") != "peak-web"
    def testSafeVersion(self):
        # safe_version normalizes common version-string variants.
        assert safe_version("1.2-1") == "1.2.post1"
        assert safe_version("1.2 alpha") == "1.2.alpha"
        assert safe_version("2.3.4 20050521") == "2.3.4.20050521"
        assert safe_version("Money$$$Maker") == "Money-Maker"
        assert safe_version("peak.web") == "peak.web"
    def testSimpleRequirements(self):
        assert (
            list(parse_requirements('Twis-Ted>=1.2-1'))
            ==
            [Requirement('Twis-Ted',[('>=','1.2-1')], ())]
        )
        # Line continuations and trailing comments are tolerated.
        assert (
            list(parse_requirements('Twisted >=1.2, \ # more\n<2.0'))
            ==
            [Requirement('Twisted',[('>=','1.2'),('<','2.0')], ())]
        )
        assert (
            Requirement.parse("FooBar==1.99a3")
            ==
            Requirement("FooBar", [('==','1.99a3')], ())
        )
        # Malformed requirement strings raise ValueError: a missing project
        # name, a dangling backslash, trailing junk, multiple lines, and a
        # comment-only string.
        with pytest.raises(ValueError):
            Requirement.parse(">=2.3")
        with pytest.raises(ValueError):
            Requirement.parse("x\\")
        with pytest.raises(ValueError):
            Requirement.parse("x==2 q")
        with pytest.raises(ValueError):
            Requirement.parse("X==1\nY==2")
        with pytest.raises(ValueError):
            Requirement.parse("#")
    def testVersionEquality(self):
        # Each pair of spellings must parse to an equal version object.
        def c(s1,s2):
            p1, p2 = parse_version(s1),parse_version(s2)
            assert p1 == p2, (s1,s2,p1,p2)
        c('1.2-rc1', '1.2rc1')
        c('0.4', '0.4.0')
        c('0.4.0.0', '0.4.0')
        c('0.4.0-0', '0.4-0')
        c('0post1', '0.0post1')
        c('0pre1', '0.0c1')
        c('0.0.0preview1', '0c1')
        c('0.0c1', '0-rc1')
        c('1.2a1', '1.2.a.1')
        c('1.2.a', '1.2a')
    def testVersionOrdering(self):
        # For each pair, the first version must sort strictly before the
        # second.
        def c(s1,s2):
            p1, p2 = parse_version(s1),parse_version(s2)
            assert p1<p2, (s1,s2,p1,p2)
        c('2.1','2.1.1')
        c('2a1','2b0')
        c('2a1','2.1')
        c('2.3a1', '2.3')
        c('2.1-1', '2.1-2')
        c('2.1-1', '2.1.1')
        c('2.1', '2.1post4')
        c('2.1a0-20040501', '2.1')
        c('1.1', '02.1')
        c('3.2', '3.2.post0')
        c('3.2post1', '3.2post2')
        c('0.4', '4.0')
        c('0.0.4', '0.4.0')
        c('0post1', '0.4post1')
        c('2.1.0-rc1','2.1.0')
        c('2.1dev','2.1a0')
        # A descending list of real-world versions: every later entry must
        # compare less than every earlier one.
        torture ="""
        0.80.1-3 0.80.1-2 0.80.1-1 0.79.9999+0.80.0pre4-1
        0.79.9999+0.80.0pre2-3 0.79.9999+0.80.0pre2-2
        0.77.2-1 0.77.1-1 0.77.0-1
        """.split()
        for p,v1 in enumerate(torture):
            for v2 in torture[p+1:]:
                c(v2,v1)
    def testVersionBuildout(self):
        """
        Buildout has a function in its bootstrap.py that inspected the return
        value of parse_version. The new parse_version returns a Version class
        which needs to support this behavior, at least for now.
        """
        # Replica of buildout's _final_version check: iterating a parsed
        # version must yield the old-style part strings.
        def buildout(parsed_version):
            _final_parts = '*final-', '*final'
            def _final_version(parsed_version):
                for part in parsed_version:
                    if (part[:1] == '*') and (part not in _final_parts):
                        return False
                return True
            return _final_version(parsed_version)
        assert buildout(parse_version("1.0"))
        assert not buildout(parse_version("1.0a1"))
    def testVersionIndexable(self):
        """
        Some projects were doing things like parse_version("v")[0], so we'll
        support indexing the same as we support iterating.
        """
        assert parse_version("1.0")[0] == "00000001"
    def testVersionTupleSort(self):
        """
        Some projects expected to be able to sort tuples against the return
        value of parse_version. So again we'll add a warning enabled shim to
        make this possible.
        """
        # All six rich comparisons must work against a plain tuple.
        assert parse_version("1.0") < tuple(parse_version("2.0"))
        assert parse_version("1.0") <= tuple(parse_version("2.0"))
        assert parse_version("1.0") == tuple(parse_version("1.0"))
        assert parse_version("3.0") > tuple(parse_version("2.0"))
        assert parse_version("3.0") >= tuple(parse_version("2.0"))
        assert parse_version("3.0") != tuple(parse_version("2.0"))
        assert not (parse_version("3.0") != tuple(parse_version("3.0")))
    def testVersionHashable(self):
        """
        Ensure that our versions stay hashable even though we've subclassed
        them and added some shim code to them.
        """
        assert (
            hash(parse_version("1.0"))
            ==
            hash(parse_version("1.0"))
        )
class TestNamespaces:
    """Namespace-package declaration across multiple sys.path entries."""
    def setup_method(self, method):
        # Snapshot global state mutated by these tests: the namespace
        # registry and sys.path; create a scratch site directory.
        self._ns_pkgs = pkg_resources._namespace_packages.copy()
        self._tmpdir = tempfile.mkdtemp(prefix="tests-setuptools-")
        os.makedirs(os.path.join(self._tmpdir, "site-pkgs"))
        self._prev_sys_path = sys.path[:]
        sys.path.append(os.path.join(self._tmpdir, "site-pkgs"))
    def teardown_method(self, method):
        # Restore the snapshots taken in setup_method and drop the
        # scratch directory.
        shutil.rmtree(self._tmpdir)
        pkg_resources._namespace_packages = self._ns_pkgs.copy()
        sys.path = self._prev_sys_path[:]
    @pytest.mark.skipif(os.path.islink(tempfile.gettempdir()),
        reason="Test fails when /tmp is a symlink. See #231")
    def test_two_levels_deep(self):
        """
        Test nested namespace packages
        Create namespace packages in the following tree :
            site-packages-1/pkg1/pkg2
            site-packages-2/pkg1/pkg2
        Check both are in the _namespace_packages dict and that their __path__
        is correct
        """
        sys.path.append(os.path.join(self._tmpdir, "site-pkgs2"))
        os.makedirs(os.path.join(self._tmpdir, "site-pkgs", "pkg1", "pkg2"))
        os.makedirs(os.path.join(self._tmpdir, "site-pkgs2", "pkg1", "pkg2"))
        # Each __init__.py declares its package as a namespace package.
        ns_str = "__import__('pkg_resources').declare_namespace(__name__)\n"
        for site in ["site-pkgs", "site-pkgs2"]:
            pkg1_init = open(os.path.join(self._tmpdir, site,
                             "pkg1", "__init__.py"), "w")
            pkg1_init.write(ns_str)
            pkg1_init.close()
            pkg2_init = open(os.path.join(self._tmpdir, site,
                             "pkg1", "pkg2", "__init__.py"), "w")
            pkg2_init.write(ns_str)
            pkg2_init.close()
        import pkg1
        assert "pkg1" in pkg_resources._namespace_packages
        # attempt to import pkg2 from site-pkgs2
        import pkg1.pkg2
        # check the _namespace_packages dict
        assert "pkg1.pkg2" in pkg_resources._namespace_packages
        assert pkg_resources._namespace_packages["pkg1"] == ["pkg1.pkg2"]
        # check the __path__ attribute contains both paths
        expected = [
            os.path.join(self._tmpdir, "site-pkgs", "pkg1", "pkg2"),
            os.path.join(self._tmpdir, "site-pkgs2", "pkg1", "pkg2"),
        ]
        assert pkg1.pkg2.__path__ == expected

View File

@@ -9,8 +9,8 @@ on how to use these modules.
''' '''
# The Olson database is updated several times a year. # The Olson database is updated several times a year.
OLSON_VERSION = '2014d' OLSON_VERSION = '2014j'
VERSION = '2014.4' # Switching to pip compatible version numbering. VERSION = '2014.10' # Switching to pip compatible version numbering.
__version__ = VERSION __version__ = VERSION
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
@@ -735,6 +735,7 @@ all_timezones = \
'Asia/Bishkek', 'Asia/Bishkek',
'Asia/Brunei', 'Asia/Brunei',
'Asia/Calcutta', 'Asia/Calcutta',
'Asia/Chita',
'Asia/Choibalsan', 'Asia/Choibalsan',
'Asia/Chongqing', 'Asia/Chongqing',
'Asia/Chungking', 'Asia/Chungking',
@@ -792,6 +793,7 @@ all_timezones = \
'Asia/Seoul', 'Asia/Seoul',
'Asia/Shanghai', 'Asia/Shanghai',
'Asia/Singapore', 'Asia/Singapore',
'Asia/Srednekolymsk',
'Asia/Taipei', 'Asia/Taipei',
'Asia/Tashkent', 'Asia/Tashkent',
'Asia/Tbilisi', 'Asia/Tbilisi',
@@ -1002,6 +1004,7 @@ all_timezones = \
'PST8PDT', 'PST8PDT',
'Pacific/Apia', 'Pacific/Apia',
'Pacific/Auckland', 'Pacific/Auckland',
'Pacific/Bougainville',
'Pacific/Chatham', 'Pacific/Chatham',
'Pacific/Chuuk', 'Pacific/Chuuk',
'Pacific/Easter', 'Pacific/Easter',
@@ -1297,8 +1300,8 @@ common_timezones = \
'Asia/Beirut', 'Asia/Beirut',
'Asia/Bishkek', 'Asia/Bishkek',
'Asia/Brunei', 'Asia/Brunei',
'Asia/Chita',
'Asia/Choibalsan', 'Asia/Choibalsan',
'Asia/Chongqing',
'Asia/Colombo', 'Asia/Colombo',
'Asia/Damascus', 'Asia/Damascus',
'Asia/Dhaka', 'Asia/Dhaka',
@@ -1306,7 +1309,6 @@ common_timezones = \
'Asia/Dubai', 'Asia/Dubai',
'Asia/Dushanbe', 'Asia/Dushanbe',
'Asia/Gaza', 'Asia/Gaza',
'Asia/Harbin',
'Asia/Hebron', 'Asia/Hebron',
'Asia/Ho_Chi_Minh', 'Asia/Ho_Chi_Minh',
'Asia/Hong_Kong', 'Asia/Hong_Kong',
@@ -1318,7 +1320,6 @@ common_timezones = \
'Asia/Kabul', 'Asia/Kabul',
'Asia/Kamchatka', 'Asia/Kamchatka',
'Asia/Karachi', 'Asia/Karachi',
'Asia/Kashgar',
'Asia/Kathmandu', 'Asia/Kathmandu',
'Asia/Khandyga', 'Asia/Khandyga',
'Asia/Kolkata', 'Asia/Kolkata',
@@ -1348,6 +1349,7 @@ common_timezones = \
'Asia/Seoul', 'Asia/Seoul',
'Asia/Shanghai', 'Asia/Shanghai',
'Asia/Singapore', 'Asia/Singapore',
'Asia/Srednekolymsk',
'Asia/Taipei', 'Asia/Taipei',
'Asia/Tashkent', 'Asia/Tashkent',
'Asia/Tbilisi', 'Asia/Tbilisi',
@@ -1460,6 +1462,7 @@ common_timezones = \
'Indian/Reunion', 'Indian/Reunion',
'Pacific/Apia', 'Pacific/Apia',
'Pacific/Auckland', 'Pacific/Auckland',
'Pacific/Bougainville',
'Pacific/Chatham', 'Pacific/Chatham',
'Pacific/Chuuk', 'Pacific/Chuuk',
'Pacific/Easter', 'Pacific/Easter',

View File

@@ -369,13 +369,15 @@ class DstTzInfo(BaseTzInfo):
# hints to be passed in (such as the UTC offset or abbreviation), # hints to be passed in (such as the UTC offset or abbreviation),
# but that is just getting silly. # but that is just getting silly.
# #
# Choose the earliest (by UTC) applicable timezone. # Choose the earliest (by UTC) applicable timezone if is_dst=True
sorting_keys = {} # Choose the latest (by UTC) applicable timezone if is_dst=False
# i.e., behave like end-of-DST transition
dates = {} # utc -> local
for local_dt in filtered_possible_loc_dt: for local_dt in filtered_possible_loc_dt:
key = local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset utc_time = local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset
sorting_keys[key] = local_dt assert utc_time not in dates
first_key = sorted(sorting_keys)[0] dates[utc_time] = local_dt
return sorting_keys[first_key] return dates[[min, max][not is_dst](dates)]
def utcoffset(self, dt, is_dst=None): def utcoffset(self, dt, is_dst=None):
'''See datetime.tzinfo.utcoffset '''See datetime.tzinfo.utcoffset
@@ -560,4 +562,3 @@ def unpickler(zone, utcoffset=None, dstoffset=None, tzname=None):
inf = (utcoffset, dstoffset, tzname) inf = (utcoffset, dstoffset, tzname)
tz._tzinfos[inf] = tz.__class__(inf, tz._tzinfos) tz._tzinfos[inf] = tz.__class__(inf, tz._tzinfos)
return tz._tzinfos[inf] return tz._tzinfos[inf]

Some files were not shown because too many files have changed in this diff Show More