Submit job to dispatcher as part of transaction (#12573)

Make it so that submitting a task to the dispatcher happens as part of the transaction.
  this applies to dispatcher task "publishers" which NOTIFY the pg_notify queue
  if the transaction is not successful, it will not be sent, as per postgres docs

This keeps current behavior for pg_notify listeners
  practically, this only applies for the awx-manage run_dispatcher service
  this requires creating a separate connection and keeping it long-lived
  arbitrary code will occasionally close the main connection, which would stop listening

Stop sending the waiting status websocket message
  this is required because the ordering cannot be maintained with other changes here
  the instance group data is moved to the running websocket message payload

Move call to create_partition from task manager to pre_run_hook
  mock this in relevant unit tests
This commit is contained in:
Alan Rominger
2022-08-18 09:43:53 -04:00
committed by GitHub
parent 532aa83555
commit e87fabe6bb
8 changed files with 68 additions and 34 deletions

View File

@@ -4,6 +4,7 @@ import select
from contextlib import contextmanager from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.db import connection as pg_connection
NOT_READY = ([], [], []) NOT_READY = ([], [], [])
@@ -15,7 +16,6 @@ def get_local_queuename():
class PubSub(object): class PubSub(object):
def __init__(self, conn): def __init__(self, conn):
assert conn.autocommit, "Connection must be in autocommit mode."
self.conn = conn self.conn = conn
def listen(self, channel): def listen(self, channel):
@@ -31,6 +31,9 @@ class PubSub(object):
cur.execute('SELECT pg_notify(%s, %s);', (channel, payload)) cur.execute('SELECT pg_notify(%s, %s);', (channel, payload))
def events(self, select_timeout=5, yield_timeouts=False): def events(self, select_timeout=5, yield_timeouts=False):
if not pg_connection.get_autocommit():
raise RuntimeError('Listening for events can only be done in autocommit mode')
while True: while True:
if select.select([self.conn], [], [], select_timeout) == NOT_READY: if select.select([self.conn], [], [], select_timeout) == NOT_READY:
if yield_timeouts: if yield_timeouts:
@@ -45,11 +48,32 @@ class PubSub(object):
@contextmanager @contextmanager
def pg_bus_conn(): def pg_bus_conn(new_connection=False):
conf = settings.DATABASES['default'] '''
conn = psycopg2.connect(dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {})) Any listeners probably want to establish a new database connection,
# Django connection.cursor().connection doesn't have autocommit=True on separate from the Django connection used for queries, because that will prevent
conn.set_session(autocommit=True) losing connection to the channel whenever a .close() happens.
Any publishers probably want to use the existing connection
so that messages follow postgres transaction rules
https://www.postgresql.org/docs/current/sql-notify.html
'''
if new_connection:
conf = settings.DATABASES['default']
conn = psycopg2.connect(
dbname=conf['NAME'], host=conf['HOST'], user=conf['USER'], password=conf['PASSWORD'], port=conf['PORT'], **conf.get("OPTIONS", {})
)
# Django connection.cursor().connection doesn't have autocommit=True on by default
conn.set_session(autocommit=True)
else:
if pg_connection.connection is None:
pg_connection.connect()
if pg_connection.connection is None:
raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions')
conn = pg_connection.connection
pubsub = PubSub(conn) pubsub = PubSub(conn)
yield pubsub yield pubsub
conn.close() if new_connection:
conn.close()

View File

@@ -154,7 +154,7 @@ class AWXConsumerPG(AWXConsumerBase):
while True: while True:
try: try:
with pg_bus_conn() as conn: with pg_bus_conn(new_connection=True) as conn:
for queue in self.queues: for queue in self.queues:
conn.listen(queue) conn.listen(queue)
if init is False: if init is False:

View File

@@ -1274,7 +1274,7 @@ class UnifiedJob(
def _websocket_emit_status(self, status): def _websocket_emit_status(self, status):
try: try:
status_data = dict(unified_job_id=self.id, status=status) status_data = dict(unified_job_id=self.id, status=status)
if status == 'waiting': if status == 'running':
if self.instance_group: if self.instance_group:
status_data['instance_group_name'] = self.instance_group.name status_data['instance_group_name'] = self.instance_group.name
else: else:

View File

@@ -11,7 +11,7 @@ import sys
import signal import signal
# Django # Django
from django.db import transaction, connection from django.db import transaction
from django.utils.translation import gettext_lazy as _, gettext_noop from django.utils.translation import gettext_lazy as _, gettext_noop
from django.utils.timezone import now as tz_now from django.utils.timezone import now as tz_now
from django.conf import settings from django.conf import settings
@@ -39,7 +39,7 @@ from awx.main.utils import (
ScheduleTaskManager, ScheduleTaskManager,
ScheduleWorkflowManager, ScheduleWorkflowManager,
) )
from awx.main.utils.common import create_partition, task_manager_bulk_reschedule from awx.main.utils.common import task_manager_bulk_reschedule
from awx.main.signals import disable_activity_stream from awx.main.signals import disable_activity_stream
from awx.main.constants import ACTIVE_STATES from awx.main.constants import ACTIVE_STATES
from awx.main.scheduler.dependency_graph import DependencyGraph from awx.main.scheduler.dependency_graph import DependencyGraph
@@ -556,22 +556,23 @@ class TaskManager(TaskBase):
task.save() task.save()
task.log_lifecycle("waiting") task.log_lifecycle("waiting")
def post_commit(): # apply_async does a NOTIFY to the channel dispatcher is listening to
if task.status != 'failed' and type(task) is not WorkflowJob: # postgres will treat this as part of the transaction, which is what we want
# Before task is dispatched, ensure that job_event partitions exist if task.status != 'failed' and type(task) is not WorkflowJob:
create_partition(task.event_class._meta.db_table, start=task.created) task_cls = task._get_task_class()
task_cls = task._get_task_class() task_cls.apply_async(
task_cls.apply_async( [task.pk],
[task.pk], opts,
opts, queue=task.get_queue_name(),
queue=task.get_queue_name(), uuid=task.celery_task_id,
uuid=task.celery_task_id, callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}],
callbacks=[{'task': handle_work_success.name, 'kwargs': {'task_actual': task_actual}}], errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}],
errbacks=[{'task': handle_work_error.name, 'args': [task.celery_task_id], 'kwargs': {'subtasks': [task_actual] + dependencies}}], )
)
task.websocket_emit_status(task.status) # adds to on_commit # In exception cases, like a job failing pre-start checks, we send the websocket status message
connection.on_commit(post_commit) # for jobs going into waiting, we omit this because of performance issues, as it should go to running quickly
if task.status != 'waiting':
task.websocket_emit_status(task.status) # adds to on_commit
@timeit @timeit
def process_running_tasks(self, running_tasks): def process_running_tasks(self, running_tasks):

