Fix inventory import to look for group and host vars files with .yml, .yaml and .json extensions in addition to no extension.

This commit is contained in:
Chris Church 2014-08-12 22:57:59 -04:00
parent 88c1b05f50
commit 578f4b9c3b
2 changed files with 55 additions and 31 deletions

View File

@ -53,13 +53,20 @@ class MemObject(object):
self.name = name
self.source_dir = source_dir
def load_vars(self, path):
if os.path.exists(path) and os.path.isfile(path):
def load_vars(self, base_path):
all_vars = {}
for suffix in ('', '.yml', '.yaml', '.json'):
path = ''.join([base_path, suffix])
if not os.path.exists(path):
continue
if not os.path.isfile(path):
continue
vars_name = os.path.basename(os.path.dirname(path))
logger.debug('Loading %s from %s', vars_name, path)
try:
v = yaml.safe_load(file(path, 'r').read())
return v if hasattr(v, 'items') else {}
if hasattr(v, 'items'): # is a dict
all_vars.update(v)
except yaml.YAMLError, e:
if hasattr(e, 'problem_mark'):
logger.error('Invalid YAML in %s:%s col %s', path,
@ -68,7 +75,7 @@ class MemObject(object):
else:
logger.error('Error loading YAML from %s', path)
raise
return {}
return all_vars
class MemGroup(MemObject):

View File

@ -20,7 +20,6 @@ from django.core.management import call_command
from django.core.management.base import CommandError
from django.utils.timezone import now
from django.test.utils import override_settings
import django.db.backends.sqlite3.base
import django.db
@ -90,10 +89,6 @@ lb[01:09:2].example.us even_odd=odd
media[0:9][0:9].example.cc
'''
TEST_GROUP_VARS = '''\
test_username: test
test_email: test@example.com
'''
class BaseCommandMixin(object):
'''
@ -431,20 +426,33 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
ini_file.close()
self._temp_paths.append(self.ini_path)
def create_test_dir(self, hostnames=None):
hostnames = hostnames or []
def create_test_dir(self, host_names=None, group_names=None, suffix=''):
host_names = host_names or []
group_names = group_names or []
if 'all' not in group_names:
group_names.insert(0, 'all')
self.inv_dir = tempfile.mkdtemp()
self._temp_paths.append(self.inv_dir)
self.create_test_ini(self.inv_dir)
group_vars = os.path.join(self.inv_dir, 'group_vars')
os.makedirs(group_vars)
file(os.path.join(group_vars, 'all'), 'wb').write(TEST_GROUP_VARS)
if hostnames:
host_vars = os.path.join(self.inv_dir, 'host_vars')
os.makedirs(host_vars)
for hostname in hostnames:
test_host_vars = '''test_hostname: %s''' % hostname
file(os.path.join(host_vars, hostname), 'wb').write(test_host_vars)
group_vars_dir = os.path.join(self.inv_dir, 'group_vars')
os.makedirs(group_vars_dir)
for group_name in group_names:
if suffix == '.json':
group_vars_content = '''{"test_group_name": "%s"}\n''' % group_name
else:
group_vars_content = '''test_group_name: %s\n''' % group_name
group_vars_file = os.path.join(group_vars_dir, '%s%s' % (group_name, suffix))
file(group_vars_file, 'wb').write(group_vars_content)
if host_names:
host_vars_dir = os.path.join(self.inv_dir, 'host_vars')
os.makedirs(host_vars_dir)
for host_name in host_names:
if suffix == '.json':
host_vars_content = '''{"test_host_name": "%s"}''' % host_name
else:
host_vars_content = '''test_host_name: %s''' % host_name
host_vars_file = os.path.join(host_vars_dir, '%s%s' % (host_name, suffix))
file(host_vars_file, 'wb').write(host_vars_content)
def check_adhoc_inventory_source(self, inventory, except_host_pks=None,
except_group_pks=None):
@ -529,7 +537,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
def test_ini_file(self, source=None):
inv_src = source or self.ini_path
# New empty inventory.
new_inv = self.organizations[0].inventories.create(name='newb')
new_inv = self.organizations[0].inventories.create(name=os.path.basename(inv_src))
self.assertEqual(new_inv.hosts.count(), 0)
self.assertEqual(new_inv.groups.count(), 0)
result, stdout, stderr = self.run_command('inventory_import',
@ -546,11 +554,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
'db2.example.com'])
host_names = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(expected_host_names, host_names)
if source:
if source and os.path.isdir(source):
self.assertEqual(new_inv.variables_dict, {
'vara': 'A',
'test_username': 'test',
'test_email': 'test@example.com',
'test_group_name': 'all',
})
else:
self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
@ -559,7 +566,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
self.assertEqual(host.variables_dict,
{'ansible_ssh_host': 'w1.example.net'})
elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source):
self.assertEqual(host.variables_dict, {'test_hostname': host.name})
self.assertEqual(host.variables_dict, {'test_host_name': host.name})
elif host.name == 'web3.example.com':
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022})
else:
@ -571,7 +578,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
self.assertEqual(children, set(['dbservers', 'webservers']))
self.assertEqual(group.hosts.count(), 0)
elif group.name == 'dbservers':
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
if source and os.path.isdir(source):
self.assertEqual(group.variables_dict, {'dbvar': 'ugh', 'test_group_name': 'dbservers'})
else:
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
self.assertEqual(group.children.count(), 0)
hosts = set(group.hosts.values_list('name', flat=True))
host_names = set(['db1.example.com','db2.example.com'])
@ -586,7 +596,17 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
self.check_adhoc_inventory_source(new_inv)
def test_dir_with_ini_file(self):
self.create_test_dir(hostnames=['db1.example.com', 'db2.example.com'])
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
group_names=['dbservers'], suffix='')
self.test_ini_file(self.inv_dir)
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
group_names=['dbservers'], suffix='.yml')
self.test_ini_file(self.inv_dir)
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
group_names=['dbservers'], suffix='.yaml')
self.test_ini_file(self.inv_dir)
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
group_names=['dbservers'], suffix='.json')
self.test_ini_file(self.inv_dir)
def test_merge_from_ini_file(self, overwrite=False, overwrite_vars=False):
@ -853,12 +873,9 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
@unittest.skipIf(getattr(settings, 'LOCAL_DEVELOPMENT', False),
'Skip this test in local development environments, '
'which may vary widely on memory.')
@unittest.skipIf(django.db.backend == django.db.backend.sqlite3.base,
@unittest.skipIf(django.db.backend == django.db.backends.sqlite3.base,
'Skip this test if we are on sqlite')
def test_splunk_inventory(self):
print django.db.backend
print django.db.backend.sqlite3.base
settings.DEBUG = True
new_inv = self.organizations[0].inventories.create(name='splunk')
self.assertEqual(new_inv.hosts.count(), 0)
self.assertEqual(new_inv.groups.count(), 0)