From 2fdce43f9e6c305f71071224326aeb787c70bb88 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Tue, 15 Nov 2022 15:18:06 -0500 Subject: [PATCH] Bulk save facts, and move to before status change (#12998) * Facts scaling fixes for large inventory, timing issue Move save of Ansible facts to before the job status changes this is considered an acceptable delay with the other performance fixes here Remove completely unrelated unused facts method Scale related changes to facts saving: Use .iterator() on queryset when looping Change save to bulk_update Apply bulk_update in batches of 100, to reduce memory Only save a single file modtime, avoiding large dict Use decorator for long func time logging update decorator to fill in format statement --- awx/main/models/inventory.py | 11 ------ awx/main/models/jobs.py | 47 ++++++++++++++++++++----- awx/main/tasks/jobs.py | 30 ++++++++-------- awx/main/tests/unit/models/test_jobs.py | 40 ++++++++------------- awx/main/utils/common.py | 26 +++++++++++--- 5 files changed, 89 insertions(+), 65 deletions(-) diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index d685ddb4e2..81af2379a0 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -567,17 +567,6 @@ class Host(CommonModelNameNotUnique, RelatedJobsMixin): # Use .job_host_summaries.all() to get jobs affecting this host. # Use .job_events.all() to get events affecting this host. - ''' - We don't use timestamp, but we may in the future. - ''' - - def update_ansible_facts(self, module, facts, timestamp=None): - if module == "ansible": - self.ansible_facts.update(facts) - else: - self.ansible_facts[module] = facts - self.save() - def get_effective_host_name(self): """ Return the name of the host that will be used in actual ansible diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py index cc1a477899..d4e473f7b7 100644 --- a/awx/main/models/jobs.py +++ b/awx/main/models/jobs.py @@ -44,7 +44,7 @@ from awx.main.models.notifications import ( NotificationTemplate, JobNotificationMixin, ) -from awx.main.utils import parse_yaml_or_json, getattr_dne, NullablePromptPseudoField, polymorphic +from awx.main.utils import parse_yaml_or_json, getattr_dne, NullablePromptPseudoField, polymorphic, log_excess_runtime from awx.main.fields import ImplicitRoleField, AskForField, JSONBlob, OrderedManyToManyField from awx.main.models.mixins import ( ResourceMixin, @@ -857,8 +857,11 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana return host_queryset.iterator() return host_queryset - def start_job_fact_cache(self, destination, modification_times, timeout=None): + @log_excess_runtime(logger, debug_cutoff=0.01, msg='Job {job_id} host facts prepared for {written_ct} hosts, took {delta:.3f} s', add_log_data=True) + def start_job_fact_cache(self, destination, log_data, timeout=None): self.log_lifecycle("start_job_fact_cache") + log_data['job_id'] = self.id + log_data['written_ct'] = 0 os.makedirs(destination, mode=0o700) if timeout is None: @@ -869,6 +872,8 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana hosts = self._get_inventory_hosts(ansible_facts_modified__gte=timeout) else: hosts = self._get_inventory_hosts() + + last_filepath_written = None for host in hosts: filepath = os.sep.join(map(str, [destination, host.name])) if not os.path.realpath(filepath).startswith(destination): @@ -878,23 +883,38 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana with codecs.open(filepath, 'w', encoding='utf-8') as f: os.chmod(f.name, 0o600) json.dump(host.ansible_facts, f) + log_data['written_ct'] += 1 + last_filepath_written = filepath except IOError: system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name))) continue - # make note of the time we wrote the file so we can check if it changed later - modification_times[filepath] = os.path.getmtime(filepath) + # make note of the time we wrote the last file so we can check if any file changed later + if last_filepath_written: + return os.path.getmtime(last_filepath_written) + return None - def finish_job_fact_cache(self, destination, modification_times): + @log_excess_runtime( + logger, + debug_cutoff=0.01, + msg='Job {job_id} host facts: updated {updated_ct}, cleared {cleared_ct}, unchanged {unmodified_ct}, took {delta:.3f} s', + add_log_data=True, + ) + def finish_job_fact_cache(self, destination, facts_write_time, log_data): self.log_lifecycle("finish_job_fact_cache") + log_data['job_id'] = self.id + log_data['updated_ct'] = 0 + log_data['unmodified_ct'] = 0 + log_data['cleared_ct'] = 0 + hosts_to_update = [] for host in self._get_inventory_hosts(): filepath = os.sep.join(map(str, [destination, host.name])) if not os.path.realpath(filepath).startswith(destination): system_tracking_logger.error('facts for host {} could not be cached'.format(smart_str(host.name))) continue if os.path.exists(filepath): - # If the file changed since we wrote it pre-playbook run... + # If the file changed since we wrote the last facts file, pre-playbook run... modified = os.path.getmtime(filepath) - if modified > modification_times.get(filepath, 0): + if (not facts_write_time) or modified > facts_write_time: with codecs.open(filepath, 'r', encoding='utf-8') as f: try: ansible_facts = json.load(f) @@ -902,7 +922,7 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana continue host.ansible_facts = ansible_facts host.ansible_facts_modified = now() - host.save(update_fields=['ansible_facts', 'ansible_facts_modified']) + hosts_to_update.append(host) system_tracking_logger.info( 'New fact for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name)), extra=dict( @@ -913,12 +933,21 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana job_id=self.id, ), ) + log_data['updated_ct'] += 1 + else: + log_data['unmodified_ct'] += 1 else: # if the file goes missing, ansible removed it (likely via clear_facts) host.ansible_facts = {} host.ansible_facts_modified = now() + hosts_to_update.append(host) system_tracking_logger.info('Facts cleared for inventory {} host {}'.format(smart_str(host.inventory.name), smart_str(host.name))) - host.save() + log_data['cleared_ct'] += 1 + if len(hosts_to_update) > 100: + self.inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified']) + hosts_to_update = [] + if hosts_to_update: + self.inventory.hosts.bulk_update(hosts_to_update, ['ansible_facts', 'ansible_facts_modified']) class LaunchTimeConfigBase(BaseModel): diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index 3557c4110c..c6eaa36fec 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -426,7 +426,7 @@ class BaseTask(object): """ instance.log_lifecycle("post_run") - def final_run_hook(self, instance, status, private_data_dir, fact_modification_times): + def final_run_hook(self, instance, status, private_data_dir): """ Hook for any steps to run after job/task is marked as complete. """ @@ -469,7 +469,6 @@ class BaseTask(object): self.instance = self.update_model(pk, status='running', start_args='') # blank field to remove encrypted passwords self.instance.websocket_emit_status("running") status, rc = 'error', None - fact_modification_times = {} self.runner_callback.event_ct = 0 ''' @@ -498,14 +497,6 @@ class BaseTask(object): if not os.path.exists(settings.AWX_ISOLATION_BASE_PATH): raise RuntimeError('AWX_ISOLATION_BASE_PATH=%s does not exist' % settings.AWX_ISOLATION_BASE_PATH) - # Fetch "cached" fact data from prior runs and put on the disk - # where ansible expects to find it - if getattr(self.instance, 'use_fact_cache', False): - self.instance.start_job_fact_cache( - os.path.join(private_data_dir, 'artifacts', str(self.instance.id), 'fact_cache'), - fact_modification_times, - ) - # May have to serialize the value private_data_files, ssh_key_data = self.build_private_data_files(self.instance, private_data_dir) passwords = self.build_passwords(self.instance, kwargs) @@ -646,7 +637,7 @@ class BaseTask(object): self.instance.send_notification_templates('succeeded' if status == 'successful' else 'failed') try: - self.final_run_hook(self.instance, status, private_data_dir, fact_modification_times) + self.final_run_hook(self.instance, status, private_data_dir) except Exception: logger.exception('{} Final run hook errored.'.format(self.instance.log_format)) @@ -1066,12 +1057,19 @@ class RunJob(SourceControlMixin, BaseTask): # ran inside of the event saving code update_smart_memberships_for_inventory(job.inventory) + # Fetch "cached" fact data from prior runs and put on the disk + # where ansible expects to find it + if job.use_fact_cache: + self.facts_write_time = self.instance.start_job_fact_cache(os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache')) + def build_project_dir(self, job, private_data_dir): self.sync_and_copy(job.project, private_data_dir, scm_branch=job.scm_branch) - def final_run_hook(self, job, status, private_data_dir, fact_modification_times): - super(RunJob, self).final_run_hook(job, status, private_data_dir, fact_modification_times) - if not private_data_dir: + def post_run_hook(self, job, status): + super(RunJob, self).post_run_hook(job, status) + job.refresh_from_db(fields=['job_env']) + private_data_dir = job.job_env.get('AWX_PRIVATE_DATA_DIR') + if (not private_data_dir) or (not hasattr(self, 'facts_write_time')): # If there's no private data dir, that means we didn't get into the # actual `run()` call; this _usually_ means something failed in # the pre_run_hook method @@ -1079,9 +1077,11 @@ class RunJob(SourceControlMixin, BaseTask): if job.use_fact_cache: job.finish_job_fact_cache( os.path.join(private_data_dir, 'artifacts', str(job.id), 'fact_cache'), - fact_modification_times, + self.facts_write_time, ) + def final_run_hook(self, job, status, private_data_dir): + super(RunJob, self).final_run_hook(job, status, private_data_dir) try: inventory = job.inventory except Inventory.DoesNotExist: diff --git a/awx/main/tests/unit/models/test_jobs.py b/awx/main/tests/unit/models/test_jobs.py index 98ac4e21d6..2f030a57c3 100644 --- a/awx/main/tests/unit/models/test_jobs.py +++ b/awx/main/tests/unit/models/test_jobs.py @@ -36,15 +36,14 @@ def job(mocker, hosts, inventory): def test_start_job_fact_cache(hosts, job, inventory, tmpdir): fact_cache = os.path.join(tmpdir, 'facts') - modified_times = {} - job.start_job_fact_cache(fact_cache, modified_times, 0) + last_modified = job.start_job_fact_cache(fact_cache, timeout=0) for host in hosts: filepath = os.path.join(fact_cache, host.name) assert os.path.exists(filepath) with open(filepath, 'r') as f: assert f.read() == json.dumps(host.ansible_facts) - assert filepath in modified_times + assert os.path.getmtime(filepath) <= last_modified def test_fact_cache_with_invalid_path_traversal(job, inventory, tmpdir, mocker): @@ -58,18 +57,16 @@ def test_fact_cache_with_invalid_path_traversal(job, inventory, tmpdir, mocker): ) fact_cache = os.path.join(tmpdir, 'facts') - job.start_job_fact_cache(fact_cache, {}, 0) + job.start_job_fact_cache(fact_cache, timeout=0) # a file called "foo" should _not_ be written outside the facts dir assert os.listdir(os.path.join(fact_cache, '..')) == ['facts'] def test_finish_job_fact_cache_with_existing_data(job, hosts, inventory, mocker, tmpdir): fact_cache = os.path.join(tmpdir, 'facts') - modified_times = {} - job.start_job_fact_cache(fact_cache, modified_times, 0) + last_modified = job.start_job_fact_cache(fact_cache, timeout=0) - for h in hosts: - h.save = mocker.Mock() + bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update') ansible_facts_new = {"foo": "bar"} filepath = os.path.join(fact_cache, hosts[1].name) @@ -83,23 +80,20 @@ def test_finish_job_fact_cache_with_existing_data(job, hosts, inventory, mocker, new_modification_time = time.time() + 3600 os.utime(filepath, (new_modification_time, new_modification_time)) - job.finish_job_fact_cache(fact_cache, modified_times) + job.finish_job_fact_cache(fact_cache, last_modified) for host in (hosts[0], hosts[2], hosts[3]): - host.save.assert_not_called() assert host.ansible_facts == {"a": 1, "b": 2} assert host.ansible_facts_modified is None assert hosts[1].ansible_facts == ansible_facts_new - hosts[1].save.assert_called_once_with(update_fields=['ansible_facts', 'ansible_facts_modified']) + bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified']) def test_finish_job_fact_cache_with_bad_data(job, hosts, inventory, mocker, tmpdir): fact_cache = os.path.join(tmpdir, 'facts') - modified_times = {} - job.start_job_fact_cache(fact_cache, modified_times, 0) + last_modified = job.start_job_fact_cache(fact_cache, timeout=0) - for h in hosts: - h.save = mocker.Mock() + bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update') for h in hosts: filepath = os.path.join(fact_cache, h.name) @@ -109,26 +103,22 @@ def test_finish_job_fact_cache_with_bad_data(job, hosts, inventory, mocker, tmpd new_modification_time = time.time() + 3600 os.utime(filepath, (new_modification_time, new_modification_time)) - job.finish_job_fact_cache(fact_cache, modified_times) + job.finish_job_fact_cache(fact_cache, last_modified) - for h in hosts: - h.save.assert_not_called() + bulk_update.assert_not_called() def test_finish_job_fact_cache_clear(job, hosts, inventory, mocker, tmpdir): fact_cache = os.path.join(tmpdir, 'facts') - modified_times = {} - job.start_job_fact_cache(fact_cache, modified_times, 0) + last_modified = job.start_job_fact_cache(fact_cache, timeout=0) - for h in hosts: - h.save = mocker.Mock() + bulk_update = mocker.patch('django.db.models.query.QuerySet.bulk_update') os.remove(os.path.join(fact_cache, hosts[1].name)) - job.finish_job_fact_cache(fact_cache, modified_times) + job.finish_job_fact_cache(fact_cache, last_modified) for host in (hosts[0], hosts[2], hosts[3]): - host.save.assert_not_called() assert host.ansible_facts == {"a": 1, "b": 2} assert host.ansible_facts_modified is None assert hosts[1].ansible_facts == {} - hosts[1].save.assert_called_once_with() + bulk_update.assert_called_once_with([hosts[1]], ['ansible_facts', 'ansible_facts_modified']) diff --git a/awx/main/utils/common.py b/awx/main/utils/common.py index e724c1bc3f..af9cdc7a18 100644 --- a/awx/main/utils/common.py +++ b/awx/main/utils/common.py @@ -90,6 +90,7 @@ __all__ = [ 'deepmerge', 'get_event_partition_epoch', 'cleanup_new_process', + 'log_excess_runtime', ] @@ -1215,15 +1216,30 @@ def cleanup_new_process(func): return wrapper_cleanup_new_process -def log_excess_runtime(func_logger, cutoff=5.0): +def log_excess_runtime(func_logger, cutoff=5.0, debug_cutoff=5.0, msg=None, add_log_data=False): def log_excess_runtime_decorator(func): @functools.wraps(func) def _new_func(*args, **kwargs): start_time = time.time() - return_value = func(*args, **kwargs) - delta = time.time() - start_time - if delta > cutoff: - logger.info(f'Running {func.__name__!r} took {delta:.2f}s') + log_data = {'name': repr(func.__name__)} + + if add_log_data: + return_value = func(*args, log_data=log_data, **kwargs) + else: + return_value = func(*args, **kwargs) + + log_data['delta'] = time.time() - start_time + if isinstance(return_value, dict): + log_data.update(return_value) + + if msg is None: + record_msg = 'Running {name} took {delta:.2f}s' + else: + record_msg = msg + if log_data['delta'] > cutoff: + func_logger.info(record_msg.format(**log_data)) + elif log_data['delta'] > debug_cutoff: + func_logger.debug(record_msg.format(**log_data)) return return_value return _new_func