diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index 4bda1b0768..fa4be381c4 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -94,10 +94,7 @@ from flags.state import flag_enabled # Workload Identity from ansible_base.lib.workload_identity.controller import AutomationControllerJobScope - -from ansible_base.resource_registry.workload_identity_client import ( - get_workload_identity_client, -) +from ansible_base.resource_registry.workload_identity_client import get_workload_identity_client logger = logging.getLogger('awx.main.tasks.jobs') @@ -161,7 +158,12 @@ def populate_claims_for_workload(unified_job) -> dict: return claims -def retrieve_workload_identity_jwt(unified_job: UnifiedJob, audience: str, scope: str) -> str: +def retrieve_workload_identity_jwt( + unified_job: UnifiedJob, + audience: str, + scope: str, + workload_ttl_seconds: int | None = None, +) -> str: """Retrieve JWT token from workload claims. Raises: RuntimeError: if the workload identity client is not configured. @@ -170,7 +172,10 @@ def retrieve_workload_identity_jwt(unified_job: UnifiedJob, audience: str, scope if client is None: raise RuntimeError("Workload identity client is not configured") claims = populate_claims_for_workload(unified_job) - return client.request_workload_jwt(claims=claims, scope=scope, audience=audience).jwt + kwargs = {"claims": claims, "scope": scope, "audience": audience} + if workload_ttl_seconds: + kwargs["workload_ttl_seconds"] = workload_ttl_seconds + return client.request_workload_jwt(**kwargs).jwt def with_path_cleanup(f): @@ -243,9 +248,14 @@ class BaseTask(object): ) for credential_ctx, input_src in credential_input_sources: if flag_enabled("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED"): + effective_timeout = self.get_instance_timeout(self.instance) + workload_ttl = effective_timeout if effective_timeout else None try: jwt = retrieve_workload_identity_jwt( - self.instance, audience=input_src.source_credential.get_input('jwt_aud'), scope=AutomationControllerJobScope.name + self.instance, + audience=input_src.source_credential.get_input('jwt_aud'), + scope=AutomationControllerJobScope.name, + workload_ttl_seconds=workload_ttl, ) # Store token keyed by input source PK, since a credential can have # multiple input sources (one per field), each potentially with a different audience @@ -500,6 +510,7 @@ class BaseTask(object): return [] def get_instance_timeout(self, instance): + """Return the effective job timeout in seconds.""" global_timeout_setting_name = instance._global_timeout_setting() if global_timeout_setting_name: global_timeout = getattr(settings, global_timeout_setting_name, 0) diff --git a/awx/main/tests/functional/test_jobs.py b/awx/main/tests/functional/test_jobs.py index 754c9823a6..8a550449a9 100644 --- a/awx/main/tests/functional/test_jobs.py +++ b/awx/main/tests/functional/test_jobs.py @@ -529,6 +529,55 @@ def test_populate_workload_identity_tokens_with_flag_enabled(job_template_with_c assert target_cred.context[input_source.pk]['workload_identity_token'] == 'eyJ.test.jwt' +@pytest.mark.django_db +@override_settings(RESOURCE_SERVER={'URL': 'https://gateway.example.com', 'SECRET_KEY': 'test-secret-key', 'VALIDATE_HTTPS': False}) +def test_populate_workload_identity_tokens_passes_workload_ttl_from_job_timeout(job_template_with_credentials, mocker): + """Test populate_workload_identity_tokens passes workload_ttl_seconds from get_instance_timeout to the client.""" + with feature_flag_enabled('FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED'): + task = jobs.RunJob() + + ssh_type = CredentialType.defaults['ssh']() + ssh_type.save() + + hashivault_type = CredentialType( + name='HashiCorp Vault Secret Lookup (OIDC)', + kind='cloud', + managed=False, + inputs={ + 'fields': [ + {'id': 'jwt_aud', 'type': 'string', 'label': 'JWT Audience'}, + {'id': 'workload_identity_token', 'type': 'string', 'label': 'Workload Identity Token', 'secret': True, 'internal': True}, + ] + }, + ) + hashivault_type.save() + + ssh_cred = Credential.objects.create(credential_type=ssh_type, name='ssh-cred') + source_cred = Credential.objects.create(credential_type=hashivault_type, name='vault-source', inputs={'jwt_aud': 'https://vault.example.com'}) + target_cred = Credential.objects.create(credential_type=ssh_type, name='target-cred', inputs={'username': 'testuser'}) + + CredentialInputSource.objects.create( + target_credential=target_cred, source_credential=source_cred, input_field_name='password', metadata={'path': 'secret/data/password'} + ) + + job = job_template_with_credentials(target_cred, ssh_cred) + job.timeout = 3600 + job.save() + task.instance = job + task._credentials = [target_cred, ssh_cred] + + mock_response = mocker.Mock(status_code=200) + mock_response.json.return_value = {'jwt': 'eyJ.test.jwt'} + mock_request = mocker.patch('requests.request', return_value=mock_response, autospec=True) + + task.populate_workload_identity_tokens() + + call_kwargs = mock_request.call_args.kwargs + assert call_kwargs['method'] == 'POST' + json_body = call_kwargs.get('json', {}) + assert json_body.get('workload_ttl_seconds') == 3600 + + @pytest.mark.django_db def test_populate_workload_identity_tokens_with_flag_disabled(job_template_with_credentials): """Test populate_workload_identity_tokens sets error status when flag is disabled.""" diff --git a/awx/main/tests/unit/tasks/test_jobs.py b/awx/main/tests/unit/tasks/test_jobs.py index c57403e1c6..288dac7206 100644 --- a/awx/main/tests/unit/tasks/test_jobs.py +++ b/awx/main/tests/unit/tasks/test_jobs.py @@ -140,7 +140,9 @@ def test_pre_post_run_hook_facts(mock_create_partition, mock_facts_settings, pri @mock.patch('awx.main.tasks.facts.bulk_update_sorted_by_id') @mock.patch('awx.main.tasks.facts.settings') @mock.patch('awx.main.tasks.jobs.create_partition', return_value=True) -def test_pre_post_run_hook_facts_deleted_sliced(mock_create_partition, mock_facts_settings, private_data_dir, execution_environment): +def test_pre_post_run_hook_facts_deleted_sliced( + mock_create_partition, mock_facts_settings, mock_bulk_update_sorted_by_id, private_data_dir, execution_environment +): # Fully mocked inventory mock_inventory = mock.MagicMock(spec=Inventory, pk=1, kind='') @@ -517,6 +519,30 @@ def test_retrieve_workload_identity_jwt_passes_audience_and_scope(mock_get_clien mock_client.request_workload_jwt.assert_called_once_with(claims={'job_id': 1}, scope=scope, audience=audience) +@mock.patch('awx.main.tasks.jobs.get_workload_identity_client') +def test_retrieve_workload_identity_jwt_passes_workload_ttl(mock_get_client): + """retrieve_workload_identity_jwt passes workload_ttl_seconds when provided.""" + mock_client = mock.Mock() + mock_client.request_workload_jwt.return_value = mock.Mock(jwt='token') + mock_get_client.return_value = mock_client + + unified_job = mock.MagicMock() + with mock.patch('awx.main.tasks.jobs.populate_claims_for_workload', return_value={'job_id': 1}): + jobs.retrieve_workload_identity_jwt( + unified_job, + audience='https://vault.example.com', + scope='aap_controller_automation_job', + workload_ttl_seconds=3600, + ) + + mock_client.request_workload_jwt.assert_called_once_with( + claims={'job_id': 1}, + scope='aap_controller_automation_job', + audience='https://vault.example.com', + workload_ttl_seconds=3600, + ) + + @mock.patch('awx.main.tasks.jobs.get_workload_identity_client') def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_get_client): """retrieve_workload_identity_jwt raises RuntimeError when client is None.""" @@ -526,3 +552,42 @@ def test_retrieve_workload_identity_jwt_raises_when_client_not_configured(mock_g with pytest.raises(RuntimeError, match="Workload identity client is not configured"): jobs.retrieve_workload_identity_jwt(unified_job, audience='test_audience', scope='test_scope') + + +@pytest.mark.parametrize('effective_timeout,expected_ttl', [(3600, 3600), (0, None)]) +@mock.patch('awx.main.tasks.jobs.retrieve_workload_identity_jwt') +@mock.patch('awx.main.tasks.jobs.flag_enabled', return_value=True) +def test_populate_workload_identity_tokens_passes_get_instance_timeout_to_client(mock_flag_enabled, mock_retrieve_jwt, effective_timeout, expected_ttl): + """populate_workload_identity_tokens passes get_instance_timeout() value as workload_ttl_seconds to retrieve_workload_identity_jwt.""" + mock_retrieve_jwt.return_value = 'eyJ.test.jwt' + + task = jobs.RunJob() + task.instance = mock.MagicMock() + + # Minimal credential with workload identity input source + credential_ctx = {} + input_src = mock.MagicMock() + input_src.pk = 1 + input_src.source_credential = mock.MagicMock() + input_src.source_credential.get_input.return_value = 'https://vault.example.com' + input_src.source_credential.name = 'vault-cred' + input_src.source_credential.credential_type = mock.MagicMock() + input_src.source_credential.credential_type.inputs = {'fields': [{'id': 'workload_identity_token', 'internal': True}]} + + credential = mock.MagicMock() + credential.context = credential_ctx + credential.input_sources = mock.MagicMock() + credential.input_sources.all.return_value = [input_src] + + task._credentials = [credential] + + with mock.patch.object(task, 'get_instance_timeout', return_value=effective_timeout): + task.populate_workload_identity_tokens() + + mock_flag_enabled.assert_called_once_with("FEATURE_OIDC_WORKLOAD_IDENTITY_ENABLED") + mock_retrieve_jwt.assert_called_once_with( + task.instance, + audience='https://vault.example.com', + scope=AutomationControllerJobScope.name, + workload_ttl_seconds=expected_ttl, + ) diff --git a/awx/main/tests/unit/test_tasks.py b/awx/main/tests/unit/test_tasks.py index ba00b1792b..e1ff5bc34b 100644 --- a/awx/main/tests/unit/test_tasks.py +++ b/awx/main/tests/unit/test_tasks.py @@ -1587,7 +1587,7 @@ def test_managed_injector_redaction(injector_cls): assert 'very_secret_value' not in str(build_safe_env(env)) -def test_job_run_no_ee(mock_me, mock_create_partition): +def test_job_run_no_ee(mock_me, mock_create_partition, private_data_dir): org = Organization(pk=1) proj = Project(pk=1, organization=org) job = Job(project=proj, organization=org, inventory=Inventory(pk=1))