diff --git a/awx/main/constants.py b/awx/main/constants.py index 3b45153721..44262a1838 100644 --- a/awx/main/constants.py +++ b/awx/main/constants.py @@ -11,6 +11,7 @@ __all__ = [ 'CAN_CANCEL', 'ACTIVE_STATES', 'STANDARD_INVENTORY_UPDATE_ENV', + 'OIDC_CREDENTIAL_TYPE_NAMESPACES', ] PRIVILEGE_ESCALATION_METHODS = [ @@ -140,3 +141,6 @@ org_role_to_permission = { 'execution_environment_admin_role': 'add_executionenvironment', 'auditor_role': 'view_project', # TODO: also doesnt really work } + +# OIDC credential type namespaces for feature flag filtering +OIDC_CREDENTIAL_TYPE_NAMESPACES = ['hashivault-kv-oidc', 'hashivault-ssh-oidc'] diff --git a/awx/main/models/credential.py b/awx/main/models/credential.py index 63edca398b..a035718202 100644 --- a/awx/main/models/credential.py +++ b/awx/main/models/credential.py @@ -28,6 +28,7 @@ from rest_framework.serializers import ValidationError as DRFValidationError from ansible_base.lib.utils.db import advisory_lock # AWX +from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES from awx.api.versioning import reverse from awx.main.fields import ( ImplicitRoleField, @@ -458,13 +459,15 @@ class CredentialType(CommonModelNameNotUnique): def from_db(cls, db, field_names, values): instance = super(CredentialType, cls).from_db(db, field_names, values) if instance.managed and instance.namespace and instance.kind != "external": - native = ManagedCredentialType.registry[instance.namespace] - instance.inputs = native.inputs - instance.injectors = native.injectors - instance.custom_injectors = getattr(native, 'custom_injectors', None) + native = ManagedCredentialType.registry.get(instance.namespace) + if native: + instance.inputs = native.inputs + instance.injectors = native.injectors + instance.custom_injectors = getattr(native, 'custom_injectors', None) elif instance.namespace and instance.kind == "external": - native = ManagedCredentialType.registry[instance.namespace] - instance.inputs = native.inputs + native = ManagedCredentialType.registry.get(instance.namespace) + if native: + instance.inputs = native.inputs return instance @@ -683,13 +686,20 @@ class CredentialInputSource(PrimordialModel): return reverse(view_name, kwargs={'pk': self.pk}, request=request) -def load_credentials(): +def _is_oidc_namespace_disabled(ns): + """Check if a credential namespace should be skipped based on the OIDC feature flag.""" + return ns in OIDC_CREDENTIAL_TYPE_NAMESPACES and not getattr(settings, 'FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED', False) + +def load_credentials(): awx_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials')} supported_entry_points = {ep.name: ep for ep in entry_points(group='awx_plugins.managed_credentials.supported')} plugin_entry_points = awx_entry_points if detect_server_product_name() == 'AWX' else {**awx_entry_points, **supported_entry_points} for ns, ep in plugin_entry_points.items(): + if _is_oidc_namespace_disabled(ns): + continue + cred_plugin = ep.load() if not hasattr(cred_plugin, 'inputs'): setattr(cred_plugin, 'inputs', {}) @@ -708,5 +718,8 @@ def load_credentials(): credential_plugins = {} for ns, ep in credential_plugins.items(): + if _is_oidc_namespace_disabled(ns): + continue + plugin = ep.load() CredentialType.load_plugin(ns, plugin) diff --git a/awx/main/tests/functional/api/test_oidc_credential_feature_flag.py b/awx/main/tests/functional/api/test_oidc_credential_feature_flag.py new file mode 100644 index 0000000000..e76e615289 --- /dev/null +++ b/awx/main/tests/functional/api/test_oidc_credential_feature_flag.py @@ -0,0 +1,163 @@ +""" +Tests for OIDC workload identity credential type feature flag. + +The FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED flag is an install-time flag that +controls whether OIDC credential types are loaded into the registry at startup. +When disabled, OIDC credential types are not loaded and do not exist in the database. +""" + +import pytest +from unittest import mock + +from django.test import override_settings + +from awx.main.constants import OIDC_CREDENTIAL_TYPE_NAMESPACES +from awx.main.models.credential import CredentialType, ManagedCredentialType, load_credentials +from awx.api.versioning import reverse + + +@pytest.fixture +def reload_credentials_with_flag(django_db_setup, django_db_blocker): + """ + Fixture that reloads credentials with a specific flag state. + This simulates what happens at application startup. + """ + # Save original registry state + original_registry = ManagedCredentialType.registry.copy() + + def _reload(flag_enabled): + with django_db_blocker.unblock(): + # Clear the entire registry before reloading + ManagedCredentialType.registry.clear() + + # Reload credentials with the specified flag state + with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=flag_enabled): + with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'): + load_credentials() + + # Sync to database + CredentialType.setup_tower_managed_defaults(lock=False) + + # In tests, the session fixture pre-loads all credential types into the DB. + # Remove OIDC types when testing the disabled state so the API test is accurate. + if not flag_enabled: + CredentialType.objects.filter(namespace__in=OIDC_CREDENTIAL_TYPE_NAMESPACES).delete() + + yield _reload + + # Restore original registry state after tests + ManagedCredentialType.registry.clear() + ManagedCredentialType.registry.update(original_registry) + + +@pytest.fixture +def isolated_registry(): + """Save and restore the ManagedCredentialType registry, with full isolation via mocked entry_points.""" + original_registry = ManagedCredentialType.registry.copy() + ManagedCredentialType.registry.clear() + yield + ManagedCredentialType.registry.clear() + ManagedCredentialType.registry.update(original_registry) + + +def _make_mock_entry_point(name): + """Create a mock entry point that mimics a credential plugin.""" + ep = mock.MagicMock() + ep.name = name + ep.value = f'test_plugin:{name}' + plugin = mock.MagicMock(spec=[]) + ep.load.return_value = plugin + return ep + + +def _mock_entry_points_factory(managed_names, supported_names): + """Return a side_effect function for mocking entry_points() with controlled plugins.""" + managed = [_make_mock_entry_point(n) for n in managed_names] + supported = [_make_mock_entry_point(n) for n in supported_names] + + def _entry_points(group): + if group == 'awx_plugins.managed_credentials': + return managed + elif group == 'awx_plugins.managed_credentials.supported': + return supported + return [] + + return _entry_points + + +# --- Unit tests for load_credentials() registry behavior --- + + +def test_oidc_types_in_registry_when_flag_enabled(isolated_registry): + """Test that OIDC credential types are added to the registry when flag is enabled.""" + mock_eps = _mock_entry_points_factory( + managed_names=['ssh', 'vault'], + supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'], + ) + with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True): + with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'): + with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps): + load_credentials() + + for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES: + assert ns in ManagedCredentialType.registry, f"{ns} should be in registry when flag is enabled" + assert 'ssh' in ManagedCredentialType.registry + assert 'vault' in ManagedCredentialType.registry + + +def test_oidc_types_not_in_registry_when_flag_disabled(isolated_registry): + """Test that OIDC credential types are excluded from the registry when flag is disabled.""" + mock_eps = _mock_entry_points_factory( + managed_names=['ssh', 'vault'], + supported_names=['hashivault-kv-oidc', 'hashivault-ssh-oidc'], + ) + with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False): + with mock.patch('awx.main.models.credential.detect_server_product_name', return_value='NOT_AWX'): + with mock.patch('awx.main.models.credential.entry_points', side_effect=mock_eps): + load_credentials() + + for ns in OIDC_CREDENTIAL_TYPE_NAMESPACES: + assert ns not in ManagedCredentialType.registry, f"{ns} should not be in registry when flag is disabled" + # Non-OIDC types should still be loaded + assert 'ssh' in ManagedCredentialType.registry + assert 'vault' in ManagedCredentialType.registry + + +def test_oidc_namespaces_constant(): + """Test that OIDC_CREDENTIAL_TYPE_NAMESPACES contains the expected namespaces.""" + assert 'hashivault-kv-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES + assert 'hashivault-ssh-oidc' in OIDC_CREDENTIAL_TYPE_NAMESPACES + assert len(OIDC_CREDENTIAL_TYPE_NAMESPACES) == 2 + + +# --- Functional API tests --- + + +@pytest.mark.django_db +def test_oidc_types_loaded_when_flag_enabled(get, admin, reload_credentials_with_flag): + """Test that OIDC credential types are visible in the API when flag is enabled.""" + reload_credentials_with_flag(flag_enabled=True) + + response = get(reverse('api:credential_type_list'), admin) + assert response.status_code == 200 + + namespaces = [ct['namespace'] for ct in response.data['results']] + assert 'hashivault-kv-oidc' in namespaces + assert 'hashivault-ssh-oidc' in namespaces + + +@pytest.mark.django_db +def test_oidc_types_not_loaded_when_flag_disabled(get, admin, reload_credentials_with_flag): + """Test that OIDC credential types are not visible in the API when flag is disabled.""" + reload_credentials_with_flag(flag_enabled=False) + + response = get(reverse('api:credential_type_list'), admin) + assert response.status_code == 200 + + namespaces = [ct['namespace'] for ct in response.data['results']] + assert 'hashivault-kv-oidc' not in namespaces + assert 'hashivault-ssh-oidc' not in namespaces + + # Verify they're also not in the database + assert not CredentialType.objects.filter(namespace='hashivault-kv-oidc').exists() + assert not CredentialType.objects.filter(namespace='hashivault-ssh-oidc').exists() diff --git a/awx/main/tests/functional/test_credential.py b/awx/main/tests/functional/test_credential.py index cfaea5df90..5d2106dbab 100644 --- a/awx/main/tests/functional/test_credential.py +++ b/awx/main/tests/functional/test_credential.py @@ -74,49 +74,64 @@ GLqbpJyX2r3p/Rmo6mLY71SqpA== @pytest.mark.django_db def test_default_cred_types(): - assert sorted(CredentialType.defaults.keys()) == sorted( - [ - 'aim', - 'aws', - 'aws_secretsmanager_credential', - 'azure_kv', - 'azure_rm', - 'bitbucket_dc_token', - 'centrify_vault_kv', - 'conjur', - 'controller', - 'galaxy_api_token', - 'gce', - 'github_token', - 'github_app_lookup', - 'gitlab_token', - 'gpg_public_key', - 'hashivault_kv', - 'hashivault_ssh', - 'hashivault-kv-oidc', - 'hashivault-ssh-oidc', - 'hcp_terraform', - 'insights', - 'kubernetes_bearer_token', - 'net', - 'openstack', - 'registry', - 'rhv', - 'satellite6', - 'scm', - 'ssh', - 'terraform', - 'thycotic_dsv', - 'thycotic_tss', - 'vault', - 'vmware', - ] - ) + expected = [ + 'aim', + 'aws', + 'aws_secretsmanager_credential', + 'azure_kv', + 'azure_rm', + 'bitbucket_dc_token', + 'centrify_vault_kv', + 'conjur', + 'controller', + 'galaxy_api_token', + 'gce', + 'github_token', + 'github_app_lookup', + 'gitlab_token', + 'gpg_public_key', + 'hashivault_kv', + 'hashivault_ssh', + 'hcp_terraform', + 'insights', + 'kubernetes_bearer_token', + 'net', + 'openstack', + 'registry', + 'rhv', + 'satellite6', + 'scm', + 'ssh', + 'terraform', + 'thycotic_dsv', + 'thycotic_tss', + 'vault', + 'vmware', + ] + assert sorted(CredentialType.defaults.keys()) == sorted(expected) + assert 'hashivault-kv-oidc' not in CredentialType.defaults + assert 'hashivault-ssh-oidc' not in CredentialType.defaults for type_ in CredentialType.defaults.values(): assert type_().managed is True +@pytest.mark.django_db +def test_default_cred_types_with_oidc_enabled(): + from django.test import override_settings + from awx.main.models.credential import load_credentials, ManagedCredentialType + + original_registry = ManagedCredentialType.registry.copy() + try: + with override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=True): + ManagedCredentialType.registry.clear() + load_credentials() + assert 'hashivault-kv-oidc' in CredentialType.defaults + assert 'hashivault-ssh-oidc' in CredentialType.defaults + finally: + ManagedCredentialType.registry = original_registry + + @pytest.mark.django_db def test_credential_creation(organization_factory): org = organization_factory('test').organization diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index 1f6b7ae2b2..30874b6d61 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -1134,6 +1134,7 @@ OPA_REQUEST_RETRIES = 2 # The number of retry attempts for connecting to the OP # feature flags FEATURE_INDIRECT_NODE_COUNTING_ENABLED = False +FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED = False # Dispatcher worker lifetime. If set to None, workers will never be retired # based on age. Note workers will finish their last task before retiring if