diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index 1301f577af..f9d7bdfdc3 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -23,10 +23,6 @@ from awx.main.licenses import LicenseReader LOGGER = None -# maps host and group names to hosts to prevent redudant additions -group_names = {} -host_names = {} - class ImportException(BaseException): def __init__(self, msg): @@ -47,6 +43,10 @@ class MemGroup(object): self.hosts = [] self.variables = {} self.parents = [] + # Used on the "all" group in place of previous global variables. + # maps host and group names to hosts to prevent redudant additions + self.host_names = {} + self.group_names = {} group_vars = os.path.join(inventory_base, 'group_vars', name) if os.path.exists(group_vars): @@ -133,43 +133,49 @@ class MemHost(object): class BaseLoader(object): + def __init__(self, inventory_base=None, all_group=None): + self.inventory_base = inventory_base + self.all_group = all_group + def get_host(self, name): if ":" in name: tokens = name.split(":") name = tokens[0] - global host_names host = None - if not name in host_names: + if not name in self.all_group.host_names: host = MemHost(name, self.inventory_base) - host_names[name] = host - return host_names[name] + self.all_group.host_names[name] = host + return self.all_group.host_names[name] - def get_group(self, name, all_group, child=False): - global group_names + def get_group(self, name, all_group=None, child=False): + all_group = all_group or self.all_group if name == 'all': return all_group - if not name in group_names: + if not name in self.all_group.group_names: group = MemGroup(name, self.inventory_base) if not child: all_group.add_child_group(group) - group_names[name] = group - return group_names[name] + self.all_group.group_names[name] = group + return self.all_group.group_names[name] + + def load(self, src): + raise NotImplementedError class IniLoader(BaseLoader): - def __init__(self, inventory_base=None): + def __init__(self, inventory_base=None, all_group=None): + super(IniLoader, self).__init__(inventory_base, all_group) LOGGER.debug("processing ini") - self.inventory_base = inventory_base - def load(self, src, all_group): - LOGGER.debug("loading: %s on %s" % (src, all_group)) + def load(self, src): + LOGGER.debug("loading: %s on %s" % (src, self.all_group)) if self.inventory_base is None: self.inventory_base = os.path.dirname(src) data = open(src).read() lines = data.split("\n") - group = all_group + group = self.all_group input_mode = 'host' for line in lines: @@ -183,14 +189,14 @@ class IniLoader(BaseLoader): if line.find(":vars") != -1: input_mode = 'vars' line = line.replace(":vars","") - group = self.get_group(line, all_group) + group = self.get_group(line) elif line.find(":children") != -1: input_mode = 'children' line = line.replace(":children","") - group = self.get_group(line, all_group) + group = self.get_group(line) else: input_mode = 'host' - group = self.get_group(line, all_group) + group = self.get_group(line) else: # add a host or variable to the existing group/host line = line.lstrip().rstrip() @@ -243,10 +249,9 @@ class IniLoader(BaseLoader): class ExecutableJsonLoader(BaseLoader): - def __init__(self, inventory_base=None): - + def __init__(self, inventory_base=None, all_group=None): + super(ExecutableJsonLoader, self).__init__(inventory_base, all_group) LOGGER.debug("processing executable JSON source") - self.inventory_base = inventory_base self.child_group_names = {} def command_to_json(self, cmd): @@ -263,9 +268,9 @@ class ExecutableJsonLoader(BaseLoader): assert type(data) == dict return data - def load(self, src, all_group): + def load(self, src): - LOGGER.debug("loading %s onto %s" % (src, all_group)) + LOGGER.debug("loading %s onto %s" % (src, self.all_group)) if self.inventory_base is None: self.inventory_base = os.path.dirname(src) @@ -273,10 +278,11 @@ class ExecutableJsonLoader(BaseLoader): data = self.command_to_json([src, "--list"]) group = None + _meta = data.pop('_meta', {}) for (k,v) in data.iteritems(): - group = self.get_group(k, all_group) + group = self.get_group(k) if type(v) == dict: @@ -311,52 +317,50 @@ class ExecutableJsonLoader(BaseLoader): host = self.get_host(x) group.add_host(host) - all_group.add_child_group(group) + if k != 'all': + self.all_group.add_child_group(group) # then we invoke the executable once for each host name we've built up # to set their variables - global host_names - for (k,v) in host_names.iteritems(): - data = self.command_to_json([src, "--host", k]) + for (k,v) in self.all_group.host_names.iteritems(): + if 'hostvars' not in _meta: + data = self.command_to_json([src, "--host", k]) + else: + data = _meta['hostvars'].get(k, {}) v.variables.update(data) -class GenericLoader(object): - - def __init__(self, src): - - LOGGER.debug("preparing loaders") +def load_generic(src): + LOGGER.debug("preparing loaders") - LOGGER.debug("analyzing type of source") - if not os.path.exists(src): - LOGGER.debug("source missing") - raise CommandError("source does not exist") - if os.path.isdir(src): - self.memGroup = memGroup = MemGroup('all', src) - for f in glob.glob("%s/*" % src): - if f.endswith(".ini"): - # config files for inventory scripts should be ignored - continue - if not os.path.isdir(f): - if os.access(f, os.X_OK): - ExecutableJsonLoader().load(f, memGroup) - else: - IniLoader().load(f, memGroup) - elif os.access(src, os.X_OK): - self.memGroup = memGroup = MemGroup('all', os.path.dirname(src)) - ExecutableJsonLoader().load(src, memGroup) - else: - self.memGroup = memGroup = MemGroup('all', os.path.dirname(src)) - IniLoader().load(src, memGroup) + LOGGER.debug("analyzing type of source") + if not os.path.exists(src): + LOGGER.debug("source missing") + raise CommandError("source does not exist") + if os.path.isdir(src): + all_group = MemGroup('all', src) + for f in glob.glob("%s/*" % src): + if f.endswith(".ini"): + # config files for inventory scripts should be ignored + continue + if not os.path.isdir(f): + if os.access(f, os.X_OK): + ExecutableJsonLoader(None, all_group).load(f) + else: + IniLoader(None, all_group).load(f) + elif os.access(src, os.X_OK): + all_group = MemGroup('all', os.path.dirname(src)) + ExecutableJsonLoader(None, all_group).load(src) + else: + all_group = MemGroup('all', os.path.dirname(src)) + IniLoader(None, all_group).load(src) - LOGGER.debug("loading process complete") + LOGGER.debug("loading process complete") + return all_group - def result(self): - return self.memGroup - class Command(NoArgsCommand): ''' Management command to import directory, INI, or dynamic inventory @@ -416,11 +420,10 @@ class Command(NoArgsCommand): LOGGER.debug("preparing loader") - loader = GenericLoader(source) - memGroup = loader.result() + all_group = load_generic(source) LOGGER.debug("debugging loaded result") - memGroup.debug_tree() + all_group.debug_tree() # now that memGroup is correct and supports JSON executables, INI, and trees # now merge and/or overwrite with the database itself! @@ -439,15 +442,15 @@ class Command(NoArgsCommand): # if overwrite is set, for each host in the database but NOT in the local # list, delete it. Delete individually so signal handlers will run. if overwrite: - LOGGER.info("deleting any hosts not in the remote source: %s" % host_names.keys()) - for host in Host.objects.exclude(name__in = host_names.keys()).filter(inventory=inventory): + LOGGER.info("deleting any hosts not in the remote source: %s" % all_group.host_names.keys()) + for host in Host.objects.exclude(name__in = all_group.host_names.keys()).filter(inventory=inventory): host.delete() # if overwrite is set, for each group in the database but NOT in the local # list, delete it. Delete individually so signal handlers will run. if overwrite: LOGGER.info("deleting any groups not in the remote source") - for group in Group.objects.exclude(name__in = group_names.keys()).filter(inventory=inventory): + for group in Group.objects.exclude(name__in = all_group.group_names.keys()).filter(inventory=inventory): group.delete() # if overwrite is set, throw away all invalid child relationships for groups @@ -457,7 +460,7 @@ class Command(NoArgsCommand): for db_group in db_groups: db_kids = db_group.children.all() - mem_kids = group_names[db_group.name].child_groups + mem_kids = all_group.group_names[db_group.name].child_groups mem_kid_names = [ k.name for k in mem_kids ] removed = False for db_kid in db_kids: @@ -468,6 +471,17 @@ class Command(NoArgsCommand): if removed: db_group.save() + # Update/overwrite inventory variables from "all" group. + db_variables = inventory.variables_dict + mem_variables = all_group.variables + if overwrite_vars or overwrite: + LOGGER.info('replacing inventory variables from "all" group') + db_variables = mem_variables + else: + LOGGER.info('updating inventory variables from "all" group') + db_variables.update(mem_variables) + inventory.variables = json.dumps(db_variables) + inventory.save() # this will be slightly inaccurate, but attribute to first superuser. user = User.objects.filter(is_superuser=True)[0] @@ -478,7 +492,7 @@ class Command(NoArgsCommand): db_host_names = [ h.name for h in db_hosts ] # for each group not in the database but in the local list, create it - for (k,v) in group_names.iteritems(): + for (k,v) in all_group.group_names.iteritems(): if k not in db_group_names: variables = json.dumps(v.variables) LOGGER.info("inserting new group %s" % k) @@ -487,7 +501,7 @@ class Command(NoArgsCommand): host.save() # for each host not in the database but in the local list, create it - for (k,v) in host_names.iteritems(): + for (k,v) in all_group.host_names.iteritems(): if k not in db_host_names: variables = json.dumps(v.variables) LOGGER.info("inserting new host %s" % k) @@ -502,7 +516,7 @@ class Command(NoArgsCommand): for db_group in db_groups: db_hosts = db_group.hosts.all() - mem_hosts = group_names[db_group.name].hosts + mem_hosts = all_group.group_names[db_group.name].hosts mem_host_names = [ h.name for h in mem_hosts ] removed = False for db_host in db_hosts: @@ -516,7 +530,7 @@ class Command(NoArgsCommand): # for each host in a mem group, add it to the parents to which it belongs # FIXME: confirm Django is ok with calling add twice and not making two rows - for (k,v) in group_names.iteritems(): + for (k,v) in all_group.group_names.iteritems(): LOGGER.info("adding parent arrangements for %s" % k) db_group = Group.objects.get(name=k, inventory__pk=inventory.pk) mem_hosts = v.hosts @@ -524,7 +538,7 @@ class Command(NoArgsCommand): db_host = Host.objects.get(name=h.name, inventory__pk=inventory.pk) db_group.hosts.add(db_host) LOGGER.debug("*** ADDING %s to %s ***" % (db_host, db_group)) - db_group.save() + #db_group.save() def variable_mangler(model, mem_hash, overwrite, overwrite_vars): db_collection = model.objects.filter(inventory=inventory) @@ -541,17 +555,17 @@ class Command(NoArgsCommand): obj.variables = db_variables obj.save() - variable_mangler(Group, group_names, overwrite, overwrite_vars) - variable_mangler(Host, host_names, overwrite, overwrite_vars) + variable_mangler(Group, all_group.group_names, overwrite, overwrite_vars) + variable_mangler(Host, all_group.host_names, overwrite, overwrite_vars) # for each group, draw in child group arrangements # FIXME: confirm django add behavior as above - for (k,v) in group_names.iteritems(): + for (k,v) in all_group.group_names.iteritems(): db_group = Group.objects.get(inventory=inventory, name=k) for mem_child_group in v.child_groups: db_child = Group.objects.get(inventory=inventory, name=mem_child_group.name) db_group.children.add(db_child) - db_group.save() + #db_group.save() reader = LicenseReader() license_info = reader.from_file() diff --git a/awx/main/tests/commands.py b/awx/main/tests/commands.py index f731fc9fba..65be68aa37 100644 --- a/awx/main/tests/commands.py +++ b/awx/main/tests/commands.py @@ -19,13 +19,13 @@ from django.utils.timezone import now # AWX from awx.main.licenses import LicenseWriter from awx.main.models import * -from awx.main.tests.base import BaseTest +from awx.main.tests.base import BaseTest, BaseLiveServerTest __all__ = ['CleanupDeletedTest', 'InventoryImportTest'] TEST_INVENTORY_INI = '''\ [webservers] -web1.example.com +web1.example.com ansible_ssh_host=w1.example.net web2.example.com web3.example.com @@ -50,19 +50,19 @@ varb=B vara=A ''' -class BaseCommandTest(BaseTest): +class BaseCommandMixin(object): ''' Base class for tests that run management commands. ''' def setUp(self): - super(BaseCommandTest, self).setUp() + super(BaseCommandMixin, self).setUp() self._sys_path = [x for x in sys.path] self._environ = dict(os.environ.items()) self._temp_files = [] def tearDown(self): - super(BaseCommandTest, self).tearDown() + super(BaseCommandMixin, self).tearDown() sys.path = self._sys_path for k,v in self._environ.items(): if os.environ.get(k, None) != v: @@ -152,7 +152,7 @@ class BaseCommandTest(BaseTest): result = CommandError(captured_stderr) return result, captured_stdout, captured_stderr -class CleanupDeletedTest(BaseCommandTest): +class CleanupDeletedTest(BaseCommandMixin, BaseTest): ''' Test cases for cleanup_deleted management command. ''' @@ -244,7 +244,7 @@ class CleanupDeletedTest(BaseCommandTest): self.assertNotEqual(counts_before, counts_after) self.assertFalse(counts_after[1]) -class InventoryImportTest(BaseCommandTest): +class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest): ''' Test cases for inventory_import management command. ''' @@ -343,8 +343,82 @@ class InventoryImportTest(BaseCommandTest): inventory_id=new_inv.pk, source=self.ini_path) self.assertEqual(result, None) - # FIXME + # Check that inventory is populated as expected. + new_inv = Inventory.objects.get(pk=new_inv.pk) + expected_group_names = set(['servers', 'dbservers', 'webservers']) + group_names = set(new_inv.groups.values_list('name', flat=True)) + self.assertEqual(expected_group_names, group_names) + expected_host_names = set(['web1.example.com', 'web2.example.com', + 'web3.example.com', 'db1.example.com', + 'db2.example.com']) + host_names = set(new_inv.hosts.values_list('name', flat=True)) + self.assertEqual(expected_host_names, host_names) + self.assertEqual(new_inv.variables_dict, {'vara': 'A'}) + for host in new_inv.hosts.all(): + if host.name == 'web1.example.com': + self.assertEqual(host.variables_dict, + {'ansible_ssh_host': 'w1.example.net'}) + else: + self.assertEqual(host.variables_dict, {}) + for group in new_inv.groups.all(): + if group.name == 'servers': + self.assertEqual(group.variables_dict, {'varb': 'B'}) + children = set(group.children.values_list('name', flat=True)) + self.assertEqual(children, set(['dbservers', 'webservers'])) + self.assertEqual(group.hosts.count(), 0) + elif group.name == 'dbservers': + self.assertEqual(group.variables_dict, {'dbvar': 'ugh'}) + self.assertEqual(group.children.count(), 0) + hosts = set(group.hosts.values_list('name', flat=True)) + host_names = set(['db1.example.com','db2.example.com']) + self.assertEqual(hosts, host_names) + elif group.name == 'webservers': + self.assertEqual(group.variables_dict, {'webvar': 'blah'}) + self.assertEqual(group.children.count(), 0) + hosts = set(group.hosts.values_list('name', flat=True)) + host_names = set(['web1.example.com','web2.example.com', + 'web3.example.com']) + self.assertEqual(hosts, host_names) def test_executable_file(self): - pass - # FIXME + # New empty inventory. + old_inv = self.inventories[1] + new_inv = self.organizations[0].inventories.create(name='newb') + self.assertEqual(new_inv.hosts.count(), 0) + self.assertEqual(new_inv.groups.count(), 0) + # Use our own inventory script as executable file. + os.environ.setdefault('REST_API_URL', self.live_server_url) + os.environ.setdefault('REST_API_TOKEN', + self.super_django_user.auth_token.key) + os.environ['INVENTORY_ID'] = str(old_inv.pk) + source = os.path.join(os.path.dirname(__file__), '..', '..', 'scripts', + 'inventory.py') + result, stdout, stderr = self.run_command('inventory_import', + inventory_id=new_inv.pk, + source=source) + self.assertEqual(result, None) + # Check that inventory is populated as expected. + new_inv = Inventory.objects.get(pk=new_inv.pk) + self.assertEqual(old_inv.variables_dict, new_inv.variables_dict) + old_groups = set(old_inv.groups.values_list('name', flat=True)) + new_groups = set(new_inv.groups.values_list('name', flat=True)) + self.assertEqual(old_groups, new_groups) + old_hosts = set(old_inv.hosts.values_list('name', flat=True)) + new_hosts = set(new_inv.hosts.values_list('name', flat=True)) + self.assertEqual(old_hosts, new_hosts) + for new_host in new_inv.hosts.all(): + old_host = old_inv.hosts.get(name=new_host.name) + self.assertEqual(old_host.variables_dict, new_host.variables_dict) + for new_group in new_inv.groups.all(): + old_group = old_inv.groups.get(name=new_group.name) + self.assertEqual(old_group.variables_dict, new_group.variables_dict) + old_children = set(old_group.children.values_list('name', flat=True)) + new_children = set(new_group.children.values_list('name', flat=True)) + self.assertEqual(old_children, new_children) + old_hosts = set(old_group.hosts.values_list('name', flat=True)) + new_hosts = set(new_group.hosts.values_list('name', flat=True)) + self.assertEqual(old_hosts, new_hosts) + + def test_executable_file_with_meta_hostvars(self): + os.environ['INVENTORY_HOSTVARS'] = '1' + self.test_executable_file()