Replace psycopg2's copy_expert with psycopg3's copy

This commit is contained in:
John Westcott IV
2023-04-27 08:03:27 -04:00
committed by John Westcott IV
parent e47d30974c
commit a665d96026
8 changed files with 114 additions and 84 deletions

View File

@@ -399,7 +399,10 @@ def _copy_table(table, query, path):
file_path = os.path.join(path, table + '_table.csv') file_path = os.path.join(path, table + '_table.csv')
file = FileSplitter(filespec=file_path) file = FileSplitter(filespec=file_path)
with connection.cursor() as cursor: with connection.cursor() as cursor:
cursor.copy_expert(query, file) with cursor.copy(query) as copy:
while data := copy.read():
byte_data = bytes(data)
file.write(byte_data.decode())
return file.file_list() return file.file_list()

View File

@@ -1137,11 +1137,6 @@ class UnifiedJob(
if total > max_supported: if total > max_supported:
raise StdoutMaxBytesExceeded(total, max_supported) raise StdoutMaxBytesExceeded(total, max_supported)
# psycopg2's copy_expert writes bytes, but callers of this
# function assume a str-based fd will be returned; decode
# .write() calls on the fly to maintain this interface
_write = fd.write
fd.write = lambda s: _write(smart_str(s))
tbl = self._meta.db_table + 'event' tbl = self._meta.db_table + 'event'
created_by_cond = '' created_by_cond = ''
if self.has_unpartitioned_events: if self.has_unpartitioned_events:
@@ -1150,7 +1145,12 @@ class UnifiedJob(
created_by_cond = f"job_created='{self.created.isoformat()}' AND " created_by_cond = f"job_created='{self.created.isoformat()}' AND "
sql = f"copy (select stdout from {tbl} where {created_by_cond}{self.event_parent_key}={self.id} and stdout != '' order by start_line) to stdout" # nosql sql = f"copy (select stdout from {tbl} where {created_by_cond}{self.event_parent_key}={self.id} and stdout != '' order by start_line) to stdout" # nosql
cursor.copy_expert(sql, fd) # psycopg3's copy writes bytes, but callers of this
# function assume a str-based fd will be returned; decode
# .write() calls on the fly to maintain this interface
with cursor.copy(sql) as copy:
while data := copy.read():
fd.write(smart_str(bytes(data)))
if hasattr(fd, 'name'): if hasattr(fd, 'name'):
# If we're dealing with a physical file, use `sed` to clean # If we're dealing with a physical file, use `sed` to clean

View File

@@ -2,8 +2,8 @@ import pytest
import tempfile import tempfile
import os import os
import re import re
import shutil
import csv import csv
from io import StringIO
from django.utils.timezone import now from django.utils.timezone import now
from datetime import timedelta from datetime import timedelta
@@ -20,15 +20,16 @@ from awx.main.models import (
) )
@pytest.fixture class MockCopy:
def sqlite_copy_expert(request): headers = None
# copy_expert is postgres-specific, and SQLite doesn't support it; mock its results = None
# behavior to test that it writes a file that contains stdout from events sent_data = False
path = tempfile.mkdtemp(prefix="copied_tables")
def write_stdout(self, sql, fd): def __init__(self, sql, parent_connection):
# Would be cool if we instead properly disected the SQL query and verified # Would be cool if we instead properly disected the SQL query and verified
# it that way. But instead, we just take the naive approach here. # it that way. But instead, we just take the naive approach here.
self.results = None
self.headers = None
sql = sql.strip() sql = sql.strip()
assert sql.startswith("COPY (") assert sql.startswith("COPY (")
assert sql.endswith(") TO STDOUT WITH CSV HEADER") assert sql.endswith(") TO STDOUT WITH CSV HEADER")
@@ -51,29 +52,49 @@ def sqlite_copy_expert(request):
elif not line.endswith(","): elif not line.endswith(","):
sql_new[-1] = sql_new[-1].rstrip(",") sql_new[-1] = sql_new[-1].rstrip(",")
sql = "\n".join(sql_new) sql = "\n".join(sql_new)
parent_connection.execute(sql)
self.results = parent_connection.fetchall()
self.headers = [i[0] for i in parent_connection.description]
self.execute(sql) def read(self):
results = self.fetchall() if not self.sent_data:
headers = [i[0] for i in self.description] mem_file = StringIO()
csv_handle = csv.writer(
mem_file,
delimiter=",",
quoting=csv.QUOTE_ALL,
escapechar="\\",
lineterminator="\n",
)
if self.headers:
csv_handle.writerow(self.headers)
if self.results:
csv_handle.writerows(self.results)
self.sent_data = True
return memoryview((mem_file.getvalue()).encode())
return None
csv_handle = csv.writer( def __enter__(self):
fd, return self
delimiter=",",
quoting=csv.QUOTE_ALL,
escapechar="\\",
lineterminator="\n",
)
csv_handle.writerow(headers)
csv_handle.writerows(results)
setattr(SQLiteCursorWrapper, "copy_expert", write_stdout) def __exit__(self, exc_type, exc_val, exc_tb):
request.addfinalizer(lambda: shutil.rmtree(path)) pass
request.addfinalizer(lambda: delattr(SQLiteCursorWrapper, "copy_expert"))
return path
@pytest.fixture
def sqlite_copy(request, mocker):
# copy is postgres-specific, and SQLite doesn't support it; mock its
# behavior to test that it writes a file that contains stdout from events
def write_stdout(self, sql):
mock_copy = MockCopy(sql, self)
return mock_copy
mocker.patch.object(SQLiteCursorWrapper, 'copy', write_stdout, create=True)
@pytest.mark.django_db @pytest.mark.django_db
def test_copy_tables_unified_job_query(sqlite_copy_expert, project, inventory, job_template): def test_copy_tables_unified_job_query(sqlite_copy, project, inventory, job_template):
""" """
Ensure that various unified job types are in the output of the query. Ensure that various unified job types are in the output of the query.
""" """
@@ -127,7 +148,7 @@ def workflow_job(states=["new", "new", "new", "new", "new"]):
@pytest.mark.django_db @pytest.mark.django_db
def test_copy_tables_workflow_job_node_query(sqlite_copy_expert, workflow_job): def test_copy_tables_workflow_job_node_query(sqlite_copy, workflow_job):
time_start = now() - timedelta(hours=9) time_start = now() - timedelta(hours=9)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:

View File

@@ -224,7 +224,7 @@ class TestControllerNode:
return AdHocCommand.objects.create(inventory=inventory) return AdHocCommand.objects.create(inventory=inventory)
@pytest.mark.django_db @pytest.mark.django_db
def test_field_controller_node_exists(self, sqlite_copy_expert, admin_user, job, project_update, inventory_update, adhoc, get, system_job_factory): def test_field_controller_node_exists(self, sqlite_copy, admin_user, job, project_update, inventory_update, adhoc, get, system_job_factory):
system_job = system_job_factory() system_job = system_job_factory()
r = get(reverse('api:unified_job_list') + '?id={}'.format(job.id), admin_user, expect=200) r = get(reverse('api:unified_job_list') + '?id={}'.format(job.id), admin_user, expect=200)

View File

@@ -57,7 +57,7 @@ def _mk_inventory_update(created=None):
[_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'],
], ],
) )
def test_text_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin): def test_text_stdout(sqlite_copy, Parent, Child, relation, view, get, admin):
job = Parent() job = Parent()
job.save() job.save()
for i in range(3): for i in range(3):
@@ -79,7 +79,7 @@ def test_text_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, adm
], ],
) )
@pytest.mark.parametrize('download', [True, False]) @pytest.mark.parametrize('download', [True, False])
def test_ansi_stdout_filtering(sqlite_copy_expert, Parent, Child, relation, view, download, get, admin): def test_ansi_stdout_filtering(sqlite_copy, Parent, Child, relation, view, download, get, admin):
job = Parent() job = Parent()
job.save() job.save()
for i in range(3): for i in range(3):
@@ -111,7 +111,7 @@ def test_ansi_stdout_filtering(sqlite_copy_expert, Parent, Child, relation, view
[_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'],
], ],
) )
def test_colorized_html_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin): def test_colorized_html_stdout(sqlite_copy, Parent, Child, relation, view, get, admin):
job = Parent() job = Parent()
job.save() job.save()
for i in range(3): for i in range(3):
@@ -134,7 +134,7 @@ def test_colorized_html_stdout(sqlite_copy_expert, Parent, Child, relation, view
[_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'],
], ],
) )
def test_stdout_line_range(sqlite_copy_expert, Parent, Child, relation, view, get, admin): def test_stdout_line_range(sqlite_copy, Parent, Child, relation, view, get, admin):
job = Parent() job = Parent()
job.save() job.save()
for i in range(20): for i in range(20):
@@ -146,7 +146,7 @@ def test_stdout_line_range(sqlite_copy_expert, Parent, Child, relation, view, ge
@pytest.mark.django_db @pytest.mark.django_db
def test_text_stdout_from_system_job_events(sqlite_copy_expert, get, admin): def test_text_stdout_from_system_job_events(sqlite_copy, get, admin):
created = tz_now() created = tz_now()
job = SystemJob(created=created) job = SystemJob(created=created)
job.save() job.save()
@@ -158,7 +158,7 @@ def test_text_stdout_from_system_job_events(sqlite_copy_expert, get, admin):
@pytest.mark.django_db @pytest.mark.django_db
def test_text_stdout_with_max_stdout(sqlite_copy_expert, get, admin): def test_text_stdout_with_max_stdout(sqlite_copy, get, admin):
created = tz_now() created = tz_now()
job = SystemJob(created=created) job = SystemJob(created=created)
job.save() job.save()
@@ -185,7 +185,7 @@ def test_text_stdout_with_max_stdout(sqlite_copy_expert, get, admin):
) )
@pytest.mark.parametrize('fmt', ['txt', 'ansi']) @pytest.mark.parametrize('fmt', ['txt', 'ansi'])
@mock.patch('awx.main.redact.UriCleaner.SENSITIVE_URI_PATTERN', mock.Mock(**{'search.return_value': None})) # really slow for large strings @mock.patch('awx.main.redact.UriCleaner.SENSITIVE_URI_PATTERN', mock.Mock(**{'search.return_value': None})) # really slow for large strings
def test_max_bytes_display(sqlite_copy_expert, Parent, Child, relation, view, fmt, get, admin): def test_max_bytes_display(sqlite_copy, Parent, Child, relation, view, fmt, get, admin):
created = tz_now() created = tz_now()
job = Parent(created=created) job = Parent(created=created)
job.save() job.save()
@@ -255,7 +255,7 @@ def test_legacy_result_stdout_with_max_bytes(Cls, view, fmt, get, admin):
], ],
) )
@pytest.mark.parametrize('fmt', ['txt', 'ansi', 'txt_download', 'ansi_download']) @pytest.mark.parametrize('fmt', ['txt', 'ansi', 'txt_download', 'ansi_download'])
def test_text_with_unicode_stdout(sqlite_copy_expert, Parent, Child, relation, view, get, admin, fmt): def test_text_with_unicode_stdout(sqlite_copy, Parent, Child, relation, view, get, admin, fmt):
job = Parent() job = Parent()
job.save() job.save()
for i in range(3): for i in range(3):
@@ -267,7 +267,7 @@ def test_text_with_unicode_stdout(sqlite_copy_expert, Parent, Child, relation, v
@pytest.mark.django_db @pytest.mark.django_db
def test_unicode_with_base64_ansi(sqlite_copy_expert, get, admin): def test_unicode_with_base64_ansi(sqlite_copy, get, admin):
created = tz_now() created = tz_now()
job = Job(created=created) job = Job(created=created)
job.save() job.save()

View File

@@ -1,8 +1,6 @@
# Python # Python
import pytest import pytest
from unittest import mock from unittest import mock
import tempfile
import shutil
import urllib.parse import urllib.parse
from unittest.mock import PropertyMock from unittest.mock import PropertyMock
@@ -789,25 +787,43 @@ def oauth_application(admin):
return Application.objects.create(name='test app', user=admin, client_type='confidential', authorization_grant_type='password') return Application.objects.create(name='test app', user=admin, client_type='confidential', authorization_grant_type='password')
@pytest.fixture class MockCopy:
def sqlite_copy_expert(request): events = []
# copy_expert is postgres-specific, and SQLite doesn't support it; mock its index = -1
# behavior to test that it writes a file that contains stdout from events
path = tempfile.mkdtemp(prefix='job-event-stdout')
def write_stdout(self, sql, fd): def __init__(self, sql):
# simulate postgres copy_expert support with ORM code self.events = []
parts = sql.split(' ') parts = sql.split(' ')
tablename = parts[parts.index('from') + 1] tablename = parts[parts.index('from') + 1]
for cls in (JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent): for cls in (JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent):
if cls._meta.db_table == tablename: if cls._meta.db_table == tablename:
for event in cls.objects.order_by('start_line').all(): for event in cls.objects.order_by('start_line').all():
fd.write(event.stdout) self.events.append(event.stdout)
setattr(SQLiteCursorWrapper, 'copy_expert', write_stdout) def read(self):
request.addfinalizer(lambda: shutil.rmtree(path)) self.index = self.index + 1
request.addfinalizer(lambda: delattr(SQLiteCursorWrapper, 'copy_expert')) if self.index < len(self.events):
return path return memoryview(self.events[self.index].encode())
return None
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
@pytest.fixture
def sqlite_copy(request, mocker):
# copy is postgres-specific, and SQLite doesn't support it; mock its
# behavior to test that it writes a file that contains stdout from events
def write_stdout(self, sql):
mock_copy = MockCopy(sql)
return mock_copy
mocker.patch.object(SQLiteCursorWrapper, 'copy', write_stdout, create=True)
@pytest.fixture @pytest.fixture

View File

@@ -98,7 +98,7 @@ class TestJobNotificationMixin(object):
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize('JobClass', [AdHocCommand, InventoryUpdate, Job, ProjectUpdate, SystemJob, WorkflowJob]) @pytest.mark.parametrize('JobClass', [AdHocCommand, InventoryUpdate, Job, ProjectUpdate, SystemJob, WorkflowJob])
def test_context(self, JobClass, sqlite_copy_expert, project, inventory_source): def test_context(self, JobClass, sqlite_copy, project, inventory_source):
"""The Jinja context defines all of the fields that can be used by a template. Ensure that the context generated """The Jinja context defines all of the fields that can be used by a template. Ensure that the context generated
for each job type has the expected structure.""" for each job type has the expected structure."""
kwargs = {} kwargs = {}

View File

@@ -98,37 +98,27 @@ class YieldedRows(StringIO):
) )
self.rowlist.append(row) self.rowlist.append(row)
def read(self, x):
if self.rows <= 0:
self.close()
return ''
elif self.rows >= 1 and self.rows < 1000:
event_rows = self.rowlist[random.randrange(len(self.rowlist))] * self.rows
self.rows -= self.rows
return event_rows
self.rows -= 1000
return self.rowlist[random.randrange(len(self.rowlist))] * 1000
def firehose(job, count, created_stamp, modified_stamp): def firehose(job, count, created_stamp, modified_stamp):
conn = psycopg.connect(dsn) conn = psycopg.connect(dsn)
f = YieldedRows(job, count, created_stamp, modified_stamp) f = YieldedRows(job, count, created_stamp, modified_stamp)
with conn.cursor() as cursor: sql = '''
cursor.copy_expert( COPY main_jobevent(
( created, modified, job_created, event, event_data, failed, changed,
'COPY ' host_name, play, role, task, counter, host_id, job_id, uuid,
'main_jobevent(' parent_uuid, end_line, playbook, start_line, stdout, verbosity
'created, modified, job_created, event, event_data, failed, changed, ' ) FROM STDIN
'host_name, play, role, task, counter, host_id, job_id, uuid, ' '''
'parent_uuid, end_line, playbook, start_line, stdout, verbosity' try:
') ' with conn.cursor() as cursor:
'FROM STDIN' with cursor.copy(sql) as copy:
), copy.write("".join(f.rowlist))
f, except Exception as e:
size=1024 * 1000, print("Failed to import events")
) print(e)
conn.commit() finally:
conn.close() conn.commit()
conn.close()
def cleanup(sql): def cleanup(sql):