diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index 4394509eb5..67907ad8e2 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -53,13 +53,20 @@ class MemObject(object): self.name = name self.source_dir = source_dir - def load_vars(self, path): - if os.path.exists(path) and os.path.isfile(path): + def load_vars(self, base_path): + all_vars = {} + for suffix in ('', '.yml', '.yaml', '.json'): + path = ''.join([base_path, suffix]) + if not os.path.exists(path): + continue + if not os.path.isfile(path): + continue vars_name = os.path.basename(os.path.dirname(path)) logger.debug('Loading %s from %s', vars_name, path) try: v = yaml.safe_load(file(path, 'r').read()) - return v if hasattr(v, 'items') else {} + if hasattr(v, 'items'): # is a dict + all_vars.update(v) except yaml.YAMLError, e: if hasattr(e, 'problem_mark'): logger.error('Invalid YAML in %s:%s col %s', path, @@ -68,7 +75,7 @@ class MemObject(object): else: logger.error('Error loading YAML from %s', path) raise - return {} + return all_vars class MemGroup(MemObject): diff --git a/awx/main/tests/commands.py b/awx/main/tests/commands.py index 1811211aeb..a600e52f68 100644 --- a/awx/main/tests/commands.py +++ b/awx/main/tests/commands.py @@ -20,7 +20,6 @@ from django.core.management import call_command from django.core.management.base import CommandError from django.utils.timezone import now from django.test.utils import override_settings - import django.db.backends.sqlite3.base import django.db @@ -90,10 +89,6 @@ lb[01:09:2].example.us even_odd=odd media[0:9][0:9].example.cc ''' -TEST_GROUP_VARS = '''\ -test_username: test -test_email: test@example.com -''' class BaseCommandMixin(object): ''' @@ -431,20 +426,33 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): ini_file.close() self._temp_paths.append(self.ini_path) - def create_test_dir(self, hostnames=None): - hostnames = hostnames or [] + def create_test_dir(self, host_names=None, group_names=None, suffix=''): + host_names = host_names or [] + group_names = group_names or [] + if 'all' not in group_names: + group_names.insert(0, 'all') self.inv_dir = tempfile.mkdtemp() self._temp_paths.append(self.inv_dir) self.create_test_ini(self.inv_dir) - group_vars = os.path.join(self.inv_dir, 'group_vars') - os.makedirs(group_vars) - file(os.path.join(group_vars, 'all'), 'wb').write(TEST_GROUP_VARS) - if hostnames: - host_vars = os.path.join(self.inv_dir, 'host_vars') - os.makedirs(host_vars) - for hostname in hostnames: - test_host_vars = '''test_hostname: %s''' % hostname - file(os.path.join(host_vars, hostname), 'wb').write(test_host_vars) + group_vars_dir = os.path.join(self.inv_dir, 'group_vars') + os.makedirs(group_vars_dir) + for group_name in group_names: + if suffix == '.json': + group_vars_content = '''{"test_group_name": "%s"}\n''' % group_name + else: + group_vars_content = '''test_group_name: %s\n''' % group_name + group_vars_file = os.path.join(group_vars_dir, '%s%s' % (group_name, suffix)) + file(group_vars_file, 'wb').write(group_vars_content) + if host_names: + host_vars_dir = os.path.join(self.inv_dir, 'host_vars') + os.makedirs(host_vars_dir) + for host_name in host_names: + if suffix == '.json': + host_vars_content = '''{"test_host_name": "%s"}''' % host_name + else: + host_vars_content = '''test_host_name: %s''' % host_name + host_vars_file = os.path.join(host_vars_dir, '%s%s' % (host_name, suffix)) + file(host_vars_file, 'wb').write(host_vars_content) def check_adhoc_inventory_source(self, inventory, except_host_pks=None, except_group_pks=None): @@ -529,7 +537,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): def test_ini_file(self, source=None): inv_src = source or self.ini_path # New empty inventory. - new_inv = self.organizations[0].inventories.create(name='newb') + new_inv = self.organizations[0].inventories.create(name=os.path.basename(inv_src)) self.assertEqual(new_inv.hosts.count(), 0) self.assertEqual(new_inv.groups.count(), 0) result, stdout, stderr = self.run_command('inventory_import', @@ -546,11 +554,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): 'db2.example.com']) host_names = set(new_inv.hosts.values_list('name', flat=True)) self.assertEqual(expected_host_names, host_names) - if source: + if source and os.path.isdir(source): self.assertEqual(new_inv.variables_dict, { 'vara': 'A', - 'test_username': 'test', - 'test_email': 'test@example.com', + 'test_group_name': 'all', }) else: self.assertEqual(new_inv.variables_dict, {'vara': 'A'}) @@ -559,7 +566,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): self.assertEqual(host.variables_dict, {'ansible_ssh_host': 'w1.example.net'}) elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source): - self.assertEqual(host.variables_dict, {'test_hostname': host.name}) + self.assertEqual(host.variables_dict, {'test_host_name': host.name}) elif host.name == 'web3.example.com': self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022}) else: @@ -571,7 +578,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): self.assertEqual(children, set(['dbservers', 'webservers'])) self.assertEqual(group.hosts.count(), 0) elif group.name == 'dbservers': - self.assertEqual(group.variables_dict, {'dbvar': 'ugh'}) + if source and os.path.isdir(source): + self.assertEqual(group.variables_dict, {'dbvar': 'ugh', 'test_group_name': 'dbservers'}) + else: + self.assertEqual(group.variables_dict, {'dbvar': 'ugh'}) self.assertEqual(group.children.count(), 0) hosts = set(group.hosts.values_list('name', flat=True)) host_names = set(['db1.example.com','db2.example.com']) @@ -586,7 +596,17 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): self.check_adhoc_inventory_source(new_inv) def test_dir_with_ini_file(self): - self.create_test_dir(hostnames=['db1.example.com', 'db2.example.com']) + self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'], + group_names=['dbservers'], suffix='') + self.test_ini_file(self.inv_dir) + self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'], + group_names=['dbservers'], suffix='.yml') + self.test_ini_file(self.inv_dir) + self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'], + group_names=['dbservers'], suffix='.yaml') + self.test_ini_file(self.inv_dir) + self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'], + group_names=['dbservers'], suffix='.json') self.test_ini_file(self.inv_dir) def test_merge_from_ini_file(self, overwrite=False, overwrite_vars=False): @@ -853,12 +873,9 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): @unittest.skipIf(getattr(settings, 'LOCAL_DEVELOPMENT', False), 'Skip this test in local development environments, ' 'which may vary widely on memory.') - @unittest.skipIf(django.db.backend == django.db.backend.sqlite3.base, + @unittest.skipIf(django.db.backend == django.db.backends.sqlite3.base, 'Skip this test if we are on sqlite') def test_splunk_inventory(self): - print django.db.backend - print django.db.backend.sqlite3.base - settings.DEBUG = True new_inv = self.organizations[0].inventories.create(name='splunk') self.assertEqual(new_inv.hosts.count(), 0) self.assertEqual(new_inv.groups.count(), 0)