Move PG version check to awx-manage check_db & migrate commands (#15463)

* Move PG version check to check_db command

Move to utils, check in pre_migrate signal

* Add back in environment var skip

* Add tests for compliance

tests Assisted-By: claude
This commit is contained in:
Alan Rominger
2026-06-08 10:57:55 -04:00
committed by GitHub
parent b531151931
commit 49e21d7c1c
7 changed files with 169 additions and 17 deletions

View File

@@ -52,14 +52,6 @@ except ImportError: # pragma: no cover
MODE = 'production'
try:
import django # noqa: F401
except ImportError:
pass
else:
from django.db import connection
def prepare_env():
# Update the default settings environment variable based on current mode.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings')
@@ -79,14 +71,6 @@ def manage():
from django.conf import settings
from django.core.management import execute_from_command_line
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
# The return of connection.pg_version is something like 12013
if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development':
if (connection.pg_version // 10000) < 12:
sys.stderr.write("At a minimum, postgres version 12 is required\n")
sys.exit(1)
if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover
sys.stdout.write('%s\n' % __version__)
# If running as a user without permission to read settings, display an

View File

@@ -5,9 +5,13 @@ from dispatcherd.config import setup as dispatcher_setup
from django.apps import AppConfig
from django.db import connection
from django.utils.translation import gettext_lazy as _
from django.core.management.base import CommandError
from django.db.models.signals import pre_migrate
from awx.main.utils.common import bypass_in_test, load_all_entry_points_for
from awx.main.utils.migration import is_database_synchronized
from awx.main.utils.named_url_graph import _customize_graph, generate_graph
from awx.main.utils.db import db_requirement_violations
from awx.conf import register, fields
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
@@ -17,6 +21,11 @@ class MainConfig(AppConfig):
name = 'awx.main'
verbose_name = _('Main')
def check_db_requirement(self, *args, **kwargs):
violations = db_requirement_violations()
if violations:
raise CommandError(violations)
def load_named_url_feature(self):
models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')]
generate_graph(models)
@@ -110,3 +119,4 @@ class MainConfig(AppConfig):
self.load_credential_types_feature()
self.load_named_url_feature()
self.load_inventory_plugins()
pre_migrate.connect(self.check_db_requirement, sender=self)

View File

@@ -1,9 +1,11 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved
from django.core.management.base import BaseCommand
from django.core.management.base import BaseCommand, CommandError
from django.db import connection
from awx.main.utils.db import db_requirement_violations
class Command(BaseCommand):
"""Checks connection to the database, and prints out connection info if not connected"""
@@ -13,4 +15,8 @@ class Command(BaseCommand):
cursor.execute("SELECT version()")
version = str(cursor.fetchone()[0])
violations = db_requirement_violations()
if violations:
raise CommandError(violations)
return "Database Version: {}".format(version)

View File

@@ -1,6 +1,7 @@
import pytest
from django.apps import apps
from django.core.management.base import CommandError
@pytest.fixture
@@ -24,3 +25,23 @@ def test_load_credential_types_feature_migrations_not_ran(mocker, mock_setup_tow
apps.get_app_config('main')._load_credential_types_feature()
mock_setup_tower_managed_defaults.assert_not_called()
def test_check_db_requirement_no_violations(mocker):
mocker.patch('awx.main.apps.db_requirement_violations', return_value=None)
main_config = apps.get_app_config('main')
result = main_config.check_db_requirement()
assert result is None
def test_check_db_requirement_with_violations(mocker):
violation_msg = "Database version check failed"
mocker.patch('awx.main.apps.db_requirement_violations', return_value=violation_msg)
main_config = apps.get_app_config('main')
with pytest.raises(CommandError) as exc_info:
main_config.check_db_requirement()
assert str(exc_info.value) == violation_msg

View File

@@ -0,0 +1,35 @@
import pytest
from django.core.management.base import CommandError
from awx.main.management.commands.check_db import Command
def test_check_db_command_success(mocker):
mock_cursor = mocker.MagicMock()
mock_cursor.fetchone.return_value = ['PostgreSQL 12.8 on x86_64-pc-linux-gnu, compiled by gcc (GCC) 9.3.0, 64-bit']
mock_connection = mocker.MagicMock()
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=None)
command = Command()
result = command.handle()
assert 'Database Version:' in result
mock_cursor.execute.assert_called_once_with('SELECT version()')
def test_check_db_command_version_violations(mocker):
mock_cursor = mocker.MagicMock()
mock_cursor.fetchone.return_value = ['PostgreSQL 11.0 on x86_64-pc-linux-gnu']
mock_connection = mocker.MagicMock()
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
violation_msg = "At a minimum, postgres version 12 is required, found 11\n"
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=violation_msg)
command = Command()
with pytest.raises(CommandError) as exc_info:
command.handle()
assert str(exc_info.value) == violation_msg

View File

@@ -8,6 +8,7 @@ import pytest
import awx
from awx.main.db.profiled_pg.base import RecordedQueryLog
from awx.main.utils.db import db_requirement_violations
QUERY = {'sql': 'SELECT * FROM main_job', 'time': '.01'}
EXPLAIN = 'Seq Scan on public.main_job (cost=0.00..1.18 rows=18 width=86)'
@@ -145,3 +146,71 @@ def test_sql_above_threshold(tmpdir):
assert q['sql'] == QUERY['sql']
assert EXPLAIN in q['explain']
assert 'test_sql_above_threshold' in q['bt']
def test_db_requirement_violations_skip_env_var(mocker):
mocker.patch.dict(os.environ, {'SKIP_PG_VERSION_CHECK': 'true'})
result = db_requirement_violations()
assert result is None
def test_db_requirement_violations_postgresql_sufficient_version(mocker):
mock_connection = mocker.MagicMock()
mock_connection.vendor = 'postgresql'
mock_connection.pg_version = 120000 # Version 12.0
mocker.patch('awx.main.utils.db.connection', mock_connection)
mocker.patch.dict(os.environ, {}, clear=True)
result = db_requirement_violations()
assert result is None
def test_db_requirement_violations_postgresql_insufficient_version(mocker):
mock_connection = mocker.MagicMock()
mock_connection.vendor = 'postgresql'
mock_connection.pg_version = 110000 # Version 11.0
mocker.patch('awx.main.utils.db.connection', mock_connection)
mocker.patch.dict(os.environ, {}, clear=True)
result = db_requirement_violations()
assert result is not None
assert "At a minimum, postgres version 12 is required, found 11" in result
def test_db_requirement_violations_non_postgresql_production(mocker):
mock_connection = mocker.MagicMock()
mock_connection.vendor = 'sqlite'
mocker.patch('awx.main.utils.db.connection', mock_connection)
mocker.patch('awx.main.utils.db.MODE', 'production')
mocker.patch.dict(os.environ, {}, clear=True)
result = db_requirement_violations()
assert result is not None
assert "Running server with 'sqlite' type database is not supported" in result
def test_db_requirement_violations_non_postgresql_development(mocker):
mock_connection = mocker.MagicMock()
mock_connection.vendor = 'sqlite'
mocker.patch('awx.main.utils.db.connection', mock_connection)
mocker.patch('awx.main.utils.db.MODE', 'development')
mocker.patch.dict(os.environ, {}, clear=True)
result = db_requirement_violations()
assert result is None
def test_db_requirement_violations_postgresql_edge_case_version(mocker):
mock_connection = mocker.MagicMock()
mock_connection.vendor = 'postgresql'
mock_connection.pg_version = 129999 # Version 12.9999
mocker.patch('awx.main.utils.db.connection', mock_connection)
mocker.patch.dict(os.environ, {}, clear=True)
result = db_requirement_violations()
assert result is None

View File

@@ -1,9 +1,14 @@
# Copyright (c) 2017 Ansible by Red Hat
# All Rights Reserved.
from typing import Optional
import os
from awx.settings.application_name import set_application_name
from awx import MODE
from django.conf import settings
from django.db import connection
def set_connection_name(function):
@@ -32,3 +37,25 @@ def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000):
sorted_objects = sorted(objects, key=lambda obj: obj.id)
return model.objects.bulk_update(sorted_objects, fields, batch_size=batch_size)
MIN_PG_VERSION = 12
def db_requirement_violations() -> Optional[str]:
if os.getenv('SKIP_PG_VERSION_CHECK', False):
return None
if connection.vendor == 'postgresql':
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
# The return of connection.pg_version is something like 12013
major_version = connection.pg_version // 10000
if major_version < MIN_PG_VERSION:
return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n"
return None
else:
if MODE == 'production':
return f"Running server with '{connection.vendor}' type database is not supported\n"
return None