diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index dc1b58efc2..8f524c1f05 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -10,6 +10,7 @@ import subprocess import sys import time import traceback +from collections import OrderedDict # Django from django.conf import settings @@ -269,12 +270,13 @@ class Command(BaseCommand): self.db_instance_id_map = {} if self.instance_id_var: host_qs = self.inventory_source.hosts.all() - host_qs = host_qs.filter(instance_id='', variables__contains=self.instance_id_var.split('.')[0]) - for host in host_qs: - instance_id = self._get_instance_id(host.variables_dict) - if not instance_id: - continue - self.db_instance_id_map[instance_id] = host.pk + for instance_id_part in reversed(self.instance_id_var.split(',')): + host_qs = host_qs.filter(instance_id='', variables__contains=instance_id_part.split('.')[0]) + for host in host_qs: + instance_id = self._get_instance_id(host.variables_dict) + if not instance_id: + continue + self.db_instance_id_map[instance_id] = host.pk def _build_mem_instance_id_map(self): """ @@ -300,7 +302,7 @@ class Command(BaseCommand): self._cached_host_pk_set = frozenset(self.inventory_source.hosts.values_list('pk', flat=True)) return self._cached_host_pk_set - def _delete_hosts(self): + def _delete_hosts(self, pk_mem_host_map): """ For each host in the database that is NOT in the local list, delete it. When importing from a cloud inventory source attached to a @@ -309,25 +311,10 @@ class Command(BaseCommand): """ if settings.SQL_DEBUG: queries_before = len(connection.queries) + hosts_qs = self.inventory_source.hosts - # Build list of all host pks, remove all that should not be deleted. 
- del_host_pks = set(self._existing_host_pks()) # makes mutable copy - if self.instance_id_var: - all_instance_ids = list(self.mem_instance_id_map.keys()) - instance_ids = [] - for offset in range(0, len(all_instance_ids), self._batch_size): - instance_ids = all_instance_ids[offset : (offset + self._batch_size)] - for host_pk in hosts_qs.filter(instance_id__in=instance_ids).values_list('pk', flat=True): - del_host_pks.discard(host_pk) - for host_pk in set([v for k, v in self.db_instance_id_map.items() if k in instance_ids]): - del_host_pks.discard(host_pk) - all_host_names = list(set(self.mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys())) - else: - all_host_names = list(self.all_group.all_hosts.keys()) - for offset in range(0, len(all_host_names), self._batch_size): - host_names = all_host_names[offset : (offset + self._batch_size)] - for host_pk in hosts_qs.filter(name__in=host_names).values_list('pk', flat=True): - del_host_pks.discard(host_pk) + del_host_pks = hosts_qs.exclude(pk__in=pk_mem_host_map.keys()).values_list('pk', flat=True) + # Now delete all remaining hosts in batches. 
all_del_pks = sorted(list(del_host_pks)) for offset in range(0, len(all_del_pks), self._batch_size): @@ -568,7 +555,63 @@ class Command(BaseCommand): logger.debug('Host "%s" is now disabled', mem_host.name) self._batch_add_m2m(self.inventory_source.hosts, db_host) - def _create_update_hosts(self): + def _build_pk_mem_host_map(self): + """ + Creates and returns a data structure that maps DB hosts to the in-memory hosts that + they correspond to - meaning that those DB hosts will be updated with the in-memory host values + """ + mem_host_pk_map = OrderedDict() # keys are mem_host name, values are matching DB host pk + host_pks_updated = set() # same as values of mem_host_pk_map but used for efficiency + mem_host_pk_map_by_id = {} # incomplete mapping by new instance_id to be sorted and pushed to mem_host_pk_map + mem_host_instance_id_map = {} + for k, v in self.all_group.all_hosts.items(): + instance_id = self._get_instance_id(v.variables) + if instance_id in self.db_instance_id_map: + mem_host_pk_map_by_id[self.db_instance_id_map[instance_id]] = v + elif instance_id: + mem_host_instance_id_map[instance_id] = v + + # Update all existing hosts where we know the PK based on instance_id. 
+ all_instance_ids = sorted(mem_host_instance_id_map.keys()) + for offset in range(0, len(all_instance_ids), self._batch_size): + instance_ids = all_instance_ids[offset : (offset + self._batch_size)] + for db_host in self.inventory.hosts.only('pk', 'instance_id').filter(instance_id__in=instance_ids): + if db_host.pk in host_pks_updated: + continue + mem_host = mem_host_instance_id_map[db_host.instance_id] + mem_host_pk_map[mem_host.name] = db_host.pk + host_pks_updated.add(db_host.pk) + + # Update all existing hosts by name. + all_host_names = sorted(self.all_group.all_hosts.keys()) + for offset in range(0, len(all_host_names), self._batch_size): + host_names = all_host_names[offset : (offset + self._batch_size)] + for db_host in self.inventory.hosts.only('pk', 'name').filter(name__in=host_names): + if db_host.pk in host_pks_updated: + continue + mem_host = self.all_group.all_hosts[db_host.name] + mem_host_pk_map[mem_host.name] = db_host.pk + host_pks_updated.add(db_host.pk) + + # Rotate the dictionary so that lookups are done by the host pk + pk_mem_host_map = OrderedDict() + for name, host_pk in mem_host_pk_map.items(): + pk_mem_host_map[host_pk] = name + + return pk_mem_host_map # keys are DB host pk, values are matching mem host name + + def _create_update_hosts(self, pk_mem_host_map): + """ For each host in the local list, create it if it doesn't exist in the database. 
Otherwise, update/replace database variables from the @@ -577,57 +620,22 @@ class Command(BaseCommand): """ if settings.SQL_DEBUG: queries_before = len(connection.queries) - host_pks_updated = set() - mem_host_pk_map = {} - mem_host_instance_id_map = {} - mem_host_name_map = {} - mem_host_names_to_update = set(self.all_group.all_hosts.keys()) - for k, v in self.all_group.all_hosts.items(): - mem_host_name_map[k] = v - instance_id = self._get_instance_id(v.variables) - if instance_id in self.db_instance_id_map: - mem_host_pk_map[self.db_instance_id_map[instance_id]] = v - elif instance_id: - mem_host_instance_id_map[instance_id] = v - # Update all existing hosts where we know the PK based on instance_id. - all_host_pks = sorted(mem_host_pk_map.keys()) + updated_mem_host_names = set() + + all_host_pks = sorted(pk_mem_host_map.keys()) for offset in range(0, len(all_host_pks), self._batch_size): host_pks = all_host_pks[offset : (offset + self._batch_size)] for db_host in self.inventory.hosts.filter(pk__in=host_pks): - if db_host.pk in host_pks_updated: - continue - mem_host = mem_host_pk_map[db_host.pk] + mem_host_name = pk_mem_host_map[db_host.pk] + mem_host = self.all_group.all_hosts[mem_host_name] self._update_db_host_from_mem_host(db_host, mem_host) - host_pks_updated.add(db_host.pk) - mem_host_names_to_update.discard(mem_host.name) + updated_mem_host_names.add(mem_host.name) - # Update all existing hosts where we know the instance_id. 
- all_instance_ids = sorted(mem_host_instance_id_map.keys()) - for offset in range(0, len(all_instance_ids), self._batch_size): - instance_ids = all_instance_ids[offset : (offset + self._batch_size)] - for db_host in self.inventory.hosts.filter(instance_id__in=instance_ids): - if db_host.pk in host_pks_updated: - continue - mem_host = mem_host_instance_id_map[db_host.instance_id] - self._update_db_host_from_mem_host(db_host, mem_host) - host_pks_updated.add(db_host.pk) - mem_host_names_to_update.discard(mem_host.name) - - # Update all existing hosts by name. - all_host_names = sorted(mem_host_name_map.keys()) - for offset in range(0, len(all_host_names), self._batch_size): - host_names = all_host_names[offset : (offset + self._batch_size)] - for db_host in self.inventory.hosts.filter(name__in=host_names): - if db_host.pk in host_pks_updated: - continue - mem_host = mem_host_name_map[db_host.name] - self._update_db_host_from_mem_host(db_host, mem_host) - host_pks_updated.add(db_host.pk) - mem_host_names_to_update.discard(mem_host.name) + mem_host_names_to_create = set(self.all_group.all_hosts.keys()) - updated_mem_host_names # Create any new hosts. 
- for mem_host_name in sorted(mem_host_names_to_update): + for mem_host_name in sorted(mem_host_names_to_create): mem_host = self.all_group.all_hosts[mem_host_name] import_vars = mem_host.variables host_desc = import_vars.pop('_awx_description', 'imported') @@ -726,13 +734,14 @@ class Command(BaseCommand): self._batch_size = 500 self._build_db_instance_id_map() self._build_mem_instance_id_map() + pk_mem_host_map = self._build_pk_mem_host_map() if self.overwrite: - self._delete_hosts() + self._delete_hosts(pk_mem_host_map) self._delete_groups() self._delete_group_children_and_hosts() self._update_inventory() self._create_update_groups() - self._create_update_hosts() + self._create_update_hosts(pk_mem_host_map) self._create_update_group_children() self._create_update_group_hosts() diff --git a/awx/main/tests/functional/commands/test_inventory_import.py b/awx/main/tests/functional/commands/test_inventory_import.py index c53630bcb5..75a09fc476 100644 --- a/awx/main/tests/functional/commands/test_inventory_import.py +++ b/awx/main/tests/functional/commands/test_inventory_import.py @@ -5,6 +5,7 @@ import pytest from unittest import mock import os +import yaml # Django from django.core.management.base import CommandError @@ -52,6 +53,110 @@ def mock_logging(self, level): pass +@pytest.mark.django_db +@mock.patch.object(inventory_import.Command, 'set_logging_level', mock_logging) +class TestMigrationCases: + """In the case that we have any bugs with the declared instance ID variables + then it is inevitable that we will, at some point, import a host with a blank ID + and then later import it with the correct id. 
+ """ + + @pytest.mark.parametrize('id_var', ('', 'foo.id', 'foo.id,other', 'other,foo.id'), ids=['none', 'simple', 'complex', 'backward']) + @pytest.mark.parametrize('host_name', ('host-1', 'fooval'), ids=['arbitrary', 'id']) + @pytest.mark.parametrize('has_var', (True, False)) + def test_single_host_not_recreated(self, inventory, id_var, host_name, has_var): + inv_src = InventorySource.objects.create(inventory=inventory, source='gce') + + options = dict(overwrite=True, instance_id_var=id_var) + + vars = {'foo': {'id': 'fooval'}} + data = { + '_meta': {'hostvars': {host_name: vars if has_var else {'unrelated': 'value'}}}, + "ungrouped": {"hosts": [host_name]}, + } + old_id = None + + for i in range(3): + inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job()) + + assert inventory.hosts.count() == inv_src.hosts.count() == 1 + host = inventory.hosts.first() + assert host.name == host_name + assert host.instance_id in ('fooval', '') + if has_var: + assert yaml.safe_load(host.variables) == vars + else: + assert yaml.safe_load(host.variables) == {'unrelated': 'value'} + + if old_id is not None: + assert host.id == old_id + old_id = host.id + + @pytest.mark.parametrize('id_var_seq', [('', 'foo.id,other'), ('foo.id,other', '')], ids=['gained', 'lost']) # second is problem case + @pytest.mark.parametrize('host_name', ('host-1', 'fooval'), ids=['arbitrary', 'id']) + def test_host_gains_or_loses_instance_id(self, inventory, id_var_seq, host_name): + inv_src = InventorySource.objects.create(inventory=inventory, source='gce') + + options = dict(overwrite=True) + + vars = {'foo': {'id': 'fooval'}} + old_id = None + + for id_var in id_var_seq: + options['instance_id_var'] = id_var + data = { + '_meta': {'hostvars': {host_name: vars}}, + "ungrouped": {"hosts": [host_name]}, + } + inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job()) + + assert inventory.hosts.count() == 
inv_src.hosts.count() == 1 + host = inventory.hosts.first() + assert host.name == host_name + assert host.instance_id == ('fooval' if id_var else '') + assert yaml.safe_load(host.variables) == vars + + if old_id is not None: + assert host.id == old_id + old_id = host.id + + @pytest.mark.parametrize('second_list', [('host-1', 'fooval'), ('host-1',), ('fooval',)]) + def test_name_and_id_confusion(self, inventory, second_list): + inv_src = InventorySource.objects.create(inventory=inventory, source='gce') + + CASES = [('', ['host-1', 'fooval']), ('foo.id', second_list)] + + options = dict(overwrite=True) + + vars = {'foo': {'id': 'fooval'}} + data = { + '_meta': {'hostvars': {}}, + "ungrouped": {"hosts": []}, + } + id_set = None + + for id_var, hosts in CASES: + options['instance_id_var'] = id_var + + data['_meta']['hostvars'] = {} + for host_name in hosts: + data['_meta']['hostvars'][host_name] = vars if id_var else {} + data['ungrouped']['hosts'] = hosts + + inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job()) + + new_ids = set(inventory.hosts.values_list('id', flat=True)) + if id_set is not None: + assert not (new_ids - id_set) + id_set = new_ids + + assert inventory.hosts.count() == len(hosts), [(host.name, host.instance_id) for host in inventory.hosts.all()] + assert inv_src.hosts.count() == len(hosts), [(host.name, host.instance_id) for host in inventory.hosts.all()] + for host_name in hosts: + host = inventory.hosts.get(name=host_name) + assert host.instance_id == ('fooval' if id_var else '') + + @pytest.mark.django_db @pytest.mark.inventory_import @mock.patch.object(inventory_import.Command, 'check_license', mock.MagicMock()) @@ -89,7 +194,7 @@ class TestINIImports: def test_inventory_single_ini_import(self, inventory, capsys): inventory_import.AnsibleInventoryLoader._data = TEST_INVENTORY_CONTENT cmd = inventory_import.Command() - r = cmd.handle(inventory_id=inventory.pk, source=__file__, method='backport') + r = 
cmd.handle(inventory_id=inventory.pk, source=__file__) out, err = capsys.readouterr() assert r is None assert out == ''