diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py
index 8765c2871b..09ea5e23e8 100644
--- a/awx/main/scheduler/task_manager.py
+++ b/awx/main/scheduler/task_manager.py
@@ -71,6 +71,7 @@ class TaskManager:
         instances = Instance.objects.filter(hostname__isnull=False, enabled=True).exclude(node_type='hop')
         self.real_instances = {i.hostname: i for i in instances}
         self.controlplane_ig = None
+        self.dependency_graph = DependencyGraph()
 
         instances_partial = [
             SimpleNamespace(
@@ -90,32 +91,18 @@ class TaskManager:
             if rampart_group.name == settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME:
                 self.controlplane_ig = rampart_group
             self.graph[rampart_group.name] = dict(
-                graph=DependencyGraph(),
-                execution_capacity=0,
-                control_capacity=0,
-                consumed_capacity=0,
-                consumed_control_capacity=0,
-                consumed_execution_capacity=0,
-                instances=[],
+                instances=[
+                    instances_by_hostname[instance.hostname] for instance in rampart_group.instances.all() if instance.hostname in instances_by_hostname
+                ],
             )
-            for instance in rampart_group.instances.all():
-                if not instance.enabled:
-                    continue
-                for capacity_type in ('control', 'execution'):
-                    if instance.node_type in (capacity_type, 'hybrid'):
-                        self.graph[rampart_group.name][f'{capacity_type}_capacity'] += instance.capacity
-            for instance in rampart_group.instances.filter(enabled=True).order_by('hostname'):
-                if instance.hostname in instances_by_hostname:
-                    self.graph[rampart_group.name]['instances'].append(instances_by_hostname[instance.hostname])
 
     def job_blocked_by(self, task):
         # TODO: I'm not happy with this, I think blocking behavior should be decided outside of the dependency graph
         # in the old task manager this was handled as a method on each task object outside of the graph and
         # probably has the side effect of cutting down *a lot* of the logic from this task manager class
-        for g in self.graph:
-            blocked_by = self.graph[g]['graph'].task_blocked_by(task)
-            if blocked_by:
-                return blocked_by
+        blocked_by = self.dependency_graph.task_blocked_by(task)
+        if blocked_by:
+            return blocked_by
 
         if not task.dependent_jobs_finished():
             blocked_by = task.dependent_jobs.first()
@@ -298,16 +285,6 @@ class TaskManager:
         task.save()
         task.log_lifecycle("waiting")
 
-        if rampart_group is not None:
-            self.consume_capacity(task, rampart_group.name, instance=instance)
-            if task.controller_node:
-                self.consume_capacity(
-                    task,
-                    settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME,
-                    instance=self.real_instances[task.controller_node],
-                    impact=settings.AWX_CONTROL_NODE_TASK_IMPACT,
-                )
-
         def post_commit():
             if task.status != 'failed' and type(task) is not WorkflowJob:
                 # Before task is dispatched, ensure that job_event partitions exist
@@ -327,8 +304,7 @@ class TaskManager:
 
     def process_running_tasks(self, running_tasks):
         for task in running_tasks:
-            if task.instance_group:
-                self.graph[task.instance_group.name]['graph'].add_job(task)
+            self.dependency_graph.add_job(task)
 
     def create_project_update(self, task):
         project_task = Project.objects.get(id=task.project_id).create_project_update(_eager_fields=dict(launch_type='dependency'))
@@ -515,7 +491,7 @@ class TaskManager:
                 task.execution_node = control_instance.hostname
                 control_instance.remaining_capacity = max(0, control_instance.remaining_capacity - control_impact)
                 control_instance.jobs_running += 1
-                self.graph[settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME]['graph'].add_job(task)
+                self.dependency_graph.add_job(task)
                 execution_instance = self.real_instances[control_instance.hostname]
                 self.start_task(task, self.controlplane_ig, task.get_jobs_fail_chain(), execution_instance)
                 found_acceptable_queue = True
@@ -524,7 +500,7 @@ class TaskManager:
             for rampart_group in preferred_instance_groups:
                 if rampart_group.is_container_group:
                     control_instance.jobs_running += 1
-                    self.graph[settings.DEFAULT_CONTROL_PLANE_QUEUE_NAME]['graph'].add_job(task)
+                    self.dependency_graph.add_job(task)
                     self.start_task(task, rampart_group, task.get_jobs_fail_chain(), None)
                     found_acceptable_queue = True
                     break
@@ -559,7 +535,7 @@ class TaskManager:
                         )
                     )
                     execution_instance = self.real_instances[execution_instance.hostname]
-                    self.graph[rampart_group.name]['graph'].add_job(task)
+                    self.dependency_graph.add_job(task)
                     self.start_task(task, rampart_group, task.get_jobs_fail_chain(), execution_instance)
                     found_acceptable_queue = True
                     break
@@ -616,29 +592,9 @@ class TaskManager:
                 logger.error(f'{j.execution_node} is not a registered instance; reaping {j.log_format}')
                 reap_job(j, 'failed')
 
-    def calculate_capacity_consumed(self, tasks):
-        self.graph = InstanceGroup.objects.capacity_values(tasks=tasks, graph=self.graph)
-
-    def consume_capacity(self, task, instance_group, instance=None, impact=None):
-        impact = impact if impact else task.task_impact
-        logger.debug(
-            '{} consumed {} capacity units from {} with prior total of {}'.format(
-                task.log_format, impact, instance_group, self.graph[instance_group]['consumed_capacity']
-            )
-        )
-        self.graph[instance_group]['consumed_capacity'] += impact
-        for capacity_type in ('control', 'execution'):
-            if instance is None or instance.node_type in ('hybrid', capacity_type):
-                self.graph[instance_group][f'consumed_{capacity_type}_capacity'] += impact
-
-    def get_remaining_capacity(self, instance_group, capacity_type='execution'):
-        return self.graph[instance_group][f'{capacity_type}_capacity'] - self.graph[instance_group][f'consumed_{capacity_type}_capacity']
-
     def process_tasks(self, all_sorted_tasks):
         running_tasks = [t for t in all_sorted_tasks if t.status in ['waiting', 'running']]
-        self.calculate_capacity_consumed(running_tasks)
-
         self.process_running_tasks(running_tasks)
 
         pending_tasks = [t for t in all_sorted_tasks if t.status == 'pending']
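
For readers skimming the diff: the core of the change is that job_blocked_by no longer loops over one DependencyGraph per instance group; a single self.dependency_graph answers the blocking question in one call. Below is a minimal sketch of that pattern, with a hypothetical resource_keys attribute and deliberately simplified logic; it is not the real class, which lives in awx/main/scheduler/dependency_graph.py and keys on shared resources such as project and inventory ids.

from types import SimpleNamespace

class SingleDependencyGraph:
    # Hypothetical, simplified stand-in for the DependencyGraph this diff
    # consolidates; illustrates the one-graph pattern, not AWX's actual code.
    def __init__(self):
        # One graph for the whole cluster instead of one per instance group.
        self._holders = {}  # resource key -> task currently holding it

    def add_job(self, task):
        for key in task.resource_keys:
            self._holders[key] = task

    def task_blocked_by(self, task):
        for key in task.resource_keys:
            holder = self._holders.get(key)
            if holder is not None and holder is not task:
                return holder
        return None

running = SimpleNamespace(id=1, resource_keys=[('project', 42)])
pending = SimpleNamespace(id=2, resource_keys=[('project', 42)])

graph = SingleDependencyGraph()
graph.add_job(running)
assert graph.task_blocked_by(pending) is running  # one lookup, no per-group loop

With one global graph, a task's blocker is found without iterating every instance group's bookkeeping, which is what lets job_blocked_by drop the `for g in self.graph` loop above.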
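
Similarly, the deleted calculate_capacity_consumed / consume_capacity / get_remaining_capacity helpers kept per-group 'consumed_*' counters inside self.graph; the hunks above instead debit capacity directly on the node objects (remaining_capacity, jobs_running). A rough sketch of that instance-level accounting, with an assumed impact value:

from types import SimpleNamespace

AWX_CONTROL_NODE_TASK_IMPACT = 1  # assumed value; in AWX this comes from settings

def consume_on_instance(instance, impact):
    # Mirrors the pattern in the hunks above: debit the node directly
    # instead of incrementing a per-instance-group 'consumed_*' counter.
    instance.remaining_capacity = max(0, instance.remaining_capacity - impact)
    instance.jobs_running += 1

control = SimpleNamespace(hostname='ctl-1', remaining_capacity=10, jobs_running=0)
consume_on_instance(control, AWX_CONTROL_NODE_TASK_IMPACT)
assert (control.remaining_capacity, control.jobs_running) == (9, 1)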