Bulk save facts, and move to before status change (#12998)

* Facts scaling fixes for large inventories, plus a timing issue

Move the save of Ansible facts to before the job status change;
  this is considered an acceptable delay given the other
  performance fixes here.

Remove a completely unrelated, unused facts method.

Scale-related changes to facts saving (a sketch follows this list):
  Use .iterator() on the queryset when looping
  Change per-host save() calls to bulk_update()
  Apply bulk_update in batches of 100 to reduce memory use
  Save only a single file modification time, avoiding a large dict
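
As an illustration, a minimal sketch of the batching pattern (standalone
Django-style code; BATCH_SIZE, facts_changed and load_new_facts are invented
for this sketch, the real logic is in finish_job_fact_cache below):

    from django.utils.timezone import now

    BATCH_SIZE = 100  # flush in small batches to bound memory use

    hosts_to_update = []
    # .iterator() streams rows from the database cursor instead of
    # caching the entire queryset in memory
    for host in inventory.hosts.iterator():
        if facts_changed(host):  # hypothetical helper
            host.ansible_facts = load_new_facts(host)  # hypothetical helper
            host.ansible_facts_modified = now()
            hosts_to_update.append(host)
        if len(hosts_to_update) > BATCH_SIZE:
            # one UPDATE statement per batch instead of one per host
            inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified'])
            hosts_to_update = []
    if hosts_to_update:  # flush the remainder
        inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified'])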

Use a decorator for logging long function runtimes;
  update the decorator to fill in the message format fields.
Alan Rominger, 2022-11-15 15:18:06 -05:00 (committed by GitHub)
parent 0933a96d60
commit 2fdce43f9e
5 changed files with 89 additions and 65 deletions


@@ -567,17 +567,6 @@ class Host(CommonModelNameNotUnique, RelatedJobsMixin):
     # Use .job_host_summaries.all() to get jobs affecting this host.
     # Use .job_events.all() to get events affecting this host.
 
-    '''
-    We don't use timestamp, but we may in the future.
-    '''
-
-    def update_ansible_facts(self, module, facts, timestamp=None):
-        if module == "ansible":
-            self.ansible_facts.update(facts)
-        else:
-            self.ansible_facts[module] = facts
-        self.save()
-
     def get_effective_host_name(self):
         """
         Return the name of the host that will be used in actual ansible


@@ -44,7 +44,7 @@ from awx.main.models.notifications import (
     NotificationTemplate,
     JobNotificationMixin,
 )
-from awx.main.utils import parse_yaml_or_json, getattr_dne, NullablePromptPseudoField, polymorphic
+from awx.main.utils import parse_yaml_or_json, getattr_dne, NullablePromptPseudoField, polymorphic, log_excess_runtime
 from awx.main.fields import ImplicitRoleField, AskForField, JSONBlob, OrderedManyToManyField
 from awx.main.models.mixins import (
     ResourceMixin,
@@ -857,8 +857,11 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
             return host_queryset.iterator()
         return host_queryset
 
-    def start_job_fact_cache(self, destination, modification_times, timeout=None):
+    @log_excess_runtime(logger, debug_cutoff=0.01, msg='Job {job_id} host facts prepared for {written_ct} hosts, took {delta:.3f} s', add_log_data=True)
+    def start_job_fact_cache(self, destination, log_data, timeout=None):
         self.log_lifecycle("start_job_fact_cache")
+        log_data['job_id'] = self.id
+        log_data['written_ct'] = 0
         os.makedirs(destination, mode=0o700)
 
         if timeout is None:
@@ -869,6 +872,8 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
             hosts = self._get_inventory_hosts(ansible_facts_modified__gte=timeout)
         else:
             hosts = self._get_inventory_hosts()
+
+        last_filepath_written = None
         for host in hosts:
             filepath = os.sep.join(map(str, [destination, host.name]))
             if not os.path.realpath(filepath).startswith(destination):
@@ -878,23 +883,38 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
                 with codecs.open(filepath, 'w', encoding='utf-8') as f:
                     os.chmod(f.name, 0o600)
                     json.dump(host.ansible_facts, f)
+                    log_data['written_ct'] += 1
+                    last_filepath_written = filepath
             except IOError:
                 system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
                 continue
-            # make note of the time we wrote the file so we can check if it changed later
-            modification_times[filepath] = os.path.getmtime(filepath)
+        # make note of the time we wrote the last file so we can check if any file changed later
+        if last_filepath_written:
+            return os.path.getmtime(last_filepath_written)
+        return None
 
-    def finish_job_fact_cache(self, destination, modification_times):
+    @log_excess_runtime(
+        logger,
+        debug_cutoff=0.01,
+        msg='Job {job_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s',
+        add_log_data=True,
+    )
+    def finish_job_fact_cache(self, destination, facts_write_time, log_data):
         self.log_lifecycle("finish_job_fact_cache")
+        log_data['job_id'] = self.id
+        log_data['updated_ct'] = 0
+        log_data['unmodified_ct'] = 0
+        log_data['cleared_ct'] = 0
+        hosts_to_update = []
         for host in self._get_inventory_hosts():
             filepath = os.sep.join(map(str, [destination, host.name]))
             if not os.path.realpath(filepath).startswith(destination):
                 system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name)))
                 continue
             if os.path.exists(filepath):
-                # If the file changed since we wrote it pre-playbook run...
+                # If the file changed since we wrote the last facts file, pre-playbook run...
                 modified = os.path.getmtime(filepath)
-                if modified > modification_times.get(filepath, 0):
+                if (not facts_write_time) or modified > facts_write_time:
                     with codecs.open(filepath, 'r', encoding='utf-8') as f:
                         try:
                             ansible_facts = json.load(f)
@@ -902,7 +922,7 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
                             continue
                     host.ansible_facts = ansible_facts
                     host.ansible_facts_modified = now()
-                    host.save(update_fields=['ansible_facts', 'ansible_facts_modified'])
+                    hosts_to_update.append(host)
                     system_tracking_logger.info(
                         'New fact for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)),
                         extra=dict(
@@ -913,12 +933,21 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
                             job_id=self.id,
                         ),
                     )
+                    log_data['updated_ct'] += 1
+                else:
+                    log_data['unmodified_ct'] += 1
             else:
                 # if the file goes missing, ansible removed it (likely via clear_facts)
                 host.ansible_facts = {}
                 host.ansible_facts_modified = now()
+                hosts_to_update.append(host)
                 system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)))
-                host.save()
+                log_data['cleared_ct'] += 1
+            if len(hosts_to_update) > 100:
+                self.inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified'])
+                hosts_to_update = []
+        if hosts_to_update:
+            self.inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified'])
 
 
 class LaunchTimeConfigBase(BaseModel):
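
The per-file modification_times dict is gone: start_job_fact_cache now returns
a single timestamp (the mtime of the last facts file it wrote), and
finish_job_fact_cache treats any file whose mtime is newer than that as changed
by the playbook. A standalone sketch of the scheme, with invented
write_cache/changed_hosts names:

    import json
    import os

    def write_cache(destination, facts_by_host):
        """Write one JSON facts file per host; return a single float, not a dict."""
        last_filepath_written = None
        for name, facts in facts_by_host.items():
            filepath = os.path.join(destination, name)
            with open(filepath, 'w') as f:
                json.dump(facts, f)
            last_filepath_written = filepath
        if last_filepath_written:
            return os.path.getmtime(last_filepath_written)
        return None

    def changed_hosts(destination, names, facts_write_time):
        """Yield hosts whose facts file was rewritten after the cache was built."""
        for name in names:
            filepath = os.path.join(destination, name)
            if os.path.exists(filepath):
                if (not facts_write_time) or os.path.getmtime(filepath) > facts_write_time:
                    yield name

Because the files are written sequentially, the last file's mtime is an upper
bound for all of them, so one float can stand in for a dict with an entry per
host.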


@@ -426,7 +426,7 @@ class BaseTask(object):
         """
         instance.log_lifecycle("post_run")
 
-    def final_run_hook(self, instance, status, private_data_dir, fact_modification_times):
+    def final_run_hook(self, instance, status, private_data_dir):
         """
         Hook for any steps to run after job/task is marked as complete.
         """
@@ -469,7 +469,6 @@
         self.instance = self.update_model(pk, status='running', start_args='')  # blank field to remove encrypted passwords
         self.instance.websocket_emit_status("running")
         status, rc = 'error', None
-        fact_modification_times = {}
         self.runner_callback.event_ct = 0
 
         '''
@@ -498,14 +497,6 @@
             if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH):
                 raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH)
 
-            # Fetch "cached" fact data from prior runs and put on the disk
-            # where ansible expects to find it
-            if getattr(self.instance, 'use_fact_cache', False):
-                self.instance.start_job_fact_cache(
-                    os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'fact_cache'),
-                    fact_modification_times,
-                )
-
             # May have to serialize the value
             private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir)
             passwords = self.build_passwords(self.instance, kwargs)
@@ -646,7 +637,7 @@
             self.instance.send_notification_templates('succeeded' if status == 'successful' else 'failed')
 
         try:
-            self.final_run_hook(self.instance, status, private_data_dir, fact_modification_times)
+            self.final_run_hook(self.instance, status, private_data_dir)
         except Exception:
             logger.exception('{} Final run hook errored.'.format(self.instance.log_format))
 
@@ -1066,12 +1057,19 @@ class RunJob(SourceControlMixin, BaseTask):
             # ran inside of the event saving code
             update_smart_memberships_for_inventory(job.inventory)
 
+        # Fetch "cached" fact data from prior runs and put on the disk
+        # where ansible expects to find it
+        if job.use_fact_cache:
+            self.facts_write_time = self.instance.start_job_fact_cache(os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'))
+
     def build_project_dir(self, job, private_data_dir):
         self.sync_and_copy(job.project, private_data_dir, scm_branch=job.scm_branch)
 
-    def final_run_hook(self, job, status, private_data_dir, fact_modification_times):
-        super(RunJob, self).final_run_hook(job, status, private_data_dir, fact_modification_times)
-        if not private_data_dir:
+    def post_run_hook(self, job, status):
+        super(RunJob, self).post_run_hook(job, status)
+        job.refresh_from_db(fields=['job_env'])
+        private_data_dir = job.job_env.get('AWX_PRIVATE_DATA_DIR')
+        if (not private_data_dir) or (not hasattr(self, 'facts_write_time')):
             # If there's no private data dir, that means we didn't get into the
             # actual `run()` call; this _usually_ means something failed in
             # the pre_run_hook method
@@ -1079,9 +1077,11 @@ class RunJob(SourceControlMixin, BaseTask):
         if job.use_fact_cache:
             job.finish_job_fact_cache(
                 os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'),
-                fact_modification_times,
+                self.facts_write_time,
             )
 
+    def final_run_hook(self, job, status, private_data_dir):
+        super(RunJob, self).final_run_hook(job, status, private_data_dir)
         try:
             inventory = job.inventory
         except Inventory.DoesNotExist:
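
The write-back now happens in post_run_hook, before the job's final status is
saved; per the commit message this is an acceptable delay given the speedups
above. facts_write_time is only set if pre_run_hook reached the cache step,
hence the hasattr guard. A made-up miniature of the ordering (TaskSketch is
illustrative, not the real AWX classes):

    class TaskSketch:
        def pre_run_hook(self, job, private_data_dir):
            if job.use_fact_cache:
                # set only if setup gets this far; an early failure leaves it unset
                self.facts_write_time = job.start_job_fact_cache(private_data_dir)

        def post_run_hook(self, job, status, private_data_dir):
            # runs before the final status change, so the fact save lands
            # inside the job's lifetime instead of racing the status update
            if (not private_data_dir) or (not hasattr(self, 'facts_write_time')):
                return  # pre_run_hook never reached the cache step
            if job.use_fact_cache:
                job.finish_job_fact_cache(private_data_dir, self.facts_write_time)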


@@ -36,15 +36,14 @@ def job(mocker, hosts, inventory):
 
 def test_start_job_fact_cache(hosts, job, inventory, tmpdir):
     fact_cache = os.path.join(tmpdir, 'facts')
-    modified_times = {}
-    job.start_job_fact_cache(fact_cache, modified_times, 0)
+    last_modified = job.start_job_fact_cache(fact_cache, timeout=0)
 
     for host in hosts:
         filepath = os.path.join(fact_cache, host.name)
         assert os.path.exists(filepath)
         with open(filepath, 'r') as f:
             assert f.read() == json.dumps(host.ansible_facts)
-        assert filepath in modified_times
+        assert os.path.getmtime(filepath) <= last_modified
 
 
 def test_fact_cache_with_invalid_path_traversal(job, inventory, tmpdir, mocker):
@@ -58,18 +57,16 @@ def test_fact_cache_with_invalid_path_traversal(job, inventory, tmpdir, mocker):
     )
 
     fact_cache = os.path.join(tmpdir, 'facts')
-    job.start_job_fact_cache(fact_cache, {}, 0)
+    job.start_job_fact_cache(fact_cache, timeout=0)
     # a file called "foo" should _not_ be written outside the facts dir
     assert os.listdir(os.path.join(fact_cache, '..')) == ['facts']
 
 
 def test_finish_job_fact_cache_with_existing_data(job, hosts, inventory, mocker, tmpdir):
     fact_cache = os.path.join(tmpdir, 'facts')
-    modified_times = {}
-    job.start_job_fact_cache(fact_cache, modified_times, 0)
+    last_modified = job.start_job_fact_cache(fact_cache, timeout=0)
 
-    for h in hosts:
-        h.save = mocker.Mock()
+    bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
 
     ansible_facts_new = {"foo": "bar"}
     filepath = os.path.join(fact_cache, hosts[1].name)
@@ -83,23 +80,20 @@ def test_finish_job_fact_cache_with_existing_data(job, hosts, inventory, mocker,
     new_modification_time = time.time() + 3600
     os.utime(filepath, (new_modification_time, new_modification_time))
 
-    job.finish_job_fact_cache(fact_cache, modified_times)
+    job.finish_job_fact_cache(fact_cache, last_modified)
 
     for host in (hosts[0], hosts[2], hosts[3]):
-        host.save.assert_not_called()
         assert host.ansible_facts == {"a": 1, "b": 2}
         assert host.ansible_facts_modified is None
     assert hosts[1].ansible_facts == ansible_facts_new
-    hosts[1].save.assert_called_once_with(update_fields=['ansible_facts', 'ansible_facts_modified'])
+    bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])
 
 
 def test_finish_job_fact_cache_with_bad_data(job, hosts, inventory, mocker, tmpdir):
     fact_cache = os.path.join(tmpdir, 'facts')
-    modified_times = {}
-    job.start_job_fact_cache(fact_cache, modified_times, 0)
+    last_modified = job.start_job_fact_cache(fact_cache, timeout=0)
 
-    for h in hosts:
-        h.save = mocker.Mock()
+    bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
 
     for h in hosts:
         filepath = os.path.join(fact_cache, h.name)
@@ -109,26 +103,22 @@ def test_finish_job_fact_cache_with_bad_data(job, hosts, inventory, mocker, tmpd
             new_modification_time = time.time() + 3600
             os.utime(filepath, (new_modification_time, new_modification_time))
 
-    job.finish_job_fact_cache(fact_cache, modified_times)
+    job.finish_job_fact_cache(fact_cache, last_modified)
 
-    for h in hosts:
-        h.save.assert_not_called()
+    bulk_update.assert_not_called()
 
 
 def test_finish_job_fact_cache_clear(job, hosts, inventory, mocker, tmpdir):
     fact_cache = os.path.join(tmpdir, 'facts')
-    modified_times = {}
-    job.start_job_fact_cache(fact_cache, modified_times, 0)
+    last_modified = job.start_job_fact_cache(fact_cache, timeout=0)
 
-    for h in hosts:
-        h.save = mocker.Mock()
+    bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')
     os.remove(os.path.join(fact_cache, hosts[1].name))
 
-    job.finish_job_fact_cache(fact_cache, modified_times)
+    job.finish_job_fact_cache(fact_cache, last_modified)
 
     for host in (hosts[0], hosts[2], hosts[3]):
-        host.save.assert_not_called()
         assert host.ansible_facts == {"a": 1, "b": 2}
         assert host.ansible_facts_modified is None
     assert hosts[1].ansible_facts == {}
-    hosts[1].save.assert_called_once_with()
+    bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified'])
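
Since hosts are no longer saved one at a time, the tests stop mocking host.save
and instead patch bulk_update once at the QuerySet class level, so every
queryset the model code builds hits the same mock. A hypothetical variant of
the tests above showing the pattern (assumes the pytest-mock plugin, which
provides the mocker fixture):

    def test_no_write_when_facts_unchanged(job, hosts, mocker, tmpdir):
        fact_cache = os.path.join(tmpdir, 'facts')
        last_modified = job.start_job_fact_cache(fact_cache, timeout=0)

        # patching at the class level intercepts every queryset's bulk_update
        bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update')

        job.finish_job_fact_cache(fact_cache, last_modified)
        bulk_update.assert_not_called()  # nothing changed, so no database write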


@@ -90,6 +90,7 @@ __all__ = [
     'deepmerge',
     'get_event_partition_epoch',
     'cleanup_new_process',
+    'log_excess_runtime',
 ]
 
@@ -1215,15 +1216,30 @@
     return wrapper_cleanup_new_process
 
 
-def log_excess_runtime(func_logger, cutoff=5.0):
+def log_excess_runtime(func_logger, cutoff=5.0, debug_cutoff=5.0, msg=None, add_log_data=False):
     def log_excess_runtime_decorator(func):
         @functools.wraps(func)
         def _new_func(*args, **kwargs):
             start_time = time.time()
-            return_value = func(*args, **kwargs)
-            delta = time.time() - start_time
-            if delta > cutoff:
-                logger.info(f'Running {func.__name__!r} took {delta:.2f}s')
+            log_data = {'name': repr(func.__name__)}
+
+            if add_log_data:
+                return_value = func(*args, log_data=log_data, **kwargs)
+            else:
+                return_value = func(*args, **kwargs)
+
+            log_data['delta'] = time.time() - start_time
+            if isinstance(return_value, dict):
+                log_data.update(return_value)
+            if msg is None:
+                record_msg = 'Running {name} took {delta:.2f}s'
+            else:
+                record_msg = msg
+
+            if log_data['delta'] > cutoff:
+                func_logger.info(record_msg.format(**log_data))
+            elif log_data['delta'] > debug_cutoff:
+                func_logger.debug(record_msg.format(**log_data))
             return return_value
 
         return _new_func
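
For reference, a usage sketch of the updated decorator (sync_items and its
message are invented; the decorator is the one defined above). With
add_log_data=True the wrapper injects a log_data dict as a keyword argument,
the wrapped function fills in its own counters, and the wrapper adds delta
before formatting msg:

    import logging

    logger = logging.getLogger(__name__)

    @log_excess_runtime(logger, debug_cutoff=0.01, msg='Synced {synced_ct} items, took {delta:.3f} s', add_log_data=True)
    def sync_items(items, log_data):
        log_data['synced_ct'] = 0
        for _ in items:
            log_data['synced_ct'] += 1

    sync_items(range(10000))  # logs at DEBUG above 0.01 s, at INFO above the 5 s default cutoff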