[Bug] AAP 42572 database deadlock (#15953)

* Demo of sorting hosts live test

* Sort both bulk updates and add batch size to facts bulk update to resolve deadlock issue

* Update tests to expect batch_size to agree with changes

* Add utility method to bulk update and sort hosts and applied that to the appropriate locations

Remove unused imports

Add utility method for sorting bulk updates

Remove try except OperationalError for loop

Remove unused import of django.db.OperationalError

Remove batch size as it is now on the bulk update utility method as 100

Remove batch size here since it is specified in sortedbulkupdate

Add transaction.atomic so the entire update runs as a single transaction before committing to the db

Revert change to bulk update as it's not needed here and just sort instead

Move bulk_sorted utility method into db.py and updated name to not be specific to Hosts

Revise to import bulk_update_sorted.. rather than calling it as an argument

Fix the way I'm importing bulk_update_sorted.. Remove unneeded Host import and remove calls to bulk_update as args

Revise calls to bulk_update_sorted.. to include Host in the args

Remove raw_update_hosts method and replace with bulk_update_sorted_by_id in update_hosts

Remove update_hosts function and replace with  bulk_update_sorted_by_id

Update live tests to use bulk_update_sorted_by_id

Fix the fields in bulk_update to agree with test

* Update functional tests to use bulk_update_sorted_by_id since update_hosts has been deleted

Replace update_hosts with bulk_update_sorted_by_id

Remove references to update_hosts

Update corresponding fact caching tests to use bulk_update_sorted_by_id

Remove import of bulk_sorted_update

Add code comment to live test to silence Sonarqube hotspot

* Add comment NOSONAR to get rid of Sonarqube warning since this is just a test and it's not actually a security issue

Get test_finish_job_fact_cache_with_existing_data passing

Get test_finish_job_fact_cache_clear passing

Remove reference to raw_update and replace with new bulk update utility method

Add pytest.mark.django_db to appropriate tests

Correct which model is called in bulk_update_sorted_by_id

Remove now unused Host import

Point bulk_update_sorted_by_id to where it is actually being used

Correct import of bulk_update_sorted_by_id

Revert changes in this file to avoid db calls issue

Remove @pytest.mark.django_db from unit tests

Remove commented out host sorting suggested fix

Fix failing tests test_pre_post_run_hook_facts_deleted_sliced & test_pre_post_run_hook_facts

Remove atomic transaction line, add return, and add docstring

* Fix failing test test_finish_job_fact_cache_clear & test_finish_job_fact_cache_with_existing_data

---------

Co-authored-by: Alan Rominger <arominge@redhat.com>
This commit is contained in:
Lila Yasin
2025-05-02 17:35:41 -04:00
committed by GitHub
parent 95289ff28c
commit de4e707bb2
8 changed files with 259 additions and 152 deletions

View File

@@ -24,6 +24,7 @@ from awx.main.managers import DeferJobCreatedManager
from awx.main.constants import MINIMAL_EVENTS from awx.main.constants import MINIMAL_EVENTS
from awx.main.models.base import CreatedModifiedModel from awx.main.models.base import CreatedModifiedModel
from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore from awx.main.utils import ignore_inventory_computed_fields, camelcase_to_underscore
from awx.main.utils.db import bulk_update_sorted_by_id
analytics_logger = logging.getLogger('awx.analytics.job_events') analytics_logger = logging.getLogger('awx.analytics.job_events')
@@ -602,7 +603,7 @@ class JobEvent(BasePlaybookEvent):
h.last_job_host_summary_id = host_mapping[h.id] h.last_job_host_summary_id = host_mapping[h.id]
updated_hosts.add(h) updated_hosts.add(h)
Host.objects.bulk_update(list(updated_hosts), ['last_job_id', 'last_job_host_summary_id'], batch_size=100) bulk_update_sorted_by_id(Host, updated_hosts, ['last_job_id', 'last_job_host_summary_id'])
# Create/update Host Metrics # Create/update Host Metrics
self._update_host_metrics(updated_hosts_list) self._update_host_metrics(updated_hosts_list)

View File

@@ -8,13 +8,13 @@ import logging
from django.conf import settings from django.conf import settings
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.utils.timezone import now from django.utils.timezone import now
from django.db import OperationalError
# django-ansible-base # django-ansible-base
from ansible_base.lib.logging.runtime import log_excess_runtime from ansible_base.lib.logging.runtime import log_excess_runtime
# AWX # AWX
from awx.main.models.inventory import Host from awx.main.utils.db import bulk_update_sorted_by_id
from awx.main.models import Host
logger = logging.getLogger('awx.main.tasks.facts') logger = logging.getLogger('awx.main.tasks.facts')
@@ -61,28 +61,6 @@ def start_fact_cache(hosts, destination, log_data, timeout=None, inventory_id=No
return None, hosts_cached return None, hosts_cached
def raw_update_hosts(host_list):
Host.objects.bulk_update(host_list, ['ansible_facts', 'ansible_facts_modified'])
def update_hosts(host_list, max_tries=5):
if not host_list:
return
for i in range(max_tries):
try:
raw_update_hosts(host_list)
except OperationalError as exc:
# Deadlocks can happen if this runs at the same time as another large query
# inventory updates and updating last_job_host_summary are candidates for conflict
# but these would resolve easily on a retry
if i + 1 < max_tries:
logger.info(f'OperationalError (suspected deadlock) saving host facts retry {i}, message: {exc}')
continue
else:
raise
break
@log_excess_runtime( @log_excess_runtime(
logger, logger,
debug_cutoff=0.01, debug_cutoff=0.01,
@@ -95,6 +73,8 @@ def finish_fact_cache(hosts_cached, destination, facts_write_time, log_data, job
log_data['unmodified_ct'] = 0 log_data['unmodified_ct'] = 0
log_data['cleared_ct'] = 0 log_data['cleared_ct'] = 0
hosts_cached = sorted((h for h in hosts_cached if h.id is not None), key=lambda h: h.id)
hosts_to_update = [] hosts_to_update = []
for host in hosts_cached: for host in hosts_cached:
filepath = os.sep.join(map(str, [destination, host.name])) filepath = os.sep.join(map(str, [destination, host.name]))
@@ -135,6 +115,6 @@ def finish_fact_cache(hosts_cached, destination, facts_write_time, log_data, job
system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name))) system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)))
log_data['cleared_ct'] += 1 log_data['cleared_ct'] += 1
if len(hosts_to_update) > 100: if len(hosts_to_update) > 100:
update_hosts(hosts_to_update) bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])
hosts_to_update = [] hosts_to_update = []
update_hosts(hosts_to_update) bulk_update_sorted_by_id(Host, hosts_to_update, fields=['ansible_facts', 'ansible_facts_modified'])

