mirror of
https://github.com/ansible/awx.git
synced 2026-06-21 06:37:45 -02:30
Move PG version check to awx-manage check_db & migrate commands (#15463)
* Move PG version check to check_db command Move to utils, check in pre_migrate signal * Add back in environment var skip * Add tests for compliance tests Assisted-By: claude
This commit is contained in:
@@ -52,14 +52,6 @@ except ImportError: # pragma: no cover
|
||||
MODE = 'production'
|
||||
|
||||
|
||||
try:
|
||||
import django # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def prepare_env():
|
||||
# Update the default settings environment variable based on current mode.
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'awx.settings')
|
||||
@@ -79,14 +71,6 @@ def manage():
|
||||
from django.conf import settings
|
||||
from django.core.management import execute_from_command_line
|
||||
|
||||
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
|
||||
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
|
||||
# The return of connection.pg_version is something like 12013
|
||||
if not os.getenv('SKIP_PG_VERSION_CHECK', False) and not MODE == 'development':
|
||||
if (connection.pg_version // 10000) < 12:
|
||||
sys.stderr.write("At a minimum, postgres version 12 is required\n")
|
||||
sys.exit(1)
|
||||
|
||||
if len(sys.argv) >= 2 and sys.argv[1] in ('version', '--version'): # pragma: no cover
|
||||
sys.stdout.write('%s\n' % __version__)
|
||||
# If running as a user without permission to read settings, display an
|
||||
|
||||
@@ -5,9 +5,13 @@ from dispatcherd.config import setup as dispatcher_setup
|
||||
from django.apps import AppConfig
|
||||
from django.db import connection
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.core.management.base import CommandError
|
||||
from django.db.models.signals import pre_migrate
|
||||
|
||||
from awx.main.utils.common import bypass_in_test, load_all_entry_points_for
|
||||
from awx.main.utils.migration import is_database_synchronized
|
||||
from awx.main.utils.named_url_graph import _customize_graph, generate_graph
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
from awx.conf import register, fields
|
||||
|
||||
from awx_plugins.interfaces._temporary_private_licensing_api import detect_server_product_name
|
||||
@@ -17,6 +21,11 @@ class MainConfig(AppConfig):
|
||||
name = 'awx.main'
|
||||
verbose_name = _('Main')
|
||||
|
||||
def check_db_requirement(self, *args, **kwargs):
|
||||
violations = db_requirement_violations()
|
||||
if violations:
|
||||
raise CommandError(violations)
|
||||
|
||||
def load_named_url_feature(self):
|
||||
models = [m for m in self.get_models() if hasattr(m, 'get_absolute_url')]
|
||||
generate_graph(models)
|
||||
@@ -110,3 +119,4 @@ class MainConfig(AppConfig):
|
||||
self.load_credential_types_feature()
|
||||
self.load_named_url_feature()
|
||||
self.load_inventory_plugins()
|
||||
pre_migrate.connect(self.check_db_requirement, sender=self)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) 2015 Ansible, Inc.
|
||||
# All Rights Reserved
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.db import connection
|
||||
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Checks connection to the database, and prints out connection info if not connected"""
|
||||
@@ -13,4 +15,8 @@ class Command(BaseCommand):
|
||||
cursor.execute("SELECT version()")
|
||||
version = str(cursor.fetchone()[0])
|
||||
|
||||
violations = db_requirement_violations()
|
||||
if violations:
|
||||
raise CommandError(violations)
|
||||
|
||||
return "Database Version: {}".format(version)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from django.apps import apps
|
||||
from django.core.management.base import CommandError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -24,3 +25,23 @@ def test_load_credential_types_feature_migrations_not_ran(mocker, mock_setup_tow
|
||||
apps.get_app_config('main')._load_credential_types_feature()
|
||||
|
||||
mock_setup_tower_managed_defaults.assert_not_called()
|
||||
|
||||
|
||||
def test_check_db_requirement_no_violations(mocker):
|
||||
mocker.patch('awx.main.apps.db_requirement_violations', return_value=None)
|
||||
main_config = apps.get_app_config('main')
|
||||
|
||||
result = main_config.check_db_requirement()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_check_db_requirement_with_violations(mocker):
|
||||
violation_msg = "Database version check failed"
|
||||
mocker.patch('awx.main.apps.db_requirement_violations', return_value=violation_msg)
|
||||
main_config = apps.get_app_config('main')
|
||||
|
||||
with pytest.raises(CommandError) as exc_info:
|
||||
main_config.check_db_requirement()
|
||||
|
||||
assert str(exc_info.value) == violation_msg
|
||||
|
||||
35
awx/main/tests/unit/management/commands/test_check_db.py
Normal file
35
awx/main/tests/unit/management/commands/test_check_db.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
from django.core.management.base import CommandError
|
||||
|
||||
from awx.main.management.commands.check_db import Command
|
||||
|
||||
|
||||
def test_check_db_command_success(mocker):
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.fetchone.return_value = ['PostgreSQL 12.8 on x86_64-pc-linux-gnu, compiled by gcc (GCC) 9.3.0, 64-bit']
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
|
||||
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=None)
|
||||
|
||||
command = Command()
|
||||
result = command.handle()
|
||||
|
||||
assert 'Database Version:' in result
|
||||
mock_cursor.execute.assert_called_once_with('SELECT version()')
|
||||
|
||||
|
||||
def test_check_db_command_version_violations(mocker):
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.fetchone.return_value = ['PostgreSQL 11.0 on x86_64-pc-linux-gnu']
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.cursor.return_value.__enter__.return_value = mock_cursor
|
||||
mocker.patch('awx.main.management.commands.check_db.connection', mock_connection)
|
||||
violation_msg = "At a minimum, postgres version 12 is required, found 11\n"
|
||||
mocker.patch('awx.main.management.commands.check_db.db_requirement_violations', return_value=violation_msg)
|
||||
|
||||
command = Command()
|
||||
with pytest.raises(CommandError) as exc_info:
|
||||
command.handle()
|
||||
|
||||
assert str(exc_info.value) == violation_msg
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
|
||||
import awx
|
||||
from awx.main.db.profiled_pg.base import RecordedQueryLog
|
||||
from awx.main.utils.db import db_requirement_violations
|
||||
|
||||
QUERY = {'sql': 'SELECT * FROM main_job', 'time': '.01'}
|
||||
EXPLAIN = 'Seq Scan on public.main_job (cost=0.00..1.18 rows=18 width=86)'
|
||||
@@ -145,3 +146,71 @@ def test_sql_above_threshold(tmpdir):
|
||||
assert q['sql'] == QUERY['sql']
|
||||
assert EXPLAIN in q['explain']
|
||||
assert 'test_sql_above_threshold' in q['bt']
|
||||
|
||||
|
||||
def test_db_requirement_violations_skip_env_var(mocker):
|
||||
mocker.patch.dict(os.environ, {'SKIP_PG_VERSION_CHECK': 'true'})
|
||||
result = db_requirement_violations()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_sufficient_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 120000 # Version 12.0
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_insufficient_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 110000 # Version 11.0
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is not None
|
||||
assert "At a minimum, postgres version 12 is required, found 11" in result
|
||||
|
||||
|
||||
def test_db_requirement_violations_non_postgresql_production(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'sqlite'
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch('awx.main.utils.db.MODE', 'production')
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is not None
|
||||
assert "Running server with 'sqlite' type database is not supported" in result
|
||||
|
||||
|
||||
def test_db_requirement_violations_non_postgresql_development(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'sqlite'
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch('awx.main.utils.db.MODE', 'development')
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_db_requirement_violations_postgresql_edge_case_version(mocker):
|
||||
mock_connection = mocker.MagicMock()
|
||||
mock_connection.vendor = 'postgresql'
|
||||
mock_connection.pg_version = 129999 # Version 12.9999
|
||||
mocker.patch('awx.main.utils.db.connection', mock_connection)
|
||||
mocker.patch.dict(os.environ, {}, clear=True)
|
||||
|
||||
result = db_requirement_violations()
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
# Copyright (c) 2017 Ansible by Red Hat
|
||||
# All Rights Reserved.
|
||||
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
from awx.settings.application_name import set_application_name
|
||||
from awx import MODE
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def set_connection_name(function):
|
||||
@@ -32,3 +37,25 @@ def bulk_update_sorted_by_id(model, objects, fields, batch_size=1000):
|
||||
|
||||
sorted_objects = sorted(objects, key=lambda obj: obj.id)
|
||||
return model.objects.bulk_update(sorted_objects, fields, batch_size=batch_size)
|
||||
|
||||
|
||||
MIN_PG_VERSION = 12
|
||||
|
||||
|
||||
def db_requirement_violations() -> Optional[str]:
|
||||
if os.getenv('SKIP_PG_VERSION_CHECK', False):
|
||||
return None
|
||||
if connection.vendor == 'postgresql':
|
||||
|
||||
# enforce the postgres version is a minimum of 12 (we need this for partitioning); if not, then terminate program with exit code of 1
|
||||
# In the future if we require a feature of a version of postgres > 12 this should be updated to reflect that.
|
||||
# The return of connection.pg_version is something like 12013
|
||||
major_version = connection.pg_version // 10000
|
||||
if major_version < MIN_PG_VERSION:
|
||||
return f"At a minimum, postgres version {MIN_PG_VERSION} is required, found {major_version}\n"
|
||||
|
||||
return None
|
||||
else:
|
||||
if MODE == 'production':
|
||||
return f"Running server with '{connection.vendor}' type database is not supported\n"
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user