From b83019bde649c8dec68ceacd4e1b286a7da1a275 Mon Sep 17 00:00:00 2001 From: Daniel Finca Date: Mon, 6 Apr 2026 21:56:11 +0200 Subject: [PATCH] feat: support for oidc credential /test endpoint (#16370) Adds support for testing external credentials that use OIDC workload identity tokens. When FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is enabled, the /test endpoints return JWT payload details alongside test results. - Add OIDC credential test endpoints with job template selection - Return JWT payload and secret value in test response - Maintain backward compatibility (detail field for errors) - Add comprehensive unit and functional tests - Refactor shared error handling logic Co-authored-by: Daniel Finca Co-authored-by: melissalkelly --- awx/api/serializers.py | 14 +- awx/api/views/__init__.py | 242 +++++++++++++--- awx/main/tasks/jobs.py | 16 +- .../functional/api/test_credential_type.py | 99 ++++++- .../api/test_oidc_credential_test.py | 259 ++++++++++++++++++ awx/main/tests/unit/tasks/test_jobs.py | 8 +- awx/main/utils/workload_identity.py | 22 ++ 7 files changed, 613 insertions(+), 47 deletions(-) create mode 100644 awx/main/tests/functional/api/test_oidc_credential_test.py diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 2a8e94d833..453c6716e7 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -122,7 +122,6 @@ from awx.main.scheduler.task_manager_models import TaskManagerModels from awx.main.redact import UriCleaner, REPLACE_STR from awx.main.signals import update_inventory_computed_fields - from awx.main.validators import vars_validate_or_raise from awx.api.versioning import reverse @@ -2932,6 +2931,19 @@ class CredentialTypeSerializer(BaseSerializer): field['label'] = _(field['label']) if 'help_text' in field: field['help_text'] = _(field['help_text']) + + # Deep copy inputs to avoid modifying the original model data + inputs = value.get('inputs') + if not isinstance(inputs, dict): + inputs = {} + value['inputs'] = copy.deepcopy(inputs) + fields = value['inputs'].get('fields', []) + if not isinstance(fields, list): + fields = [] + + # Normalize fields and filter out internal fields + value['inputs']['fields'] = [f for f in fields if not f.get('internal')] + return value def filter_field_metadata(self, fields, method): diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index 1456e8bb16..b8b582805e 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -14,6 +14,7 @@ import sys import time from base64 import b64encode from collections import OrderedDict +from jwt import decode as _jwt_decode from urllib3.exceptions import ConnectTimeoutError @@ -58,8 +59,13 @@ from drf_spectacular.utils import extend_schema_view, extend_schema from ansible_base.lib.utils.requests import get_remote_hosts from ansible_base.rbac.models import RoleEvaluation from ansible_base.lib.utils.schema import extend_schema_if_available +from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope + +# flags +from flags.state import flag_enabled # AWX +from awx.main.tasks.jobs import retrieve_workload_identity_jwt_with_claims from awx.main.tasks.system import send_notifications, update_inventory_computed_fields from awx.main.access import get_user_queryset from awx.api.generics import ( @@ -1595,7 +1601,177 @@ class CredentialCopy(CopyAPIView): resource_purpose = 'copy of a credential' -class CredentialExternalTest(SubDetailAPIView): +class OIDCCredentialTestMixin: + """ + Mixin to add OIDC workload identity token support to credential test endpoints. + + This mixin provides methods to handle OIDC-enabled external credentials that use + workload identity tokens for authentication. + """ + + @staticmethod + def _get_workload_identity_token(job_template: models.JobTemplate, jwt_aud: str) -> str: + """Generate a workload identity token for a job template. + + Args: + job_template: The JobTemplate instance to generate claims for + jwt_aud: The JWT audience claim value + + Returns: + str: The generated JWT token + """ + claims = { + AutomationControllerJobScope.CLAIM_ORGANIZATION_NAME: job_template.organization.name, + AutomationControllerJobScope.CLAIM_ORGANIZATION_ID: job_template.organization.id, + AutomationControllerJobScope.CLAIM_PROJECT_NAME: job_template.project.name, + AutomationControllerJobScope.CLAIM_PROJECT_ID: job_template.project.id, + AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_NAME: job_template.name, + AutomationControllerJobScope.CLAIM_JOB_TEMPLATE_ID: job_template.id, + AutomationControllerJobScope.CLAIM_PLAYBOOK_NAME: job_template.playbook, + } + return retrieve_workload_identity_jwt_with_claims( + claims=claims, + audience=jwt_aud, + scope=AutomationControllerJobScope.name, + ) + + @staticmethod + def _decode_jwt_payload_for_display(jwt_token): + """Decode JWT payload for display purposes only (signature not verified). + + This is safe because the JWT was just created by AWX and is only decoded + to show the user what claims are being sent to the external system. + The external system will perform proper signature verification. + + Args: + jwt_token: The JWT token to decode + + Returns: + dict: The decoded JWT payload + """ + return _jwt_decode(jwt_token, algorithms=["RS256"], options={"verify_signature": False}) # NOSONAR python:S5659 + + def _has_workload_identity_token(self, credential_type_inputs): + """Check if credential type has an internal workload_identity_token field. + + Args: + credential_type_inputs: The inputs dict from a credential type + + Returns: + bool: True if the credential type has a workload_identity_token field marked as internal + """ + fields = credential_type_inputs.get('fields', []) if isinstance(credential_type_inputs, dict) else [] + return any(field.get('internal') and field.get('id') == 'workload_identity_token' for field in fields) + + def _validate_and_get_job_template(self, job_template_id): + """Validate job template ID and return the JobTemplate instance. + + Args: + job_template_id: The job template ID from metadata + + Returns: + JobTemplate instance + + Raises: + ParseError: If job_template_id is invalid or not found + """ + if job_template_id is None: + raise ParseError(_('Job template ID is required.')) + + try: + return models.JobTemplate.objects.get(id=int(job_template_id)) + except ValueError: + raise ParseError(_('Job template ID must be an integer.')) + except models.JobTemplate.DoesNotExist: + raise ParseError(_('Job template with ID %(id)s does not exist.') % {'id': job_template_id}) + + def _handle_oidc_credential_test(self, backend_kwargs): + """ + Handle OIDC workload identity token generation for external credential test endpoints. + + This method should only be called when FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED is enabled + and the credential type has a workload_identity_token field. + + Args: + backend_kwargs: The kwargs dict to pass to the backend (will be modified in place) + + Returns: + dict: Response body containing details with the sent JWT payload + + Raises: + PermissionDenied: If user lacks access to the job template (re-raised for 403 response) + + All other exceptions are caught and converted to 400 responses with error details. + + Modifies backend_kwargs in place to add workload_identity_token. + """ + # Validate job template + job_template_id = backend_kwargs.pop('job_template_id', None) + job_template = self._validate_and_get_job_template(job_template_id) + + # Check user access + if not self.request.user.can_access(models.JobTemplate, 'start', job_template): + raise PermissionDenied(_('You do not have access to job template with id: %(id)s.') % {'id': job_template.id}) + + # Generate workload identity token + jwt_token = self._get_workload_identity_token(job_template, backend_kwargs.pop('jwt_aud', None)) + backend_kwargs['workload_identity_token'] = jwt_token + + return {'details': {'sent_jwt_payload': self._decode_jwt_payload_for_display(jwt_token)}} + + def _call_backend_with_error_handling(self, plugin, backend_kwargs, response_body): + """Call credential backend and handle errors, adding secret_value to response if OIDC details present.""" + try: + with set_environ(**settings.AWX_TASK_ENV): + secret_value = plugin.backend(**backend_kwargs) + if 'details' in response_body: + response_body['details']['secret_value'] = secret_value + return Response(response_body, status=status.HTTP_202_ACCEPTED) + except requests.exceptions.HTTPError as exc: + message = self._extract_http_error_message(exc) + self._add_error_to_response(response_body, message) + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) + except Exception as exc: + message = self._extract_generic_error_message(exc) + self._add_error_to_response(response_body, message) + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) + + @staticmethod + def _extract_http_error_message(exc): + """Extract error message from HTTPError, checking response JSON and text.""" + message = str(exc) + if not hasattr(exc, 'response') or exc.response is None: + return message + + try: + error_data = exc.response.json() + if 'errors' in error_data and error_data['errors']: + return ', '.join(error_data['errors']) + if 'error' in error_data: + return error_data['error'] + except (ValueError, KeyError): + if exc.response.text: + return exc.response.text + return message + + @staticmethod + def _extract_generic_error_message(exc): + """Extract error message from exception, handling ConnectTimeoutError specially.""" + message = str(exc) if str(exc) else exc.__class__.__name__ + for arg in getattr(exc, 'args', []): + if isinstance(getattr(arg, 'reason', None), ConnectTimeoutError): + return str(arg.reason) + return message + + @staticmethod + def _add_error_to_response(response_body, message): + """Add error message to both 'detail' and 'details.error_message' fields.""" + response_body['detail'] = message + if 'details' in response_body: + response_body['details']['error_message'] = message + + +class CredentialExternalTest(OIDCCredentialTestMixin, SubDetailAPIView): """ Test updates to the input values and metadata of an external credential before saving them. @@ -1622,23 +1798,22 @@ class CredentialExternalTest(SubDetailAPIView): if value != '$encrypted$': backend_kwargs[field_name] = value backend_kwargs.update(request.data.get('metadata', {})) - try: - with set_environ(**settings.AWX_TASK_ENV): - obj.credential_type.plugin.backend(**backend_kwargs) - return Response({}, status=status.HTTP_202_ACCEPTED) - except requests.exceptions.HTTPError: - message = """Test operation is not supported for credential type {}. - This endpoint only supports credentials that connect to - external secret management systems such as CyberArk, HashiCorp - Vault, or cloud-based secret managers.""".format(obj.credential_type.kind) - return Response({'detail': message}, status=status.HTTP_400_BAD_REQUEST) - except Exception as exc: - message = exc.__class__.__name__ - exc_args = getattr(exc, 'args', []) - for a in exc_args: - if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): - message = str(a.reason) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) + + # Handle OIDC workload identity token generation if enabled + response_body = {} + if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.credential_type.inputs): + try: + oidc_response_body = self._handle_oidc_credential_test(backend_kwargs) + response_body.update(oidc_response_body) + except PermissionDenied: + raise + except Exception as exc: + error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc) + response_body['detail'] = error_message + response_body['details'] = {'error_message': error_message} + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) + + return self._call_backend_with_error_handling(obj.credential_type.plugin, backend_kwargs, response_body) class CredentialInputSourceDetail(RetrieveUpdateDestroyAPIView): @@ -1668,7 +1843,7 @@ class CredentialInputSourceSubList(SubListCreateAPIView): parent_key = 'target_credential' -class CredentialTypeExternalTest(SubDetailAPIView): +class CredentialTypeExternalTest(OIDCCredentialTestMixin, SubDetailAPIView): """ Test a complete set of input values for an external credential before saving it. @@ -1685,19 +1860,22 @@ class CredentialTypeExternalTest(SubDetailAPIView): obj = self.get_object() backend_kwargs = request.data.get('inputs', {}) backend_kwargs.update(request.data.get('metadata', {})) - try: - obj.plugin.backend(**backend_kwargs) - return Response({}, status=status.HTTP_202_ACCEPTED) - except requests.exceptions.HTTPError as exc: - message = 'HTTP {}'.format(exc.response.status_code) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) - except Exception as exc: - message = exc.__class__.__name__ - args_exc = getattr(exc, 'args', []) - for a in args_exc: - if isinstance(getattr(a, 'reason', None), ConnectTimeoutError): - message = str(a.reason) - return Response({'inputs': message}, status=status.HTTP_400_BAD_REQUEST) + + # Handle OIDC workload identity token generation if enabled + response_body = {} + if flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED') and self._has_workload_identity_token(obj.inputs): + try: + oidc_response_body = self._handle_oidc_credential_test(backend_kwargs) + response_body.update(oidc_response_body) + except PermissionDenied: + raise + except Exception as exc: + error_message = str(exc.detail) if hasattr(exc, 'detail') else str(exc) + response_body['detail'] = error_message + response_body['details'] = {'error_message': error_message} + return Response(response_body, status=status.HTTP_400_BAD_REQUEST) + + return self._call_backend_with_error_handling(obj.plugin, backend_kwargs, response_body) class HostRelatedSearchMixin(object): diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index e6f6a893f1..dc93a4dd83 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -94,7 +94,7 @@ from flags.state import flag_enabled # Workload Identity from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope -from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client +from awx.main.utils.workload_identity import retrieve_workload_identity_jwt_with_claims logger = logging.getLogger('awx.main.tasks.jobs') @@ -168,14 +168,12 @@ def retrieve_workload_identity_jwt( Raises: RuntimeError: if the workload identity client is not configured. """ - client = get_workload_identity_client() - if client is None: - raise RuntimeError("Workload identity client is not configured") - claims = populate_claims_for_workload(unified_job) - kwargs = {"claims": claims, "scope": scope, "audience": audience} - if workload_ttl_seconds: - kwargs["workload_ttl_seconds"] = workload_ttl_seconds - return client.request_workload_jwt(**kwargs).jwt + return retrieve_workload_identity_jwt_with_claims( + populate_claims_for_workload(unified_job), + audience, + scope, + workload_ttl_seconds, + ) def with_path_cleanup(f): diff --git a/awx/main/tests/functional/api/test_credential_type.py b/awx/main/tests/functional/api/test_credential_type.py index ed0f1e9f28..bc1637079f 100644 --- a/awx/main/tests/functional/api/test_credential_type.py +++ b/awx/main/tests/functional/api/test_credential_type.py @@ -2,6 +2,7 @@ import json import pytest +from ansible_base.lib.testing.util import feature_flag_enabled from awx.main.models.credential import CredentialType, Credential from awx.api.versioning import reverse @@ -159,7 +160,8 @@ def test_create_as_admin(get, post, admin): response = get(reverse('api:credential_type_list'), admin) assert response.data['count'] == 1 assert response.data['results'][0]['name'] == 'Custom Credential Type' - assert response.data['results'][0]['inputs'] == {} + # Serializer normalizes empty inputs to {'fields': []} + assert response.data['results'][0]['inputs'] == {'fields': []} assert response.data['results'][0]['injectors'] == {} assert response.data['results'][0]['managed'] is False @@ -474,3 +476,98 @@ def test_credential_type_rbac_external_test(post, alice, admin, credentialtype_e data = {'inputs': {}, 'metadata': {}} assert post(url, data, admin).status_code == 202 assert post(url, data, alice).status_code == 403 + + +# --- Tests for internal field filtering with None/invalid inputs --- + + +@pytest.mark.django_db +def test_credential_type_with_none_inputs(get, admin): + """Test that credential type with empty inputs dict works correctly.""" + # Create a credential type with empty dict + ct = CredentialType.objects.create( + kind='cloud', + name='Test Type', + managed=False, + inputs={}, # Empty dict, not None (DB has NOT NULL constraint) + ) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + response = get(url, admin) + assert response.status_code == 200 + # Should have normalized inputs to empty dict + assert 'inputs' in response.data + assert isinstance(response.data['inputs'], dict) + assert response.data['inputs']['fields'] == [] + + +@pytest.mark.django_db +def test_credential_type_with_invalid_inputs_type(get, admin): + """Test that credential type with non-dict inputs doesn't cause errors.""" + # Create a credential type with invalid inputs type + ct = CredentialType.objects.create(kind='cloud', name='Test Type', managed=False, inputs={'fields': 'not-a-list'}) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + response = get(url, admin) + assert response.status_code == 200 + # Should gracefully handle invalid fields type + assert 'inputs' in response.data + assert response.data['inputs']['fields'] == [] + + +@pytest.mark.django_db +def test_credential_type_filters_internal_fields(get, admin): + """Test that internal fields are filtered from API responses.""" + ct = CredentialType.objects.create( + kind='cloud', + name='Test OIDC Type', + managed=False, + inputs={ + 'fields': [ + {'id': 'url', 'label': 'URL', 'type': 'string'}, + {'id': 'token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True}, + {'id': 'public_field', 'label': 'Public', 'type': 'string'}, + ] + }, + ) + + url = reverse('api:credential_type_detail', kwargs={'pk': ct.pk}) + with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'): + response = get(url, admin) + assert response.status_code == 200 + + field_ids = [f['id'] for f in response.data['inputs']['fields']] + # Internal field should be filtered out + assert 'token' not in field_ids + assert 'url' in field_ids + assert 'public_field' in field_ids + + +@pytest.mark.django_db +def test_credential_type_list_filters_internal_fields(get, admin): + """Test that internal fields are filtered in list view.""" + CredentialType.objects.create( + kind='cloud', + name='Test OIDC Type', + managed=False, + inputs={ + 'fields': [ + {'id': 'url', 'label': 'URL', 'type': 'string'}, + {'id': 'workload_identity_token', 'label': 'Token', 'type': 'string', 'secret': True, 'internal': True}, + ] + }, + ) + + url = reverse('api:credential_type_list') + with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'): + response = get(url, admin) + assert response.status_code == 200 + + # Find our credential type in the results + test_ct = next((ct for ct in response.data['results'] if ct['name'] == 'Test OIDC Type'), None) + assert test_ct is not None + + field_ids = [f['id'] for f in test_ct['inputs']['fields']] + # Internal field should be filtered out + assert 'workload_identity_token' not in field_ids + assert 'url' in field_ids diff --git a/awx/main/tests/functional/api/test_oidc_credential_test.py b/awx/main/tests/functional/api/test_oidc_credential_test.py new file mode 100644 index 0000000000..6837d1ca42 --- /dev/null +++ b/awx/main/tests/functional/api/test_oidc_credential_test.py @@ -0,0 +1,259 @@ +""" +Tests for OIDC workload identity credential test endpoints. + +Tests the /api/v2/credentials//test/ and /api/v2/credential_types//test/ +endpoints when used with OIDC-enabled credential types. +""" + +import pytest +from unittest import mock + +from django.test import override_settings + +from awx.main.models import Credential, CredentialType, JobTemplate +from awx.api.versioning import reverse + + +@pytest.fixture +def job_template(organization, project): + """Job template with organization and project for OIDC JWT generation.""" + return JobTemplate.objects.create(name='test-jt', organization=organization, project=project, playbook='helloworld.yml') + + +@pytest.fixture +def oidc_credentialtype(): + """Create a credential type with workload_identity_token internal field.""" + oidc_type_inputs = { + 'fields': [ + {'id': 'url', 'label': 'Vault URL', 'type': 'string', 'help_text': 'The Vault server URL.'}, + {'id': 'auth_path', 'label': 'Auth Path', 'type': 'string', 'help_text': 'JWT auth mount path.'}, + {'id': 'role_id', 'label': 'Role ID', 'type': 'string', 'help_text': 'Vault role.'}, + {'id': 'jwt_aud', 'label': 'JWT Audience', 'type': 'string', 'help_text': 'Expected audience.'}, + {'id': 'workload_identity_token', 'label': 'Workload Identity Token', 'type': 'string', 'secret': True, 'internal': True}, + ], + 'metadata': [ + {'id': 'secret_path', 'label': 'Secret Path', 'type': 'string'}, + {'id': 'job_template_id', 'label': 'Job Template ID', 'type': 'string'}, + ], + 'required': ['url', 'auth_path', 'role_id'], + } + + class MockPlugin(object): + def backend(self, **kwargs): + # Simulate successful backend call + return 'secret' + + with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin: + mock_plugin.return_value = MockPlugin() + oidc_type = CredentialType(kind='external', managed=True, namespace='hashivault-kv-oidc', name='HashiCorp Vault KV (OIDC)', inputs=oidc_type_inputs) + oidc_type.save() + yield oidc_type + + +@pytest.fixture +def oidc_credential(oidc_credentialtype): + """Create a credential using the OIDC credential type.""" + return Credential.objects.create( + credential_type=oidc_credentialtype, + name='oidc-vault-cred', + inputs={'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + ) + + +@pytest.fixture +def mock_oidc_backend(): + """Fixture that mocks OIDC JWT generation and credential backend.""" + with mock.patch('awx.api.views.retrieve_workload_identity_jwt_with_claims') as mock_jwt, mock.patch('awx.api.views._jwt_decode') as mock_decode, mock.patch( + 'awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock + ) as mock_plugin: + + # Set default return values + mock_jwt.return_value = 'fake.jwt.token' + mock_decode.return_value = {'iss': 'http://gateway/o', 'aud': 'vault'} + + # Create mock backend + mock_backend = mock.MagicMock() + mock_backend.backend.return_value = 'secret' + mock_plugin.return_value = mock_backend + + # Yield all mocks for test customization + yield { + 'jwt': mock_jwt, + 'decode': mock_decode, + 'plugin': mock_plugin, + 'backend': mock_backend, + } + + +# --- Tests for CredentialExternalTest endpoint --- + + +@pytest.mark.django_db +@override_settings(FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED=False) +def test_credential_test_without_oidc_feature_flag(post, admin, oidc_credential): + """Test that credential test works without OIDC feature flag enabled.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': '1'}} + + with mock.patch('awx.main.models.credential.CredentialType.plugin', new_callable=mock.PropertyMock) as mock_plugin: + mock_backend = mock.MagicMock() + mock_backend.backend.return_value = 'secret' + mock_plugin.return_value = mock_backend + + response = post(url, data, admin) + assert response.status_code == 202 + # Should not contain JWT payload when feature flag is disabled + assert 'details' not in response.data or 'sent_jwt_payload' not in response.data.get('details', {}) + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +@pytest.mark.parametrize( + 'job_template_id, expected_error', + [ + (None, 'Job template ID is required'), + ('not-an-integer', 'must be an integer'), + ('99999', 'does not exist'), + ], + ids=['missing_job_template_id', 'invalid_job_template_id_type', 'nonexistent_job_template_id'], +) +def test_credential_test_job_template_validation(mock_flag, post, admin, oidc_credential, job_template_id, expected_error): + """Test that invalid job_template_id values return 400 with appropriate error messages.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret'}} + if job_template_id is not None: + data['metadata']['job_template_id'] = job_template_id + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert expected_error in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_no_access_to_job_template(mock_flag, post, alice, oidc_credential, job_template): + """Test that user without access to job template gets 403.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Give alice use permission on credential but not on job template + oidc_credential.use_role.members.add(alice) + + response = post(url, data, alice) + assert response.status_code == 403 + assert 'You do not have access to job template' in str(response.data) + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that successful test returns JWT payload in response.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Customize mock for this test + mock_oidc_backend['decode'].return_value = { + 'iss': 'http://gateway/o', + 'sub': 'system:serviceaccount:default:awx-operator', + 'aud': 'vault', + 'job_template_id': job_template.id, + } + + response = post(url, data, admin) + assert response.status_code == 202 + assert 'details' in response.data + assert 'sent_jwt_payload' in response.data['details'] + assert response.data['details']['sent_jwt_payload']['job_template_id'] == job_template.id + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_backend_failure_returns_jwt_and_error(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that backend failure still returns JWT payload along with error message.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + # Make backend fail + mock_oidc_backend['backend'].backend.side_effect = RuntimeError('Connection failed') + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + # Both JWT payload and error message should be present + assert 'sent_jwt_payload' in response.data['details'] + assert 'error_message' in response.data['details'] + assert 'Connection failed' in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_jwt_generation_failure(mock_flag, post, admin, oidc_credential, job_template): + """Test that JWT generation failure returns error without JWT payload.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + with mock.patch('awx.api.views.OIDCCredentialTestMixin._get_workload_identity_token') as mock_jwt: + mock_jwt.side_effect = RuntimeError('Failed to generate JWT') + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert 'Failed to generate JWT' in response.data['details']['error_message'] + # No JWT payload when generation fails + assert 'sent_jwt_payload' not in response.data['details'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_test_job_template_id_not_passed_to_backend(mock_flag, post, admin, oidc_credential, job_template, mock_oidc_backend): + """Test that job_template_id and jwt_aud are removed from backend_kwargs.""" + url = reverse('api:credential_external_test', kwargs={'pk': oidc_credential.pk}) + data = {'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}} + + response = post(url, data, admin) + assert response.status_code == 202 + + # Check that backend was called without job_template_id or jwt_aud + call_kwargs = mock_oidc_backend['backend'].backend.call_args[1] + assert 'job_template_id' not in call_kwargs + assert 'jwt_aud' not in call_kwargs + assert 'workload_identity_token' in call_kwargs + + +# --- Tests for CredentialTypeExternalTest endpoint --- + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_type_test_missing_job_template_id(mock_flag, post, admin, oidc_credentialtype): + """Test that missing job_template_id returns 400 for credential type test endpoint.""" + url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk}) + data = { + 'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + 'metadata': {'secret_path': 'test/secret'}, + } + + response = post(url, data, admin) + assert response.status_code == 400 + assert 'details' in response.data + assert 'error_message' in response.data['details'] + assert 'Job template ID is required' in response.data['details']['error_message'] + + +@pytest.mark.django_db +@mock.patch('awx.api.views.flag_enabled', return_value=True) +def test_credential_type_test_success_returns_jwt_payload(mock_flag, post, admin, oidc_credentialtype, job_template, mock_oidc_backend): + """Test that successful credential type test returns JWT payload.""" + url = reverse('api:credential_type_external_test', kwargs={'pk': oidc_credentialtype.pk}) + data = { + 'inputs': {'url': 'http://vault.example.com:8200', 'auth_path': 'jwt', 'role_id': 'test-role', 'jwt_aud': 'vault'}, + 'metadata': {'secret_path': 'test/secret', 'job_template_id': str(job_template.id)}, + } + + response = post(url, data, admin) + assert response.status_code == 202 + assert 'details' in response.data + assert 'sent_jwt_payload' in response.data['details'] diff --git a/awx/main/tests/unit/tasks/test_jobs.py b/awx/main/tests/unit/tasks/test_jobs.py index bcc6f4d0fd..e4df52b63f 100644 --- a/awx/main/tests/unit/tasks/test_jobs.py +++ b/awx/main/tests/unit/tasks/test_jobs.py @@ -473,7 +473,7 @@ def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims): assert claims == expected_claims -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client): """retrieve_workload_identity_jwt returns the JWT string from the client.""" mock_client = mock.MagicMock() @@ -502,7 +502,7 @@ def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client) assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job' -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client): """retrieve_workload_identity_jwt passes audience and scope to the client.""" mock_client = mock.MagicMock() @@ -518,7 +518,7 @@ def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_clien mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience) -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client): """retrieve_workload_identity_jwt passes workload_ttl_seconds when provided.""" mock_client = mock.Mock() @@ -542,7 +542,7 @@ def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client): ) -@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +@mock.patch('awx.main.utils.workload_identity.get_workload_identity_client') def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client): """retrieve_workload_identity_jwt raises RuntimeError when client is None.""" mock_get_client.return_value = None diff --git a/awx/main/utils/workload_identity.py b/awx/main/utils/workload_identity.py index e69de29bb2..50582e2245 100644 --- a/awx/main/utils/workload_identity.py +++ b/awx/main/utils/workload_identity.py @@ -0,0 +1,22 @@ +from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client + +__all__ = ['retrieve_workload_identity_jwt_with_claims'] + + +def retrieve_workload_identity_jwt_with_claims( + claims: dict, + audience: str, + scope: str, + workload_ttl_seconds: int | None = None, +) -> str: + """Retrieve JWT token from workload claims. + Raises: + RuntimeError: if the workload identity client is not configured. + """ + client = get_workload_identity_client() + if client is None: + raise RuntimeError("Workload identity client is not configured") + kwargs = {"claims": claims, "scope": scope, "audience": audience} + if workload_ttl_seconds: + kwargs["workload_ttl_seconds"] = workload_ttl_seconds + return client.request_workload_jwt(**kwargs).jwt