View File

@@ -413,6 +413,9 @@ class BaseTask(object):
""" """
instance.log_lifecycle("pre_run") instance.log_lifecycle("pre_run")
# Before task is started, ensure that job_event partitions exist
create_partition(instance.event_class._meta.db_table, start=instance.created)
def post_run_hook(self, instance, status): def post_run_hook(self, instance, status):
""" """
Hook for any steps to run before job/task is marked as complete. Hook for any steps to run before job/task is marked as complete.
@@ -718,7 +721,6 @@ class SourceControlMixin(BaseTask):
local_project_sync = project.create_project_update(_eager_fields=sync_metafields) local_project_sync = project.create_project_update(_eager_fields=sync_metafields)
local_project_sync.log_lifecycle("controller_node_chosen") local_project_sync.log_lifecycle("controller_node_chosen")
local_project_sync.log_lifecycle("execution_node_chosen") local_project_sync.log_lifecycle("execution_node_chosen")
create_partition(local_project_sync.event_class._meta.db_table, start=local_project_sync.created)
return local_project_sync return local_project_sync
def sync_and_copy_without_lock(self, project, private_data_dir, scm_branch=None): def sync_and_copy_without_lock(self, project, private_data_dir, scm_branch=None):

View File

@@ -66,7 +66,7 @@ class TestJobLifeCycle:
# Submits jobs # Submits jobs
# intermission - jobs will run and reschedule TM when finished # intermission - jobs will run and reschedule TM when finished
self.run_tm(DependencyManager()) # flip dependencies_processed to True self.run_tm(DependencyManager()) # flip dependencies_processed to True
self.run_tm(TaskManager(), [mock.call('waiting'), mock.call('waiting')]) self.run_tm(TaskManager())
# I am the job runner # I am the job runner
for job in jt.jobs.all(): for job in jt.jobs.all():
job.status = 'successful' job.status = 'successful'

