From 5fe3ee3bf416361cb2b1fc46c4f7dadcfdf3520c Mon Sep 17 00:00:00 2001 From: Chris Church Date: Fri, 12 Sep 2014 02:34:07 -0400 Subject: [PATCH] Fixes https://trello.com/c/ZBHrkuLb - Add support for IPv6 addresses in inventory import. --- .../management/commands/inventory_import.py | 31 +++++++++++++------ awx/main/tests/commands.py | 28 +++++++++++++---- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index d78fe93358..d393846986 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -93,8 +93,7 @@ class MemGroup(MemObject): # maps host and group names to hosts to prevent redudant additions self.all_hosts = {} self.all_groups = {} - - group_vars = os.path.join(self.source_dir, 'group_vars', self.name) + group_vars = os.path.join(source_dir, 'group_vars', self.name) self.variables = self.load_vars(group_vars) logger.debug('Loaded group: %s', self.name) @@ -149,14 +148,13 @@ class MemHost(MemObject): In-memory representation of an inventory host. ''' - def __init__(self, name, source_dir): + def __init__(self, name, source_dir, port=None): super(MemHost, self).__init__(name, source_dir) self.variables = {} self.instance_id = None - if ':' in name: - tokens = name.split(':') - self.name = tokens[0] - self.variables['ansible_ssh_port'] = int(tokens[1]) + self.name = name + if port: + self.variables['ansible_ssh_port'] = port host_vars = os.path.join(source_dir, 'host_vars', name) self.variables.update(self.load_vars(host_vars)) logger.debug('Loaded host: %s', self.name) @@ -173,19 +171,29 @@ class BaseLoader(object): self.all_group = all_group or MemGroup('all', self.source_dir) self.group_filter_re = group_filter_re self.host_filter_re = host_filter_re + self.ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$') def get_host(self, name): ''' Return a MemHost instance from host name, creating if needed. If name - contains brackets, they will not be interpreted as a host pattern. + contains brackets, they will NOT be interpreted as a host pattern. ''' - host_name = name.split(':')[0] + m = self.ipv6_port_re.match(name) + if m: + host_name = m.groups()[0] + port = int(m.groups()[1]) + elif name.count(':') == 1: + host_name = name.split(':')[0] + port = int(name.split(':')[1]) + else: + host_name = name + port = None if self.host_filter_re and not self.host_filter_re.match(host_name): logger.debug('Filtering host %s', host_name) return None host = None if not host_name in self.all_group.all_hosts: - host = MemHost(name, self.source_dir) + host = MemHost(host_name, self.source_dir, port) self.all_group.all_hosts[host_name] = host return self.all_group.all_hosts[host_name] @@ -201,6 +209,9 @@ class BaseLoader(object): yield ''.join([str(i), j]) else: yield '' + if self.ipv6_port_re.match(name): + yield self.get_host(name) + return pattern_re = re.compile(r'(\[(?:(?:\d+\:\d+)|(?:[A-Za-z]\:[A-Za-z]))(?:\:\d+)??\])') iters = [] for s in re.split(pattern_re, name): diff --git a/awx/main/tests/commands.py b/awx/main/tests/commands.py index 32efaf4028..0841812bbd 100644 --- a/awx/main/tests/commands.py +++ b/awx/main/tests/commands.py @@ -65,6 +65,13 @@ varb=B [all:vars] vara=A + +[others] +10.11.12.13 +10.12.14.16:8022 +fe80::1610:9fff:fedd:654b +[fe80::1610:9fff:fedd:b654]:1022 +::1 ''' TEST_INVENTORY_INI_WITH_HOST_PATTERNS = '''\ @@ -539,12 +546,14 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): self.assertEqual(result, None, stdout + stderr) # Check that inventory is populated as expected. new_inv = Inventory.objects.get(pk=new_inv.pk) - expected_group_names = set(['servers', 'dbservers', 'webservers']) + expected_group_names = set(['servers', 'dbservers', 'webservers', 'others']) group_names = set(new_inv.groups.values_list('name', flat=True)) self.assertEqual(expected_group_names, group_names) expected_host_names = set(['web1.example.com', 'web2.example.com', 'web3.example.com', 'db1.example.com', - 'db2.example.com']) + 'db2.example.com', '10.11.12.13', + '10.12.14.16', 'fe80::1610:9fff:fedd:654b', + 'fe80::1610:9fff:fedd:b654', '::1']) host_names = set(new_inv.hosts.values_list('name', flat=True)) self.assertEqual(expected_host_names, host_names) if source and os.path.isdir(source): @@ -560,8 +569,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): {'ansible_ssh_host': 'w1.example.net'}) elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source): self.assertEqual(host.variables_dict, {'test_host_name': host.name}) - elif host.name == 'web3.example.com': + elif host.name in ('web3.example.com', 'fe80::1610:9fff:fedd:b654'): self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022}) + elif host.name == '10.12.14.16': + self.assertEqual(host.variables_dict, {'ansible_ssh_port': 8022}) else: self.assertEqual(host.variables_dict, {}) for group in new_inv.groups.all(): @@ -624,14 +635,17 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): # Check that inventory is populated as expected. new_inv = Inventory.objects.get(pk=new_inv.pk) expected_group_names = set(['servers', 'dbservers', 'webservers', - 'lbservers']) + 'lbservers', 'others']) if overwrite: expected_group_names.remove('lbservers') group_names = set(new_inv.groups.filter(active=True).values_list('name', flat=True)) self.assertEqual(expected_group_names, group_names) expected_host_names = set(['web1.example.com', 'web2.example.com', 'web3.example.com', 'db1.example.com', - 'db2.example.com', 'lb.example.com']) + 'db2.example.com', 'lb.example.com', + '10.11.12.13', '10.12.14.16', + 'fe80::1610:9fff:fedd:654b', + 'fe80::1610:9fff:fedd:b654', '::1']) if overwrite: expected_host_names.remove('lb.example.com') host_names = set(new_inv.hosts.filter(active=True).values_list('name', flat=True)) @@ -644,8 +658,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): if host.name == 'web1.example.com': self.assertEqual(host.variables_dict, {'ansible_ssh_host': 'w1.example.net'}) - elif host.name == 'web3.example.com': + elif host.name in ('web3.example.com', 'fe80::1610:9fff:fedd:b654'): self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022}) + elif host.name == '10.12.14.16': + self.assertEqual(host.variables_dict, {'ansible_ssh_port': 8022}) elif host.name == 'lb.example.com': self.assertEqual(host.variables_dict, {'lbvar': 'ni!'}) else: