diff --git a/awx/main/dispatch/__init__.py b/awx/main/dispatch/__init__.py
index c240f6fee9..7fa4bd06f1 100644
--- a/awx/main/dispatch/__init__.py
+++ b/awx/main/dispatch/__init__.py
@@ -4,6 +4,7 @@ import select
 from contextlib import contextmanager
 
 from django.conf import settings
+from django.db import connection as pg_connection
 
 
 NOT_READY = ([], [], [])
@@ -15,7 +16,6 @@ def get_local_queuename():
 
 class PubSub(object):
     def __init__(self, conn):
-        assert conn.autocommit, "Connection must be in autocommit mode."
         self.conn = conn
 
     def listen(self, channel):
@@ -31,6 +31,11 @@ class PubSub(object):
             cur.execute('SELECT pg_notify(%s, %s);', (channel, payload))
 
     def events(self, select_timeout=5, yield_timeouts=False):
+        # Check the connection we actually listen on: pg_bus_conn(new_connection=True)
+        # hands us a dedicated psycopg2 connection that is separate from Django's.
+        if not self.conn.autocommit:
+            raise RuntimeError('Listening for events can only be done in autocommit mode')
+
         while True:
             if select.select([self.conn], [], [], select_timeout) == NOT_READY:
                 if yield_timeouts:
@@ -45,11 +50,36 @@ class PubSub(object):
 
 
 @contextmanager
-def pg_bus_conn():
-    conf = settings.DATABASES['default']
-    conn = psycopg2.connect(dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {}))
-    # Django connection.cursor().connection doesn't have autocommit=True on
-    conn.set_session(autocommit=True)
+def pg_bus_conn(new_connection=False):
+    '''
+    Any listeners probably want to establish a new database connection,
+    separate from the Django connection used for queries, because that will prevent
+    losing connection to the channel whenever a .close() happens.
+
+    Any publishers probably want to use the existing connection
+    so that messages follow postgres transaction rules
+    https://www.postgresql.org/docs/current/sql-notify.html
+    '''
+
+    if new_connection:
+        conf = settings.DATABASES['default']
+        conn = psycopg2.connect(
+            dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {})
+        )
+        # Django connection.cursor().connection doesn't have autocommit=True on by default
+        conn.set_session(autocommit=True)
+    else:
+        if pg_connection.connection is None:
+            pg_connection.connect()
+        if pg_connection.connection is None:
+            raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions')
+        conn = pg_connection.connection
+
     pubsub = PubSub(conn)
-    yield pubsub
-    conn.close()
+    try:
+        yield pubsub
+    finally:
+        # Only close connections this function opened, and do so even when the
+        # with-body raises; the shared Django connection is left to Django.
+        if new_connection:
+            conn.close()
diff --git a/awx/main/dispatch/worker/base.py b/awx/main/dispatch/worker/base.py
index b982cb8ab4..b30f7ec17a 100644
--- a/awx/main/dispatch/worker/base.py
+++ b/awx/main/dispatch/worker/base.py
@@ -154,7 +154,7 @@ class AWXConsumerPG(AWXConsumerBase):
 
         while True:
             try:
-                with pg_bus_conn() as conn:
+                with pg_bus_conn(new_connection=True) as conn:
                     for queue in self.queues:
                         conn.listen(queue)
                     if init is False:
diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py
index b048b8ee2e..0c95d5b3e9 100644
--- a/awx/main/models/unified_jobs.py
+++ b/awx/main/models/unified_jobs.py
@@ -1274,7 +1274,7 @@ class UnifiedJob(
     def _websocket_emit_status(self, status):
         try:
             status_data = dict(unified_job_id=self.id, status=status)
-            if status == 'waiting':
+            if status == 'running':
                 if self.instance_group:
                     status_data['instance_group_name'] = self.instance_group.name
                 else:
diff --git a/awx/main/scheduler/task_manager.py b/awx/main/scheduler/task_manager.py
index 157cbb1610..45f262ebe6 100644
--- a/awx/main/scheduler/task_manager.py
+++ b/awx/main/scheduler/task_manager.py
@@ -11,7 +11,7 @@ import sys
 import signal
 
 # Django
-from django.db import transaction, connection
+from django.db import transaction
 from django.utils.translation import gettext_lazy as _, gettext_noop
 from django.utils.timezone import now as tz_now
 from django.conf import settings
@@ -39,7 +39,7 @@ from awx.main.utils import (
     ScheduleTaskManager,
     ScheduleWorkflowManager,
 )
-from awx.main.utils.common import create_partition, task_manager_bulk_reschedule
+from awx.main.utils.common import task_manager_bulk_reschedule
 from awx.main.signals import disable_activity_stream
 from awx.main.constants import ACTIVE_STATES
 from awx.main.scheduler.dependency_graph import DependencyGraph
@@ -556,22 +556,23 @@ class TaskManager(TaskBase):
                 task.save()
                 task.log_lifecycle("waiting")
 
-        def post_commit():
-            if task.status != 'failed' and type(task) is not WorkflowJob:
-                # Before task is dispatched, ensure that job_event partitions exist
-                create_partition(task.event_class._meta.db_table, start=task.created)
-                task_cls = task._get_task_class()
-                task_cls.apply_async(
-                    [task.pk],
-                    opts,
-                    queue=task.get_queue_name(),
-                    uuid=task.celery_task_id,
-                    callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}],
-                    errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}],
-                )
+        # apply_async does a NOTIFY to the channel dispatcher is listening to
+        # postgres will treat this as part of the transaction, which is what we want
+        if task.status != 'failed' and type(task) is not WorkflowJob:
+            task_cls = task._get_task_class()
+            task_cls.apply_async(
+                [task.pk],
+                opts,
+                queue=task.get_queue_name(),
+                uuid=task.celery_task_id,
+                callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}],
+                errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}],
+            )
 
