Only fetch fields we need in task manager

By using .only we select fewer columns, avoiding potentially large
fields that we never reference.

Also, small tweak to eliminate what was a duplicate dictionary of
hostname:instance, because we don't need build and carry two copies of
the same data.
This commit is contained in:
Elijah DeLee
2022-04-12 21:13:44 -04:00
parent 81cda0ba74
commit 868e811b3f

View File

@@ -68,8 +68,7 @@ class TaskManager:
""" """
Init AFTER we know this instance of the task manager will run because the lock is acquired. Init AFTER we know this instance of the task manager will run because the lock is acquired.
""" """
instances = Instance.objects.filter(hostname__isnull=False, enabled=True).exclude(node_type='hop') instances = Instance.objects.filter(hostname__isnull=False, enabled=True).exclude(node_type='hop').only('node_type', 'capacity', 'hostname', 'enabled')
self.real_instances = {i.hostname: i for i in instances}
self.controlplane_ig = None self.controlplane_ig = None
self.dependency_graph = DependencyGraph() self.dependency_graph = DependencyGraph()
@@ -85,11 +84,12 @@ class TaskManager:
] ]
instances_by_hostname = {i.hostname: i for i in instances_partial} instances_by_hostname = {i.hostname: i for i in instances_partial}
self.instances_by_hostname = instances_by_hostname
# updates remaining capacity value based on currently running and waiting tasks # updates remaining capacity value based on currently running and waiting tasks
Instance.update_remaining_capacity(instances_by_hostname, all_sorted_tasks) Instance.update_remaining_capacity(instances_by_hostname, all_sorted_tasks)
for rampart_group in InstanceGroup.objects.prefetch_related('instances'): for rampart_group in InstanceGroup.objects.prefetch_related('instances').only('name', 'instances'):
if rampart_group.name == settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME: if rampart_group.name == settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME:
self.controlplane_ig = rampart_group self.controlplane_ig = rampart_group
self.graph[rampart_group.name] = dict( self.graph[rampart_group.name] = dict(
@@ -114,16 +114,23 @@ class TaskManager:
return None return None
def get_tasks(self, status_list=('pending', 'waiting', 'running')): def get_tasks(self, status_list=('pending', 'waiting', 'running')):
jobs = [j for j in Job.objects.filter(status__in=status_list).prefetch_related('instance_group')] common_needed_values = ['celery_task_id', 'controller_node', 'created', 'execution_node', 'instance_group', 'job_explanation', 'name', 'pk', 'status']
inv_needed_values = ['inventory_source']
jobs = [j for j in Job.objects.filter(status__in=status_list).prefetch_related('instance_group').only(*common_needed_values)]
inventory_updates_qs = ( inventory_updates_qs = (
InventoryUpdate.objects.filter(status__in=status_list).exclude(source='file').prefetch_related('inventory_source', 'instance_group') InventoryUpdate.objects.filter(status__in=status_list)
.exclude(source='file')
.prefetch_related('inventory_source', 'instance_group')
.only(*(common_needed_values + inv_needed_values))
) )
inventory_updates = [i for i in inventory_updates_qs] inventory_updates = [i for i in inventory_updates_qs]
# Notice the job_type='check': we want to prevent implicit project updates from blocking our jobs. # Notice the job_type='check': we want to prevent implicit project updates from blocking our jobs.
project_updates = [p for p in ProjectUpdate.objects.filter(status__in=status_list, job_type='check').prefetch_related('instance_group')] project_updates = [
system_jobs = [s for s in SystemJob.objects.filter(status__in=status_list).prefetch_related('instance_group')] p for p in ProjectUpdate.objects.filter(status__in=status_list, job_type='check').prefetch_related('instance_group').only(*common_needed_values)
ad_hoc_commands = [a for a in AdHocCommand.objects.filter(status__in=status_list).prefetch_related('instance_group')] ]
workflow_jobs = [w for w in WorkflowJob.objects.filter(status__in=status_list)] system_jobs = [s for s in SystemJob.objects.filter(status__in=status_list).prefetch_related('instance_group').only(*common_needed_values)]
ad_hoc_commands = [a for a in AdHocCommand.objects.filter(status__in=status_list).prefetch_related('instance_group').only(*common_needed_values)]
workflow_jobs = [w for w in WorkflowJob.objects.filter(status__in=status_list).only(*common_needed_values)]
all_tasks = sorted(jobs + project_updates + inventory_updates + system_jobs + ad_hoc_commands + workflow_jobs, key=lambda task: task.created) all_tasks = sorted(jobs + project_updates + inventory_updates + system_jobs + ad_hoc_commands + workflow_jobs, key=lambda task: task.created)
return all_tasks return all_tasks
@@ -493,7 +500,7 @@ class TaskManager:
task.execution_node = control_instance.hostname task.execution_node = control_instance.hostname
control_instance.remaining_capacity = max(0, control_instance.remaining_capacity - control_impact) control_instance.remaining_capacity = max(0, control_instance.remaining_capacity - control_impact)
self.dependency_graph.add_job(task) self.dependency_graph.add_job(task)
execution_instance = self.real_instances[control_instance.hostname] execution_instance = self.instances_by_hostname[control_instance.hostname].obj
task.log_lifecycle("controller_node_chosen") task.log_lifecycle("controller_node_chosen")
task.log_lifecycle("execution_node_chosen") task.log_lifecycle("execution_node_chosen")
self.start_task(task, self.controlplane_ig, task.get_jobs_fail_chain(), execution_instance) self.start_task(task, self.controlplane_ig, task.get_jobs_fail_chain(), execution_instance)
@@ -533,7 +540,7 @@ class TaskManager:
task.log_format, rampart_group.name, execution_instance.hostname, execution_instance.remaining_capacity task.log_format, rampart_group.name, execution_instance.hostname, execution_instance.remaining_capacity
) )
) )
execution_instance = self.real_instances[execution_instance.hostname] execution_instance = self.instances_by_hostname[execution_instance.hostname].obj
self.dependency_graph.add_job(task) self.dependency_graph.add_job(task)
self.start_task(task, rampart_group, task.get_jobs_fail_chain(), execution_instance) self.start_task(task, rampart_group, task.get_jobs_fail_chain(), execution_instance)
found_acceptable_queue = True found_acceptable_queue = True