diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py index dc93a4dd83..4d8f1229e1 100644 --- a/awx/main/tasks/jobs.py +++ b/awx/main/tasks/jobs.py @@ -228,16 +228,19 @@ class BaseTask(object): # Convert to list to prevent re-evaluation of QuerySet return list(credentials_list) - def populate_workload_identity_tokens(self): + def populate_workload_identity_tokens(self, additional_credentials=None): """ Populate credentials with workload identity tokens. Sets the context on Credential objects that have input sources using compatible external credential types. """ + credentials = list(self._credentials) + if additional_credentials: + credentials.extend(additional_credentials) credential_input_sources = ( (credential.context, src) - for credential in self._credentials + for credential in credentials for src in credential.input_sources.all() if any( field.get('id') == 'workload_identity_token' and field.get('internal') @@ -1863,6 +1866,24 @@ class RunInventoryUpdate(SourceControlMixin, BaseTask): # All credentials not used by inventory source injector return inventory_update.get_extra_credentials() + def populate_workload_identity_tokens(self, additional_credentials=None): + """Also generate OIDC tokens for the cloud credential. + + The cloud credential is not in _credentials (it is handled by the + inventory source injector), but it may still need a workload identity + token generated for it. + """ + cloud_cred = self.instance.get_cloud_credential() + creds = list(additional_credentials or []) + if cloud_cred: + creds.append(cloud_cred) + super().populate_workload_identity_tokens(additional_credentials=creds or None) + # Override get_cloud_credential on this instance so the injector + # uses the credential with OIDC context instead of doing a fresh + # DB fetch that would lose it. + if cloud_cred and cloud_cred.context: + self.instance.get_cloud_credential = lambda: cloud_cred + def build_project_dir(self, inventory_update, private_data_dir): source_project = None if inventory_update.inventory_source: diff --git a/awx/main/tests/unit/tasks/test_jobs.py b/awx/main/tests/unit/tasks/test_jobs.py index e4df52b63f..0f4f6d3031 100644 --- a/awx/main/tests/unit/tasks/test_jobs.py +++ b/awx/main/tests/unit/tasks/test_jobs.py @@ -590,3 +590,67 @@ def test_populate_workload_identity_tokens_passes_get_instance_timeout_to_client scope=AutomationControllerJobScope.name, workload_ttl_seconds=expected_ttl, ) + + +class TestRunInventoryUpdatePopulateWorkloadIdentityTokens: + """Tests for RunInventoryUpdate.populate_workload_identity_tokens.""" + + def test_cloud_credential_passed_as_additional_credential(self): + """The cloud credential is forwarded to super().populate_workload_identity_tokens via additional_credentials.""" + cloud_cred = mock.MagicMock(name='cloud_cred') + cloud_cred.context = {} + + task = jobs.RunInventoryUpdate() + task.instance = mock.MagicMock() + task.instance.get_cloud_credential.return_value = cloud_cred + task._credentials = [] + + with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super: + task.populate_workload_identity_tokens() + + mock_super.assert_called_once_with(additional_credentials=[cloud_cred]) + + def test_no_cloud_credential_calls_super_with_none(self): + """When there is no cloud credential, super() is called with additional_credentials=None.""" + task = jobs.RunInventoryUpdate() + task.instance = mock.MagicMock() + task.instance.get_cloud_credential.return_value = None + task._credentials = [] + + with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super: + task.populate_workload_identity_tokens() + + mock_super.assert_called_once_with(additional_credentials=None) + + def test_additional_credentials_combined_with_cloud_credential(self): + """Caller-supplied additional_credentials are combined with the cloud credential.""" + cloud_cred = mock.MagicMock(name='cloud_cred') + cloud_cred.context = {} + extra_cred = mock.MagicMock(name='extra_cred') + + task = jobs.RunInventoryUpdate() + task.instance = mock.MagicMock() + task.instance.get_cloud_credential.return_value = cloud_cred + task._credentials = [] + + with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens') as mock_super: + task.populate_workload_identity_tokens(additional_credentials=[extra_cred]) + + mock_super.assert_called_once_with(additional_credentials=[extra_cred, cloud_cred]) + + def test_cloud_credential_override_after_context_set(self): + """After OIDC processing, get_cloud_credential is overridden on the instance when context is populated.""" + cloud_cred = mock.MagicMock(name='cloud_cred') + # Simulate that super().populate_workload_identity_tokens populates context + cloud_cred.context = {'workload_identity_token': 'eyJ.test.jwt'} + + task = jobs.RunInventoryUpdate() + task.instance = mock.MagicMock() + task.instance.get_cloud_credential.return_value = cloud_cred + task._credentials = [] + + with mock.patch.object(jobs.BaseTask, 'populate_workload_identity_tokens'): + task.populate_workload_identity_tokens() + + # The instance's get_cloud_credential should now return the same object with context + assert task.instance.get_cloud_credential() is cloud_cred