From e87fabe6bb84691472ab67e5da737c9fe515cf3f Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Thu, 18 Aug 2022 09:43:53 -0400 Subject: [PATCH] Submit job to dispatcher as part of transaction (#12573) Make it so that submitting a task to the dispatcher happens as part of the transaction. This applies to dispatcher task "publishers", which NOTIFY the pg_notify queue; if the transaction is not successful, the message will not be sent, as per the postgres docs. This keeps the current behavior for pg_notify listeners. Practically, this only applies to the awx-manage run_dispatcher service. This requires creating a separate connection and keeping it long-lived; arbitrary code will occasionally close the main connection, which would stop listening. Stop sending the waiting status websocket message. This is required because the ordering cannot be maintained with the other changes here; the instance group data is moved to the running websocket message payload. Move the call to create_partition from the task manager to pre_run_hook, and mock this in the relevant unit tests. --- awx/main/dispatch/__init__.py | 38 +++++++++++++++---- awx/main/dispatch/worker/base.py | 2 +- awx/main/models/unified_jobs.py | 2 +- awx/main/scheduler/task_manager.py | 35 ++++++++--------- awx/main/tasks/jobs.py | 4 +- .../task_management/test_scheduler.py | 2 +- .../test_inventory_source_injectors.py | 5 ++- awx/main/tests/unit/test_tasks.py | 14 +++++-- 8 files changed, 68 insertions(+), 34 deletions(-) diff --git a/awx/main/dispatch/__init__.py b/awx/main/dispatch/__init__.py index c240f6fee9..7fa4bd06f1 100644 --- a/awx/main/dispatch/__init__.py +++ b/awx/main/dispatch/__init__.py @@ -4,6 +4,7 @@ import select from contextlib import contextmanager from django.conf import settings +from django.db import connection as pg_connection NOT_READY = ([], [], []) @@ -15,7 +16,6 @@ def get_local_queuename(): class PubSub(object): def __init__(self, conn): - assert conn.autocommit, "Connection must be in autocommit mode." 
self.conn = conn def listen(self, channel): @@ -31,6 +31,9 @@ class PubSub(object): cur.execute('SELECT pg_notify(%s, %s);', (channel, payload)) def events(self, select_timeout=5, yield_timeouts=False): + if not pg_connection.get_autocommit(): + raise RuntimeError('Listening for events can only be done in autocommit mode') + while True: if select.select([self.conn], [], [], select_timeout) == NOT_READY: if yield_timeouts: @@ -45,11 +48,32 @@ class PubSub(object): @contextmanager -def pg_bus_conn(): - conf = settings.DATABASES['default'] - conn = psycopg2.connect(dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {})) - # Django connection.cursor().connection doesn't have autocommit=True on - conn.set_session(autocommit=True) +def pg_bus_conn(new_connection=False): + ''' + Any listeners probably want to establish a new database connection, + separate from the Django connection used for queries, because that will prevent + losing connection to the channel whenever a .close() happens. 
+ + Any publishers probably want to use the existing connection + so that messages follow postgres transaction rules + https://www.postgresql.org/docs/current/sql-notify.html + ''' + + if new_connection: + conf = settings.DATABASES['default'] + conn = psycopg2.connect( + dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {}) + ) + # Django connection.cursor().connection doesn't have autocommit=True on by default + conn.set_session(autocommit=True) + else: + if pg_connection.connection is None: + pg_connection.connect() + if pg_connection.connection is None: + raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') + conn = pg_connection.connection + pubsub = PubSub(conn) yield pubsub - conn.close() + if new_connection: + conn.close() diff --git a/awx/main/dispatch/worker/base.py b/awx/main/dispatch/worker/base.py index b982cb8ab4..b30f7ec17a 100644 --- a/awx/main/dispatch/worker/base.py +++ b/awx/main/dispatch/worker/base.py @@ -154,7 +154,7 @@ class AWXConsumerPG(AWXConsumerBase): while True: try: - with pg_bus_conn() as conn: + with pg_bus_conn(new_connection=True) as conn: for queue in self.queues: conn.listen(queue) if init is False: diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py index b048b8ee2e..0c95d5b3e9 100644 --- a/awx/main/models/unified_jobs.py +++ b/awx/main/models/unified_jobs.py @@ -1274,7 +1274,7 @@ class UnifiedJob( def _websocket_emit_status(self, status): try: status_data = dict(unified_job_id=self.id, status=status) - if status == 'waiting': + if status == 'running': if self.instance_group: status_data['instance_group_name'] = self.instance_group.name else: diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py index 157cbb1610..45f262ebe6 100644 --- a/awx/main/scheduler/task_manager.py +++ b/awx/main/scheduler/task_manager.py @@ -11,7 +11,7 @@ import sys import signal # Django 
-from django.db import transaction, connection +from django.db import transaction from django.utils.translation import gettext_lazy as _, gettext_noop from django.utils.timezone import now as tz_now from django.conf import settings @@ -39,7 +39,7 @@ from awx.main.utils import ( ScheduleTaskManager, ScheduleWorkflowManager, ) -from awx.main.utils.common import create_partition, task_manager_bulk_reschedule +from awx.main.utils.common import task_manager_bulk_reschedule from awx.main.signals import disable_activity_stream from awx.main.constants import ACTIVE_STATES from awx.main.scheduler.dependency_graph import DependencyGraph @@ -556,22 +556,23 @@ class TaskManager(TaskBase): task.save() task.log_lifecycle("waiting") - def post_commit(): - if task.status != 'failed' and type(task) is not WorkflowJob: - # Before task is dispatched, ensure that job_event partitions exist - create_partition(task.event_class._meta.db_table, start=task.created) - task_cls = task._get_task_class() - task_cls.apply_async( - [task.pk], - opts, - queue=task.get_queue_name(), - uuid=task.celery_task_id, - callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}], - errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}], - ) + # apply_async does a NOTIFY to the channel dispatcher is listening to + # postgres will treat this as part of the transaction, which is what we want + if task.status != 'failed' and type(task) is not WorkflowJob: + task_cls = task._get_task_class() + task_cls.apply_async( + [task.pk], + opts, + queue=task.get_queue_name(), + uuid=task.celery_task_id, + callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}], + errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}], + ) - task.websocket_emit_status(task.status) # adds to on_commit - connection.on_commit(post_commit) 
+ # In exception cases, like a job failing pre-start checks, we send the websocket status message + # for jobs going into waiting, we omit this because of performance issues, as it should go to running quickly + if task.status != 'waiting': + task.websocket_emit_status(task.status) # adds to on_commit @timeit def process_running_tasks(self, running_tasks): diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index 774469c3ba..420f171e22 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -413,6 +413,9 @@ class BaseTask(object): """ instance.log_lifecycle("pre_run") + # Before task is started, ensure that job_event partitions exist + create_partition(instance.event_class._meta.db_table, start=instance.created) + def post_run_hook(self, instance, status): """ Hook for any steps to run before job/task is marked as complete. @@ -718,7 +721,6 @@ class SourceControlMixin(BaseTask): local_project_sync = project.create_project_update(_eager_fields=sync_metafields) local_project_sync.log_lifecycle("controller_node_chosen") local_project_sync.log_lifecycle("execution_node_chosen") - create_partition(local_project_sync.event_class._meta.db_table, start=local_project_sync.created) return local_project_sync def sync_and_copy_without_lock(self, project, private_data_dir, scm_branch=None): diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py index 4081601918..b28152ff01 100644 --- a/awx/main/tests/functional/task_management/test_scheduler.py +++ b/awx/main/tests/functional/task_management/test_scheduler.py @@ -66,7 +66,7 @@ class TestJobLifeCycle: # Submits jobs # intermission - jobs will run and reschedule TM when finished self.run_tm(DependencyManager()) # flip dependencies_processed to True - self.run_tm(TaskManager(), [mock.call('waiting'), mock.call('waiting')]) + self.run_tm(TaskManager()) # I am the job runner for job in jt.jobs.all(): job.status = 'successful' diff 
--git a/awx/main/tests/functional/test_inventory_source_injectors.py b/awx/main/tests/functional/test_inventory_source_injectors.py index 547694c8a1..97fc8a7c17 100644 --- a/awx/main/tests/functional/test_inventory_source_injectors.py +++ b/awx/main/tests/functional/test_inventory_source_injectors.py @@ -261,5 +261,6 @@ def test_inventory_update_injected_content(this_kind, inventory, fake_credential with mock.patch.object(UnifiedJob, 'websocket_emit_status', mock.Mock()): # The point of this test is that we replace run with assertions with mock.patch('awx.main.tasks.receptor.AWXReceptorJob.run', substitute_run): - # so this sets up everything for a run and then yields control over to substitute_run - task.run(inventory_update.pk) + with mock.patch('awx.main.tasks.jobs.create_partition'): + # so this sets up everything for a run and then yields control over to substitute_run + task.run(inventory_update.pk) diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py index 04fd42ec92..9a59e091d1 100644 --- a/awx/main/tests/unit/test_tasks.py +++ b/awx/main/tests/unit/test_tasks.py @@ -80,6 +80,12 @@ def patch_Job(): yield +@pytest.fixture +def mock_create_partition(): + with mock.patch('awx.main.tasks.jobs.create_partition') as cp_mock: + yield cp_mock + + @pytest.fixture def patch_Organization(): _credentials = [] @@ -463,7 +469,7 @@ class TestExtraVarSanitation(TestJobExecution): class TestGenericRun: - def test_generic_failure(self, patch_Job, execution_environment, mock_me): + def test_generic_failure(self, patch_Job, execution_environment, mock_me, mock_create_partition): job = Job(status='running', inventory=Inventory(), project=Project(local_path='/projects/_23_foo')) job.websocket_emit_status = mock.Mock() job.execution_environment = execution_environment @@ -483,7 +489,7 @@ class TestGenericRun: assert update_model_call['status'] == 'error' assert update_model_call['emitted_events'] == 0 - def test_cancel_flag(self, job, 
update_model_wrapper, execution_environment, mock_me): + def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me, mock_create_partition): job.status = 'running' job.cancel_flag = True job.websocket_emit_status = mock.Mock() @@ -582,7 +588,7 @@ class TestGenericRun: @pytest.mark.django_db class TestAdhocRun(TestJobExecution): - def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me): + def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me, mock_create_partition): ExecutionEnvironment.objects.create(name='Control Plane EE', managed=True) ExecutionEnvironment.objects.create(name='Default Job EE', managed=False) @@ -1936,7 +1942,7 @@ def test_managed_injector_redaction(injector_cls): assert 'very_secret_value' not in str(build_safe_env(env)) -def test_job_run_no_ee(mock_me): +def test_job_run_no_ee(mock_me, mock_create_partition): org = Organization(pk=1) proj = Project(pk=1, organization=org) job = Job(project=proj, organization=org, inventory=Inventory(pk=1))