mirror of
https://github.com/ansible/awx.git
synced 2026-05-14 12:57:40 -02:30
Fix inventory import to look for group and host vars files with .yml, .yaml and .json extensions in addition to no extension.
This commit is contained in:
@@ -53,13 +53,20 @@ class MemObject(object):
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.source_dir = source_dir
|
self.source_dir = source_dir
|
||||||
|
|
||||||
def load_vars(self, path):
|
def load_vars(self, base_path):
|
||||||
if os.path.exists(path) and os.path.isfile(path):
|
all_vars = {}
|
||||||
|
for suffix in ('', '.yml', '.yaml', '.json'):
|
||||||
|
path = ''.join([base_path, suffix])
|
||||||
|
if not os.path.exists(path):
|
||||||
|
continue
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
continue
|
||||||
vars_name = os.path.basename(os.path.dirname(path))
|
vars_name = os.path.basename(os.path.dirname(path))
|
||||||
logger.debug('Loading %s from %s', vars_name, path)
|
logger.debug('Loading %s from %s', vars_name, path)
|
||||||
try:
|
try:
|
||||||
v = yaml.safe_load(file(path, 'r').read())
|
v = yaml.safe_load(file(path, 'r').read())
|
||||||
return v if hasattr(v, 'items') else {}
|
if hasattr(v, 'items'): # is a dict
|
||||||
|
all_vars.update(v)
|
||||||
except yaml.YAMLError, e:
|
except yaml.YAMLError, e:
|
||||||
if hasattr(e, 'problem_mark'):
|
if hasattr(e, 'problem_mark'):
|
||||||
logger.error('Invalid YAML in %s:%s col %s', path,
|
logger.error('Invalid YAML in %s:%s col %s', path,
|
||||||
@@ -68,7 +75,7 @@ class MemObject(object):
|
|||||||
else:
|
else:
|
||||||
logger.error('Error loading YAML from %s', path)
|
logger.error('Error loading YAML from %s', path)
|
||||||
raise
|
raise
|
||||||
return {}
|
return all_vars
|
||||||
|
|
||||||
|
|
||||||
class MemGroup(MemObject):
|
class MemGroup(MemObject):
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from django.core.management import call_command
|
|||||||
from django.core.management.base import CommandError
|
from django.core.management.base import CommandError
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
from django.test.utils import override_settings
|
from django.test.utils import override_settings
|
||||||
|
|
||||||
import django.db.backends.sqlite3.base
|
import django.db.backends.sqlite3.base
|
||||||
import django.db
|
import django.db
|
||||||
|
|
||||||
@@ -90,10 +89,6 @@ lb[01:09:2].example.us even_odd=odd
|
|||||||
media[0:9][0:9].example.cc
|
media[0:9][0:9].example.cc
|
||||||
'''
|
'''
|
||||||
|
|
||||||
TEST_GROUP_VARS = '''\
|
|
||||||
test_username: test
|
|
||||||
test_email: test@example.com
|
|
||||||
'''
|
|
||||||
|
|
||||||
class BaseCommandMixin(object):
|
class BaseCommandMixin(object):
|
||||||
'''
|
'''
|
||||||
@@ -431,20 +426,33 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
ini_file.close()
|
ini_file.close()
|
||||||
self._temp_paths.append(self.ini_path)
|
self._temp_paths.append(self.ini_path)
|
||||||
|
|
||||||
def create_test_dir(self, hostnames=None):
|
def create_test_dir(self, host_names=None, group_names=None, suffix=''):
|
||||||
hostnames = hostnames or []
|
host_names = host_names or []
|
||||||
|
group_names = group_names or []
|
||||||
|
if 'all' not in group_names:
|
||||||
|
group_names.insert(0, 'all')
|
||||||
self.inv_dir = tempfile.mkdtemp()
|
self.inv_dir = tempfile.mkdtemp()
|
||||||
self._temp_paths.append(self.inv_dir)
|
self._temp_paths.append(self.inv_dir)
|
||||||
self.create_test_ini(self.inv_dir)
|
self.create_test_ini(self.inv_dir)
|
||||||
group_vars = os.path.join(self.inv_dir, 'group_vars')
|
group_vars_dir = os.path.join(self.inv_dir, 'group_vars')
|
||||||
os.makedirs(group_vars)
|
os.makedirs(group_vars_dir)
|
||||||
file(os.path.join(group_vars, 'all'), 'wb').write(TEST_GROUP_VARS)
|
for group_name in group_names:
|
||||||
if hostnames:
|
if suffix == '.json':
|
||||||
host_vars = os.path.join(self.inv_dir, 'host_vars')
|
group_vars_content = '''{"test_group_name": "%s"}\n''' % group_name
|
||||||
os.makedirs(host_vars)
|
else:
|
||||||
for hostname in hostnames:
|
group_vars_content = '''test_group_name: %s\n''' % group_name
|
||||||
test_host_vars = '''test_hostname: %s''' % hostname
|
group_vars_file = os.path.join(group_vars_dir, '%s%s' % (group_name, suffix))
|
||||||
file(os.path.join(host_vars, hostname), 'wb').write(test_host_vars)
|
file(group_vars_file, 'wb').write(group_vars_content)
|
||||||
|
if host_names:
|
||||||
|
host_vars_dir = os.path.join(self.inv_dir, 'host_vars')
|
||||||
|
os.makedirs(host_vars_dir)
|
||||||
|
for host_name in host_names:
|
||||||
|
if suffix == '.json':
|
||||||
|
host_vars_content = '''{"test_host_name": "%s"}''' % host_name
|
||||||
|
else:
|
||||||
|
host_vars_content = '''test_host_name: %s''' % host_name
|
||||||
|
host_vars_file = os.path.join(host_vars_dir, '%s%s' % (host_name, suffix))
|
||||||
|
file(host_vars_file, 'wb').write(host_vars_content)
|
||||||
|
|
||||||
def check_adhoc_inventory_source(self, inventory, except_host_pks=None,
|
def check_adhoc_inventory_source(self, inventory, except_host_pks=None,
|
||||||
except_group_pks=None):
|
except_group_pks=None):
|
||||||
@@ -529,7 +537,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
def test_ini_file(self, source=None):
|
def test_ini_file(self, source=None):
|
||||||
inv_src = source or self.ini_path
|
inv_src = source or self.ini_path
|
||||||
# New empty inventory.
|
# New empty inventory.
|
||||||
new_inv = self.organizations[0].inventories.create(name='newb')
|
new_inv = self.organizations[0].inventories.create(name=os.path.basename(inv_src))
|
||||||
self.assertEqual(new_inv.hosts.count(), 0)
|
self.assertEqual(new_inv.hosts.count(), 0)
|
||||||
self.assertEqual(new_inv.groups.count(), 0)
|
self.assertEqual(new_inv.groups.count(), 0)
|
||||||
result, stdout, stderr = self.run_command('inventory_import',
|
result, stdout, stderr = self.run_command('inventory_import',
|
||||||
@@ -546,11 +554,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
'db2.example.com'])
|
'db2.example.com'])
|
||||||
host_names = set(new_inv.hosts.values_list('name', flat=True))
|
host_names = set(new_inv.hosts.values_list('name', flat=True))
|
||||||
self.assertEqual(expected_host_names, host_names)
|
self.assertEqual(expected_host_names, host_names)
|
||||||
if source:
|
if source and os.path.isdir(source):
|
||||||
self.assertEqual(new_inv.variables_dict, {
|
self.assertEqual(new_inv.variables_dict, {
|
||||||
'vara': 'A',
|
'vara': 'A',
|
||||||
'test_username': 'test',
|
'test_group_name': 'all',
|
||||||
'test_email': 'test@example.com',
|
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
|
self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
|
||||||
@@ -559,7 +566,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
self.assertEqual(host.variables_dict,
|
self.assertEqual(host.variables_dict,
|
||||||
{'ansible_ssh_host': 'w1.example.net'})
|
{'ansible_ssh_host': 'w1.example.net'})
|
||||||
elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source):
|
elif host.name in ('db1.example.com', 'db2.example.com') and source and os.path.isdir(source):
|
||||||
self.assertEqual(host.variables_dict, {'test_hostname': host.name})
|
self.assertEqual(host.variables_dict, {'test_host_name': host.name})
|
||||||
elif host.name == 'web3.example.com':
|
elif host.name == 'web3.example.com':
|
||||||
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022})
|
self.assertEqual(host.variables_dict, {'ansible_ssh_port': 1022})
|
||||||
else:
|
else:
|
||||||
@@ -571,7 +578,10 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
self.assertEqual(children, set(['dbservers', 'webservers']))
|
self.assertEqual(children, set(['dbservers', 'webservers']))
|
||||||
self.assertEqual(group.hosts.count(), 0)
|
self.assertEqual(group.hosts.count(), 0)
|
||||||
elif group.name == 'dbservers':
|
elif group.name == 'dbservers':
|
||||||
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
|
if source and os.path.isdir(source):
|
||||||
|
self.assertEqual(group.variables_dict, {'dbvar': 'ugh', 'test_group_name': 'dbservers'})
|
||||||
|
else:
|
||||||
|
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
|
||||||
self.assertEqual(group.children.count(), 0)
|
self.assertEqual(group.children.count(), 0)
|
||||||
hosts = set(group.hosts.values_list('name', flat=True))
|
hosts = set(group.hosts.values_list('name', flat=True))
|
||||||
host_names = set(['db1.example.com','db2.example.com'])
|
host_names = set(['db1.example.com','db2.example.com'])
|
||||||
@@ -586,7 +596,17 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
self.check_adhoc_inventory_source(new_inv)
|
self.check_adhoc_inventory_source(new_inv)
|
||||||
|
|
||||||
def test_dir_with_ini_file(self):
|
def test_dir_with_ini_file(self):
|
||||||
self.create_test_dir(hostnames=['db1.example.com', 'db2.example.com'])
|
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
|
||||||
|
group_names=['dbservers'], suffix='')
|
||||||
|
self.test_ini_file(self.inv_dir)
|
||||||
|
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
|
||||||
|
group_names=['dbservers'], suffix='.yml')
|
||||||
|
self.test_ini_file(self.inv_dir)
|
||||||
|
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
|
||||||
|
group_names=['dbservers'], suffix='.yaml')
|
||||||
|
self.test_ini_file(self.inv_dir)
|
||||||
|
self.create_test_dir(host_names=['db1.example.com', 'db2.example.com'],
|
||||||
|
group_names=['dbservers'], suffix='.json')
|
||||||
self.test_ini_file(self.inv_dir)
|
self.test_ini_file(self.inv_dir)
|
||||||
|
|
||||||
def test_merge_from_ini_file(self, overwrite=False, overwrite_vars=False):
|
def test_merge_from_ini_file(self, overwrite=False, overwrite_vars=False):
|
||||||
@@ -853,12 +873,9 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
|
|||||||
@unittest.skipIf(getattr(settings, 'LOCAL_DEVELOPMENT', False),
|
@unittest.skipIf(getattr(settings, 'LOCAL_DEVELOPMENT', False),
|
||||||
'Skip this test in local development environments, '
|
'Skip this test in local development environments, '
|
||||||
'which may vary widely on memory.')
|
'which may vary widely on memory.')
|
||||||
@unittest.skipIf(django.db.backend == django.db.backend.sqlite3.base,
|
@unittest.skipIf(django.db.backend == django.db.backends.sqlite3.base,
|
||||||
'Skip this test if we are on sqlite')
|
'Skip this test if we are on sqlite')
|
||||||
def test_splunk_inventory(self):
|
def test_splunk_inventory(self):
|
||||||
print django.db.backend
|
|
||||||
print django.db.backend.sqlite3.base
|
|
||||||
settings.DEBUG = True
|
|
||||||
new_inv = self.organizations[0].inventories.create(name='splunk')
|
new_inv = self.organizations[0].inventories.create(name='splunk')
|
||||||
self.assertEqual(new_inv.hosts.count(), 0)
|
self.assertEqual(new_inv.hosts.count(), 0)
|
||||||
self.assertEqual(new_inv.groups.count(), 0)
|
self.assertEqual(new_inv.groups.count(), 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user