From a665d960267e56710a631c23054184ffd19078ff Mon Sep 17 00:00:00 2001
From: John Westcott IV
Date: Thu, 27 Apr 2023 08:03:27 -0400
Subject: [PATCH] Replacing psycopg2.copy_expert with psycopg3.copy

---
 awx/main/analytics/collectors.py              |  5 +-
 awx/main/models/unified_jobs.py               | 12 ++--
 .../functional/analytics/test_collectors.py   | 71 ++++++++++++-------
 awx/main/tests/functional/api/test_job.py     |  2 +-
 .../api/test_unified_jobs_stdout.py           | 18 ++---
 awx/main/tests/functional/conftest.py         | 44 ++++++++----
 .../functional/models/test_notifications.py   |  2 +-
 tools/scripts/firehose.py                     | 44 +++++-------
 8 files changed, 114 insertions(+), 84 deletions(-)

diff --git a/awx/main/analytics/collectors.py b/awx/main/analytics/collectors.py
index 15577c9696..1279c4596e 100644
--- a/awx/main/analytics/collectors.py
+++ b/awx/main/analytics/collectors.py
@@ -399,7 +399,10 @@ def _copy_table(table, query, path):
     file_path = os.path.join(path, table + '_table.csv')
     file = FileSplitter(filespec=file_path)
     with connection.cursor() as cursor:
-        cursor.copy_expert(query, file)
+        with cursor.copy(query) as copy:
+            while data := copy.read():
+                byte_data = bytes(data)
+                file.write(byte_data.decode())
     return file.file_list()
 
 
diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py
index d5885a2b0b..1e987fa982 100644
--- a/awx/main/models/unified_jobs.py
+++ b/awx/main/models/unified_jobs.py
@@ -1137,11 +1137,6 @@ class UnifiedJob(
         if total > max_supported:
             raise StdoutMaxBytesExceeded(total, max_supported)
 
-        # psycopg2's copy_expert writes bytes, but callers of this
-        # function assume a str-based fd will be returned; decode
-        # .write() calls on the fly to maintain this interface
-        _write = fd.write
-        fd.write = lambda s: _write(smart_str(s))
         tbl = self._meta.db_table + 'event'
         created_by_cond = ''
         if self.has_unpartitioned_events:
@@ -1150,7 +1145,12 @@ class UnifiedJob(
             created_by_cond = f"job_created='{self.created.isoformat()}' AND "
 
         sql = f"copy (select stdout from {tbl} where {created_by_cond}{self.event_parent_key}={self.id} and stdout != '' order by start_line) to stdout"  # nosql
-        cursor.copy_expert(sql, fd)
+        # psycopg3's copy yields bytes, but callers of this
+        # function expect str data to be written to fd; decode
+        # each chunk on the fly to maintain this interface
+        with cursor.copy(sql) as copy:
+            while data := copy.read():
+                fd.write(smart_str(bytes(data)))
 
         if hasattr(fd, 'name'):
             # If we're dealing with a physical file, use `sed` to clean
diff --git a/awx/main/tests/functional/analytics/test_collectors.py b/awx/main/tests/functional/analytics/test_collectors.py
index 0fed6e9c15..4dcb9cd3c3 100644
--- a/awx/main/tests/functional/analytics/test_collectors.py
+++ b/awx/main/tests/functional/analytics/test_collectors.py
@@ -2,8 +2,8 @@ import pytest
 import tempfile
 import os
 import re
-import shutil
 import csv
+from io import StringIO
 
 from django.utils.timezone import now
 from datetime import timedelta
@@ -20,15 +20,16 @@ from awx.main.models import (
 )
 
 
-@pytest.fixture
-def sqlite_copy_expert(request):
-    # copy_expert is postgres-specific, and SQLite doesn't support it; mock its
-    # behavior to test that it writes a file that contains stdout from events
-    path = tempfile.mkdtemp(prefix="copied_tables")
+class MockCopy:
+    headers = None
+    results = None
+    sent_data = False
 
-    def write_stdout(self, sql, fd):
+    def __init__(self, sql, parent_connection):
         # Would be cool if we instead properly dissected the SQL query and verified
         # it that way. But instead, we just take the naive approach here.
+        self.results = None
+        self.headers = None
         sql = sql.strip()
         assert sql.startswith("COPY (")
         assert sql.endswith(") TO STDOUT WITH CSV HEADER")
@@ -51,29 +52,49 @@ def sqlite_copy_expert(request):
         elif not line.endswith(","):
             sql_new[-1] = sql_new[-1].rstrip(",")
         sql = "\n".join(sql_new)
+        parent_connection.execute(sql)
+        self.results = parent_connection.fetchall()
+        self.headers = [i[0] for i in parent_connection.description]
 
-        self.execute(sql)
-        results = self.fetchall()
-        headers = [i[0] for i in self.description]
+    def read(self):
+        if not self.sent_data:
+            mem_file = StringIO()
+            csv_handle = csv.writer(
+                mem_file,
+                delimiter=",",
+                quoting=csv.QUOTE_ALL,
+                escapechar="\\",
+                lineterminator="\n",
+            )
+            if self.headers:
+                csv_handle.writerow(self.headers)
+            if self.results:
+                csv_handle.writerows(self.results)
+            self.sent_data = True
+            return memoryview((mem_file.getvalue()).encode())
+        return None
 
-        csv_handle = csv.writer(
-            fd,
-            delimiter=",",
-            quoting=csv.QUOTE_ALL,
-            escapechar="\\",
-            lineterminator="\n",
-        )
-        csv_handle.writerow(headers)
-        csv_handle.writerows(results)
+    def __enter__(self):
+        return self
 
-    setattr(SQLiteCursorWrapper, "copy_expert", write_stdout)
-    request.addfinalizer(lambda: shutil.rmtree(path))
-    request.addfinalizer(lambda: delattr(SQLiteCursorWrapper, "copy_expert"))
-    return path
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+@pytest.fixture
+def sqlite_copy(request, mocker):
+    # copy is postgres-specific, and SQLite doesn't support it; mock its
+    # behavior to test that it writes a file that contains stdout from events
+
+    def write_stdout(self, sql):
+        mock_copy = MockCopy(sql, self)
+        return mock_copy
+
+    mocker.patch.object(SQLiteCursorWrapper, 'copy', write_stdout, create=True)
 
 
 @pytest.mark.django_db
-def test_copy_tables_unified_job_query(sqlite_copy_expert, project, inventory, job_template):
+def test_copy_tables_unified_job_query(sqlite_copy, project, inventory, job_template):
     """
     Ensure that various unified job types are in the output of the query.
""" @@ -127,7 +148,7 @@ def workflow_job(states=["new", "new", "new", "new", "new"]): @pytest.mark.django_db -def test_copy_tables_workflow_job_node_query(sqlite_copy_expert, workflow_job): +def test_copy_tables_workflow_job_node_query(sqlite_copy, workflow_job): time_start = now() - timedelta(hours=9) with tempfile.TemporaryDirectory() as tmpdir: diff --git a/awx/main/tests/functional/api/test_job.py b/awx/main/tests/functional/api/test_job.py index 836f94da0c..ea87fe40a0 100644 --- a/awx/main/tests/functional/api/test_job.py +++ b/awx/main/tests/functional/api/test_job.py @@ -224,7 +224,7 @@ class TestControllerNode: return AdHocCommand.objects.create(inventory=inventory) @pytest.mark.django_db - def test_field_controller_node_exists(self, sqlite_copy_expert, admin_user, job, project_update, inventory_update, adhoc, get, system_job_factory): + def test_field_controller_node_exists(self, sqlite_copy, admin_user, job, project_update, inventory_update, adhoc, get, system_job_factory): system_job = system_job_factory() r = get(reverse('api:unified_job_list') + '?id={}'.format(job.id), admin_user, expect=200) diff --git a/awx/main/tests/functional/api/test_unified_jobs_stdout.py b/awx/main/tests/functional/api/test_unified_jobs_stdout.py index dad55c5ba0..3dcef8f0e7 100644 --- a/awx/main/tests/functional/api/test_unified_jobs_stdout.py +++ b/awx/main/tests/functional/api/test_unified_jobs_stdout.py @@ -57,7 +57,7 @@ def _mk_inventory_update(created=None): [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], ], ) -def test_text_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin): +def test_text_stdout(sqlite_copy, Parent, Child, relation, view, get, admin): job = Parent() job.save() for i in range(3): @@ -79,7 +79,7 @@ def test_text_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, adm ], ) @pytest.mark.parametrize('download', [True, False]) -def test_ansi_stdout_filtering(sqlite_copy_expert, Parent, Child, relation, view, download, get, admin): +def test_ansi_stdout_filtering(sqlite_copy, Parent, Child, relation, view, download, get, admin): job = Parent() job.save() for i in range(3): @@ -111,7 +111,7 @@ def test_ansi_stdout_filtering(sqlite_copy_expert, Parent, Child, relation, view [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], ], ) -def test_colorized_html_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin): +def test_colorized_html_stdout(sqlite_copy, Parent, Child, relation, view, get, admin): job = Parent() job.save() for i in range(3): @@ -134,7 +134,7 @@ def test_colorized_html_stdout(sqlite_copy_expert, Parent, Child, relation, view [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], ], ) -def test_stdout_line_range(sqlite_copy_expert, Parent, Child, relation, view, get, admin): +def test_stdout_line_range(sqlite_copy, Parent, Child, relation, view, get, admin): job = Parent() job.save() for i in range(20): @@ -146,7 +146,7 @@ def test_stdout_line_range(sqlite_copy_expert, Parent, Child, relation, view, ge @pytest.mark.django_db -def test_text_stdout_from_system_job_events(sqlite_copy_expert, get, admin): +def test_text_stdout_from_system_job_events(sqlite_copy, get, admin): created = tz_now() job = SystemJob(created=created) job.save() @@ -158,7 +158,7 @@ def test_text_stdout_from_system_job_events(sqlite_copy_expert, get, admin): @pytest.mark.django_db -def 
+def test_text_stdout_with_max_stdout(sqlite_copy, get, admin):
     created = tz_now()
     job = SystemJob(created=created)
     job.save()
@@ -185,7 +185,7 @@
 )
 @pytest.mark.parametrize('fmt', ['txt', 'ansi'])
 @mock.patch('awx.main.redact.UriCleaner.SENSITIVE_URI_PATTERN', mock.Mock(**{'search.return_value': None}))  # really slow for large strings
-def test_max_bytes_display(sqlite_copy_expert, Parent, Child, relation, view, fmt, get, admin):
+def test_max_bytes_display(sqlite_copy, Parent, Child, relation, view, fmt, get, admin):
     created = tz_now()
     job = Parent(created=created)
     job.save()
@@ -255,7 +255,7 @@ def test_legacy_result_stdout_with_max_bytes(Cls, view, fmt, get, admin):
     ],
 )
 @pytest.mark.parametrize('fmt', ['txt', 'ansi', 'txt_download', 'ansi_download'])
-def test_text_with_unicode_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin, fmt):
+def test_text_with_unicode_stdout(sqlite_copy, Parent, Child, relation, view, get, admin, fmt):
     job = Parent()
     job.save()
     for i in range(3):
@@ -267,7 +267,7 @@
 
 
 @pytest.mark.django_db
-def test_unicode_with_base64_ansi(sqlite_copy_expert, get, admin):
+def test_unicode_with_base64_ansi(sqlite_copy, get, admin):
     created = tz_now()
     job = Job(created=created)
     job.save()
diff --git a/awx/main/tests/functional/conftest.py b/awx/main/tests/functional/conftest.py
index c87f0a6c1a..d65c80e96c 100644
--- a/awx/main/tests/functional/conftest.py
+++ b/awx/main/tests/functional/conftest.py
@@ -1,8 +1,6 @@
 # Python
 import pytest
 from unittest import mock
-import tempfile
-import shutil
 import urllib.parse
 from unittest.mock import PropertyMock
 
@@ -789,25 +787,43 @@ def oauth_application(admin):
     return Application.objects.create(name='test app', user=admin, client_type='confidential', authorization_grant_type='password')
 
 
-@pytest.fixture
-def sqlite_copy_expert(request):
-    # copy_expert is postgres-specific, and SQLite doesn't support it; mock its
-    # behavior to test that it writes a file that contains stdout from events
-    path = tempfile.mkdtemp(prefix='job-event-stdout')
+class MockCopy:
+    events = []
+    index = -1
 
-    def write_stdout(self, sql, fd):
-        # simulate postgres copy_expert support with ORM code
+    def __init__(self, sql):
+        self.events = []
         parts = sql.split(' ')
         tablename = parts[parts.index('from') + 1]
         for cls in (JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent):
             if cls._meta.db_table == tablename:
                 for event in cls.objects.order_by('start_line').all():
-                    fd.write(event.stdout)
+                    self.events.append(event.stdout)
 
-    setattr(SQLiteCursorWrapper, 'copy_expert', write_stdout)
-    request.addfinalizer(lambda: shutil.rmtree(path))
-    request.addfinalizer(lambda: delattr(SQLiteCursorWrapper, 'copy_expert'))
-    return path
+    def read(self):
+        self.index = self.index + 1
+        if self.index < len(self.events):
+            return memoryview(self.events[self.index].encode())
+
+        return None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+@pytest.fixture
+def sqlite_copy(request, mocker):
+    # copy is postgres-specific, and SQLite doesn't support it; mock its
+    # behavior to test that it writes a file that contains stdout from events
+
+    def write_stdout(self, sql):
+        mock_copy = MockCopy(sql)
+        return mock_copy
+
+    mocker.patch.object(SQLiteCursorWrapper, 'copy', write_stdout, create=True)
 
 
 @pytest.fixture
diff --git a/awx/main/tests/functional/models/test_notifications.py b/awx/main/tests/functional/models/test_notifications.py
index 2d1d5e0f17..2c1d6022de 100644
--- a/awx/main/tests/functional/models/test_notifications.py
+++ b/awx/main/tests/functional/models/test_notifications.py
@@ -98,7 +98,7 @@ class TestJobNotificationMixin(object):
 
     @pytest.mark.django_db
     @pytest.mark.parametrize('JobClass', [AdHocCommand, InventoryUpdate, Job, ProjectUpdate, SystemJob, WorkflowJob])
-    def test_context(self, JobClass, sqlite_copy_expert, project, inventory_source):
+    def test_context(self, JobClass, sqlite_copy, project, inventory_source):
         """The Jinja context defines all of the fields that can be used by a template.
         Ensure that the context generated for each job type has the expected structure."""
         kwargs = {}
diff --git a/tools/scripts/firehose.py b/tools/scripts/firehose.py
index e23287a5b8..d11edd47eb 100755
--- a/tools/scripts/firehose.py
+++ b/tools/scripts/firehose.py
@@ -98,37 +98,27 @@ class YieldedRows(StringIO):
             )
             self.rowlist.append(row)
 
-    def read(self, x):
-        if self.rows <= 0:
-            self.close()
-            return ''
-        elif self.rows >= 1 and self.rows < 1000:
-            event_rows = self.rowlist[random.randrange(len(self.rowlist))] * self.rows
-            self.rows -= self.rows
-            return event_rows
-        self.rows -= 1000
-        return self.rowlist[random.randrange(len(self.rowlist))] * 1000
-
 
 def firehose(job, count, created_stamp, modified_stamp):
     conn = psycopg.connect(dsn)
     f = YieldedRows(job, count, created_stamp, modified_stamp)
 
-    with conn.cursor() as cursor:
-        cursor.copy_expert(
-            (
-                'COPY '
-                'main_jobevent('
-                'created, modified, job_created, event, event_data, failed, changed, '
-                'host_name, play, role, task, counter, host_id, job_id, uuid, '
-                'parent_uuid, end_line, playbook, start_line, stdout, verbosity'
-                ') '
-                'FROM STDIN'
-            ),
-            f,
-            size=1024 * 1000,
-        )
-    conn.commit()
-    conn.close()
+    sql = '''
+        COPY main_jobevent(
+            created, modified, job_created, event, event_data, failed, changed,
+            host_name, play, role, task, counter, host_id, job_id, uuid,
+            parent_uuid, end_line, playbook, start_line, stdout, verbosity
+        ) FROM STDIN
+    '''
+    try:
+        with conn.cursor() as cursor:
+            with cursor.copy(sql) as copy:
+                copy.write("".join(f.rowlist))
+    except Exception as e:
+        print("Failed to import events")
+        print(e)
+    finally:
+        conn.commit()
+        conn.close()
 
 def cleanup(sql):
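
Every hunk above follows the same psycopg3 COPY pattern: cursor.copy() returns a context manager; for COPY ... TO STDOUT its read() yields bytes-like chunks (falsy once the stream is exhausted), and for COPY ... FROM STDIN its write() accepts str or bytes. A minimal standalone sketch of both directions follows; the DSN and table name are placeholders for illustration, not part of this patch:

    import psycopg

    dsn = "dbname=example user=postgres"  # placeholder connection string

    with psycopg.connect(dsn) as conn:
        with conn.cursor() as cursor:
            # COPY TO: read() returns successive bytes-like chunks and an
            # empty (falsy) result once the stream is exhausted.
            with cursor.copy("COPY example_table TO STDOUT") as copy:
                while data := copy.read():
                    print(bytes(data).decode(), end="")

            # COPY FROM: write() accepts str or bytes in the default
            # tab-separated text format.
            with cursor.copy("COPY example_table FROM STDIN") as copy:
                copy.write("1\tfoo\n2\tbar\n")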