From 5e0870a7ec4afd2a7ae3f76f5f21ce16a85241ff Mon Sep 17 00:00:00 2001 From: Andrew Potozniak Date: Fri, 18 Jul 2025 12:41:21 -0400 Subject: [PATCH] AAP-48510 Enable Service Tokens with the Authentication Migration Management Command (#7017) * Enabled Service Token Auth for Management Command import_auth_config_to_gateway Co-authored-by: Peter Braun Co-authored-by: Zack Kayyali Assisted-by: Cursor --- .../commands/import_auth_config_to_gateway.py | 58 ++- .../test_import_auth_config_to_gateway.py | 493 ++++++++++++++++++ awx/main/utils/gateway_client.py | 51 +- awx/main/utils/gateway_client_svc_token.py | 77 +++ 4 files changed, 654 insertions(+), 25 deletions(-) create mode 100644 awx/main/tests/unit/commands/test_import_auth_config_to_gateway.py create mode 100644 awx/main/utils/gateway_client_svc_token.py diff --git a/awx/main/management/commands/import_auth_config_to_gateway.py b/awx/main/management/commands/import_auth_config_to_gateway.py index ca93a98534..31f232d567 100644 --- a/awx/main/management/commands/import_auth_config_to_gateway.py +++ b/awx/main/management/commands/import_auth_config_to_gateway.py @@ -2,6 +2,7 @@ import sys import os from django.core.management.base import BaseCommand +from urllib.parse import urlparse, urlunparse from awx.sso.utils.azure_ad_migrator import AzureADMigrator from awx.sso.utils.github_migrator import GitHubMigrator from awx.sso.utils.ldap_migrator import LDAPMigrator @@ -11,12 +12,15 @@ from awx.sso.utils.radius_migrator import RADIUSMigrator from awx.sso.utils.tacacs_migrator import TACACSMigrator from awx.sso.utils.google_oauth2_migrator import GoogleOAuth2Migrator from awx.main.utils.gateway_client import GatewayClient, GatewayAPIError +from awx.main.utils.gateway_client_svc_token import GatewayClientSVCToken +from ansible_base.resource_registry.tasks.sync import create_api_client class Command(BaseCommand): help = 'Import existing auth provider configurations to AAP Gateway via API requests' def add_arguments(self, parser): + parser.add_argument('--basic-auth', action='store_true', help='Use HTTP Basic Authentication between Controller and Gateway') parser.add_argument('--skip-oidc', action='store_true', help='Skip importing GitHub and generic OIDC authenticators') parser.add_argument('--skip-ldap', action='store_true', help='Skip importing LDAP authenticators') parser.add_argument('--skip-ad', action='store_true', help='Skip importing Azure AD authenticator') @@ -41,28 +45,64 @@ class Command(BaseCommand): skip_tacacs = options['skip_tacacs'] skip_google = options['skip_google'] force = options['force'] + basic_auth = options['basic_auth'] + + management_command_validation_errors = [] # If the management command isn't called with all parameters needed to talk to Gateway, consider # it a dry-run and exit cleanly - if not gateway_base_url or not gateway_user or not gateway_password: + if not gateway_base_url and basic_auth: + management_command_validation_errors.append('- GATEWAY_BASE_URL: Base URL of the AAP Gateway instance') + if (not gateway_user or not gateway_password) and basic_auth: + management_command_validation_errors.append('- GATEWAY_USER: Username for AAP Gateway authentication') + management_command_validation_errors.append('- GATEWAY_PASSWORD: Password for AAP Gateway authentication') + + if len(management_command_validation_errors) > 0: self.stdout.write(self.style.WARNING('Missing required environment variables:')) - self.stdout.write(self.style.WARNING('- GATEWAY_BASE_URL: Base URL of the AAP Gateway instance')) - self.stdout.write(self.style.WARNING('- GATEWAY_USER: Username for AAP Gateway authentication')) - self.stdout.write(self.style.WARNING('- GATEWAY_PASSWORD: Password for AAP Gateway authentication')) + for validation_error in management_command_validation_errors: + self.stdout.write(self.style.WARNING(f"{validation_error}")) self.stdout.write(self.style.WARNING('- GATEWAY_SKIP_VERIFY: Skip SSL certificate verification (optional)')) sys.exit(0) - self.stdout.write(self.style.SUCCESS(f'Gateway Base URL: {gateway_base_url}')) - self.stdout.write(self.style.SUCCESS(f'Gateway User: {gateway_user}')) - self.stdout.write(self.style.SUCCESS(f'Gateway Password: {"*" * len(gateway_password)}')) - self.stdout.write(self.style.SUCCESS(f'Skip SSL Verification: {gateway_skip_verify}')) + resource_api_client = None + response = None + + if basic_auth: + self.stdout.write(self.style.SUCCESS('HTTP Basic Auth: true')) + self.stdout.write(self.style.SUCCESS(f'Gateway Base URL: {gateway_base_url}')) + self.stdout.write(self.style.SUCCESS(f'Gateway User: {gateway_user}')) + self.stdout.write(self.style.SUCCESS('Gateway Password: *******************')) + self.stdout.write(self.style.SUCCESS(f'Skip SSL Verification: {gateway_skip_verify}')) + + else: + resource_api_client = create_api_client() + resource_api_client.verify_https = not gateway_skip_verify + response = resource_api_client.get_service_metadata() + parsed_url = urlparse(resource_api_client.base_url) + resource_api_client.base_url = urlunparse((parsed_url.scheme, parsed_url.netloc, '/', '', '', '')) + + self.stdout.write(self.style.SUCCESS('Gateway Service Token: true')) + self.stdout.write(self.style.SUCCESS(f'Gateway Base URL: {resource_api_client.base_url}')) + self.stdout.write(self.style.SUCCESS(f'Gateway JWT User: {resource_api_client.jwt_user_id}')) + self.stdout.write(self.style.SUCCESS(f'Gateway JWT Expiration: {resource_api_client.jwt_expiration}')) + self.stdout.write(self.style.SUCCESS(f'Skip SSL Verification: {not resource_api_client.verify_https}')) + self.stdout.write(self.style.SUCCESS(f'Connection Validated: {response.status_code == 200}')) # Create Gateway client and run migrations try: self.stdout.write(self.style.SUCCESS('\n=== Connecting to Gateway ===')) + pre_gateway_client = None + if basic_auth: + self.stdout.write(self.style.SUCCESS('\n=== With Basic HTTP Auth ===')) + pre_gateway_client = GatewayClient( + base_url=gateway_base_url, username=gateway_user, password=gateway_password, skip_verify=gateway_skip_verify, command=self + ) - with GatewayClient(base_url=gateway_base_url, username=gateway_user, password=gateway_password, skip_verify=gateway_skip_verify) as gateway_client: + else: + self.stdout.write(self.style.SUCCESS('\n=== With Service Token ===')) + pre_gateway_client = GatewayClientSVCToken(resource_api_client=resource_api_client, command=self) + with pre_gateway_client as gateway_client: self.stdout.write(self.style.SUCCESS('Successfully connected to Gateway')) # Initialize migrators diff --git a/awx/main/tests/unit/commands/test_import_auth_config_to_gateway.py b/awx/main/tests/unit/commands/test_import_auth_config_to_gateway.py new file mode 100644 index 0000000000..1b06d00dfa --- /dev/null +++ b/awx/main/tests/unit/commands/test_import_auth_config_to_gateway.py @@ -0,0 +1,493 @@ +import os +from unittest.mock import patch, Mock, call +from io import StringIO + +from django.test import TestCase + +from awx.main.management.commands.import_auth_config_to_gateway import Command +from awx.main.utils.gateway_client import GatewayAPIError + + +class TestImportAuthConfigToGatewayCommand(TestCase): + def setUp(self): + self.command = Command() + + def test_add_arguments(self): + """Test that all expected arguments are properly added to the parser.""" + parser = Mock() + self.command.add_arguments(parser) + + expected_calls = [ + call('--basic-auth', action='store_true', help='Use HTTP Basic Authentication between Controller and Gateway'), + call('--skip-oidc', action='store_true', help='Skip importing GitHub and generic OIDC authenticators'), + call('--skip-ldap', action='store_true', help='Skip importing LDAP authenticators'), + call('--skip-ad', action='store_true', help='Skip importing Azure AD authenticator'), + call('--skip-saml', action='store_true', help='Skip importing SAML authenticator'), + call('--skip-radius', action='store_true', help='Skip importing RADIUS authenticator'), + call('--skip-tacacs', action='store_true', help='Skip importing TACACS+ authenticator'), + call('--skip-google', action='store_true', help='Skip importing Google OAuth2 authenticator'), + call('--force', action='store_true', help='Force migration even if configurations already exist'), + ] + + parser.add_argument.assert_has_calls(expected_calls, any_order=True) + + @patch.dict(os.environ, {}, clear=True) + @patch('sys.stdout', new_callable=StringIO) + def test_handle_missing_env_vars_basic_auth(self, mock_stdout): + """Test that missing environment variables cause clean exit when using basic auth.""" + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': False, + 'skip_ad': False, + 'skip_saml': False, + 'skip_radius': False, + 'skip_tacacs': False, + 'skip_google': False, + 'force': False, + } + + with patch('sys.exit') as mock_exit: + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + mock_exit.assert_called_once_with(0) + + output = mock_stdout.getvalue() + self.assertIn('Missing required environment variables:', output) + self.assertIn('GATEWAY_BASE_URL', output) + self.assertIn('GATEWAY_USER', output) + self.assertIn('GATEWAY_PASSWORD', output) + + @patch.dict( + os.environ, + {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass', 'GATEWAY_SKIP_VERIFY': 'true'}, + ) + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') + @patch('awx.main.management.commands.import_auth_config_to_gateway.GitHubMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.OIDCMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.SAMLMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.AzureADMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.LDAPMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.RADIUSMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.TACACSMigrator') + @patch('sys.stdout', new_callable=StringIO) + def test_handle_basic_auth_success( + self, mock_stdout, mock_tacacs, mock_radius, mock_ldap, mock_azure, mock_saml, mock_oidc, mock_github, mock_gateway_client + ): + """Test successful execution with basic auth.""" + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client.return_value.__exit__.return_value = None + + # Mock migrators + mock_migration_result = { + 'created': 1, + 'updated': 0, + 'unchanged': 0, + 'failed': 0, + 'mappers_created': 2, + 'mappers_updated': 0, + 'mappers_failed': 0, + } + + for mock_migrator_class in [mock_github, mock_oidc, mock_saml, mock_azure, mock_ldap, mock_radius, mock_tacacs]: + mock_migrator = Mock() + mock_migrator.get_authenticator_type.return_value = 'TestAuth' + mock_migrator.migrate.return_value = mock_migration_result + mock_migrator_class.return_value = mock_migrator + + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': False, + 'skip_ad': False, + 'skip_saml': False, + 'skip_radius': False, + 'skip_tacacs': False, + 'skip_google': False, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify gateway client was created with correct parameters + mock_gateway_client.assert_called_once_with( + base_url='https://gateway.example.com', username='testuser', password='testpass', skip_verify=True, command=self.command + ) + + # Verify all migrators were created + mock_github.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_oidc.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_saml.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_azure.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_ldap.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_radius.assert_called_once_with(mock_client_instance, self.command, force=False) + mock_tacacs.assert_called_once_with(mock_client_instance, self.command, force=False) + + # Verify output contains success messages + output = mock_stdout.getvalue() + self.assertIn('HTTP Basic Auth: true', output) + self.assertIn('Successfully connected to Gateway', output) + self.assertIn('Migration Summary', output) + + @patch.dict(os.environ, {'GATEWAY_SKIP_VERIFY': 'false'}) + @patch('awx.main.management.commands.import_auth_config_to_gateway.create_api_client') + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClientSVCToken') + @patch('awx.main.management.commands.import_auth_config_to_gateway.urlparse') + @patch('awx.main.management.commands.import_auth_config_to_gateway.urlunparse') + @patch('sys.stdout', new_callable=StringIO) + def test_handle_service_token_success(self, mock_stdout, mock_urlunparse, mock_urlparse, mock_gateway_client_svc, mock_create_api_client): + """Test successful execution with service token.""" + # Mock resource API client + mock_resource_client = Mock() + mock_resource_client.base_url = 'https://gateway.example.com/api/v1' + mock_resource_client.jwt_user_id = 'test-user' + mock_resource_client.jwt_expiration = '2024-12-31' + mock_resource_client.verify_https = True + mock_response = Mock() + mock_response.status_code = 200 + mock_resource_client.get_service_metadata.return_value = mock_response + mock_create_api_client.return_value = mock_resource_client + + # Mock URL parsing + mock_parsed = Mock() + mock_parsed.scheme = 'https' + mock_parsed.netloc = 'gateway.example.com' + mock_urlparse.return_value = mock_parsed + mock_urlunparse.return_value = 'https://gateway.example.com/' + + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client_svc.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client_svc.return_value.__exit__.return_value = None + + options = { + 'basic_auth': False, + 'skip_oidc': True, + 'skip_ldap': True, + 'skip_ad': True, + 'skip_saml': True, + 'skip_radius': True, + 'skip_tacacs': True, + 'skip_google': True, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify resource API client was created and configured + mock_create_api_client.assert_called_once() + self.assertTrue(mock_resource_client.verify_https) # Should be True when GATEWAY_SKIP_VERIFY='false' + mock_resource_client.get_service_metadata.assert_called_once() + + # Verify service token client was created + mock_gateway_client_svc.assert_called_once_with(resource_api_client=mock_resource_client, command=self.command) + + # Verify output contains service token messages + output = mock_stdout.getvalue() + self.assertIn('Gateway Service Token: true', output) + self.assertIn('No authentication configurations found to migrate.', output) + + @patch('sys.stdout', new_callable=StringIO) + def test_skip_flags_prevent_migrator_creation(self, mock_stdout): + """Test that skip flags prevent corresponding migrators from being created.""" + with patch.dict(os.environ, {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass'}): + with patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') as mock_gateway_client: + with patch('awx.main.management.commands.import_auth_config_to_gateway.GitHubMigrator') as mock_github: + with patch('awx.main.management.commands.import_auth_config_to_gateway.OIDCMigrator') as mock_oidc: + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client.return_value.__exit__.return_value = None + + options = { + 'basic_auth': True, + 'skip_oidc': True, + 'skip_ldap': True, + 'skip_ad': True, + 'skip_saml': True, + 'skip_radius': True, + 'skip_tacacs': True, + 'skip_google': True, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify no migrators were created + mock_github.assert_not_called() + mock_oidc.assert_not_called() + + # Verify warning message about no configurations + output = mock_stdout.getvalue() + self.assertIn('No authentication configurations found to migrate.', output) + + @patch.dict(os.environ, {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass'}) + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') + @patch('sys.stdout', new_callable=StringIO) + def test_handle_gateway_api_error(self, mock_stdout, mock_gateway_client): + """Test handling of GatewayAPIError exceptions.""" + # Mock gateway client to raise GatewayAPIError + mock_gateway_client.side_effect = GatewayAPIError('Test error message', status_code=400, response_data={'error': 'Bad request'}) + + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': False, + 'skip_ad': False, + 'skip_saml': False, + 'skip_radius': False, + 'skip_tacacs': False, + 'skip_google': False, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify error message output + output = mock_stdout.getvalue() + self.assertIn('Gateway API Error: Test error message', output) + self.assertIn('Status Code: 400', output) + self.assertIn("Response: {'error': 'Bad request'}", output) + + @patch.dict(os.environ, {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass'}) + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') + @patch('sys.stdout', new_callable=StringIO) + def test_handle_unexpected_error(self, mock_stdout, mock_gateway_client): + """Test handling of unexpected exceptions.""" + # Mock gateway client to raise unexpected error + mock_gateway_client.side_effect = ValueError('Unexpected error') + + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': False, + 'skip_ad': False, + 'skip_saml': False, + 'skip_radius': False, + 'skip_tacacs': False, + 'skip_google': False, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify error message output + output = mock_stdout.getvalue() + self.assertIn('Unexpected error during migration: Unexpected error', output) + + @patch.dict(os.environ, {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass'}) + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') + @patch('awx.main.management.commands.import_auth_config_to_gateway.GitHubMigrator') + @patch('sys.stdout', new_callable=StringIO) + def test_force_flag_passed_to_migrators(self, mock_stdout, mock_github, mock_gateway_client): + """Test that force flag is properly passed to migrators.""" + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client.return_value.__exit__.return_value = None + + # Mock migrator + mock_migrator = Mock() + mock_migrator.get_authenticator_type.return_value = 'GitHub' + mock_migrator.migrate.return_value = { + 'created': 0, + 'updated': 0, + 'unchanged': 0, + 'failed': 0, + 'mappers_created': 0, + 'mappers_updated': 0, + 'mappers_failed': 0, + } + mock_github.return_value = mock_migrator + + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': True, + 'skip_ad': True, + 'skip_saml': True, + 'skip_radius': True, + 'skip_tacacs': True, + 'skip_google': True, + 'force': True, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify migrator was created with force=True + mock_github.assert_called_once_with(mock_client_instance, self.command, force=True) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_export_summary(self, mock_stdout): + """Test the _print_export_summary method.""" + result = { + 'created': 2, + 'updated': 1, + 'unchanged': 3, + 'failed': 0, + 'mappers_created': 5, + 'mappers_updated': 2, + 'mappers_failed': 1, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command._print_export_summary('SAML', result) + + output = mock_stdout.getvalue() + self.assertIn('--- SAML Export Summary ---', output) + self.assertIn('Authenticators created: 2', output) + self.assertIn('Authenticators updated: 1', output) + self.assertIn('Authenticators unchanged: 3', output) + self.assertIn('Authenticators failed: 0', output) + self.assertIn('Mappers created: 5', output) + self.assertIn('Mappers updated: 2', output) + self.assertIn('Mappers failed: 1', output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_export_summary_missing_keys(self, mock_stdout): + """Test _print_export_summary handles missing keys gracefully.""" + result = { + 'created': 1, + 'updated': 2, + # Missing other keys + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command._print_export_summary('LDAP', result) + + output = mock_stdout.getvalue() + self.assertIn('--- LDAP Export Summary ---', output) + self.assertIn('Authenticators created: 1', output) + self.assertIn('Authenticators updated: 2', output) + self.assertIn('Authenticators unchanged: 0', output) # Default value + self.assertIn('Mappers created: 0', output) # Default value + + @patch.dict(os.environ, {'GATEWAY_BASE_URL': 'https://gateway.example.com', 'GATEWAY_USER': 'testuser', 'GATEWAY_PASSWORD': 'testpass'}) + @patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') + @patch('awx.main.management.commands.import_auth_config_to_gateway.GitHubMigrator') + @patch('awx.main.management.commands.import_auth_config_to_gateway.OIDCMigrator') + @patch('sys.stdout', new_callable=StringIO) + def test_total_results_accumulation(self, mock_stdout, mock_oidc, mock_github, mock_gateway_client): + """Test that results from multiple migrators are properly accumulated.""" + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client.return_value.__exit__.return_value = None + + # Mock migrators with different results + mock_github_migrator = Mock() + mock_github_migrator.get_authenticator_type.return_value = 'GitHub' + mock_github_migrator.migrate.return_value = { + 'created': 1, + 'updated': 0, + 'unchanged': 0, + 'failed': 0, + 'mappers_created': 2, + 'mappers_updated': 0, + 'mappers_failed': 0, + } + mock_github.return_value = mock_github_migrator + + mock_oidc_migrator = Mock() + mock_oidc_migrator.get_authenticator_type.return_value = 'OIDC' + mock_oidc_migrator.migrate.return_value = { + 'created': 0, + 'updated': 1, + 'unchanged': 1, + 'failed': 0, + 'mappers_created': 1, + 'mappers_updated': 1, + 'mappers_failed': 0, + } + mock_oidc.return_value = mock_oidc_migrator + + options = { + 'basic_auth': True, + 'skip_oidc': False, + 'skip_ldap': True, + 'skip_ad': True, + 'skip_saml': True, + 'skip_radius': True, + 'skip_tacacs': True, + 'skip_google': True, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify total results are accumulated correctly + output = mock_stdout.getvalue() + self.assertIn('Total authenticators created: 1', output) # 1 + 0 + self.assertIn('Total authenticators updated: 1', output) # 0 + 1 + self.assertIn('Total authenticators unchanged: 1', output) # 0 + 1 + self.assertIn('Total authenticators failed: 0', output) # 0 + 0 + self.assertIn('Total mappers created: 3', output) # 2 + 1 + self.assertIn('Total mappers updated: 1', output) # 0 + 1 + self.assertIn('Total mappers failed: 0', output) # 0 + 0 + + @patch('sys.stdout', new_callable=StringIO) + def test_environment_variable_parsing(self, mock_stdout): + """Test that environment variables are parsed correctly.""" + test_cases = [ + ('true', True), + ('1', True), + ('yes', True), + ('on', True), + ('TRUE', True), + ('false', False), + ('0', False), + ('no', False), + ('off', False), + ('', False), + ('random', False), + ] + + for env_value, expected in test_cases: + with patch.dict( + os.environ, + { + 'GATEWAY_BASE_URL': 'https://gateway.example.com', + 'GATEWAY_USER': 'testuser', + 'GATEWAY_PASSWORD': 'testpass', + 'GATEWAY_SKIP_VERIFY': env_value, + }, + ): + with patch('awx.main.management.commands.import_auth_config_to_gateway.GatewayClient') as mock_gateway_client: + # Mock gateway client context manager + mock_client_instance = Mock() + mock_gateway_client.return_value.__enter__.return_value = mock_client_instance + mock_gateway_client.return_value.__exit__.return_value = None + + options = { + 'basic_auth': True, + 'skip_oidc': True, + 'skip_ldap': True, + 'skip_ad': True, + 'skip_saml': True, + 'skip_radius': True, + 'skip_tacacs': True, + 'skip_google': True, + 'force': False, + } + + with patch.object(self.command, 'stdout', mock_stdout): + self.command.handle(**options) + + # Verify gateway client was called with correct skip_verify value + mock_gateway_client.assert_called_once_with( + base_url='https://gateway.example.com', username='testuser', password='testpass', skip_verify=expected, command=self.command + ) + + # Reset for next iteration + mock_gateway_client.reset_mock() + mock_stdout.seek(0) + mock_stdout.truncate(0) diff --git a/awx/main/utils/gateway_client.py b/awx/main/utils/gateway_client.py index b7ceaca97f..748423c21a 100644 --- a/awx/main/utils/gateway_client.py +++ b/awx/main/utils/gateway_client.py @@ -27,7 +27,7 @@ class GatewayAPIError(Exception): class GatewayClient: """Client for AAP Gateway REST API interactions.""" - def __init__(self, base_url: str, username: str, password: str, skip_verify: bool = False): + def __init__(self, base_url: str, username: str, password: str, skip_verify: bool = False, skip_session_init: bool = False, command=None): """Initialize Gateway client. Args: @@ -35,31 +35,38 @@ class GatewayClient: username: Username for authentication password: Password for authentication skip_verify: Skip SSL certificate verification + skip_session_init: Skip initializing the session. Only set to True if you are using a base class that doesn't need the initialization of the session. + command: The command object. This is used to write output to the console. """ self.base_url = base_url.rstrip('/') self.username = username self.password = password self.skip_verify = skip_verify + self.command = command + self.session_was_not_initialized = skip_session_init # Initialize session - self.session = requests.Session() + if not skip_session_init: + self.session = requests.Session() - # Configure SSL verification - if skip_verify: - self.session.verify = False - # Disable SSL warnings when verification is disabled - import urllib3 + # Configure SSL verification + if skip_verify: + self.session.verify = False + # Disable SSL warnings when verification is disabled + import urllib3 - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - # Set default headers - self.session.headers.update( - { - 'User-Agent': 'AWX-Gateway-Migration-Client/1.0', - 'Accept': 'application/json', - 'Content-Type': 'application/json', - } - ) + # Set default headers + self.session.headers.update( + { + 'User-Agent': 'AWX-Gateway-Migration-Client/1.0', + 'Accept': 'application/json', + 'Content-Type': 'application/json', + } + ) + else: + self.session = None # Authentication state self._authenticated = False @@ -402,3 +409,15 @@ class GatewayClient: def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.close() + + def _write_output(self, message, style=None): + """Write output message if command is available.""" + if self.command: + if style == 'success': + self.command.stdout.write(self.command.style.SUCCESS(message)) + elif style == 'warning': + self.command.stdout.write(self.command.style.WARNING(message)) + elif style == 'error': + self.command.stdout.write(self.command.style.ERROR(message)) + else: + self.command.stdout.write(message) diff --git a/awx/main/utils/gateway_client_svc_token.py b/awx/main/utils/gateway_client_svc_token.py new file mode 100644 index 0000000000..1a7cd8fcac --- /dev/null +++ b/awx/main/utils/gateway_client_svc_token.py @@ -0,0 +1,77 @@ +""" +Gateway API client for AAP Gateway interactions with Service Tokens. + +This module provides a client class to interact with the AAP Gateway REST API, +specifically for creating authenticators and mapping configurations. +""" + +import requests +import logging +from typing import Dict, Optional +from awx.main.utils.gateway_client import GatewayClient, GatewayAPIError + + +logger = logging.getLogger(__name__) + + +class GatewayClientSVCToken(GatewayClient): + """Client for AAP Gateway REST API interactions.""" + + def __init__(self, resource_api_client=None, command=None): + """Initialize Gateway client. + + Args: + resource_api_client: Resource API Client for Gateway leveraging service tokens + """ + super().__init__( + base_url=resource_api_client.base_url, + username=resource_api_client.jwt_user_id, + password="required-in-GatewayClient-authenticate()-but-unused-by-GatewayClientSVCToken", + skip_verify=(not resource_api_client.verify_https), + skip_session_init=True, + command=command, + ) + self.resource_api_client = resource_api_client + # Authentication state + self._authenticated = True + + def authenticate(self) -> bool: + """Overload the base class method to always return True. + + Returns: + bool: True always + """ + + return True + + def _ensure_authenticated(self): + """Refresh JWT service token""" + self.resource_api_client.refresh_jwt() + + def _make_request(self, method: str, endpoint: str, data: Optional[Dict] = None, params: Optional[Dict] = None) -> requests.Response: + """Make a service token authenticated request to the Gateway API. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, etc.) + endpoint: API endpoint (without base URL) + data: JSON data to send in request body + params: Query parameters + + Returns: + requests.Response: The response object + + Raises: + GatewayAPIError: If request fails + """ + self._ensure_authenticated() + + try: + response = self.resource_api_client._make_request(method=method, path=endpoint, data=data, params=params) + + # Log request details + logger.debug(f"{method.upper()} {self.base_url}{endpoint} - Status: {response.status_code}") + + return response + + except requests.RequestException as e: + raise GatewayAPIError(f"Request failed: {str(e)}")