Mirror of https://github.com/ansible/awx.git (synced 2026-01-13 02:50:02 -03:30)
Add import command support for AC-332. Refactored the import script to eliminate its use of global variables, which breaks the tests.
parent d774327c68
commit 290768c20d
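
The refactor replaces the module-level group_names and host_names dictionaries with state kept on the in-memory "all" group, which every loader now receives through its constructor, so nothing leaks between runs or between test cases. Below is a minimal illustrative sketch of that pattern; it uses heavily simplified stand-ins for the real MemGroup, MemHost and BaseLoader classes (which also track variables, parents, inventory_base, and so on):

    # Simplified sketch only -- not the actual awx classes.
    class MemGroup(object):
        def __init__(self, name):
            self.name = name
            # kept on the "all" group in place of the old module-level globals
            self.host_names = {}    # host name -> MemHost
            self.group_names = {}   # group name -> MemGroup

    class MemHost(object):
        def __init__(self, name):
            self.name = name
            self.variables = {}

    class BaseLoader(object):
        def __init__(self, all_group=None):
            # each loader is handed the shared "all" group explicitly
            self.all_group = all_group

        def get_host(self, name):
            # de-duplicate hosts through the group's dict, not a global
            if name not in self.all_group.host_names:
                self.all_group.host_names[name] = MemHost(name)
            return self.all_group.host_names[name]

    all_group = MemGroup('all')
    loader = BaseLoader(all_group)
    loader.get_host('web1.example.com')
    print(sorted(all_group.host_names))    # ['web1.example.com']
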
@@ -23,10 +23,6 @@ from awx.main.licenses import LicenseReader

 LOGGER = None

-# maps host and group names to hosts to prevent redudant additions
-group_names = {}
-host_names = {}
-
 class ImportException(BaseException):

     def __init__(self, msg):
@@ -47,6 +43,10 @@ class MemGroup(object):
         self.hosts = []
         self.variables = {}
         self.parents = []
+        # Used on the "all" group in place of previous global variables.
+        # maps host and group names to hosts to prevent redudant additions
+        self.host_names = {}
+        self.group_names = {}

         group_vars = os.path.join(inventory_base, 'group_vars', name)
         if os.path.exists(group_vars):
@@ -133,43 +133,49 @@ class MemHost(object):

 class BaseLoader(object):

+    def __init__(self, inventory_base=None, all_group=None):
+        self.inventory_base = inventory_base
+        self.all_group = all_group
+
     def get_host(self, name):
         if ":" in name:
             tokens = name.split(":")
             name = tokens[0]
-        global host_names
         host = None
-        if not name in host_names:
+        if not name in self.all_group.host_names:
             host = MemHost(name, self.inventory_base)
-            host_names[name] = host
-        return host_names[name]
+            self.all_group.host_names[name] = host
+        return self.all_group.host_names[name]

-    def get_group(self, name, all_group, child=False):
-        global group_names
+    def get_group(self, name, all_group=None, child=False):
+        all_group = all_group or self.all_group
         if name == 'all':
             return all_group
-        if not name in group_names:
+        if not name in self.all_group.group_names:
             group = MemGroup(name, self.inventory_base)
             if not child:
                 all_group.add_child_group(group)
-            group_names[name] = group
-        return group_names[name]
+            self.all_group.group_names[name] = group
+        return self.all_group.group_names[name]
+
+    def load(self, src):
+        raise NotImplementedError

 class IniLoader(BaseLoader):

-    def __init__(self, inventory_base=None):
+    def __init__(self, inventory_base=None, all_group=None):
+        super(IniLoader, self).__init__(inventory_base, all_group)
         LOGGER.debug("processing ini")
-        self.inventory_base = inventory_base

-    def load(self, src, all_group):
-        LOGGER.debug("loading: %s on %s" % (src, all_group))
+    def load(self, src):
+        LOGGER.debug("loading: %s on %s" % (src, self.all_group))

         if self.inventory_base is None:
             self.inventory_base = os.path.dirname(src)

         data = open(src).read()
         lines = data.split("\n")
-        group = all_group
+        group = self.all_group
         input_mode = 'host'

         for line in lines:
@@ -183,14 +189,14 @@ class IniLoader(BaseLoader):
                 if line.find(":vars") != -1:
                     input_mode = 'vars'
                     line = line.replace(":vars","")
-                    group = self.get_group(line, all_group)
+                    group = self.get_group(line)
                 elif line.find(":children") != -1:
                     input_mode = 'children'
                     line = line.replace(":children","")
-                    group = self.get_group(line, all_group)
+                    group = self.get_group(line)
                 else:
                     input_mode = 'host'
-                    group = self.get_group(line, all_group)
+                    group = self.get_group(line)
             else:
                 # add a host or variable to the existing group/host
                 line = line.lstrip().rstrip()
@@ -243,10 +249,9 @@ class IniLoader(BaseLoader):

 class ExecutableJsonLoader(BaseLoader):

-    def __init__(self, inventory_base=None):
-
+    def __init__(self, inventory_base=None, all_group=None):
+        super(ExecutableJsonLoader, self).__init__(inventory_base, all_group)
         LOGGER.debug("processing executable JSON source")
-        self.inventory_base = inventory_base
         self.child_group_names = {}

     def command_to_json(self, cmd):
@@ -263,9 +268,9 @@ class ExecutableJsonLoader(BaseLoader):
         assert type(data) == dict
         return data

-    def load(self, src, all_group):
+    def load(self, src):

-        LOGGER.debug("loading %s onto %s" % (src, all_group))
+        LOGGER.debug("loading %s onto %s" % (src, self.all_group))

         if self.inventory_base is None:
             self.inventory_base = os.path.dirname(src)
@@ -273,10 +278,11 @@ class ExecutableJsonLoader(BaseLoader):
         data = self.command_to_json([src, "--list"])

         group = None
+        _meta = data.pop('_meta', {})

         for (k,v) in data.iteritems():

-            group = self.get_group(k, all_group)
+            group = self.get_group(k)

             if type(v) == dict:

@@ -311,52 +317,50 @@ class ExecutableJsonLoader(BaseLoader):
                     host = self.get_host(x)
                     group.add_host(host)

-            all_group.add_child_group(group)
+            if k != 'all':
+                self.all_group.add_child_group(group)


         # then we invoke the executable once for each host name we've built up
         # to set their variables
-        global host_names
-        for (k,v) in host_names.iteritems():
-            data = self.command_to_json([src, "--host", k])
+        for (k,v) in self.all_group.host_names.iteritems():
+            if 'hostvars' not in _meta:
+                data = self.command_to_json([src, "--host", k])
+            else:
+                data = _meta['hostvars'].get(k, {})
             v.variables.update(data)


-class GenericLoader(object):
-
-    def __init__(self, src):
-
-        LOGGER.debug("preparing loaders")
-
-
-        LOGGER.debug("analyzing type of source")
-        if not os.path.exists(src):
-            LOGGER.debug("source missing")
-            raise CommandError("source does not exist")
-        if os.path.isdir(src):
-            self.memGroup = memGroup = MemGroup('all', src)
-            for f in glob.glob("%s/*" % src):
-                if f.endswith(".ini"):
-                    # config files for inventory scripts should be ignored
-                    continue
-                if not os.path.isdir(f):
-                    if os.access(f, os.X_OK):
-                        ExecutableJsonLoader().load(f, memGroup)
-                    else:
-                        IniLoader().load(f, memGroup)
-        elif os.access(src, os.X_OK):
-            self.memGroup = memGroup = MemGroup('all', os.path.dirname(src))
-            ExecutableJsonLoader().load(src, memGroup)
-        else:
-            self.memGroup = memGroup = MemGroup('all', os.path.dirname(src))
-            IniLoader().load(src, memGroup)
-
-        LOGGER.debug("loading process complete")
-
-
-    def result(self):
-        return self.memGroup
+def load_generic(src):
+    LOGGER.debug("preparing loaders")
+
+
+    LOGGER.debug("analyzing type of source")
+    if not os.path.exists(src):
+        LOGGER.debug("source missing")
+        raise CommandError("source does not exist")
+    if os.path.isdir(src):
+        all_group = MemGroup('all', src)
+        for f in glob.glob("%s/*" % src):
+            if f.endswith(".ini"):
+                # config files for inventory scripts should be ignored
+                continue
+            if not os.path.isdir(f):
+                if os.access(f, os.X_OK):
+                    ExecutableJsonLoader(None, all_group).load(f)
+                else:
+                    IniLoader(None, all_group).load(f)
+    elif os.access(src, os.X_OK):
+        all_group = MemGroup('all', os.path.dirname(src))
+        ExecutableJsonLoader(None, all_group).load(src)
+    else:
+        all_group = MemGroup('all', os.path.dirname(src))
+        IniLoader(None, all_group).load(src)
+
+    LOGGER.debug("loading process complete")
+    return all_group
+

 class Command(NoArgsCommand):
     '''
     Management command to import directory, INI, or dynamic inventory
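
A note on the _meta handling introduced above: Ansible dynamic-inventory scripts may embed per-host variables in a top-level _meta.hostvars block of their --list output, and the reworked ExecutableJsonLoader.load() prefers that block, only falling back to one --host invocation per host when _meta is absent. A rough standalone sketch of that fallback logic (fetch_json is a hypothetical helper standing in for command_to_json):

    import json
    import subprocess

    def fetch_json(args):
        # hypothetical stand-in for ExecutableJsonLoader.command_to_json()
        return json.loads(subprocess.check_output(args).decode('utf-8'))

    def load_host_variables(script, host_names):
        data = fetch_json([script, "--list"])
        _meta = data.pop('_meta', {})
        hostvars = {}
        for name in host_names:
            if 'hostvars' not in _meta:
                # older scripts: ask for each host's variables individually
                hostvars[name] = fetch_json([script, "--host", name])
            else:
                # newer scripts return everything up front in _meta.hostvars
                hostvars[name] = _meta['hostvars'].get(name, {})
        return hostvars
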
@@ -416,11 +420,10 @@ class Command(NoArgsCommand):

         LOGGER.debug("preparing loader")

-        loader = GenericLoader(source)
-        memGroup = loader.result()
+        all_group = load_generic(source)

         LOGGER.debug("debugging loaded result")
-        memGroup.debug_tree()
+        all_group.debug_tree()

         # now that memGroup is correct and supports JSON executables, INI, and trees
         # now merge and/or overwrite with the database itself!
@@ -439,15 +442,15 @@ class Command(NoArgsCommand):
         # if overwrite is set, for each host in the database but NOT in the local
         # list, delete it. Delete individually so signal handlers will run.
         if overwrite:
-            LOGGER.info("deleting any hosts not in the remote source: %s" % host_names.keys())
-            for host in Host.objects.exclude(name__in = host_names.keys()).filter(inventory=inventory):
+            LOGGER.info("deleting any hosts not in the remote source: %s" % all_group.host_names.keys())
+            for host in Host.objects.exclude(name__in = all_group.host_names.keys()).filter(inventory=inventory):
                 host.delete()

         # if overwrite is set, for each group in the database but NOT in the local
         # list, delete it. Delete individually so signal handlers will run.
         if overwrite:
             LOGGER.info("deleting any groups not in the remote source")
-            for group in Group.objects.exclude(name__in = group_names.keys()).filter(inventory=inventory):
+            for group in Group.objects.exclude(name__in = all_group.group_names.keys()).filter(inventory=inventory):
                 group.delete()

         # if overwrite is set, throw away all invalid child relationships for groups
@@ -457,7 +460,7 @@ class Command(NoArgsCommand):

         for db_group in db_groups:
             db_kids = db_group.children.all()
-            mem_kids = group_names[db_group.name].child_groups
+            mem_kids = all_group.group_names[db_group.name].child_groups
             mem_kid_names = [ k.name for k in mem_kids ]
             removed = False
             for db_kid in db_kids:
@@ -468,6 +471,17 @@ class Command(NoArgsCommand):
             if removed:
                 db_group.save()

+        # Update/overwrite inventory variables from "all" group.
+        db_variables = inventory.variables_dict
+        mem_variables = all_group.variables
+        if overwrite_vars or overwrite:
+            LOGGER.info('replacing inventory variables from "all" group')
+            db_variables = mem_variables
+        else:
+            LOGGER.info('updating inventory variables from "all" group')
+            db_variables.update(mem_variables)
+        inventory.variables = json.dumps(db_variables)
+        inventory.save()

         # this will be slightly inaccurate, but attribute to first superuser.
         user = User.objects.filter(is_superuser=True)[0]
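
The inventory-variable handling added above reduces to a choice between replacing the stored variables and merging the imported ones into them. A tiny illustration with made-up values:

    db_variables = {'vara': 'old', 'extra': 1}    # what the database already holds
    mem_variables = {'vara': 'A'}                 # parsed from the "all" group

    overwrite = False
    overwrite_vars = False
    if overwrite_vars or overwrite:
        db_variables = mem_variables              # replace wholesale
    else:
        db_variables.update(mem_variables)        # merge; imported keys win

    print(db_variables)    # 'vara' becomes 'A', 'extra' is kept
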
@@ -478,7 +492,7 @@ class Command(NoArgsCommand):
         db_host_names = [ h.name for h in db_hosts ]

         # for each group not in the database but in the local list, create it
-        for (k,v) in group_names.iteritems():
+        for (k,v) in all_group.group_names.iteritems():
             if k not in db_group_names:
                 variables = json.dumps(v.variables)
                 LOGGER.info("inserting new group %s" % k)
@@ -487,7 +501,7 @@ class Command(NoArgsCommand):
                 host.save()

         # for each host not in the database but in the local list, create it
-        for (k,v) in host_names.iteritems():
+        for (k,v) in all_group.host_names.iteritems():
             if k not in db_host_names:
                 variables = json.dumps(v.variables)
                 LOGGER.info("inserting new host %s" % k)
@@ -502,7 +516,7 @@ class Command(NoArgsCommand):

         for db_group in db_groups:
             db_hosts = db_group.hosts.all()
-            mem_hosts = group_names[db_group.name].hosts
+            mem_hosts = all_group.group_names[db_group.name].hosts
             mem_host_names = [ h.name for h in mem_hosts ]
             removed = False
             for db_host in db_hosts:
@@ -516,7 +530,7 @@ class Command(NoArgsCommand):

         # for each host in a mem group, add it to the parents to which it belongs
         # FIXME: confirm Django is ok with calling add twice and not making two rows
-        for (k,v) in group_names.iteritems():
+        for (k,v) in all_group.group_names.iteritems():
             LOGGER.info("adding parent arrangements for %s" % k)
             db_group = Group.objects.get(name=k, inventory__pk=inventory.pk)
             mem_hosts = v.hosts
@@ -524,7 +538,7 @@ class Command(NoArgsCommand):
                 db_host = Host.objects.get(name=h.name, inventory__pk=inventory.pk)
                 db_group.hosts.add(db_host)
                 LOGGER.debug("*** ADDING %s to %s ***" % (db_host, db_group))
-            db_group.save()
+            #db_group.save()

         def variable_mangler(model, mem_hash, overwrite, overwrite_vars):
             db_collection = model.objects.filter(inventory=inventory)
@@ -541,17 +555,17 @@ class Command(NoArgsCommand):
             obj.variables = db_variables
             obj.save()

-        variable_mangler(Group, group_names, overwrite, overwrite_vars)
-        variable_mangler(Host, host_names, overwrite, overwrite_vars)
+        variable_mangler(Group, all_group.group_names, overwrite, overwrite_vars)
+        variable_mangler(Host, all_group.host_names, overwrite, overwrite_vars)

         # for each group, draw in child group arrangements
         # FIXME: confirm django add behavior as above
-        for (k,v) in group_names.iteritems():
+        for (k,v) in all_group.group_names.iteritems():
             db_group = Group.objects.get(inventory=inventory, name=k)
             for mem_child_group in v.child_groups:
                 db_child = Group.objects.get(inventory=inventory, name=mem_child_group.name)
                 db_group.children.add(db_child)
-            db_group.save()
+            #db_group.save()

         reader = LicenseReader()
         license_info = reader.from_file()
@@ -19,13 +19,13 @@ from django.utils.timezone import now
 # AWX
 from awx.main.licenses import LicenseWriter
 from awx.main.models import *
-from awx.main.tests.base import BaseTest
+from awx.main.tests.base import BaseTest, BaseLiveServerTest

 __all__ = ['CleanupDeletedTest', 'InventoryImportTest']

 TEST_INVENTORY_INI = '''\
 [webservers]
-web1.example.com
+web1.example.com ansible_ssh_host=w1.example.net
 web2.example.com
 web3.example.com

@@ -50,19 +50,19 @@ varb=B
 vara=A
 '''

-class BaseCommandTest(BaseTest):
+class BaseCommandMixin(object):
     '''
     Base class for tests that run management commands.
     '''

     def setUp(self):
-        super(BaseCommandTest, self).setUp()
+        super(BaseCommandMixin, self).setUp()
         self._sys_path = [x for x in sys.path]
         self._environ = dict(os.environ.items())
         self._temp_files = []

     def tearDown(self):
-        super(BaseCommandTest, self).tearDown()
+        super(BaseCommandMixin, self).tearDown()
         sys.path = self._sys_path
         for k,v in self._environ.items():
             if os.environ.get(k, None) != v:
@@ -152,7 +152,7 @@ class BaseCommandTest(BaseTest):
             result = CommandError(captured_stderr)
         return result, captured_stdout, captured_stderr

-class CleanupDeletedTest(BaseCommandTest):
+class CleanupDeletedTest(BaseCommandMixin, BaseTest):
     '''
     Test cases for cleanup_deleted management command.
     '''
@@ -244,7 +244,7 @@ class CleanupDeletedTest(BaseCommandTest):
         self.assertNotEqual(counts_before, counts_after)
         self.assertFalse(counts_after[1])

-class InventoryImportTest(BaseCommandTest):
+class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
     '''
     Test cases for inventory_import management command.
     '''
@@ -343,8 +343,82 @@ class InventoryImportTest(BaseCommandTest):
                                                   inventory_id=new_inv.pk,
                                                   source=self.ini_path)
         self.assertEqual(result, None)
-        # FIXME
+        # Check that inventory is populated as expected.
+        new_inv = Inventory.objects.get(pk=new_inv.pk)
+        expected_group_names = set(['servers', 'dbservers', 'webservers'])
+        group_names = set(new_inv.groups.values_list('name', flat=True))
+        self.assertEqual(expected_group_names, group_names)
+        expected_host_names = set(['web1.example.com', 'web2.example.com',
+                                   'web3.example.com', 'db1.example.com',
+                                   'db2.example.com'])
+        host_names = set(new_inv.hosts.values_list('name', flat=True))
+        self.assertEqual(expected_host_names, host_names)
+        self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
+        for host in new_inv.hosts.all():
+            if host.name == 'web1.example.com':
+                self.assertEqual(host.variables_dict,
+                                 {'ansible_ssh_host': 'w1.example.net'})
+            else:
+                self.assertEqual(host.variables_dict, {})
+        for group in new_inv.groups.all():
+            if group.name == 'servers':
+                self.assertEqual(group.variables_dict, {'varb': 'B'})
+                children = set(group.children.values_list('name', flat=True))
+                self.assertEqual(children, set(['dbservers', 'webservers']))
+                self.assertEqual(group.hosts.count(), 0)
+            elif group.name == 'dbservers':
+                self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
+                self.assertEqual(group.children.count(), 0)
+                hosts = set(group.hosts.values_list('name', flat=True))
+                host_names = set(['db1.example.com','db2.example.com'])
+                self.assertEqual(hosts, host_names)
+            elif group.name == 'webservers':
+                self.assertEqual(group.variables_dict, {'webvar': 'blah'})
+                self.assertEqual(group.children.count(), 0)
+                hosts = set(group.hosts.values_list('name', flat=True))
+                host_names = set(['web1.example.com','web2.example.com',
+                                  'web3.example.com'])
+                self.assertEqual(hosts, host_names)

     def test_executable_file(self):
-        pass
-        # FIXME
+        # New empty inventory.
+        old_inv = self.inventories[1]
+        new_inv = self.organizations[0].inventories.create(name='newb')
+        self.assertEqual(new_inv.hosts.count(), 0)
+        self.assertEqual(new_inv.groups.count(), 0)
+        # Use our own inventory script as executable file.
+        os.environ.setdefault('REST_API_URL', self.live_server_url)
+        os.environ.setdefault('REST_API_TOKEN',
+                              self.super_django_user.auth_token.key)
+        os.environ['INVENTORY_ID'] = str(old_inv.pk)
+        source = os.path.join(os.path.dirname(__file__), '..', '..', 'scripts',
+                              'inventory.py')
+        result, stdout, stderr = self.run_command('inventory_import',
+                                                  inventory_id=new_inv.pk,
+                                                  source=source)
+        self.assertEqual(result, None)
+        # Check that inventory is populated as expected.
+        new_inv = Inventory.objects.get(pk=new_inv.pk)
+        self.assertEqual(old_inv.variables_dict, new_inv.variables_dict)
+        old_groups = set(old_inv.groups.values_list('name', flat=True))
+        new_groups = set(new_inv.groups.values_list('name', flat=True))
+        self.assertEqual(old_groups, new_groups)
+        old_hosts = set(old_inv.hosts.values_list('name', flat=True))
+        new_hosts = set(new_inv.hosts.values_list('name', flat=True))
+        self.assertEqual(old_hosts, new_hosts)
+        for new_host in new_inv.hosts.all():
+            old_host = old_inv.hosts.get(name=new_host.name)
+            self.assertEqual(old_host.variables_dict, new_host.variables_dict)
+        for new_group in new_inv.groups.all():
+            old_group = old_inv.groups.get(name=new_group.name)
+            self.assertEqual(old_group.variables_dict, new_group.variables_dict)
+            old_children = set(old_group.children.values_list('name', flat=True))
+            new_children = set(new_group.children.values_list('name', flat=True))
+            self.assertEqual(old_children, new_children)
+            old_hosts = set(old_group.hosts.values_list('name', flat=True))
+            new_hosts = set(new_group.hosts.values_list('name', flat=True))
+            self.assertEqual(old_hosts, new_hosts)
+
+    def test_executable_file_with_meta_hostvars(self):
+        os.environ['INVENTORY_HOSTVARS'] = '1'
+        self.test_executable_file()