diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index 74aab14c85..ce750489fa 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -96,6 +96,10 @@ from flags.state import flag_enabled # Workload Identity from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope +from ansible_base.resource_registry.workload_identity_client import ( + get_workload_identity_client, +) + logger = logging.getLogger('awx.main.tasks.jobs') @@ -163,6 +167,18 @@ def populate_claims_for_workload(unified_job) -> dict: return claims +def retrieve_workload_identity_jwt(unified_job: UnifiedJob, audience: str, scope: str) -> str: + """Retrieve JWT token from workload claims. + Raises: + RuntimeError: if the workload identity client is not configured. + """ + client = get_workload_identity_client() + if client is None: + raise RuntimeError("Workload identity client is not configured") + claims = populate_claims_for_workload(unified_job) + return client.request_workload_jwt(claims=claims, scope=scope, audience=audience).jwt + + def with_path_cleanup(f): @functools.wraps(f) def _wrapped(self, *args, **kwargs): diff --git a/awx/main/tests/unit/tasks/test_jobs.py b/awx/main/tests/unit/tasks/test_jobs.py index 6995776b0d..b678add12d 100644 --- a/awx/main/tests/unit/tasks/test_jobs.py +++ b/awx/main/tests/unit/tasks/test_jobs.py @@ -427,3 +427,59 @@ def test_populate_claims_for_adhoc_command(workload_attrs, expected_claims): claims = jobs.populate_claims_for_workload(adhoc_command) assert claims == expected_claims + + +@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +def test_retrieve_workload_identity_jwt_returns_jwt_from_client(mock_get_client): + """retrieve_workload_identity_jwt returns the JWT string from the client.""" + mock_client = mock.MagicMock() + mock_response = mock.MagicMock() + mock_response.jwt = 'eyJ.test.jwt' + mock_client.request_workload_jwt.return_value = mock_response + mock_get_client.return_value = mock_client + + unified_job = Job() + unified_job.id = 42 + unified_job.name = 'Test Job' + unified_job.launch_type = 'manual' + unified_job.organization = Organization(id=1, name='Test Org') + unified_job.unified_job_template = None + unified_job.instance_group = None + + result = jobs.retrieve_workload_identity_jwt(unified_job, audience='https://api.example.com', scope='aap_controller_automation_job') + + assert result == 'eyJ.test.jwt' + mock_client.request_workload_jwt.assert_called_once() + call_kwargs = mock_client.request_workload_jwt.call_args[1] + assert call_kwargs['audience'] == 'https://api.example.com' + assert call_kwargs['scope'] == 'aap_controller_automation_job' + assert 'claims' in call_kwargs + assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_ID] == 42 + assert call_kwargs['claims'][AutomationControllerJobScope.CLAIM_JOB_NAME] == 'Test Job' + + +@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_client): + """retrieve_workload_identity_jwt passes audience and scope to the client.""" + mock_client = mock.MagicMock() + mock_client.request_workload_jwt.return_value = mock.MagicMock(jwt='token') + mock_get_client.return_value = mock_client + + unified_job = mock.MagicMock() + audience = 'custom_audience' + scope = 'custom_scope' + with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}): + jobs.retrieve_workload_identity_jwt(unified_job, audience=audience, scope=scope) + + mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience) + + +@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client): + """retrieve_workload_identity_jwt raises RuntimeError when client is None.""" + mock_get_client.return_value = None + + unified_job = mock.MagicMock() + + with pytest.raises(RuntimeError, match="Workload identity client is not configured"): + jobs.retrieve_workload_identity_jwt(unified_job, audience='test_audience', scope='test_scope')