Shift task start signal into an asynchronous task so we don't deadlock trying to update the same record from the task runner while waiting for the signal to be received from the signaler

This commit is contained in:
Matthew Jones
2014-03-22 11:18:25 -04:00
parent 4b9059388e
commit bec301c2a8
5 changed files with 20 additions and 24 deletions

View File

@@ -142,6 +142,10 @@ def get_tasks():
def rebuild_graph(message):
''' Regenerate the task graph by refreshing known tasks from Tower, purging orphaned running tasks,
and creatingdependencies for new tasks before generating directed edge relationships between those tasks '''
all_sorted_tasks = get_tasks()
if not len(all_sorted_tasks):
return None
inspector = inspect()
if not hasattr(settings, 'IGNORE_CELERY_INSPECTOR'):
active_task_queues = inspector.active()
@@ -159,9 +163,6 @@ def rebuild_graph(message):
# TODO: Something needs to be done here to signal to the system as a whole that celery appears to be down
if not hasattr(settings, 'CELERY_UNIT_TEST'):
return None
all_sorted_tasks = get_tasks()
if not len(all_sorted_tasks):
return None
running_tasks = filter(lambda t: t.status == 'running', all_sorted_tasks)
waiting_tasks = filter(lambda t: t.status != 'running', all_sorted_tasks)
new_tasks = filter(lambda t: t.status == 'new', all_sorted_tasks)
@@ -264,7 +265,7 @@ def run_taskmanager(command_port):
command_socket.send("1")
except zmq.ZMQError,e:
message = None
if message is not None or (datetime.datetime.now() - last_rebuild).seconds > 60:
if message is not None or (datetime.datetime.now() - last_rebuild).seconds > 10:
if message is not None and 'pause' in message:
print("Pause command received: %s" % str(message))
paused = message['pause']

View File

@@ -768,6 +768,7 @@ class InventoryUpdate(CommonTask):
return 50
def signal_start(self, **kwargs):
from awx.main.tasks import notify_task_runner
if not self.can_start:
return False
needed = self._get_passwords_needed_to_start()
@@ -780,9 +781,5 @@ class InventoryUpdate(CommonTask):
self.save()
self.start_args = encrypt_field(self, 'start_args')
self.save()
signal_context = zmq.Context()
signal_socket = signal_context.socket(zmq.REQ)
signal_socket.connect(settings.TASK_COMMAND_PORT)
signal_socket.send_json(dict(task_type="inventory_update", id=self.id, metadata=kwargs))
signal_socket.recv()
notify_task_runner.delay(dict(task_type="inventory_update", id=self.id, metadata=kwargs))
return True

View File

@@ -385,6 +385,7 @@ class Job(CommonTask):
return dependencies
def signal_start(self, **kwargs):
from awx.main.tasks import notify_task_runner
if hasattr(settings, 'CELERY_UNIT_TEST'):
return self.start(None, **kwargs)
if not self.can_start:
@@ -399,11 +400,7 @@ class Job(CommonTask):
self.save()
self.start_args = encrypt_field(self, 'start_args')
self.save()
signal_context = zmq.Context()
signal_socket = signal_context.socket(zmq.REQ)
signal_socket.connect(settings.TASK_COMMAND_PORT)
signal_socket.send_json(dict(task_type="ansible_playbook", id=self.id))
signal_socket.recv()
notify_task_runner.delay(dict(task_type="ansible_playbook", id=self.id))
return True
def start(self, error_callback, **kwargs):

View File

@@ -380,6 +380,7 @@ class ProjectUpdate(CommonTask):
return 20
def signal_start(self, **kwargs):
from awx.main.tasks import notify_task_runner
if not self.can_start:
return False
needed = self._get_passwords_needed_to_start()
@@ -392,11 +393,7 @@ class ProjectUpdate(CommonTask):
self.save()
self.start_args = encrypt_field(self, 'start_args')
self.save()
signal_context = zmq.Context()
signal_socket = signal_context.socket(zmq.REQ)
signal_socket.connect(settings.TASK_COMMAND_PORT)
signal_socket.send_json(dict(task_type="project_update", id=self.id, metadata=kwargs))
signal_socket.recv()
notify_task_runner.delay(dict(task_type="project_update", id=self.id, metadata=kwargs))
return True
def _update_parent_instance(self):

View File

@@ -51,6 +51,14 @@ logger = logging.getLogger('awx.main.tasks')
# FIXME: Cleanly cancel task when celery worker is stopped.
@task()
def notify_task_runner(metadata_dict):
time.sleep(1)
signal_context = zmq.Context()
signal_socket = signal_context.socket(zmq.PUSH)
signal_socket.connect(settings.TASK_COMMAND_PORT)
signal_socket.send_json(metadata_dict)
@task(bind=True)
def handle_work_error(self, task_id, subtasks=None):
print('Executing error task id %s, subtasks: %s' % (str(self.request.id), str(subtasks)))
@@ -124,11 +132,7 @@ class BaseTask(Task):
self.model._meta.object_name, retry_count)
def signal_finished(self, pk):
signal_context = zmq.Context()
signal_socket = signal_context.socket(zmq.REQ)
signal_socket.connect(settings.TASK_COMMAND_PORT)
signal_socket.send_json(dict(complete=pk))
signal_socket.recv()
notify_task_runner(dict(complete=pk))
def get_model(self, pk):
return self.model.objects.get(pk=pk)