From 49e21d7c1ca8eb60c11d616001a948f64b5c3464 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Mon, 8 Jun 2026 10:57:55 -0400 Subject: [PATCH] Move PG version check to `awx-manage` `check_db` & `migrate` commands (#15463) * Move PG version check to check_db command Move to utils, check in pre_migrate signal * Add back in environment var skip * Add tests for compliance tests Assisted-By: claude --- awx/__init__.py | 16 ----- awx/main/apps.py | 10 +++ awx/main/management/commands/check_db.py | 8 ++- awx/main/tests/functional/test_apps.py | 21 ++++++ .../unit/management/commands/test_check_db.py | 35 ++++++++++ awx/main/tests/unit/test_db.py | 69 +++++++++++++++++++ awx/main/utils/db.py | 27 ++++++++ 7 files changed, 169 insertions(+), 17 deletions(-) create mode 100644 awx/main/tests/unit/management/commands/test_check_db.py diff --git a/awx/__init__.py b/awx/__init__.py index 59cccd5d8b..de1b3cf9c9 100644 --- a/awx/__init__.py +++ b/awx/__init__.py @@ -52,14 +52,6 @@ except ImportError: # pragma: no cover MODE = 'production' -try: - import django # noqa: F401 -except ImportError: - pass -else: - from django.db import connection - - def prepare_env(): # Update the default settings environment variable based on current mode. os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings') @@ -79,14 +71,6 @@ def manage(): from django.conf import settings from django.core.management import execute_from_command_line - # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 - # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. - # The return of connection.pg_version is something like 12013 - if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development': - if (connection.pg_version // 10000) < 12: - sys.stderr.write("At a minimum, postgres version 12 is required\n") - sys.exit(1) - if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover sys.stdout.write('%s\n' % __version__) # If running as a user without permission to read settings, display an diff --git a/awx/main/apps.py b/awx/main/apps.py index acb6c8ea93..2b67de1cf9 100644 --- a/awx/main/apps.py +++ b/awx/main/apps.py @@ -5,9 +5,13 @@ from dispatcherd.config import setup as dispatcher_setup from django.apps import AppConfig from django.db import connection from django.utils.translation import gettext_lazy as _ +from django.core.management.base import CommandError +from django.db.models.signals import pre_migrate + from awx.main.utils.common import bypass_in_test, load_all_entry_points_for from awx.main.utils.migration import is_database_synchronized from awx.main.utils.named_url_graph import _customize_graph, generate_graph +from awx.main.utils.db import db_requirement_violations from awx.conf import register, fields from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name @@ -17,6 +21,11 @@ class MainConfig(AppConfig): name = 'awx.main' verbose_name = _('Main') + def check_db_requirement(self, *args, **kwargs): + violations = db_requirement_violations() + if violations: + raise CommandError(violations) + def load_named_url_feature(self): models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')] generate_graph(models) @@ -110,3 +119,4 @@ class MainConfig(AppConfig): self.load_credential_types_feature() self.load_named_url_feature() self.load_inventory_plugins() + pre_migrate.connect(self.check_db_requirement, sender=self) diff --git a/awx/main/management/commands/check_db.py b/awx/main/management/commands/check_db.py index e490e7a0e1..0d34340f3d 100644 --- a/awx/main/management/commands/check_db.py +++ b/awx/main/management/commands/check_db.py @@ -1,9 +1,11 @@ # Copyright (c) 2015 Ansible, Inc. # All Rights Reserved -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError from django.db import connection +from awx.main.utils.db import db_requirement_violations + class Command(BaseCommand): """Checks connection to the database, and prints out connection info if not connected""" @@ -13,4 +15,8 @@ class Command(BaseCommand): cursor.execute("SELECT version()") version = str(cursor.fetchone()[0]) + violations = db_requirement_violations() + if violations: + raise CommandError(violations) + return "Database Version: {}".format(version) diff --git a/awx/main/tests/functional/test_apps.py b/awx/main/tests/functional/test_apps.py index a52d4aa723..fbe9a4a370 100644 --- a/awx/main/tests/functional/test_apps.py +++ b/awx/main/tests/functional/test_apps.py @@ -1,6 +1,7 @@ import pytest from django.apps import apps +from django.core.management.base import CommandError @pytest.fixture @@ -24,3 +25,23 @@ def test_load_credential_types_feature_migrations_not_ran(mocker, mock_setup_tow apps.get_app_config('main')._load_credential_types_feature() mock_setup_tower_managed_defaults.assert_not_called() + + +def test_check_db_requirement_no_violations(mocker): + mocker.patch('awx.main.apps.db_requirement_violations', return_value=None) + main_config = apps.get_app_config('main') + + result = main_config.check_db_requirement() + + assert result is None + + +def test_check_db_requirement_with_violations(mocker): + violation_msg = "Database version check failed" + mocker.patch('awx.main.apps.db_requirement_violations', return_value=violation_msg) + main_config = apps.get_app_config('main') + + with pytest.raises(CommandError) as exc_info: + main_config.check_db_requirement() + + assert str(exc_info.value) == violation_msg diff --git a/awx/main/tests/unit/management/commands/test_check_db.py b/awx/main/tests/unit/management/commands/test_check_db.py new file mode 100644 index 0000000000..e1bb9efbd9 --- /dev/null +++ b/awx/main/tests/unit/management/commands/test_check_db.py @@ -0,0 +1,35 @@ +import pytest +from django.core.management.base import CommandError + +from awx.main.management.commands.check_db import Command + + +def test_check_db_command_success(mocker): + mock_cursor = mocker.MagicMock() + mock_cursor.fetchone.return_value = ['PostgreSQL 12.8 on x86_64-pc-linux-gnu, compiled by gcc (GCC) 9.3.0, 64-bit'] + mock_connection = mocker.MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + mocker.patch('awx.main.management.commands.check_db.connection', mock_connection) + mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=None) + + command = Command() + result = command.handle() + + assert 'Database Version:' in result + mock_cursor.execute.assert_called_once_with('SELECT version()') + + +def test_check_db_command_version_violations(mocker): + mock_cursor = mocker.MagicMock() + mock_cursor.fetchone.return_value = ['PostgreSQL 11.0 on x86_64-pc-linux-gnu'] + mock_connection = mocker.MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + mocker.patch('awx.main.management.commands.check_db.connection', mock_connection) + violation_msg = "At a minimum, postgres version 12 is required, found 11\n" + mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=violation_msg) + + command = Command() + with pytest.raises(CommandError) as exc_info: + command.handle() + + assert str(exc_info.value) == violation_msg diff --git a/awx/main/tests/unit/test_db.py b/awx/main/tests/unit/test_db.py index b1ffbfc0d8..dc40ea77f3 100644 --- a/awx/main/tests/unit/test_db.py +++ b/awx/main/tests/unit/test_db.py @@ -8,6 +8,7 @@ import pytest import awx from awx.main.db.profiled_pg.base import RecordedQueryLog +from awx.main.utils.db import db_requirement_violations QUERY = {'sql': 'SELECT * FROM main_job', 'time': '.01'} EXPLAIN = 'Seq Scan on public.main_job (cost=0.00..1.18 rows=18 width=86)' @@ -145,3 +146,71 @@ def test_sql_above_threshold(tmpdir): assert q['sql'] == QUERY['sql'] assert EXPLAIN in q['explain'] assert 'test_sql_above_threshold' in q['bt'] + + +def test_db_requirement_violations_skip_env_var(mocker): + mocker.patch.dict(os.environ, {'SKIP_PG_VERSION_CHECK': 'true'}) + result = db_requirement_violations() + assert result is None + + +def test_db_requirement_violations_postgresql_sufficient_version(mocker): + mock_connection = mocker.MagicMock() + mock_connection.vendor = 'postgresql' + mock_connection.pg_version = 120000 # Version 12.0 + mocker.patch('awx.main.utils.db.connection', mock_connection) + mocker.patch.dict(os.environ, {}, clear=True) + + result = db_requirement_violations() + + assert result is None + + +def test_db_requirement_violations_postgresql_insufficient_version(mocker): + mock_connection = mocker.MagicMock() + mock_connection.vendor = 'postgresql' + mock_connection.pg_version = 110000 # Version 11.0 + mocker.patch('awx.main.utils.db.connection', mock_connection) + mocker.patch.dict(os.environ, {}, clear=True) + + result = db_requirement_violations() + + assert result is not None + assert "At a minimum, postgres version 12 is required, found 11" in result + + +def test_db_requirement_violations_non_postgresql_production(mocker): + mock_connection = mocker.MagicMock() + mock_connection.vendor = 'sqlite' + mocker.patch('awx.main.utils.db.connection', mock_connection) + mocker.patch('awx.main.utils.db.MODE', 'production') + mocker.patch.dict(os.environ, {}, clear=True) + + result = db_requirement_violations() + + assert result is not None + assert "Running server with 'sqlite' type database is not supported" in result + + +def test_db_requirement_violations_non_postgresql_development(mocker): + mock_connection = mocker.MagicMock() + mock_connection.vendor = 'sqlite' + mocker.patch('awx.main.utils.db.connection', mock_connection) + mocker.patch('awx.main.utils.db.MODE', 'development') + mocker.patch.dict(os.environ, {}, clear=True) + + result = db_requirement_violations() + + assert result is None + + +def test_db_requirement_violations_postgresql_edge_case_version(mocker): + mock_connection = mocker.MagicMock() + mock_connection.vendor = 'postgresql' + mock_connection.pg_version = 129999 # Version 12.9999 + mocker.patch('awx.main.utils.db.connection', mock_connection) + mocker.patch.dict(os.environ, {}, clear=True) + + result = db_requirement_violations() + + assert result is None diff --git a/awx/main/utils/db.py b/awx/main/utils/db.py index 2078b28d49..9b1b887ebe 100644 --- a/awx/main/utils/db.py +++ b/awx/main/utils/db.py @@ -1,9 +1,14 @@ # Copyright (c) 2017 Ansible by Red Hat # All Rights Reserved. +from typing import Optional +import os + from awx.settings.application_name import set_application_name +from awx import MODE from django.conf import settings +from django.db import connection def set_connection_name(function): @@ -32,3 +37,25 @@ def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000): sorted_objects = sorted(objects, key=lambda obj: obj.id) return model.objects.bulk_update(sorted_objects, fields, batch_size=batch_size) + + +MIN_PG_VERSION = 12 + + +def db_requirement_violations() -> Optional[str]: + if os.getenv('SKIP_PG_VERSION_CHECK', False): + return None + if connection.vendor == 'postgresql': + + # enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1 + # In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that. + # The return of connection.pg_version is something like 12013 + major_version = connection.pg_version // 10000 + if major_version < MIN_PG_VERSION: + return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n" + + return None + else: + if MODE == 'production': + return f"Running server with '{connection.vendor}' type database is not supported\n" + return None