From 85b6aa2262064d9c2a300ed473fc5242dd1c9a25 Mon Sep 17 00:00:00 2001 From: Matthew Jones <“mjones@ansible.com”> Date: Mon, 10 Mar 2014 16:07:20 -0400 Subject: [PATCH] Rebasing for initial task system work. Current work towards actual task running flow --- awx/api/views.py | 4 +- .../management/commands/run_task_system.py | 233 ++++++++++++++++++ awx/main/models/base.py | 27 +- awx/main/models/inventory.py | 16 +- awx/main/models/jobs.py | 120 +++++---- awx/main/models/projects.py | 17 +- awx/main/tasks.py | 2 +- awx/main/utils.py | 9 + 8 files changed, 366 insertions(+), 62 deletions(-) create mode 100644 awx/main/management/commands/run_task_system.py diff --git a/awx/api/views.py b/awx/api/views.py index 61f6a3edf4..89090d5f7c 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -1122,7 +1122,7 @@ class JobTemplateCallback(GenericAPIView): return Response(data, status=status.HTTP_400_BAD_REQUEST) limit = ':'.join(filter(None, [job_template.limit, host.name])) job = job_template.create_job(limit=limit, launch_type='callback') - result = job.start() + result = job.signal_start() if not result: data = dict(msg='Error starting job!') return Response(data, status=status.HTTP_400_BAD_REQUEST) @@ -1178,7 +1178,7 @@ class JobStart(GenericAPIView): def post(self, request, *args, **kwargs): obj = self.get_object() if obj.can_start: - result = obj.start(**request.DATA) + result = obj.signal_start(**request.DATA) if not result: data = dict(passwords_needed_to_start=obj.passwords_needed_to_start) return Response(data, status=status.HTTP_400_BAD_REQUEST) diff --git a/awx/main/management/commands/run_task_system.py b/awx/main/management/commands/run_task_system.py new file mode 100644 index 0000000000..a32e9d7db2 --- /dev/null +++ b/awx/main/management/commands/run_task_system.py @@ -0,0 +1,233 @@ +#Copyright (c) 2014 Ansible, Inc. +# All Rights Reserved + +# Python +import os +import datetime +import logging +import json +import signal +import time +from optparse import make_option +from multiprocessing import Process + +# Django +from django.conf import settings +from django.core.management.base import NoArgsCommand, CommandError +from django.db import transaction, DatabaseError +from django.contrib.auth.models import User +from django.utils.dateparse import parse_datetime +from django.utils.timezone import now, is_aware, make_aware +from django.utils.tzinfo import FixedOffset + +# AWX +from awx.main.models import * +from awx.main.tasks import handle_work_error +from awx.main.utils import get_system_task_capacity, decrypt_field + +# ZeroMQ +import zmq + +# Celery +from celery.task.control import inspect + +class SimpleDAG(object): + + def __init__(self, nodes=[], edges=[]): + self.nodes = nodes + self.edges = edges + + def __contains__(self, obj): + for node in self.nodes: + if node['node_object'] == obj: + return True + return False + + def __len__(self): + return len(self.nodes) + + def __iter__(self): + return self.nodes.__iter__() + + def generate_graphviz_plot(self): + doc = """ + digraph g { + rankdir = LR + """ + for n in self.nodes: + doc += "%s [color = %s]\n" % (str(n), "red" if n.status == 'running' else "black") + for from, to in self.edges: + doc += "%s -> %s;\n" % (str(self.nodes[from]), str(self.nodes[to])) + doc += "}" + gv_file = open('/tmp/graph.gv', 'w') + gv_file.write(doc) + gv_file.close() + + def add_node(self, obj, metadata=None): + if self.find_ord(obj) is None: + self.nodes.append(dict(node_object=obj, metadata=metadata)) + + def add_edge(self, from_obj, to_obj): + from_obj_ord = self.find_ord(from_obj) + to_obj_ord = self.find_ord(from_obj) + if from_obj_ord is None or to_obj_ord is None: + raise LookupError("Object not found") + self.edges.append((from_obj_ord, to_obj_ord)) + + def add_edges(self, edgelist): + for from_obj, to_obj in edgelist: + self.add_edge(from_obj, to_obj) + + def find_ord(self, obj): + for idx in range(len(self.nodes)): + if obj == self.nodes[idx]['node_object']: + return idx + return None + + def get_node_type(self, obj): + if type(obj) == Job: + return "ansible_playbook" + elif type(obj) == InventoryUpdate: + return "inventory_update" + elif type(obj) == ProjectUpdate: + return "project_update" + return "unknown" + + def get_dependencies(self, obj): + antecedents = [] + this_ord = find_ord(self, obj) + for node, dep in self.edges: + if node == this_ord: + antecedents.append(self.nodes[dep]) + return antecedents + + def get_dependents(self, obj): + decendents = [] + this_ord = find_ord(self, obj) + for node, dep in self.edges: + if dep == this_ord: + decendents.append(self.nodes[node]) + return decendents + + def get_leaf_nodes(): + leafs = [] + for n in self.nodes: + if len(self.get_dependencies(n)) < 1: + leafs.append(n) + return n + +def get_tasks(): + # TODO: Replace this when we can grab all objects in a sane way + graph_jobs = [j for j in Job.objects.filter(status__in=('new', 'waiting', 'pending', 'running'))] + graph_inventory_updates = [iu for iu in InventoryUpdate.objects.filter(status__in=('new', 'waiting', 'pending', 'running'))] + graph_project_updates = [pu for pu in ProjectUpdate.objects.filter(status__in=('new', 'waiting', 'pending', 'running'))] + all_actions = sorted(graph_jobs + graph_inventory_updates + graph_project_updates, key=lambda task: task.created) + +def rebuild_graph(message): + inspector = inspect() + active_task_queues = inspector.active() + active_tasks = [] + for queue in active_task_queues: + active_tasks += active_task_queues[queue] + + all_sorted_tasks = get_tasks() + running_tasks = filter(lambda t: t.status == 'running', all_sorted_tasks) + waiting_tasks = filter(lambda t: t.status != 'running', all_sorted_tasks) + new_tasks = filter(lambda t: t.status == 'new', all_sorted_tasks) + + # Check running tasks and make sure they are active in celery + for task in list(running_tasks): + if task.celery_task_id not in active_tasks: + task.status = 'failed' + task.result_traceback += "Task was marked as running in Tower but was not present in Celery so it has been marked as failed" + task.save() + running_tasks.pop(task) + if settings.DEBUG: + print("Task %s appears orphaned... marking as failed" % task) + + # Create and process dependencies for new tasks + for task in new_tasks: + task_dependencies = task.generate_dependencies(running_tasks + waiting_tasks) #TODO: other 'new' tasks? Need to investigate this scenario + for dep in task_dependencies: + # We recalculate the created time for the moment to ensure the dependencies are always sorted in the right order relative to the dependent task + time_delt = len(task_dependencies) - task_dependencies.index(dep) + dep.created = task.created - datetime.timedelta(seconds=1+time_delt) + dep.save() + waiting_tasks.insert(dep, waiting_tasks.index(task)) + + # Rebuild graph + graph = SimpleDAG() + for task in running_tasks: + graph.add_node(task) + for wait_task in waiting_tasks: + node_dependencies = [] + for node in graph: + if wait_task.is_blocked_by(node['node_objects']): + node_dependencies.append(node) + graph.add_node(wait_task) + graph.add_edges([(wait_task, n) for n in node_dependencies]) + if settings.DEBUG: + graph.generate_graphviz_plot() + return graph + +def process_graph(graph, task_capacity): + leaf_nodes = graph.get_leaf_nodes() + running_nodes = filter(lambda x['node_object'].status == 'running', leaf_nodes) + running_impact = sum([t['node_object'].task_impact for t in running_nodes]) + ready_nodes = filter(lambda x['node_object'].status != 'running', leaf_nodes) + remaining_volume = task_capacity - running_impact + for task_node in ready_nodes: + node_obj = task_node['node_object'] + node_args = task_node['metadata'] + impact = node_obj.task_impact + if impact <= remaining_volume or running_impact == 0: + dependent_nodes = [{'type': graph.get_node_type(n), 'id': n.id} for n in graph.get_dependents()] + error_handler = handle_work_error.s(subtasks=dependent_nodes) + node_obj.start(error_callback=error_handler) + remaining_volume -= impact + running_impact += impact + +def run_taskmanager(command_port): + paused = False + task_capacity = get_system_task_capacity() + command_context = zmq.Context() + command_socket = command_context.socket(zmq.REP) + command_socket.bind(command_port) + last_rebuild = datetime.datetime.now() + while True: + try: + message = command_socket.recv_json(flags=zmq.NOBLOCK) + command_socket.send("1") + except zmq.core.error.ZMQError,e: + message = None + if message is not None or (datetime.datetime.now() - last_rebuild).seconds > 60: + if 'pause' in message: + paused = message['pause'] + graph = rebuild_graph(message) + if not paused: + process_graph(graph, task_capacity) + last_rebuild = datetime.datetime.now() + time.sleep(0.1) + +class Command(NoArgsCommand): + + help = 'Launch the job graph runner' + + def init_logging(self): + log_levels = dict(enumerate([logging.ERROR, logging.INFO, + logging.DEBUG, 0])) + self.logger = logging.getLogger('awx.main.commands.run_task_system') + self.logger.setLevel(log_levels.get(self.verbosity, 0)) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(message)s')) + self.logger.addHandler(handler) + self.logger.propagate = False + + def handle_noargs(self, **options): + self.verbosity = int(options.get('verbosity', 1)) + self.init_logging() + command_port = settings.TASK_COMMAND_PORT + try: + run_taskmanager(command_port) + except KeyboardInterrupt: + pass diff --git a/awx/main/models/base.py b/awx/main/models/base.py index 45cabc5a83..ba38198959 100644 --- a/awx/main/models/base.py +++ b/awx/main/models/base.py @@ -278,6 +278,11 @@ class CommonTask(PrimordialModel): default={}, editable=False, ) + start_args = models.TextField( + blank=True, + default='', + editable=False, + ) _result_stdout = models.TextField( blank=True, default='', @@ -367,12 +372,29 @@ class CommonTask(PrimordialModel): def can_start(self): return bool(self.status == 'new') + @property + def task_impact(self): + raise NotImplementedError + def _get_task_class(self): raise NotImplementedError def _get_passwords_needed_to_start(self): return [] + def is_blocked_by(self, task_object): + ''' Given another task object determine if this task would be blocked by it ''' + raise NotImplementedError + + def generate_dependencies(self, active_tasks): + ''' Generate any tasks that the current task might be dependent on given a list of active + tasks that might preclude creating one''' + return [] + + def signal_start(self): + ''' Notify the task runner system to begin work on this task ''' + raise NotImplementedError + def start_signature(self, **kwargs): from awx.main.tasks import handle_work_error @@ -383,13 +405,10 @@ class CommonTask(PrimordialModel): opts = dict([(field, kwargs.get(field, '')) for field in needed]) if not all(opts.values()): return False - self.status = 'pending' - self.save(update_fields=['status']) - transaction.commit() task_actual = task_class().si(self.pk, **opts) return task_actual - def start(self, **kwargs): + def start(self, error_callback, **kwargs): task_actual = self.start_signature(**kwargs) # TODO: Callback for status task_result = task_actual.delay() diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index 4a9a2c2355..a48b40e7e4 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -705,7 +705,7 @@ class InventorySource(PrimordialModel): def update(self, **kwargs): if self.can_update: inventory_update = self.inventory_updates.create() - inventory_update.start() + inventory_update.signal_start() return inventory_update def get_absolute_url(self): @@ -739,7 +739,7 @@ class InventoryUpdate(CommonTask): if 'license_error' not in update_fields: update_fields.append('license_error') super(InventoryUpdate, self).save(*args, **kwargs) - + def _get_parent_instance(self): return self.inventory_source @@ -749,3 +749,15 @@ class InventoryUpdate(CommonTask): def _get_task_class(self): from awx.main.tasks import RunInventoryUpdate return RunInventoryUpdate + + @property + def task_impact(self): + return 50 + + def signal_start(self, **kwargs): + signal_context = zmq.Context() + signal_socket = signal_context.socket(zmq.REQ) + signal_socket.connect(settings.TASK_COMMAND_PORT) + signal_socket.send_json(dict(task_type="inventory_update", id=self.id, metadata=kwargs)) + self.socket.recv() + return True diff --git a/awx/main/models/jobs.py b/awx/main/models/jobs.py index 36fb326588..909ce22f73 100644 --- a/awx/main/models/jobs.py +++ b/awx/main/models/jobs.py @@ -31,6 +31,7 @@ from jsonfield import JSONField # AWX from awx.main.models.base import * +from awx.main.utils import encrypt_field # Celery from celery import chain @@ -298,7 +299,7 @@ class Job(CommonTask): def _get_task_class(self): from awx.main.tasks import RunJob return RunJob - + def _get_passwords_needed_to_start(self): return self.passwords_needed_to_start @@ -307,6 +308,28 @@ class Job(CommonTask): kwargs['job_host_summaries__job__pk'] = self.pk return Host.objects.filter(**kwargs) + def is_blocked_by(self, obj): + from awx.main.models import InventoryUpdate, ProjectUpdate + if type(obj) == Job: + if obj.job_template == self.job_template: + return True + return False + if type(obj) == InventoryUpdate: + for i_s in self.inventory.inventory_sources.filter(active=True): + if i_s == obj.inventory_source: + return True + return False + if type(obj) == ProjectUpdate: + if obj.project == self.project: + return True + return False + return False + + @property + def task_impact(self): + # NOTE: We sorta have to assume the host count matches and that forks default to 5 + return min(self._get_hosts().count(), 5 if self.forks == 0 else self.forks) * 10 + @property def successful_hosts(self): return self._get_hosts(job_host_summaries__ok__gt=0) @@ -335,64 +358,57 @@ class Job(CommonTask): def processed_hosts(self): return self._get_hosts(job_host_summaries__processed__gt=0) - def start(self, **kwargs): + def generate_dependencies(self, active_tasks): + from awx.main.models import InventoryUpdate, ProjectUpdate + inventory_sources = self.inventory.inventory_sources.filter(active=True, update_on_launch=True) + project_found = False + inventory_sources_found = [] + dependencies = [] + for obj in active_tasks: + if type(obj) == ProjectUpdate: + if obj.project == self.project: + project_found = True + if type(obj) == InventoryUpdate: + if obj.inventory_source in inventory_sources: + inventory_sources_found.append(obj.inventory_source) + if not project_found and self.project.scm_update_on_launch:: + dependencies.append(self.project.project_updates.create()) + if inventory_sources.count(): # and not has_setup_failures? Probably handled as an error scenario in the task runner + for source in inventory_sources: + if not source in inventory_sources_found: + dependencies.append(source.inventory_updates.create()) + return dependencies + + def signal_start(self, **kwargs): + json_args = json.dumps(kwargs) + self.start_args = json_args + self.save() + self.start_args = encrypt_field(self, 'start_args') + self.save() + signal_context = zmq.Context() + signal_socket = signal_context.socket(zmq.REQ) + signal_socket.connect(settings.TASK_COMMAND_PORT) + signal_socket.send_json(dict(task_type="ansible_playbook", id=self.id)) + self.socket.recv() + return True + + def start(self, error_callback, **kwargs): from awx.main.tasks import handle_work_error task_class = self._get_task_class() if not self.can_start: return False needed = self._get_passwords_needed_to_start() - opts = dict([(field, kwargs.get(field, '')) for field in needed]) + try: + stored_args = json.loads(decrypt_field(self, 'start_args')) + except Exception, e: + stored_args = None + if stored_args is None or stored_args == '': + opts = dict([(field, kwargs.get(field, '')) for field in needed]) + else: + opts = stored_args if not all(opts.values()): return False - self.status = 'waiting' - self.save(update_fields=['status']) - transaction.commit() - - runnable_tasks = [] - run_tasks = [] - inventory_updates_actual = [] - project_update_actual = None - has_setup_failures = False - setup_failure_message = "" - - project = self.project - inventory = self.inventory - is_qs = inventory.inventory_sources.filter(active=True, update_on_launch=True) - if project.scm_update_on_launch: - project_update_details = project.update_signature() - if not project_update_details: - has_setup_failures = True - setup_failure_message = "Failed to check dependent project update task" - else: - runnable_tasks.append({'obj': project_update_details[0], - 'sig': project_update_details[1], - 'type': 'project_update'}) - if is_qs.count() and not has_setup_failures: - for inventory_source in is_qs: - inventory_update_details = inventory_source.update_signature() - if not inventory_update_details: - has_setup_failures = True - setup_failure_message = "Failed to check dependent inventory update task" - break - else: - runnable_tasks.append({'obj': inventory_update_details[0], - 'sig': inventory_update_details[1], - 'type': 'inventory_update'}) - if has_setup_failures: - for each_task in runnable_tasks: - obj = each_task['obj'] - obj.status = 'error' - obj.result_traceback = setup_failure_message - obj.save() - self.status = 'error' - self.result_traceback = setup_failure_message - self.save() - thisjob = {'type': 'job', 'id': self.id} - for idx in xrange(len(runnable_tasks)): - dependent_tasks = [{'type': r['type'], 'id': r['obj'].id} for r in runnable_tasks[idx:]] + [thisjob] - run_tasks.append(runnable_tasks[idx]['sig'].set(link_error=handle_work_error.s(subtasks=dependent_tasks))) - run_tasks.append(task_class().si(self.pk, **opts).set(link_error=handle_work_error.s(subtasks=[thisjob]))) - res = chain(run_tasks)() + task_class().apply_async((self.pk, **opts), link_error=error_callback) return True class JobHostSummary(BaseModel): diff --git a/awx/main/models/projects.py b/awx/main/models/projects.py index 4f36f00405..aa05a6b69a 100644 --- a/awx/main/models/projects.py +++ b/awx/main/models/projects.py @@ -16,6 +16,9 @@ import uuid # PyYAML import yaml +# ZeroMQ +import zmq + # Django from django.conf import settings from django.db import models @@ -291,7 +294,7 @@ class Project(CommonModel): def update(self, **kwargs): if self.can_update: project_update = self.project_updates.create() - project_update.start() + project_update.signal_start() return project_update def get_absolute_url(self): @@ -362,6 +365,18 @@ class ProjectUpdate(CommonTask): from awx.main.tasks import RunProjectUpdate return RunProjectUpdate + @property + def task_impact(self): + return 20 + + def signal_start(self, **kwargs): + signal_context = zmq.Context() + signal_socket = signal_context.socket(zmq.REQ) + signal_socket.connect(settings.TASK_COMMAND_PORT) + signal_socket.send_json(dict(task_type="project_update", id=self.id, metadata=kwargs)) + self.socket.recv() + return True + def _update_parent_instance(self): parent_instance = self._get_parent_instance() if parent_instance: diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 7abf10af28..8915b49c7f 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -63,7 +63,7 @@ def handle_work_error(self, task_id, subtasks=None): elif each_task['type'] == 'inventory_update': instance = InventoryUpdate.objects.get(id=each_task['id']) instance_name = instance.inventory_source.inventory.name - elif each_task['type'] == 'job': + elif each_task['type'] == 'ansible_playbook': instance = Job.objects.get(id=each_task['id']) instance_name = instance.job_template.name else: diff --git a/awx/main/utils.py b/awx/main/utils.py index 0ff157138e..e5cdcd7de3 100644 --- a/awx/main/utils.py +++ b/awx/main/utils.py @@ -300,3 +300,12 @@ def model_to_dict(obj, serializer_mapping=None): else: attr_d[field.name] = "hidden" return attr_d + +def get_system_task_capacity(): + from django.conf import settings + if hasattr(settings, 'SYSTEM_TASK_CAPACITY'): + return settings.SYSTEM_TASK_CAPACITY + total_mem_value = subprocess.check_output(['free','-m']).split()[7] + if int(total_mem_value) <= 2048: + return 50 + return 50 + ((int(total_mem_value) / 1024) - 2) * 75