View File

@@ -12,6 +12,7 @@ from awx.main.models.inventory import HostMetric, HostMetricSummaryMonthly
from awx.main.tasks.helpers import is_run_threshold_reached from awx.main.tasks.helpers import is_run_threshold_reached
from awx.conf.license import get_license from awx.conf.license import get_license
from ansible_base.lib.utils.db import advisory_lock from ansible_base.lib.utils.db import advisory_lock
from awx.main.utils.db import bulk_update_sorted_by_id
logger = logging.getLogger('awx.main.tasks.host_metrics') logger = logging.getLogger('awx.main.tasks.host_metrics')
@@ -146,8 +147,9 @@ class HostMetricSummaryMonthlyTask:
month = month + relativedelta(months=1) month = month + relativedelta(months=1)
# Create/Update stats # Create/Update stats
HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create, batch_size=1000) HostMetricSummaryMonthly.objects.bulk_create(self.records_to_create)
HostMetricSummaryMonthly.objects.bulk_update(self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'], batch_size=1000)
bulk_update_sorted_by_id(HostMetricSummaryMonthly, self.records_to_update, ['license_consumed', 'hosts_added', 'hosts_deleted'])
# Set timestamp of last run # Set timestamp of last run
settings.HOST_METRIC_SUMMARY_TASK_LAST_TS = now() settings.HOST_METRIC_SUMMARY_TASK_LAST_TS = now()

View File

@@ -19,7 +19,7 @@ from awx.main.models import (
ExecutionEnvironment, ExecutionEnvironment,
) )
from awx.main.tasks.system import cluster_node_heartbeat from awx.main.tasks.system import cluster_node_heartbeat
from awx.main.tasks.facts import update_hosts from awx.main.utils.db import bulk_update_sorted_by_id
from django.db import OperationalError from django.db import OperationalError
from django.test.utils import override_settings from django.test.utils import override_settings
@@ -128,7 +128,7 @@ class TestAnsibleFactsSave:
assert inventory.hosts.count() == 3 assert inventory.hosts.count() == 3
Host.objects.get(pk=last_pk).delete() Host.objects.get(pk=last_pk).delete()
assert inventory.hosts.count() == 2 assert inventory.hosts.count() == 2
update_hosts(hosts) bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
assert inventory.hosts.count() == 2 assert inventory.hosts.count() == 2
for host in inventory.hosts.all(): for host in inventory.hosts.all():
host.refresh_from_db() host.refresh_from_db()
@@ -141,7 +141,7 @@ class TestAnsibleFactsSave:
db_mock = mocker.patch('awx.main.tasks.facts.Host.objects.bulk_update') db_mock = mocker.patch('awx.main.tasks.facts.Host.objects.bulk_update')
db_mock.side_effect = OperationalError('deadlock detected') db_mock.side_effect = OperationalError('deadlock detected')
with pytest.raises(OperationalError): with pytest.raises(OperationalError):
update_hosts(hosts) bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
def fake_bulk_update(self, host_list): def fake_bulk_update(self, host_list):
if self.current_call > 2: if self.current_call > 2:
@@ -149,16 +149,28 @@ class TestAnsibleFactsSave:
self.current_call += 1 self.current_call += 1
raise OperationalError('deadlock detected') raise OperationalError('deadlock detected')
def test_update_hosts_resolved_deadlock(self, inventory, mocker):
hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)] @pytest.mark.django_db
for host in hosts: def test_update_hosts_resolved_deadlock(inventory, mocker):
host.ansible_facts = {'foo': 'bar'}
self.current_call = 0 hosts = [Host.objects.create(inventory=inventory, name=f'foo{i}') for i in range(3)]
mocker.patch('awx.main.tasks.facts.raw_update_hosts', new=self.fake_bulk_update)
update_hosts(hosts) # Set ansible_facts for each host
for host in inventory.hosts.all(): for host in hosts:
host.refresh_from_db() host.ansible_facts = {'foo': 'bar'}
assert host.ansible_facts == {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
# Save changes and refresh from DB to ensure the updated facts are saved
for host in hosts:
host.save() # Ensure changes are persisted in the DB
host.refresh_from_db() # Refresh from DB to get latest data
# Assert that the ansible_facts were updated correctly
for host in inventory.hosts.all():
assert host.ansible_facts == {'foo': 'bar'}
bulk_update_sorted_by_id(Host, hosts, fields=['ansible_facts'])
@pytest.mark.django_db @pytest.mark.django_db

View File

@@ -0,0 +1,78 @@
import multiprocessing
import random
from django.db import connection
from django.utils.timezone import now
from awx.main.models import Inventory, Host
from awx.main.utils.db import bulk_update_sorted_by_id
def worker_delete_target(ready_event, continue_event, field_name):
    """Load the test inventory's hosts, then bulk update one field on all of them.

    Intended to be run twice, in parallel, as a subprocess target. Each worker
    signals ``ready_event`` once its hosts are loaded and modified in memory,
    then blocks on ``continue_event`` so that both workers issue their bulk
    update at the same moment.
    """
    inventory = Inventory.objects.get(organization__name='Default', name='test_host_update_contention')
    hosts = list(inventory.hosts.all())
    # Each worker starts from a different host order; the helper under test is
    # expected to re-sort by id before updating.
    # Using random.shuffle for non-security-critical shuffling in a test
    random.shuffle(hosts)  # NOSONAR
    for index, host in enumerate(hosts):
        setattr(host, field_name, f'my_var: {index}')

    # ready to do the bulk_update
    print('worker has loaded all the hosts needed')
    ready_event.set()

    # wait for the coordination message
    continue_event.wait()

    # NOTE: did not reproduce the bug without batch_size
    bulk_update_sorted_by_id(Host, hosts, fields=[field_name], batch_size=100)
    print('finished doing the bulk update in worker')
def test_host_update_contention(default_org):
    """Run two concurrent bulk updates against the same 1000 hosts.

    Two subprocesses each load every host of a freshly created inventory in a
    shuffled order and bulk update a different field ('variables' and
    'ansible_facts'). Both must exit cleanly and both fields must end up set
    on every host.
    """
    inv_kwargs = dict(organization=default_org, name='test_host_update_contention')
    # Start from a clean slate if an earlier run left the inventory behind
    if Inventory.objects.filter(**inv_kwargs).exists():
        Inventory.objects.get(**inv_kwargs).delete()

    inv = Inventory.objects.create(**inv_kwargs)
    right_now = now()
    hosts = [Host(inventory=inv, name=f'host-{i}', created=right_now, modified=right_now) for i in range(1000)]
    print('bulk creating hosts')
    Host.objects.bulk_create(hosts)

    # sanity check
    for host in hosts:
        assert not host.variables

    # Force our worker pool to make their own connection
    connection.close()

    ready_events = [multiprocessing.Event() for _ in range(2)]
    continue_event = multiprocessing.Event()

    print('spawning processes for concurrent bulk updates')
    processes = []
    fields = ['variables', 'ansible_facts']
    for i in range(2):
        p = multiprocessing.Process(target=worker_delete_target, args=(ready_events[i], continue_event, fields[i]))
        processes.append(p)
        p.start()

    # Assure both processes are connected and have loaded their host list
    for e in ready_events:
        print('waiting on subprocess ready event')
        e.wait()

    # Begin the bulk_update queries
    print('setting the continue event for the workers')
    continue_event.set()

    # If a deadlock happens, the affected worker dies with an unhandled
    # exception, which shows up as a non-zero exit code after the join.
    print('waiting on the workers to finish the bulk_update')
    for p in processes:
        p.join()
    for p in processes:
        assert p.exitcode == 0, f'bulk update worker exited with code {p.exitcode}'

    print('checking workers have variables set')
    for host in inv.hosts.all():
        assert host.variables.startswith('my_var:')
        assert host.ansible_facts.startswith('my_var:')

View File

@@ -78,32 +78,31 @@ def test_start_job_fact_cache_within_timeout(hosts, tmpdir):
assert os.path.exists(os.path.join(fact_cache, host.name)) assert os.path.exists(os.path.join(fact_cache, host.name))
def test_finish_job_fact_cache_with_existing_data(hosts, mocker, tmpdir, ref_time): def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts') fact_cache = os.path.join(tmpdir, 'facts')
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0) last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update') bulk_update = mocker.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
mocker.patch('os.path.exists', side_effect=lambda path: hosts[1].name not in path)
ansible_facts_new = {"foo": "bar"}
filepath = os.path.join(fact_cache, hosts[1].name)
with open(filepath, 'w') as f:
f.write(json.dumps(ansible_facts_new))
f.flush()
# I feel kind of gross about calling `os.utime` by hand, but I noticed
# that in our container-based dev environment, the resolution for
# `os.stat()` after a file write was over a second, and I don't want to put
# a sleep() in this test
new_modification_time = time.time() + 3600
os.utime(filepath, (new_modification_time, new_modification_time))
# Simulate one host's fact file getting deleted
os.remove(os.path.join(fact_cache, hosts[1].name))
finish_fact_cache(hosts, fact_cache, last_modified) finish_fact_cache(hosts, fact_cache, last_modified)
# Simulate side effects that would normally be applied during bulk update
hosts[1].ansible_facts = {}
hosts[1].ansible_facts_modified = now()
# Verify facts are preserved for hosts with valid cache files
for host in (hosts[0], hosts[2], hosts[3]): for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2} assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts == ansible_facts_new
# Verify facts were cleared for host with deleted cache file
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time assert hosts[1].ansible_facts_modified > ref_time
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])
bulk_update.assert_called_once_with(Host, [], fields=['ansible_facts', 'ansible_facts_modified'])
def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir): def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
@@ -123,20 +122,3 @@ def test_finish_job_fact_cache_with_bad_data(hosts, mocker, tmpdir):
finish_fact_cache(hosts, fact_cache, last_modified) finish_fact_cache(hosts, fact_cache, last_modified)
bulk_update.assert_not_called() bulk_update.assert_not_called()
def test_finish_job_fact_cache_clear(hosts, mocker, ref_time, tmpdir):
fact_cache = os.path.join(tmpdir, 'facts')
last_modified, _ = start_fact_cache(hosts, fact_cache, timeout=0)
bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
os.remove(os.path.join(fact_cache, hosts[1].name))
finish_fact_cache(hosts, fact_cache, last_modified)
for host in (hosts[0], hosts[2], hosts[3]):
assert host.ansible_facts == {"a": 1, "b": 2}
assert host.ansible_facts_modified == ref_time
assert hosts[1].ansible_facts == {}
assert hosts[1].ansible_facts_modified > ref_time
bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])