View File

@@ -261,5 +261,6 @@ def test_inventory_update_injected_content(this_kind, inventory, fake_credential
with mock.patch.object(UnifiedJob, 'websocket_emit_status', mock.Mock()): with mock.patch.object(UnifiedJob, 'websocket_emit_status', mock.Mock()):
# The point of this test is that we replace run with assertions # The point of this test is that we replace run with assertions
with mock.patch('awx.main.tasks.receptor.AWXReceptorJob.run', substitute_run): with mock.patch('awx.main.tasks.receptor.AWXReceptorJob.run', substitute_run):
# so this sets up everything for a run and then yields control over to substitute_run with mock.patch('awx.main.tasks.jobs.create_partition'):
task.run(inventory_update.pk) # so this sets up everything for a run and then yields control over to substitute_run
task.run(inventory_update.pk)

View File

@@ -80,6 +80,12 @@ def patch_Job():
yield yield
@pytest.fixture
def mock_create_partition():
with mock.patch('awx.main.tasks.jobs.create_partition') as cp_mock:
yield cp_mock
@pytest.fixture @pytest.fixture
def patch_Organization(): def patch_Organization():
_credentials = [] _credentials = []
@@ -463,7 +469,7 @@ class TestExtraVarSanitation(TestJobExecution):
class TestGenericRun: class TestGenericRun:
def test_generic_failure(self, patch_Job, execution_environment, mock_me): def test_generic_failure(self, patch_Job, execution_environment, mock_me, mock_create_partition):
job = Job(status='running', inventory=Inventory(), project=Project(local_path='/projects/_23_foo')) job = Job(status='running', inventory=Inventory(), project=Project(local_path='/projects/_23_foo'))
job.websocket_emit_status = mock.Mock() job.websocket_emit_status = mock.Mock()
job.execution_environment = execution_environment job.execution_environment = execution_environment
@@ -483,7 +489,7 @@ class TestGenericRun:
assert update_model_call['status'] == 'error' assert update_model_call['status'] == 'error'
assert update_model_call['emitted_events'] == 0 assert update_model_call['emitted_events'] == 0
def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me): def test_cancel_flag(self, job, update_model_wrapper, execution_environment, mock_me, mock_create_partition):
job.status = 'running' job.status = 'running'
job.cancel_flag = True job.cancel_flag = True
job.websocket_emit_status = mock.Mock() job.websocket_emit_status = mock.Mock()
@@ -582,7 +588,7 @@ class TestGenericRun:
@pytest.mark.django_db @pytest.mark.django_db
class TestAdhocRun(TestJobExecution): class TestAdhocRun(TestJobExecution):
def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me): def test_options_jinja_usage(self, adhoc_job, adhoc_update_model_wrapper, mock_me, mock_create_partition):
ExecutionEnvironment.objects.create(name='Control Plane EE', managed=True) ExecutionEnvironment.objects.create(name='Control Plane EE', managed=True)
ExecutionEnvironment.objects.create(name='Default Job EE', managed=False) ExecutionEnvironment.objects.create(name='Default Job EE', managed=False)
@@ -1936,7 +1942,7 @@ def test_managed_injector_redaction(injector_cls):
assert 'very_secret_value' not in str(build_safe_env(env)) assert 'very_secret_value' not in str(build_safe_env(env))
def test_job_run_no_ee(mock_me): def test_job_run_no_ee(mock_me, mock_create_partition):
org = Organization(pk=1) org = Organization(pk=1)
proj = Project(pk=1, organization=org) proj = Project(pk=1, organization=org)
job = Job(project=proj, organization=org, inventory=Inventory(pk=1)) job = Job(project=proj, organization=org, inventory=Inventory(pk=1))