Add import command support for AC-332. Refactored import script to eliminate use of global variables, which breaks tests.

This commit is contained in:
Chris Church 2013-08-29 23:46:38 -04:00
parent d774327c68
commit 290768c20d
2 changed files with 175 additions and 87 deletions

View File

@ -23,10 +23,6 @@ from awx.main.licenses import LicenseReader
LOGGER = None
# maps host and group names to hosts to prevent redudant additions
group_names = {}
host_names = {}
class ImportException(BaseException):
def __init__(self, msg):
@ -47,6 +43,10 @@ class MemGroup(object):
self.hosts = []
self.variables = {}
self.parents = []
# Used on the "all" group in place of previous global variables.
# maps host and group names to hosts to prevent redudant additions
self.host_names = {}
self.group_names = {}
group_vars = os.path.join(inventory_base, 'group_vars', name)
if os.path.exists(group_vars):
@ -133,43 +133,49 @@ class MemHost(object):
class BaseLoader(object):
def __init__(self, inventory_base=None, all_group=None):
self.inventory_base = inventory_base
self.all_group = all_group
def get_host(self, name):
if ":" in name:
tokens = name.split(":")
name = tokens[0]
global host_names
host = None
if not name in host_names:
if not name in self.all_group.host_names:
host = MemHost(name, self.inventory_base)
host_names[name] = host
return host_names[name]
self.all_group.host_names[name] = host
return self.all_group.host_names[name]
def get_group(self, name, all_group, child=False):
global group_names
def get_group(self, name, all_group=None, child=False):
all_group = all_group or self.all_group
if name == 'all':
return all_group
if not name in group_names:
if not name in self.all_group.group_names:
group = MemGroup(name, self.inventory_base)
if not child:
all_group.add_child_group(group)
group_names[name] = group
return group_names[name]
self.all_group.group_names[name] = group
return self.all_group.group_names[name]
def load(self, src):
raise NotImplementedError
class IniLoader(BaseLoader):
def __init__(self, inventory_base=None):
def __init__(self, inventory_base=None, all_group=None):
super(IniLoader, self).__init__(inventory_base, all_group)
LOGGER.debug("processing ini")
self.inventory_base = inventory_base
def load(self, src, all_group):
LOGGER.debug("loading: %s on %s" % (src, all_group))
def load(self, src):
LOGGER.debug("loading: %s on %s" % (src, self.all_group))
if self.inventory_base is None:
self.inventory_base = os.path.dirname(src)
data = open(src).read()
lines = data.split("\n")
group = all_group
group = self.all_group
input_mode = 'host'
for line in lines:
@ -183,14 +189,14 @@ class IniLoader(BaseLoader):
if line.find(":vars") != -1:
input_mode = 'vars'
line = line.replace(":vars","")
group = self.get_group(line, all_group)
group = self.get_group(line)
elif line.find(":children") != -1:
input_mode = 'children'
line = line.replace(":children","")
group = self.get_group(line, all_group)
group = self.get_group(line)
else:
input_mode = 'host'
group = self.get_group(line, all_group)
group = self.get_group(line)
else:
# add a host or variable to the existing group/host
line = line.lstrip().rstrip()
@ -243,10 +249,9 @@ class IniLoader(BaseLoader):
class ExecutableJsonLoader(BaseLoader):
def __init__(self, inventory_base=None):
def __init__(self, inventory_base=None, all_group=None):
super(ExecutableJsonLoader, self).__init__(inventory_base, all_group)
LOGGER.debug("processing executable JSON source")
self.inventory_base = inventory_base
self.child_group_names = {}
def command_to_json(self, cmd):
@ -263,9 +268,9 @@ class ExecutableJsonLoader(BaseLoader):
assert type(data) == dict
return data
def load(self, src, all_group):
def load(self, src):
LOGGER.debug("loading %s onto %s" % (src, all_group))
LOGGER.debug("loading %s onto %s" % (src, self.all_group))
if self.inventory_base is None:
self.inventory_base = os.path.dirname(src)
@ -273,10 +278,11 @@ class ExecutableJsonLoader(BaseLoader):
data = self.command_to_json([src, "--list"])
group = None
_meta = data.pop('_meta', {})
for (k,v) in data.iteritems():
group = self.get_group(k, all_group)
group = self.get_group(k)
if type(v) == dict:
@ -311,52 +317,50 @@ class ExecutableJsonLoader(BaseLoader):
host = self.get_host(x)
group.add_host(host)
all_group.add_child_group(group)
if k != 'all':
self.all_group.add_child_group(group)
# then we invoke the executable once for each host name we've built up
# to set their variables
global host_names
for (k,v) in host_names.iteritems():
data = self.command_to_json([src, "--host", k])
for (k,v) in self.all_group.host_names.iteritems():
if 'hostvars' not in _meta:
data = self.command_to_json([src, "--host", k])
else:
data = _meta['hostvars'].get(k, {})
v.variables.update(data)
class GenericLoader(object):
def __init__(self, src):
LOGGER.debug("preparing loaders")
def load_generic(src):
LOGGER.debug("preparing loaders")
LOGGER.debug("analyzing type of source")
if not os.path.exists(src):
LOGGER.debug("source missing")
raise CommandError("source does not exist")
if os.path.isdir(src):
self.memGroup = memGroup = MemGroup('all', src)
for f in glob.glob("%s/*" % src):
if f.endswith(".ini"):
# config files for inventory scripts should be ignored
continue
if not os.path.isdir(f):
if os.access(f, os.X_OK):
ExecutableJsonLoader().load(f, memGroup)
else:
IniLoader().load(f, memGroup)
elif os.access(src, os.X_OK):
self.memGroup = memGroup = MemGroup('all', os.path.dirname(src))
ExecutableJsonLoader().load(src, memGroup)
else:
self.memGroup = memGroup = MemGroup('all', os.path.dirname(src))
IniLoader().load(src, memGroup)
LOGGER.debug("analyzing type of source")
if not os.path.exists(src):
LOGGER.debug("source missing")
raise CommandError("source does not exist")
if os.path.isdir(src):
all_group = MemGroup('all', src)
for f in glob.glob("%s/*" % src):
if f.endswith(".ini"):
# config files for inventory scripts should be ignored
continue
if not os.path.isdir(f):
if os.access(f, os.X_OK):
ExecutableJsonLoader(None, all_group).load(f)
else:
IniLoader(None, all_group).load(f)
elif os.access(src, os.X_OK):
all_group = MemGroup('all', os.path.dirname(src))
ExecutableJsonLoader(None, all_group).load(src)
else:
all_group = MemGroup('all', os.path.dirname(src))
IniLoader(None, all_group).load(src)
LOGGER.debug("loading process complete")
LOGGER.debug("loading process complete")
return all_group
def result(self):
return self.memGroup
class Command(NoArgsCommand):
'''
Management command to import directory, INI, or dynamic inventory
@ -416,11 +420,10 @@ class Command(NoArgsCommand):
LOGGER.debug("preparing loader")
loader = GenericLoader(source)
memGroup = loader.result()
all_group = load_generic(source)
LOGGER.debug("debugging loaded result")
memGroup.debug_tree()
all_group.debug_tree()
# now that memGroup is correct and supports JSON executables, INI, and trees
# now merge and/or overwrite with the database itself!
@ -439,15 +442,15 @@ class Command(NoArgsCommand):
# if overwrite is set, for each host in the database but NOT in the local
# list, delete it. Delete individually so signal handlers will run.
if overwrite:
LOGGER.info("deleting any hosts not in the remote source: %s" % host_names.keys())
for host in Host.objects.exclude(name__in = host_names.keys()).filter(inventory=inventory):
LOGGER.info("deleting any hosts not in the remote source: %s" % all_group.host_names.keys())
for host in Host.objects.exclude(name__in = all_group.host_names.keys()).filter(inventory=inventory):
host.delete()
# if overwrite is set, for each group in the database but NOT in the local
# list, delete it. Delete individually so signal handlers will run.
if overwrite:
LOGGER.info("deleting any groups not in the remote source")
for group in Group.objects.exclude(name__in = group_names.keys()).filter(inventory=inventory):
for group in Group.objects.exclude(name__in = all_group.group_names.keys()).filter(inventory=inventory):
group.delete()
# if overwrite is set, throw away all invalid child relationships for groups
@ -457,7 +460,7 @@ class Command(NoArgsCommand):
for db_group in db_groups:
db_kids = db_group.children.all()
mem_kids = group_names[db_group.name].child_groups
mem_kids = all_group.group_names[db_group.name].child_groups
mem_kid_names = [ k.name for k in mem_kids ]
removed = False
for db_kid in db_kids:
@ -468,6 +471,17 @@ class Command(NoArgsCommand):
if removed:
db_group.save()
# Update/overwrite inventory variables from "all" group.
db_variables = inventory.variables_dict
mem_variables = all_group.variables
if overwrite_vars or overwrite:
LOGGER.info('replacing inventory variables from "all" group')
db_variables = mem_variables
else:
LOGGER.info('updating inventory variables from "all" group')
db_variables.update(mem_variables)
inventory.variables = json.dumps(db_variables)
inventory.save()
# this will be slightly inaccurate, but attribute to first superuser.
user = User.objects.filter(is_superuser=True)[0]
@ -478,7 +492,7 @@ class Command(NoArgsCommand):
db_host_names = [ h.name for h in db_hosts ]
# for each group not in the database but in the local list, create it
for (k,v) in group_names.iteritems():
for (k,v) in all_group.group_names.iteritems():
if k not in db_group_names:
variables = json.dumps(v.variables)
LOGGER.info("inserting new group %s" % k)
@ -487,7 +501,7 @@ class Command(NoArgsCommand):
host.save()
# for each host not in the database but in the local list, create it
for (k,v) in host_names.iteritems():
for (k,v) in all_group.host_names.iteritems():
if k not in db_host_names:
variables = json.dumps(v.variables)
LOGGER.info("inserting new host %s" % k)
@ -502,7 +516,7 @@ class Command(NoArgsCommand):
for db_group in db_groups:
db_hosts = db_group.hosts.all()
mem_hosts = group_names[db_group.name].hosts
mem_hosts = all_group.group_names[db_group.name].hosts
mem_host_names = [ h.name for h in mem_hosts ]
removed = False
for db_host in db_hosts:
@ -516,7 +530,7 @@ class Command(NoArgsCommand):
# for each host in a mem group, add it to the parents to which it belongs
# FIXME: confirm Django is ok with calling add twice and not making two rows
for (k,v) in group_names.iteritems():
for (k,v) in all_group.group_names.iteritems():
LOGGER.info("adding parent arrangements for %s" % k)
db_group = Group.objects.get(name=k, inventory__pk=inventory.pk)
mem_hosts = v.hosts
@ -524,7 +538,7 @@ class Command(NoArgsCommand):
db_host = Host.objects.get(name=h.name, inventory__pk=inventory.pk)
db_group.hosts.add(db_host)
LOGGER.debug("*** ADDING %s to %s ***" % (db_host, db_group))
db_group.save()
#db_group.save()
def variable_mangler(model, mem_hash, overwrite, overwrite_vars):
db_collection = model.objects.filter(inventory=inventory)
@ -541,17 +555,17 @@ class Command(NoArgsCommand):
obj.variables = db_variables
obj.save()
variable_mangler(Group, group_names, overwrite, overwrite_vars)
variable_mangler(Host, host_names, overwrite, overwrite_vars)
variable_mangler(Group, all_group.group_names, overwrite, overwrite_vars)
variable_mangler(Host, all_group.host_names, overwrite, overwrite_vars)
# for each group, draw in child group arrangements
# FIXME: confirm django add behavior as above
for (k,v) in group_names.iteritems():
for (k,v) in all_group.group_names.iteritems():
db_group = Group.objects.get(inventory=inventory, name=k)
for mem_child_group in v.child_groups:
db_child = Group.objects.get(inventory=inventory, name=mem_child_group.name)
db_group.children.add(db_child)
db_group.save()
#db_group.save()
reader = LicenseReader()
license_info = reader.from_file()

View File

@ -19,13 +19,13 @@ from django.utils.timezone import now
# AWX
from awx.main.licenses import LicenseWriter
from awx.main.models import *
from awx.main.tests.base import BaseTest
from awx.main.tests.base import BaseTest, BaseLiveServerTest
__all__ = ['CleanupDeletedTest', 'InventoryImportTest']
TEST_INVENTORY_INI = '''\
[webservers]
web1.example.com
web1.example.com ansible_ssh_host=w1.example.net
web2.example.com
web3.example.com
@ -50,19 +50,19 @@ varb=B
vara=A
'''
class BaseCommandTest(BaseTest):
class BaseCommandMixin(object):
'''
Base class for tests that run management commands.
'''
def setUp(self):
super(BaseCommandTest, self).setUp()
super(BaseCommandMixin, self).setUp()
self._sys_path = [x for x in sys.path]
self._environ = dict(os.environ.items())
self._temp_files = []
def tearDown(self):
super(BaseCommandTest, self).tearDown()
super(BaseCommandMixin, self).tearDown()
sys.path = self._sys_path
for k,v in self._environ.items():
if os.environ.get(k, None) != v:
@ -152,7 +152,7 @@ class BaseCommandTest(BaseTest):
result = CommandError(captured_stderr)
return result, captured_stdout, captured_stderr
class CleanupDeletedTest(BaseCommandTest):
class CleanupDeletedTest(BaseCommandMixin, BaseTest):
'''
Test cases for cleanup_deleted management command.
'''
@ -244,7 +244,7 @@ class CleanupDeletedTest(BaseCommandTest):
self.assertNotEqual(counts_before, counts_after)
self.assertFalse(counts_after[1])
class InventoryImportTest(BaseCommandTest):
class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
'''
Test cases for inventory_import management command.
'''
@ -343,8 +343,82 @@ class InventoryImportTest(BaseCommandTest):
inventory_id=new_inv.pk,
source=self.ini_path)
self.assertEqual(result, None)
# FIXME
# Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
expected_group_names = set(['servers', 'dbservers', 'webservers'])
group_names = set(new_inv.groups.values_list('name', flat=True))
self.assertEqual(expected_group_names, group_names)
expected_host_names = set(['web1.example.com', 'web2.example.com',
'web3.example.com', 'db1.example.com',
'db2.example.com'])
host_names = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(expected_host_names, host_names)
self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
for host in new_inv.hosts.all():
if host.name == 'web1.example.com':
self.assertEqual(host.variables_dict,
{'ansible_ssh_host': 'w1.example.net'})
else:
self.assertEqual(host.variables_dict, {})
for group in new_inv.groups.all():
if group.name == 'servers':
self.assertEqual(group.variables_dict, {'varb': 'B'})
children = set(group.children.values_list('name', flat=True))
self.assertEqual(children, set(['dbservers', 'webservers']))
self.assertEqual(group.hosts.count(), 0)
elif group.name == 'dbservers':
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
self.assertEqual(group.children.count(), 0)
hosts = set(group.hosts.values_list('name', flat=True))
host_names = set(['db1.example.com','db2.example.com'])
self.assertEqual(hosts, host_names)
elif group.name == 'webservers':
self.assertEqual(group.variables_dict, {'webvar': 'blah'})
self.assertEqual(group.children.count(), 0)
hosts = set(group.hosts.values_list('name', flat=True))
host_names = set(['web1.example.com','web2.example.com',
'web3.example.com'])
self.assertEqual(hosts, host_names)
def test_executable_file(self):
pass
# FIXME
# New empty inventory.
old_inv = self.inventories[1]
new_inv = self.organizations[0].inventories.create(name='newb')
self.assertEqual(new_inv.hosts.count(), 0)
self.assertEqual(new_inv.groups.count(), 0)
# Use our own inventory script as executable file.
os.environ.setdefault('REST_API_URL', self.live_server_url)
os.environ.setdefault('REST_API_TOKEN',
self.super_django_user.auth_token.key)
os.environ['INVENTORY_ID'] = str(old_inv.pk)
source = os.path.join(os.path.dirname(__file__), '..', '..', 'scripts',
'inventory.py')
result, stdout, stderr = self.run_command('inventory_import',
inventory_id=new_inv.pk,
source=source)
self.assertEqual(result, None)
# Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
self.assertEqual(old_inv.variables_dict, new_inv.variables_dict)
old_groups = set(old_inv.groups.values_list('name', flat=True))
new_groups = set(new_inv.groups.values_list('name', flat=True))
self.assertEqual(old_groups, new_groups)
old_hosts = set(old_inv.hosts.values_list('name', flat=True))
new_hosts = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(old_hosts, new_hosts)
for new_host in new_inv.hosts.all():
old_host = old_inv.hosts.get(name=new_host.name)
self.assertEqual(old_host.variables_dict, new_host.variables_dict)
for new_group in new_inv.groups.all():
old_group = old_inv.groups.get(name=new_group.name)
self.assertEqual(old_group.variables_dict, new_group.variables_dict)
old_children = set(old_group.children.values_list('name', flat=True))
new_children = set(new_group.children.values_list('name', flat=True))
self.assertEqual(old_children, new_children)
old_hosts = set(old_group.hosts.values_list('name', flat=True))
new_hosts = set(new_group.hosts.values_list('name', flat=True))
self.assertEqual(old_hosts, new_hosts)
def test_executable_file_with_meta_hostvars(self):
os.environ['INVENTORY_HOSTVARS'] = '1'
self.test_executable_file()