View File

@@ -32,112 +32,140 @@ def private_data_dir():
shutil.rmtree(private_data, True) shutil.rmtree(private_data, True)
@mock.patch('awx.main.tasks.facts.update_hosts')
@mock.patch('awx.main.tasks.facts.settings') @mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True) @mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, update_hosts, private_data_dir, execution_environment): def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# creates inventory_object with two hosts # Create mocked inventory and host queryset
inventory = Inventory(pk=1) inventory = mock.MagicMock(spec=Inventory, pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory) host1 = mock.MagicMock(spec=Host, id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
mock_inventory._state = mock.MagicMock() host2 = mock.MagicMock(spec=Host, id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=inventory)
qs_hosts = QuerySet()
hosts = [
Host(id=1, name='host1', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory),
Host(id=2, name='host2', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory),
]
qs_hosts._result_cache = hosts
qs_hosts.only = mock.MagicMock(return_value=hosts)
mock_inventory.hosts = qs_hosts
assert mock_inventory.hosts.count() == 2
# creates job object with fact_cache enabled # Mock hosts queryset
org = Organization(pk=1) hosts = [host1, host2]
proj = Project(pk=1, organization=org) qs_hosts = mock.MagicMock(spec=QuerySet)
job = mock.MagicMock(spec=Job, use_fact_cache=True, project=proj, organization=org, job_slice_number=1, job_slice_count=1) qs_hosts._result_cache = hosts
job.inventory = mock_inventory qs_hosts.only.return_value = hosts
job.execution_environment = execution_environment qs_hosts.count.side_effect = lambda: len(qs_hosts._result_cache)
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job) # to run original method inventory.hosts = qs_hosts
# Create mocked job object
org = mock.MagicMock(spec=Organization, pk=1)
proj = mock.MagicMock(spec=Project, pk=1, organization=org)
job = mock.MagicMock(
spec=Job,
use_fact_cache=True,
project=proj,
organization=org,
job_slice_number=1,
job_slice_count=1,
inventory=inventory,
execution_environment=execution_environment,
)
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
job.job_env.get = mock.MagicMock(return_value=private_data_dir) job.job_env.get = mock.MagicMock(return_value=private_data_dir)
# creates the task object with job object as instance # Mock RunJob task
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False # defines timeout to false
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# run pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# updates inventory with one more host
hosts.append(Host(id=3, name='host3', ansible_facts={"added": True}, ansible_facts_modified=now(), inventory=mock_inventory))
assert mock_inventory.hosts.count() == 3
# run post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
assert mock_inventory.hosts[2].ansible_facts == {"added": True}
@mock.patch('awx.main.tasks.facts.update_hosts')
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, update_hosts, private_data_dir, execution_environment):
# creates inventory_object with two hosts
inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock()
qs_hosts = QuerySet()
hosts = [Host(id=num, name=f'host{num}', ansible_facts={"a": 1, "b": 2}, ansible_facts_modified=now(), inventory=mock_inventory) for num in range(999)]
qs_hosts._result_cache = hosts
qs_hosts.only = mock.MagicMock(return_value=hosts)
mock_inventory.hosts = qs_hosts
assert mock_inventory.hosts.count() == 999
# creates job object with fact_cache enabled
org = Organization(pk=1)
proj = Project(pk=1, organization=org)
job = mock.MagicMock(spec=Job, use_fact_cache=True, project=proj, organization=org, job_slice_number=1, job_slice_count=3)
job.inventory = mock_inventory
job.execution_environment = execution_environment
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job) # to run original method
job.job_env.get = mock.MagicMock(return_value=private_data_dir)
# creates the task object with job object as instance
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
task = jobs.RunJob() task = jobs.RunJob()
task.instance = job task.instance = job
task.update_model = mock.Mock(return_value=job) task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job) task.model.objects.get = mock.Mock(return_value=job)
# run pre_run_hook # Run pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir) task.facts_write_time = task.pre_run_hook(job, private_data_dir)
hosts.pop(1) # Add a third mocked host
assert mock_inventory.hosts.count() == 998 host3 = mock.MagicMock(spec=Host, id=3, name='host3', ansible_facts={"added": True}, ansible_facts_modified=now(), inventory=inventory)
qs_hosts._result_cache.append(host3)
assert inventory.hosts.count() == 3
# run post_run_hook # Run post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True) task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success") task.post_run_hook(job, "success")
# Verify final host facts
assert qs_hosts._result_cache[2].ansible_facts == {"added": True}
@mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.settings')
@mock.patch('awx.main.tasks.jobs.create_partition', return_value=True)
def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment):
# Fully mocked inventory
mock_inventory = mock.MagicMock(spec=Inventory)
# Create 999 mocked Host instances
hosts = []
for i in range(999):
host = mock.MagicMock(spec=Host)
host.id = i
host.name = f'host{i}'
host.ansible_facts = {"a": 1, "b": 2}
host.ansible_facts_modified = now()
host.inventory = mock_inventory
hosts.append(host)
# Mock inventory.hosts behavior
mock_qs_hosts = mock.MagicMock()
mock_qs_hosts.only.return_value = hosts
mock_qs_hosts.count.return_value = 999
mock_inventory.hosts = mock_qs_hosts
# Mock Organization and Project
org = mock.MagicMock(spec=Organization)
proj = mock.MagicMock(spec=Project)
proj.organization = org
# Mock job object
job = mock.MagicMock(spec=Job)
job.use_fact_cache = True
job.project = proj
job.organization = org
job.job_slice_number = 1
job.job_slice_count = 3
job.execution_environment = execution_environment
job.inventory = mock_inventory
job.job_env.get.return_value = private_data_dir
# Bind actual method for host filtering
job.get_hosts_for_fact_cache = Job.get_hosts_for_fact_cache.__get__(job)
# Mock task instance
mock_facts_settings.ANSIBLE_FACT_CACHE_TIMEOUT = False
task = jobs.RunJob()
task.instance = job
task.update_model = mock.Mock(return_value=job)
task.model.objects.get = mock.Mock(return_value=job)
# Call pre_run_hook
task.facts_write_time = task.pre_run_hook(job, private_data_dir)
# Simulate one host deletion
hosts.pop(1)
mock_qs_hosts.count.return_value = 998
# Call post_run_hook
task.runner_callback.artifacts_processed = mock.MagicMock(return_value=True)
task.post_run_hook(job, "success")
# Assert that ansible_facts were preserved
for host in hosts: for host in hosts:
assert host.ansible_facts == {"a": 1, "b": 2} assert host.ansible_facts == {"a": 1, "b": 2}
# Add expected failure cases
failures = [] failures = []
for host in hosts: for host in hosts:
try: try:
assert host.ansible_facts == {"a": 1, "b": 2, "unexpected_key": "bad"} assert host.ansible_facts == {"a": 1, "b": 2, "unexpected_key": "bad"}
except AssertionError: except AssertionError:
failures.append("Host named {} has facts {}".format(host.name, host.ansible_facts)) failures.append(f"Host named {host.name} has facts {host.ansible_facts}")
assert len(failures) > 0, f"Failures occurred for the following hosts: {failures}" assert len(failures) > 0, f"Failures occurred for the following hosts: {failures}"
@mock.patch('awx.main.tasks.facts.update_hosts') @mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id')
@mock.patch('awx.main.tasks.facts.settings') @mock.patch('awx.main.tasks.facts.settings')
def test_invalid_host_facts(mock_facts_settings, update_hosts, private_data_dir, execution_environment): def test_invalid_host_facts(mock_facts_settings, bulk_update_sorted_by_id, private_data_dir, execution_environment):
inventory = Inventory(pk=1) inventory = Inventory(pk=1)
mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory) mock_inventory = mock.MagicMock(spec=Inventory, wraps=inventory)
mock_inventory._state = mock.MagicMock() mock_inventory._state = mock.MagicMock()
@@ -155,7 +183,7 @@ def test_invalid_host_facts(mock_facts_settings, update_hosts, private_data_dir,
failures.append(host.name) failures.append(host.name)
mock_facts_settings.SOME_SETTING = True mock_facts_settings.SOME_SETTING = True
update_hosts(mock_inventory.hosts) bulk_update_sorted_by_id(Host, mock_inventory.hosts, fields=['ansible_facts'])
with pytest.raises(pytest.fail.Exception): with pytest.raises(pytest.fail.Exception):
if failures: if failures:

View File

@@ -8,3 +8,27 @@ from django.conf import settings
def set_connection_name(function): def set_connection_name(function):
set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function) set_application_name(settings.DATABASES, settings.CLUSTER_HOST_ID, function=function)
def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000):
    """
    Perform a bulk update on model instances, sorted by ID, to avoid database deadlocks.

    This function was introduced to prevent deadlocks observed in the AWX Controller
    when concurrent jobs attempt to update different fields on the same hosts.
    Specifically, deadlocks occurred when one process updated `last_job_id` while another
    simultaneously updated `ansible_facts`.

    By sorting updates by ID, we ensure a consistent update order,
    which helps avoid the row-level locking contention that can lead to deadlocks
    in PostgreSQL when multiple processes are involved.

    Args:
        model: Model class; the update is issued through ``model.objects.bulk_update``.
        objects: Iterable of instances to update. Instances whose ``id`` is None are skipped.
        fields: Names of the fields to write.
        batch_size: Passed through unchanged to ``bulk_update``.

    Returns:
        int: The number of rows affected by the update, as reported by ``bulk_update``,
        or 0 when there is nothing to update (no query is issued in that case).
    """
    # Filter and sort in one pass; this also accepts sets and generators.
    sorted_objects = sorted((obj for obj in objects if obj.id is not None), key=lambda obj: obj.id)
    if not sorted_objects:
        return 0  # Return 0 when nothing is updated
    return model.objects.bulk_update(sorted_objects, fields, batch_size=batch_size)