Fix overwrite bug where hosts with no instance ID var are re-created (#10910)

* Write tests to assure air-tightness of overwrite

* Candidate fix for group overwrite air-tightness

* Another proposed fix for the id mapping

* Further double down on tracking old instance_id

* Separate unchanging data case and fix some test issues

* Parametrize final confusion test

* Cut down further on test cases and fix a bug in the prior fix

* Rewrite of _delete_host code sharing with update method

This is a start-to-finish rewrite of the host overwrite bug fix.
This version is much more conservative: it keeps the overall
code structure, in which hosts are deleted before host updates
are made.

To fix the bug, we share code between the method that deletes
hosts and the method that updates hosts.
A data structure is created and passed to both methods.

Because both methods use the same data structure, which maps
the in-memory hosts to DB hosts, we ensure that the deletions
will ONLY delete hosts that will not be updated.
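
In outline, that works as follows (a minimal, self-contained sketch with hypothetical stand-in names, not the actual AWX code): one matching pass decides which DB host each imported host corresponds to, and both the delete and update phases consume that one result.

# Sketch of the shared-map idea; names here are hypothetical stand-ins.
db_hosts = {1: 'host-1', 2: 'stale-host'}  # pk -> name (stand-in for DB rows)
mem_hosts = {'host-1': {}, 'host-2': {}}   # imported hosts, keyed by name

def match_db_host(name):
    # Stand-in for matching by instance_id, prior instance_id, then name.
    for pk, db_name in db_hosts.items():
        if db_name == name:
            return pk
    return None

pk_mem_host_map = {}  # DB host pk -> imported host name
for name in mem_hosts:
    pk = match_db_host(name)
    if pk is not None:
        pk_mem_host_map[pk] = name

delete_pks = set(db_hosts) - set(pk_mem_host_map)              # only unmatched rows are deleted
create_names = set(mem_hosts) - set(pk_mem_host_map.values())  # the rest are created
assert delete_pks == {2} and create_names == {'host-2'}
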
Alan Rominger authored 2021-09-16 15:29:57 -04:00, committed by GitHub
parent 181bda51ce
commit 1319fadc60
2 changed files with 186 additions and 72 deletions

awx/main/management/commands/inventory_import.py

@@ -10,6 +10,7 @@ import subprocess
 import sys
 import time
 import traceback
+from collections import OrderedDict
 
 # Django
 from django.conf import settings
@@ -269,12 +270,13 @@ class Command(BaseCommand):
         self.db_instance_id_map = {}
         if self.instance_id_var:
             host_qs = self.inventory_source.hosts.all()
-            host_qs = host_qs.filter(instance_id='', variables__contains=self.instance_id_var.split('.')[0])
-            for host in host_qs:
-                instance_id = self._get_instance_id(host.variables_dict)
-                if not instance_id:
-                    continue
-                self.db_instance_id_map[instance_id] = host.pk
+            for instance_id_part in reversed(self.instance_id_var.split(',')):
+                host_qs = host_qs.filter(instance_id='', variables__contains=instance_id_part.split('.')[0])
+                for host in host_qs:
+                    instance_id = self._get_instance_id(host.variables_dict)
+                    if not instance_id:
+                        continue
+                    self.db_instance_id_map[instance_id] = host.pk
 
     def _build_mem_instance_id_map(self):
         """
@@ -300,7 +302,7 @@ class Command(BaseCommand):
             self._cached_host_pk_set = frozenset(self.inventory_source.hosts.values_list('pk', flat=True))
         return self._cached_host_pk_set
 
-    def _delete_hosts(self):
+    def _delete_hosts(self, pk_mem_host_map):
         """
         For each host in the database that is NOT in the local list, delete
         it. When importing from a cloud inventory source attached to a
@@ -309,25 +311,10 @@ class Command(BaseCommand):
         """
         if settings.SQL_DEBUG:
             queries_before = len(connection.queries)
         hosts_qs = self.inventory_source.hosts
-        # Build list of all host pks, remove all that should not be deleted.
-        del_host_pks = set(self._existing_host_pks())  # makes mutable copy
-        if self.instance_id_var:
-            all_instance_ids = list(self.mem_instance_id_map.keys())
-            instance_ids = []
-            for offset in range(0, len(all_instance_ids), self._batch_size):
-                instance_ids = all_instance_ids[offset : (offset + self._batch_size)]
-                for host_pk in hosts_qs.filter(instance_id__in=instance_ids).values_list('pk', flat=True):
-                    del_host_pks.discard(host_pk)
-                for host_pk in set([v for k, v in self.db_instance_id_map.items() if k in instance_ids]):
-                    del_host_pks.discard(host_pk)
-            all_host_names = list(set(self.mem_instance_id_map.values()) - set(self.all_group.all_hosts.keys()))
-        else:
-            all_host_names = list(self.all_group.all_hosts.keys())
-        for offset in range(0, len(all_host_names), self._batch_size):
-            host_names = all_host_names[offset : (offset + self._batch_size)]
-            for host_pk in hosts_qs.filter(name__in=host_names).values_list('pk', flat=True):
-                del_host_pks.discard(host_pk)
+        del_host_pks = hosts_qs.exclude(pk__in=pk_mem_host_map.keys()).values_list('pk', flat=True)
         # Now delete all remaining hosts in batches.
         all_del_pks = sorted(list(del_host_pks))
         for offset in range(0, len(all_del_pks), self._batch_size):
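
The deletion rule is now a set complement of the shared map: anything the matching pass did not claim gets deleted, so delete and update can never disagree about a host. An illustrative pure-Python equivalent (not the ORM code):

existing_pks = {1, 2, 3, 4}                   # pks currently in the DB
pk_mem_host_map = {1: 'host-1', 3: 'host-3'}  # built once, shared with the update pass
del_host_pks = sorted(existing_pks - set(pk_mem_host_map))
assert del_host_pks == [2, 4]
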
@@ -568,7 +555,63 @@ class Command(BaseCommand):
                 logger.debug('Host "%s" is now disabled', mem_host.name)
             self._batch_add_m2m(self.inventory_source.hosts, db_host)
 
-    def _create_update_hosts(self):
+    def _build_pk_mem_host_map(self):
+        """
+        Create and return a data structure that maps each DB host to the
+        in-memory host it corresponds to, meaning that the DB host will be
+        updated with that in-memory host's values
+        """
+        mem_host_pk_map = OrderedDict()  # keys are mem_host name, values are matching DB host pk
+        host_pks_updated = set()  # same as the items of mem_host_pk_map, kept as a set for efficiency
+        mem_host_pk_map_by_id = {}  # incomplete mapping by new instance_id, to be sorted and pushed into mem_host_pk_map
+        mem_host_instance_id_map = {}
+        for k, v in self.all_group.all_hosts.items():
+            instance_id = self._get_instance_id(v.variables)
+            if instance_id in self.db_instance_id_map:
+                mem_host_pk_map_by_id[self.db_instance_id_map[instance_id]] = v
+            elif instance_id:
+                mem_host_instance_id_map[instance_id] = v
+
+        # Match all existing hosts where we know the PK based on instance_id.
+        all_host_pks = sorted(mem_host_pk_map_by_id.keys())
+        for offset in range(0, len(all_host_pks), self._batch_size):
+            host_pks = all_host_pks[offset : (offset + self._batch_size)]
+            for db_host in self.inventory.hosts.only('pk').filter(pk__in=host_pks):
+                if db_host.pk in host_pks_updated:
+                    continue
+                mem_host = mem_host_pk_map_by_id[db_host.pk]
+                mem_host_pk_map[mem_host.name] = db_host.pk
+                host_pks_updated.add(db_host.pk)
+
+        # Match all existing hosts where we know the DB (the prior) instance_id.
+        all_instance_ids = sorted(mem_host_instance_id_map.keys())
+        for offset in range(0, len(all_instance_ids), self._batch_size):
+            instance_ids = all_instance_ids[offset : (offset + self._batch_size)]
+            for db_host in self.inventory.hosts.only('pk', 'instance_id').filter(instance_id__in=instance_ids):
+                if db_host.pk in host_pks_updated:
+                    continue
+                mem_host = mem_host_instance_id_map[db_host.instance_id]
+                mem_host_pk_map[mem_host.name] = db_host.pk
+                host_pks_updated.add(db_host.pk)
+
+        # Match all existing hosts by name.
+        all_host_names = sorted(self.all_group.all_hosts.keys())
+        for offset in range(0, len(all_host_names), self._batch_size):
+            host_names = all_host_names[offset : (offset + self._batch_size)]
+            for db_host in self.inventory.hosts.only('pk', 'name').filter(name__in=host_names):
+                if db_host.pk in host_pks_updated:
+                    continue
+                mem_host = self.all_group.all_hosts[db_host.name]
+                mem_host_pk_map[mem_host.name] = db_host.pk
+                host_pks_updated.add(db_host.pk)
+
+        # Rotate the dictionary so that lookups are done by the host pk
+        pk_mem_host_map = OrderedDict()
+        for name, host_pk in mem_host_pk_map.items():
+            pk_mem_host_map[host_pk] = name
+
+        return pk_mem_host_map  # keys are DB host pk, values are matching mem host name
+
+    def _create_update_hosts(self, pk_mem_host_map):
         """
         For each host in the local list, create it if it doesn't exist in the
         database. Otherwise, update/replace database variables from the
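
The map is built in three first-match-wins passes: by PK recovered from the prior variables-derived instance_id, then by the stored instance_id column, then by name, with each DB host claimed at most once. An idealized sketch of that layering (hypothetical match functions, and slightly stricter than the real code in that an imported host is also matched only once):

from collections import OrderedDict

def build_pk_map(mem_host_names, passes):
    # passes: callables name -> db pk (or None), in decreasing precedence,
    # e.g. [match_by_prior_id_pk, match_by_instance_id, match_by_name]
    pk_map = OrderedDict()  # db host pk -> imported host name
    matched_names = set()
    for match in passes:
        for name in sorted(mem_host_names):
            if name in matched_names:
                continue
            pk = match(name)
            if pk is not None and pk not in pk_map:
                pk_map[pk] = name
                matched_names.add(name)
    return pk_map

by_name = {'host-1': 10}.get
assert build_pk_map(['host-1', 'host-2'], [by_name]) == OrderedDict({10: 'host-1'})
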
@@ -577,57 +620,22 @@ class Command(BaseCommand):
         """
         if settings.SQL_DEBUG:
             queries_before = len(connection.queries)
-        host_pks_updated = set()
-        mem_host_pk_map = {}
-        mem_host_instance_id_map = {}
-        mem_host_name_map = {}
-        mem_host_names_to_update = set(self.all_group.all_hosts.keys())
-        for k, v in self.all_group.all_hosts.items():
-            mem_host_name_map[k] = v
-            instance_id = self._get_instance_id(v.variables)
-            if instance_id in self.db_instance_id_map:
-                mem_host_pk_map[self.db_instance_id_map[instance_id]] = v
-            elif instance_id:
-                mem_host_instance_id_map[instance_id] = v
-
-        # Update all existing hosts where we know the PK based on instance_id.
-        all_host_pks = sorted(mem_host_pk_map.keys())
+        updated_mem_host_names = set()
+        all_host_pks = sorted(pk_mem_host_map.keys())
         for offset in range(0, len(all_host_pks), self._batch_size):
             host_pks = all_host_pks[offset : (offset + self._batch_size)]
             for db_host in self.inventory.hosts.filter(pk__in=host_pks):
-                if db_host.pk in host_pks_updated:
-                    continue
-                mem_host = mem_host_pk_map[db_host.pk]
+                mem_host_name = pk_mem_host_map[db_host.pk]
+                mem_host = self.all_group.all_hosts[mem_host_name]
                 self._update_db_host_from_mem_host(db_host, mem_host)
-                host_pks_updated.add(db_host.pk)
-                mem_host_names_to_update.discard(mem_host.name)
+                updated_mem_host_names.add(mem_host.name)
 
-        # Update all existing hosts where we know the instance_id.
-        all_instance_ids = sorted(mem_host_instance_id_map.keys())
-        for offset in range(0, len(all_instance_ids), self._batch_size):
-            instance_ids = all_instance_ids[offset : (offset + self._batch_size)]
-            for db_host in self.inventory.hosts.filter(instance_id__in=instance_ids):
-                if db_host.pk in host_pks_updated:
-                    continue
-                mem_host = mem_host_instance_id_map[db_host.instance_id]
-                self._update_db_host_from_mem_host(db_host, mem_host)
-                host_pks_updated.add(db_host.pk)
-                mem_host_names_to_update.discard(mem_host.name)
-
-        # Update all existing hosts by name.
-        all_host_names = sorted(mem_host_name_map.keys())
-        for offset in range(0, len(all_host_names), self._batch_size):
-            host_names = all_host_names[offset : (offset + self._batch_size)]
-            for db_host in self.inventory.hosts.filter(name__in=host_names):
-                if db_host.pk in host_pks_updated:
-                    continue
-                mem_host = mem_host_name_map[db_host.name]
-                self._update_db_host_from_mem_host(db_host, mem_host)
-                host_pks_updated.add(db_host.pk)
-                mem_host_names_to_update.discard(mem_host.name)
+        mem_host_names_to_create = set(self.all_group.all_hosts.keys()) - updated_mem_host_names
 
         # Create any new hosts.
-        for mem_host_name in sorted(mem_host_names_to_update):
+        for mem_host_name in sorted(mem_host_names_to_create):
             mem_host = self.all_group.all_hosts[mem_host_name]
             import_vars = mem_host.variables
             host_desc = import_vars.pop('_awx_description', 'imported')
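
Both methods keep the same batching idiom for large inventories: sort the keys, slice them into windows of _batch_size, and issue one filtered query per window. The pattern in isolation:

def batched(keys, batch_size=500):
    # One query per window of keys keeps each IN (...) clause small.
    keys = sorted(keys)
    for offset in range(0, len(keys), batch_size):
        yield keys[offset : (offset + batch_size)]

assert list(batched(range(5), batch_size=2)) == [[0, 1], [2, 3], [4]]
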
@@ -726,13 +734,14 @@ class Command(BaseCommand):
             self._batch_size = 500
             self._build_db_instance_id_map()
             self._build_mem_instance_id_map()
+            pk_mem_host_map = self._build_pk_mem_host_map()
 
             if self.overwrite:
-                self._delete_hosts()
+                self._delete_hosts(pk_mem_host_map)
                 self._delete_groups()
                 self._delete_group_children_and_hosts()
             self._update_inventory()
             self._create_update_groups()
-            self._create_update_hosts()
+            self._create_update_hosts(pk_mem_host_map)
             self._create_update_group_children()
             self._create_update_group_hosts()

awx/main/tests/functional/commands/test_inventory_import.py

@@ -5,6 +5,7 @@
 import pytest
 from unittest import mock
 import os
+import yaml
 
 # Django
 from django.core.management.base import CommandError
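
The new yaml import is used to read Host.variables back in the assertions below; the field is persisted as text (YAML or JSON, and yaml.safe_load parses both), so the tests compare parsed values rather than raw strings. For example:

import yaml

assert yaml.safe_load('{"foo": {"id": "fooval"}}') == {'foo': {'id': 'fooval'}}
assert yaml.safe_load('foo:\n  id: fooval') == {'foo': {'id': 'fooval'}}
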
@@ -52,6 +53,110 @@ def mock_logging(self, level):
     pass
 
 
+@pytest.mark.django_db
+@mock.patch.object(inventory_import.Command, 'set_logging_level', mock_logging)
+class TestMigrationCases:
+    """If there are any bugs in how the declared instance ID variables are
+    handled, then it is inevitable that we will, at some point, import a host
+    with a blank instance_id and then later import it with the correct instance_id.
+    """
+
+    @pytest.mark.parametrize('id_var', ('', 'foo.id', 'foo.id,other', 'other,foo.id'), ids=['none', 'simple', 'complex', 'backward'])
+    @pytest.mark.parametrize('host_name', ('host-1', 'fooval'), ids=['arbitrary', 'id'])
+    @pytest.mark.parametrize('has_var', (True, False))
+    def test_single_host_not_recreated(self, inventory, id_var, host_name, has_var):
+        inv_src = InventorySource.objects.create(inventory=inventory, source='gce')
+        options = dict(overwrite=True, instance_id_var=id_var)
+        vars = {'foo': {'id': 'fooval'}}
+        data = {
+            '_meta': {'hostvars': {host_name: vars if has_var else {'unrelated': 'value'}}},
+            "ungrouped": {"hosts": [host_name]},
+        }
+        old_id = None
+        for i in range(3):
+            inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job())
+            assert inventory.hosts.count() == inv_src.hosts.count() == 1
+            host = inventory.hosts.first()
+            assert host.name == host_name
+            assert host.instance_id in ('fooval', '')
+            if has_var:
+                assert yaml.safe_load(host.variables) == vars
+            else:
+                assert yaml.safe_load(host.variables) == {'unrelated': 'value'}
+            if old_id is not None:
+                assert host.id == old_id
+            old_id = host.id
+
+    @pytest.mark.parametrize('id_var_seq', [('', 'foo.id,other'), ('foo.id,other', '')], ids=['gained', 'lost'])  # second is the problem case
+    @pytest.mark.parametrize('host_name', ('host-1', 'fooval'), ids=['arbitrary', 'id'])
+    def test_host_gains_or_loses_instance_id(self, inventory, id_var_seq, host_name):
+        inv_src = InventorySource.objects.create(inventory=inventory, source='gce')
+        options = dict(overwrite=True)
+        vars = {'foo': {'id': 'fooval'}}
+        old_id = None
+        for id_var in id_var_seq:
+            options['instance_id_var'] = id_var
+            data = {
+                '_meta': {'hostvars': {host_name: vars}},
+                "ungrouped": {"hosts": [host_name]},
+            }
+            inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job())
+            assert inventory.hosts.count() == inv_src.hosts.count() == 1
+            host = inventory.hosts.first()
+            assert host.name == host_name
+            assert host.instance_id == ('fooval' if id_var else '')
+            assert yaml.safe_load(host.variables) == vars
+            if old_id is not None:
+                assert host.id == old_id
+            old_id = host.id
+
+    @pytest.mark.parametrize('second_list', [('host-1', 'fooval'), ('host-1',), ('fooval',)])
+    def test_name_and_id_confusion(self, inventory, second_list):
+        inv_src = InventorySource.objects.create(inventory=inventory, source='gce')
+        CASES = [('', ['host-1', 'fooval']), ('foo.id', second_list)]
+        options = dict(overwrite=True)
+        vars = {'foo': {'id': 'fooval'}}
+        data = {
+            '_meta': {'hostvars': {}},
+            "ungrouped": {"hosts": []},
+        }
+        id_set = None
+        for id_var, hosts in CASES:
+            options['instance_id_var'] = id_var
+            data['_meta']['hostvars'] = {}
+            for host_name in hosts:
+                data['_meta']['hostvars'][host_name] = vars if id_var else {}
+            data['ungrouped']['hosts'] = hosts
+            inventory_import.Command().perform_update(options.copy(), data.copy(), inv_src.create_unified_job())
+            new_ids = set(inventory.hosts.values_list('id', flat=True))
+            if id_set is not None:
+                assert not (new_ids - id_set)
+            id_set = new_ids
+            assert inventory.hosts.count() == len(hosts), [(host.name, host.instance_id) for host in inventory.hosts.all()]
+            assert inv_src.hosts.count() == len(hosts), [(host.name, host.instance_id) for host in inventory.hosts.all()]
+            for host_name in hosts:
+                host = inventory.hosts.get(name=host_name)
+                assert host.instance_id == ('fooval' if id_var else '')
+
+
 @pytest.mark.django_db
 @pytest.mark.inventory_import
 @mock.patch.object(inventory_import.Command, 'check_license', mock.MagicMock())
@@ -89,7 +194,7 @@ class TestINIImports:
     def test_inventory_single_ini_import(self, inventory, capsys):
         inventory_import.AnsibleInventoryLoader._data = TEST_INVENTORY_CONTENT
         cmd = inventory_import.Command()
-        r = cmd.handle(inventory_id=inventory.pk, source=__file__, method='backport')
+        r = cmd.handle(inventory_id=inventory.pk, source=__file__)
         out, err = capsys.readouterr()
         assert r is None
         assert out == ''