-        task.websocket_emit_status(task.status)  # adds to on_commit
-        connection.on_commit(post_commit)
+        # In exception cases, like a job failing pre-start checks, we send the websocket status message
+        # for jobs going into waiting, we omit this because of performance issues, as it should go to running quickly
+        if task.status != 'waiting':
+            task.websocket_emit_status(task.status)  # adds to on_commit
 
     @timeit
     def process_running_tasks(self, running_tasks):
diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py
index 774469c3ba..420f171e22 100644
--- a/awx/main/tasks/jobs.py
+++ b/awx/main/tasks/jobs.py
@@ -413,6 +413,9 @@ class BaseTask(object):
         """
         instance.log_lifecycle("pre_run")
 
+        # Before task is started, ensure that job_event partitions exist
+        create_partition(instance.event_class._meta.db_table, start=instance.created)
+
     def post_run_hook(self, instance, status):
         """
         Hook for any steps to run before job/task is marked as complete.
@@ -718,7 +721,6 @@ class SourceControlMixin(BaseTask):
         local_project_sync = project.create_project_update(_eager_fields=sync_metafields)
         local_project_sync.log_lifecycle("controller_node_chosen")
         local_project_sync.log_lifecycle("execution_node_chosen")
-        create_partition(local_project_sync.event_class._meta.db_table, start=local_project_sync.created)
         return local_project_sync
 
     def sync_and_copy_without_lock(self, project, private_data_dir, scm_branch=None):
diff --git a/awx/main/tests/functional/task_management/test_scheduler.py b/awx/main/tests/functional/task_management/test_scheduler.py
index 4081601918..b28152ff01 100644
--- a/awx/main/tests/functional/task_management/test_scheduler.py
+++ b/awx/main/tests/functional/task_management/test_scheduler.py
@@ -66,7 +66,7 @@ class TestJobLifeCycle:
         # Submits jobs
         # intermission - jobs will run and reschedule TM when finished
         self.run_tm(DependencyManager())  # flip dependencies_processed to True
-        self.run_tm(TaskManager(), [mock.call('waiting'), mock.call('waiting')])
+        self.run_tm(TaskManager())
         # I am the job runner
         for job in jt.jobs.all():
             job.status = 'successful'
diff --git a/awx/main/tests/functional/test_inventory_source_injectors.py b/awx/main/tests/functional/test_inventory_source_injectors.py
index 547694c8a1..97fc8a7c17 100644
--- a/awx/main/tests/functional/test_inventory_source_injectors.py
+++ b/awx/main/tests/functional/test_inventory_source_injectors.py
@@ -261,5 +261,6 @@ def test_inventory_update_injected_content(this_kind, inventory, fake_credential
         with mock.patch.object(UnifiedJob, 'websocket_emit_status', mock.Mock()):
             # The point of this test is that we replace run with assertions
             with mock.patch('awx.main.tasks.receptor.AWXReceptorJob.run', substitute_run):
-                # so this sets up everything for a run and then yields control over to substitute_run
-                task.run(inventory_update.pk)
+                with mock.patch('awx.main.tasks.jobs.create_partition'):
+                    # so this sets up everything for a run and then yields control over to substitute_run
+                    task.run(inventory_update.pk)
diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py
index 04fd42ec92..9a59e091d1 100644
--- a/awx/main/tests/unit/test_tasks.py
+++ b/awx/main/tests/unit/test_tasks.py
@@ -80,6 +80,12 @@ def patch_Job():
         yield
 
 
+@pytest.fixture
+def mock_create_partition():
+    with mock.patch('awx.main.tasks.jobs.create_partition') as cp_mock:
+        yield cp_mock
+
+
 @pytest.fixture
 def patch_Organization():
     _credentials = []
@@ -463,7 +469,7 @@ class TestExtraVarSanitation(TestJobExecution):
 
 
 class TestGenericRun:
-    def test_generic_failure(self, patch_Job, execution_environment, mock_me):
+    def test_generic_failure(self, patch_Job, execution_environment, mock_me, mock_create_partition):
         job = Job(status='running', inventory=Inventory(), project=Project(local_path='/projects/_23_foo'))
         job.websocket_emit_status = mock.Mock()
         job.execution_environment = execution_environment
@@ -483,7 +489,7 @@ class TestGenericRun:
         assert update_model_call['status'] == 'error'
         assert update_model_call['emitted_events'] == 0
 
-    def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me):
+    def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me, mock_create_partition):
         job.status = 'running'
         job.cancel_flag = True
         job.websocket_emit_status = mock.Mock()
@@ -582,7 +588,7 @@ class TestGenericRun:
 
 @pytest.mark.django_db
 class TestAdhocRun(TestJobExecution):
-    def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me):
+    def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me, mock_create_partition):
         ExecutionEnvironment.objects.create(name='Control Plane EE', managed=True)
         ExecutionEnvironment.objects.create(name='Default Job EE', managed=False)
 
@@ -1936,7 +1942,7 @@ def test_managed_injector_redaction(injector_cls):
     assert 'very_secret_value' not in str(build_safe_env(env))
 
 
-def test_job_run_no_ee(mock_me):
+def test_job_run_no_ee(mock_me, mock_create_partition):
     org = Organization(pk=1)
     proj = Project(pk=1, organization=org)
     job = Job(project=proj, organization=org, inventory=Inventory(pk=1))