Fixes https://trello.com/c/ZBHrkuLb - Add support for IPv6 addresses in inventory import.

This commit is contained in:
Chris Church 2014-09-12 02:34:07 -04:00
parent 1128c55cc3
commit 5fe3ee3bf4
2 changed files with 43 additions and 16 deletions

View File

@ -93,8 +93,7 @@ class MemGroup(MemObject):
# maps host and group names to hosts to prevent redudant additions
self.all_hosts = {}
self.all_groups = {}
group_vars = os.path.join(self.source_dir, 'group_vars', self.name)
group_vars = os.path.join(source_dir, 'group_vars', self.name)
self.variables = self.load_vars(group_vars)
logger.debug('Loaded group: %s', self.name)
@ -149,14 +148,13 @@ class MemHost(MemObject):
In-memory representation of an inventory host.
'''
def __init__(self, name, source_dir):
def __init__(self, name, source_dir, port=None):
super(MemHost, self).__init__(name, source_dir)
self.variables = {}
self.instance_id = None
if ':' in name:
tokens = name.split(':')
self.name = tokens[0]
self.variables['ansible_ssh_port'] = int(tokens[1])
self.name = name
if port:
self.variables['ansible_ssh_port'] = port
host_vars = os.path.join(source_dir, 'host_vars', name)
self.variables.update(self.load_vars(host_vars))
logger.debug('Loaded host: %s', self.name)
@ -173,19 +171,29 @@ class BaseLoader(object):
self.all_group = all_group or MemGroup('all', self.source_dir)
self.group_filter_re = group_filter_re
self.host_filter_re = host_filter_re
self.ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$')
def get_host(self, name):
'''
Return a MemHost instance from host name, creating if needed. If name
contains brackets, they will not be interpreted as a host pattern.
contains brackets, they will NOT be interpreted as a host pattern.
'''
host_name = name.split(':')[0]
m = self.ipv6_port_re.match(name)
if m:
host_name = m.groups()[0]
port = int(m.groups()[1])
elif name.count(':') == 1:
host_name = name.split(':')[0]
port = int(name.split(':')[1])
else:
host_name = name
port = None
if self.host_filter_re and not self.host_filter_re.match(host_name):
logger.debug('Filtering host %s', host_name)
return None
host = None
if not host_name in self.all_group.all_hosts:
host = MemHost(name, self.source_dir)
host = MemHost(host_name, self.source_dir, port)
self.all_group.all_hosts[host_name] = host
return self.all_group.all_hosts[host_name]
@ -201,6 +209,9 @@ class BaseLoader(object):
yield ''.join([str(i), j])
else:
yield ''
if self.ipv6_port_re.match(name):
yield self.get_host(name)
return
pattern_re = re.compile(r'(\[(?:(?:\d+\:\d+)|(?:[A-Za-z]\:[A-Za-z]))(?:\:\d+)??\])')
iters = []
for s in re.split(pattern_re, name):

View File

@ -65,6 +65,13 @@ varb=B
[all:vars]
vara=A
[others]
10.11.12.13
10.12.14.16:8022
fe80::1610:9fff:fedd:654b
[fe80::1610:9fff:fedd:b654]:1022
::1
'''
TEST_INVENTORY_INI_WITH_HOST_PATTERNS = '''\
@ -539,12 +546,14 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
self.assertEqual(result, None, stdout + stderr)
# Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
expected_group_names = set(['servers', 'dbservers', 'webservers'])
expected_group_names = set(['servers', 'dbservers', 'webservers', 'others'])
group_names = set(new_inv.groups.values_list('name', flat=True))
self.assertEqual(expected_group_names, group_names)
expected_host_names = set(['web1.example.com', 'web2.example.com',
'web3.example.com', 'db1.example.com',
'db2.example.com'])
'db2.example.com', '10.11.12.13',
'10.12.14.16', 'fe80::1610:9fff:fedd:654b',
'fe80::1610:9fff:fedd:b654', '::1'])
host_names = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(expected_host_names, host_names)
if source and os.path.isdir(source):
@ -560,8 +569,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
{'ansible_ssh_host': 'w1.example.net'})
elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source):
self.assertEqual(host.variables_dict, {'test_host_name': host.name})
elif host.name == 'web3.example.com':
elif host.name in ('web3.example.com', 'fe80::1610:9fff:fedd:b654'):
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022})
elif host.name == '10.12.14.16':
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 8022})
else:
self.assertEqual(host.variables_dict, {})
for group in new_inv.groups.all():
@ -624,14 +635,17 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
# Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
expected_group_names = set(['servers', 'dbservers', 'webservers',
'lbservers'])
'lbservers', 'others'])
if overwrite:
expected_group_names.remove('lbservers')
group_names = set(new_inv.groups.filter(active=True).values_list('name', flat=True))
self.assertEqual(expected_group_names, group_names)
expected_host_names = set(['web1.example.com', 'web2.example.com',
'web3.example.com', 'db1.example.com',
'db2.example.com', 'lb.example.com'])
'db2.example.com', 'lb.example.com',
'10.11.12.13', '10.12.14.16',
'fe80::1610:9fff:fedd:654b',
'fe80::1610:9fff:fedd:b654', '::1'])
if overwrite:
expected_host_names.remove('lb.example.com')
host_names = set(new_inv.hosts.filter(active=True).values_list('name', flat=True))
@ -644,8 +658,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
if host.name == 'web1.example.com':
self.assertEqual(host.variables_dict,
{'ansible_ssh_host': 'w1.example.net'})
elif host.name == 'web3.example.com':
elif host.name in ('web3.example.com', 'fe80::1610:9fff:fedd:b654'):
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022})
elif host.name == '10.12.14.16':
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 8022})
elif host.name == 'lb.example.com':
self.assertEqual(host.variables_dict, {'lbvar': 'ni!'})
else: