Add import command support for AC-332. Refactored import script to eliminate use of global variables, which breaks tests.

This commit is contained in:
Chris Church
2013-08-29 23:46:38 -04:00
parent d774327c68
commit 290768c20d
2 changed files with 175 additions and 87 deletions

View File

@@ -23,10 +23,6 @@ from awx.main.licenses import LicenseReader
LOGGER = None LOGGER = None
# maps host and group names to hosts to prevent redudant additions
group_names = {}
host_names = {}
class ImportException(BaseException): class ImportException(BaseException):
def __init__(self, msg): def __init__(self, msg):
@@ -47,6 +43,10 @@ class MemGroup(object):
self.hosts = [] self.hosts = []
self.variables = {} self.variables = {}
self.parents = [] self.parents = []
# Used on the "all" group in place of previous global variables.
# maps host and group names to hosts to prevent redudant additions
self.host_names = {}
self.group_names = {}
group_vars = os.path.join(inventory_base, 'group_vars', name) group_vars = os.path.join(inventory_base, 'group_vars', name)
if os.path.exists(group_vars): if os.path.exists(group_vars):
@@ -133,43 +133,49 @@ class MemHost(object):
class BaseLoader(object): class BaseLoader(object):
def __init__(self, inventory_base=None, all_group=None):
self.inventory_base = inventory_base
self.all_group = all_group
def get_host(self, name): def get_host(self, name):
if ":" in name: if ":" in name:
tokens = name.split(":") tokens = name.split(":")
name = tokens[0] name = tokens[0]
global host_names
host = None host = None
if not name in host_names: if not name in self.all_group.host_names:
host = MemHost(name, self.inventory_base) host = MemHost(name, self.inventory_base)
host_names[name] = host self.all_group.host_names[name] = host
return host_names[name] return self.all_group.host_names[name]
def get_group(self, name, all_group, child=False): def get_group(self, name, all_group=None, child=False):
global group_names all_group = all_group or self.all_group
if name == 'all': if name == 'all':
return all_group return all_group
if not name in group_names: if not name in self.all_group.group_names:
group = MemGroup(name, self.inventory_base) group = MemGroup(name, self.inventory_base)
if not child: if not child:
all_group.add_child_group(group) all_group.add_child_group(group)
group_names[name] = group self.all_group.group_names[name] = group
return group_names[name] return self.all_group.group_names[name]
def load(self, src):
raise NotImplementedError
class IniLoader(BaseLoader): class IniLoader(BaseLoader):
def __init__(self, inventory_base=None): def __init__(self, inventory_base=None, all_group=None):
super(IniLoader, self).__init__(inventory_base, all_group)
LOGGER.debug("processing ini") LOGGER.debug("processing ini")
self.inventory_base = inventory_base
def load(self, src, all_group): def load(self, src):
LOGGER.debug("loading: %s on %s" % (src, all_group)) LOGGER.debug("loading: %s on %s" % (src, self.all_group))
if self.inventory_base is None: if self.inventory_base is None:
self.inventory_base = os.path.dirname(src) self.inventory_base = os.path.dirname(src)
data = open(src).read() data = open(src).read()
lines = data.split("\n") lines = data.split("\n")
group = all_group group = self.all_group
input_mode = 'host' input_mode = 'host'
for line in lines: for line in lines:
@@ -183,14 +189,14 @@ class IniLoader(BaseLoader):
if line.find(":vars") != -1: if line.find(":vars") != -1:
input_mode = 'vars' input_mode = 'vars'
line = line.replace(":vars","") line = line.replace(":vars","")
group = self.get_group(line, all_group) group = self.get_group(line)
elif line.find(":children") != -1: elif line.find(":children") != -1:
input_mode = 'children' input_mode = 'children'
line = line.replace(":children","") line = line.replace(":children","")
group = self.get_group(line, all_group) group = self.get_group(line)
else: else:
input_mode = 'host' input_mode = 'host'
group = self.get_group(line, all_group) group = self.get_group(line)
else: else:
# add a host or variable to the existing group/host # add a host or variable to the existing group/host
line = line.lstrip().rstrip() line = line.lstrip().rstrip()
@@ -243,10 +249,9 @@ class IniLoader(BaseLoader):
class ExecutableJsonLoader(BaseLoader): class ExecutableJsonLoader(BaseLoader):
def __init__(self, inventory_base=None): def __init__(self, inventory_base=None, all_group=None):
super(ExecutableJsonLoader, self).__init__(inventory_base, all_group)
LOGGER.debug("processing executable JSON source") LOGGER.debug("processing executable JSON source")
self.inventory_base = inventory_base
self.child_group_names = {} self.child_group_names = {}
def command_to_json(self, cmd): def command_to_json(self, cmd):
@@ -263,9 +268,9 @@ class ExecutableJsonLoader(BaseLoader):
assert type(data) == dict assert type(data) == dict
return data return data
def load(self, src, all_group): def load(self, src):
LOGGER.debug("loading %s onto %s" % (src, all_group)) LOGGER.debug("loading %s onto %s" % (src, self.all_group))
if self.inventory_base is None: if self.inventory_base is None:
self.inventory_base = os.path.dirname(src) self.inventory_base = os.path.dirname(src)
@@ -273,10 +278,11 @@ class ExecutableJsonLoader(BaseLoader):
data = self.command_to_json([src, "--list"]) data = self.command_to_json([src, "--list"])
group = None group = None
_meta = data.pop('_meta', {})
for (k,v) in data.iteritems(): for (k,v) in data.iteritems():
group = self.get_group(k, all_group) group = self.get_group(k)
if type(v) == dict: if type(v) == dict:
@@ -311,52 +317,50 @@ class ExecutableJsonLoader(BaseLoader):
host = self.get_host(x) host = self.get_host(x)
group.add_host(host) group.add_host(host)
all_group.add_child_group(group) if k != 'all':
self.all_group.add_child_group(group)
# then we invoke the executable once for each host name we've built up # then we invoke the executable once for each host name we've built up
# to set their variables # to set their variables
global host_names for (k,v) in self.all_group.host_names.iteritems():
for (k,v) in host_names.iteritems(): if 'hostvars' not in _meta:
data = self.command_to_json([src, "--host", k]) data = self.command_to_json([src, "--host", k])
else:
data = _meta['hostvars'].get(k, {})
v.variables.update(data) v.variables.update(data)
class GenericLoader(object): def load_generic(src):
LOGGER.debug("preparing loaders")
def __init__(self, src):
LOGGER.debug("preparing loaders")
LOGGER.debug("analyzing type of source") LOGGER.debug("analyzing type of source")
if not os.path.exists(src): if not os.path.exists(src):
LOGGER.debug("source missing") LOGGER.debug("source missing")
raise CommandError("source does not exist") raise CommandError("source does not exist")
if os.path.isdir(src): if os.path.isdir(src):
self.memGroup = memGroup = MemGroup('all', src) all_group = MemGroup('all', src)
for f in glob.glob("%s/*" % src): for f in glob.glob("%s/*" % src):
if f.endswith(".ini"): if f.endswith(".ini"):
# config files for inventory scripts should be ignored # config files for inventory scripts should be ignored
continue continue
if not os.path.isdir(f): if not os.path.isdir(f):
if os.access(f, os.X_OK): if os.access(f, os.X_OK):
ExecutableJsonLoader().load(f, memGroup) ExecutableJsonLoader(None, all_group).load(f)
else: else:
IniLoader().load(f, memGroup) IniLoader(None, all_group).load(f)
elif os.access(src, os.X_OK): elif os.access(src, os.X_OK):
self.memGroup = memGroup = MemGroup('all', os.path.dirname(src)) all_group = MemGroup('all', os.path.dirname(src))
ExecutableJsonLoader().load(src, memGroup) ExecutableJsonLoader(None, all_group).load(src)
else: else:
self.memGroup = memGroup = MemGroup('all', os.path.dirname(src)) all_group = MemGroup('all', os.path.dirname(src))
IniLoader().load(src, memGroup) IniLoader(None, all_group).load(src)
LOGGER.debug("loading process complete") LOGGER.debug("loading process complete")
return all_group
def result(self):
return self.memGroup
class Command(NoArgsCommand): class Command(NoArgsCommand):
''' '''
Management command to import directory, INI, or dynamic inventory Management command to import directory, INI, or dynamic inventory
@@ -416,11 +420,10 @@ class Command(NoArgsCommand):
LOGGER.debug("preparing loader") LOGGER.debug("preparing loader")
loader = GenericLoader(source) all_group = load_generic(source)
memGroup = loader.result()
LOGGER.debug("debugging loaded result") LOGGER.debug("debugging loaded result")
memGroup.debug_tree() all_group.debug_tree()
# now that memGroup is correct and supports JSON executables, INI, and trees # now that memGroup is correct and supports JSON executables, INI, and trees
# now merge and/or overwrite with the database itself! # now merge and/or overwrite with the database itself!
@@ -439,15 +442,15 @@ class Command(NoArgsCommand):
# if overwrite is set, for each host in the database but NOT in the local # if overwrite is set, for each host in the database but NOT in the local
# list, delete it. Delete individually so signal handlers will run. # list, delete it. Delete individually so signal handlers will run.
if overwrite: if overwrite:
LOGGER.info("deleting any hosts not in the remote source: %s" % host_names.keys()) LOGGER.info("deleting any hosts not in the remote source: %s" % all_group.host_names.keys())
for host in Host.objects.exclude(name__in = host_names.keys()).filter(inventory=inventory): for host in Host.objects.exclude(name__in = all_group.host_names.keys()).filter(inventory=inventory):
host.delete() host.delete()
# if overwrite is set, for each group in the database but NOT in the local # if overwrite is set, for each group in the database but NOT in the local
# list, delete it. Delete individually so signal handlers will run. # list, delete it. Delete individually so signal handlers will run.
if overwrite: if overwrite:
LOGGER.info("deleting any groups not in the remote source") LOGGER.info("deleting any groups not in the remote source")
for group in Group.objects.exclude(name__in = group_names.keys()).filter(inventory=inventory): for group in Group.objects.exclude(name__in = all_group.group_names.keys()).filter(inventory=inventory):
group.delete() group.delete()
# if overwrite is set, throw away all invalid child relationships for groups # if overwrite is set, throw away all invalid child relationships for groups
@@ -457,7 +460,7 @@ class Command(NoArgsCommand):
for db_group in db_groups: for db_group in db_groups:
db_kids = db_group.children.all() db_kids = db_group.children.all()
mem_kids = group_names[db_group.name].child_groups mem_kids = all_group.group_names[db_group.name].child_groups
mem_kid_names = [ k.name for k in mem_kids ] mem_kid_names = [ k.name for k in mem_kids ]
removed = False removed = False
for db_kid in db_kids: for db_kid in db_kids:
@@ -468,6 +471,17 @@ class Command(NoArgsCommand):
if removed: if removed:
db_group.save() db_group.save()
# Update/overwrite inventory variables from "all" group.
db_variables = inventory.variables_dict
mem_variables = all_group.variables
if overwrite_vars or overwrite:
LOGGER.info('replacing inventory variables from "all" group')
db_variables = mem_variables
else:
LOGGER.info('updating inventory variables from "all" group')
db_variables.update(mem_variables)
inventory.variables = json.dumps(db_variables)
inventory.save()
# this will be slightly inaccurate, but attribute to first superuser. # this will be slightly inaccurate, but attribute to first superuser.
user = User.objects.filter(is_superuser=True)[0] user = User.objects.filter(is_superuser=True)[0]
@@ -478,7 +492,7 @@ class Command(NoArgsCommand):
db_host_names = [ h.name for h in db_hosts ] db_host_names = [ h.name for h in db_hosts ]
# for each group not in the database but in the local list, create it # for each group not in the database but in the local list, create it
for (k,v) in group_names.iteritems(): for (k,v) in all_group.group_names.iteritems():
if k not in db_group_names: if k not in db_group_names:
variables = json.dumps(v.variables) variables = json.dumps(v.variables)
LOGGER.info("inserting new group %s" % k) LOGGER.info("inserting new group %s" % k)
@@ -487,7 +501,7 @@ class Command(NoArgsCommand):
host.save() host.save()
# for each host not in the database but in the local list, create it # for each host not in the database but in the local list, create it
for (k,v) in host_names.iteritems(): for (k,v) in all_group.host_names.iteritems():
if k not in db_host_names: if k not in db_host_names:
variables = json.dumps(v.variables) variables = json.dumps(v.variables)
LOGGER.info("inserting new host %s" % k) LOGGER.info("inserting new host %s" % k)
@@ -502,7 +516,7 @@ class Command(NoArgsCommand):
for db_group in db_groups: for db_group in db_groups:
db_hosts = db_group.hosts.all() db_hosts = db_group.hosts.all()
mem_hosts = group_names[db_group.name].hosts mem_hosts = all_group.group_names[db_group.name].hosts
mem_host_names = [ h.name for h in mem_hosts ] mem_host_names = [ h.name for h in mem_hosts ]
removed = False removed = False
for db_host in db_hosts: for db_host in db_hosts:
@@ -516,7 +530,7 @@ class Command(NoArgsCommand):
# for each host in a mem group, add it to the parents to which it belongs # for each host in a mem group, add it to the parents to which it belongs
# FIXME: confirm Django is ok with calling add twice and not making two rows # FIXME: confirm Django is ok with calling add twice and not making two rows
for (k,v) in group_names.iteritems(): for (k,v) in all_group.group_names.iteritems():
LOGGER.info("adding parent arrangements for %s" % k) LOGGER.info("adding parent arrangements for %s" % k)
db_group = Group.objects.get(name=k, inventory__pk=inventory.pk) db_group = Group.objects.get(name=k, inventory__pk=inventory.pk)
mem_hosts = v.hosts mem_hosts = v.hosts
@@ -524,7 +538,7 @@ class Command(NoArgsCommand):
db_host = Host.objects.get(name=h.name, inventory__pk=inventory.pk) db_host = Host.objects.get(name=h.name, inventory__pk=inventory.pk)
db_group.hosts.add(db_host) db_group.hosts.add(db_host)
LOGGER.debug("*** ADDING %s to %s ***" % (db_host, db_group)) LOGGER.debug("*** ADDING %s to %s ***" % (db_host, db_group))
db_group.save() #db_group.save()
def variable_mangler(model, mem_hash, overwrite, overwrite_vars): def variable_mangler(model, mem_hash, overwrite, overwrite_vars):
db_collection = model.objects.filter(inventory=inventory) db_collection = model.objects.filter(inventory=inventory)
@@ -541,17 +555,17 @@ class Command(NoArgsCommand):
obj.variables = db_variables obj.variables = db_variables
obj.save() obj.save()
variable_mangler(Group, group_names, overwrite, overwrite_vars) variable_mangler(Group, all_group.group_names, overwrite, overwrite_vars)
variable_mangler(Host, host_names, overwrite, overwrite_vars) variable_mangler(Host, all_group.host_names, overwrite, overwrite_vars)
# for each group, draw in child group arrangements # for each group, draw in child group arrangements
# FIXME: confirm django add behavior as above # FIXME: confirm django add behavior as above
for (k,v) in group_names.iteritems(): for (k,v) in all_group.group_names.iteritems():
db_group = Group.objects.get(inventory=inventory, name=k) db_group = Group.objects.get(inventory=inventory, name=k)
for mem_child_group in v.child_groups: for mem_child_group in v.child_groups:
db_child = Group.objects.get(inventory=inventory, name=mem_child_group.name) db_child = Group.objects.get(inventory=inventory, name=mem_child_group.name)
db_group.children.add(db_child) db_group.children.add(db_child)
db_group.save() #db_group.save()
reader = LicenseReader() reader = LicenseReader()
license_info = reader.from_file() license_info = reader.from_file()

View File

@@ -19,13 +19,13 @@ from django.utils.timezone import now
# AWX # AWX
from awx.main.licenses import LicenseWriter from awx.main.licenses import LicenseWriter
from awx.main.models import * from awx.main.models import *
from awx.main.tests.base import BaseTest from awx.main.tests.base import BaseTest, BaseLiveServerTest
__all__ = ['CleanupDeletedTest', 'InventoryImportTest'] __all__ = ['CleanupDeletedTest', 'InventoryImportTest']
TEST_INVENTORY_INI = '''\ TEST_INVENTORY_INI = '''\
[webservers] [webservers]
web1.example.com web1.example.com ansible_ssh_host=w1.example.net
web2.example.com web2.example.com
web3.example.com web3.example.com
@@ -50,19 +50,19 @@ varb=B
vara=A vara=A
''' '''
class BaseCommandTest(BaseTest): class BaseCommandMixin(object):
''' '''
Base class for tests that run management commands. Base class for tests that run management commands.
''' '''
def setUp(self): def setUp(self):
super(BaseCommandTest, self).setUp() super(BaseCommandMixin, self).setUp()
self._sys_path = [x for x in sys.path] self._sys_path = [x for x in sys.path]
self._environ = dict(os.environ.items()) self._environ = dict(os.environ.items())
self._temp_files = [] self._temp_files = []
def tearDown(self): def tearDown(self):
super(BaseCommandTest, self).tearDown() super(BaseCommandMixin, self).tearDown()
sys.path = self._sys_path sys.path = self._sys_path
for k,v in self._environ.items(): for k,v in self._environ.items():
if os.environ.get(k, None) != v: if os.environ.get(k, None) != v:
@@ -152,7 +152,7 @@ class BaseCommandTest(BaseTest):
result = CommandError(captured_stderr) result = CommandError(captured_stderr)
return result, captured_stdout, captured_stderr return result, captured_stdout, captured_stderr
class CleanupDeletedTest(BaseCommandTest): class CleanupDeletedTest(BaseCommandMixin, BaseTest):
''' '''
Test cases for cleanup_deleted management command. Test cases for cleanup_deleted management command.
''' '''
@@ -244,7 +244,7 @@ class CleanupDeletedTest(BaseCommandTest):
self.assertNotEqual(counts_before, counts_after) self.assertNotEqual(counts_before, counts_after)
self.assertFalse(counts_after[1]) self.assertFalse(counts_after[1])
class InventoryImportTest(BaseCommandTest): class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
''' '''
Test cases for inventory_import management command. Test cases for inventory_import management command.
''' '''
@@ -343,8 +343,82 @@ class InventoryImportTest(BaseCommandTest):
inventory_id=new_inv.pk, inventory_id=new_inv.pk,
source=self.ini_path) source=self.ini_path)
self.assertEqual(result, None) self.assertEqual(result, None)
# FIXME # Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
expected_group_names = set(['servers', 'dbservers', 'webservers'])
group_names = set(new_inv.groups.values_list('name', flat=True))
self.assertEqual(expected_group_names, group_names)
expected_host_names = set(['web1.example.com', 'web2.example.com',
'web3.example.com', 'db1.example.com',
'db2.example.com'])
host_names = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(expected_host_names, host_names)
self.assertEqual(new_inv.variables_dict, {'vara': 'A'})
for host in new_inv.hosts.all():
if host.name == 'web1.example.com':
self.assertEqual(host.variables_dict,
{'ansible_ssh_host': 'w1.example.net'})
else:
self.assertEqual(host.variables_dict, {})
for group in new_inv.groups.all():
if group.name == 'servers':
self.assertEqual(group.variables_dict, {'varb': 'B'})
children = set(group.children.values_list('name', flat=True))
self.assertEqual(children, set(['dbservers', 'webservers']))
self.assertEqual(group.hosts.count(), 0)
elif group.name == 'dbservers':
self.assertEqual(group.variables_dict, {'dbvar': 'ugh'})
self.assertEqual(group.children.count(), 0)
hosts = set(group.hosts.values_list('name', flat=True))
host_names = set(['db1.example.com','db2.example.com'])
self.assertEqual(hosts, host_names)
elif group.name == 'webservers':
self.assertEqual(group.variables_dict, {'webvar': 'blah'})
self.assertEqual(group.children.count(), 0)
hosts = set(group.hosts.values_list('name', flat=True))
host_names = set(['web1.example.com','web2.example.com',
'web3.example.com'])
self.assertEqual(hosts, host_names)
def test_executable_file(self): def test_executable_file(self):
pass # New empty inventory.
# FIXME old_inv = self.inventories[1]
new_inv = self.organizations[0].inventories.create(name='newb')
self.assertEqual(new_inv.hosts.count(), 0)
self.assertEqual(new_inv.groups.count(), 0)
# Use our own inventory script as executable file.
os.environ.setdefault('REST_API_URL', self.live_server_url)
os.environ.setdefault('REST_API_TOKEN',
self.super_django_user.auth_token.key)
os.environ['INVENTORY_ID'] = str(old_inv.pk)
source = os.path.join(os.path.dirname(__file__), '..', '..', 'scripts',
'inventory.py')
result, stdout, stderr = self.run_command('inventory_import',
inventory_id=new_inv.pk,
source=source)
self.assertEqual(result, None)
# Check that inventory is populated as expected.
new_inv = Inventory.objects.get(pk=new_inv.pk)
self.assertEqual(old_inv.variables_dict, new_inv.variables_dict)
old_groups = set(old_inv.groups.values_list('name', flat=True))
new_groups = set(new_inv.groups.values_list('name', flat=True))
self.assertEqual(old_groups, new_groups)
old_hosts = set(old_inv.hosts.values_list('name', flat=True))
new_hosts = set(new_inv.hosts.values_list('name', flat=True))
self.assertEqual(old_hosts, new_hosts)
for new_host in new_inv.hosts.all():
old_host = old_inv.hosts.get(name=new_host.name)
self.assertEqual(old_host.variables_dict, new_host.variables_dict)
for new_group in new_inv.groups.all():
old_group = old_inv.groups.get(name=new_group.name)
self.assertEqual(old_group.variables_dict, new_group.variables_dict)
old_children = set(old_group.children.values_list('name', flat=True))
new_children = set(new_group.children.values_list('name', flat=True))
self.assertEqual(old_children, new_children)
old_hosts = set(old_group.hosts.values_list('name', flat=True))
new_hosts = set(new_group.hosts.values_list('name', flat=True))
self.assertEqual(old_hosts, new_hosts)
def test_executable_file_with_meta_hostvars(self):
os.environ['INVENTORY_HOSTVARS'] = '1'
self.test_executable_file()