From cb2df43580c0ab8f968a6200b49481ff8a24c48d Mon Sep 17 00:00:00 2001 From: TVo Date: Mon, 17 Feb 2025 13:31:33 -0700 Subject: [PATCH] [4.6_Backport] Added helper method for fetching serviceaccount token (#6823) --- .../tests/unit/utils/test_analytics_proxy.py | 113 +++++++++++ awx/main/utils/analytics_proxy.py | 184 ++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 awx/main/tests/unit/utils/test_analytics_proxy.py create mode 100644 awx/main/utils/analytics_proxy.py diff --git a/awx/main/tests/unit/utils/test_analytics_proxy.py b/awx/main/tests/unit/utils/test_analytics_proxy.py new file mode 100644 index 0000000000..0096306e57 --- /dev/null +++ b/awx/main/tests/unit/utils/test_analytics_proxy.py @@ -0,0 +1,113 @@ +import pytest +import requests +from unittest import mock + +from awx.main.utils.analytics_proxy import OIDCClient, TokenType, TokenError + + +MOCK_TOKEN_RESPONSE = { + 'access_token': 'bob-access-token', + 'expires_in': 500, + 'refresh_expires_in': 900, + 'token_type': 'Bearer', + 'not-before-policy': 6, + 'scope': 'fake-scope1, fake-scope2', +} + + +@pytest.fixture +def oidc_client(): + ''' + oidc client instantiation fixture. + ''' + return OIDCClient( + 'fake-client-id', + 'fake-client-secret', + 'https://my-token-url.com/get/a/token/', + ['api.console'], + ) + + +@pytest.fixture +def token(): + ''' + Create Token class out of example OIDC token response. + ''' + return OIDCClient._json_response_to_token(MOCK_TOKEN_RESPONSE) + + +def test_generate_access_token(oidc_client): + with mock.patch( + 'awx.main.utils.analytics_proxy.requests.post', + return_value=mock.Mock(json=lambda: MOCK_TOKEN_RESPONSE, raise_for_status=mock.Mock(return_value=None)), # No exception raised + ): + oidc_client._generate_access_token() + + assert oidc_client.token + assert oidc_client.token.access_token == 'bob-access-token' + assert oidc_client.token.expires_in == 500 + assert oidc_client.token.refresh_expires_in == 900 + assert oidc_client.token.token_type == TokenType.BEARER + assert oidc_client.token.not_before_policy == 6 + assert oidc_client.token.scope == 'fake-scope1, fake-scope2' + + +def test_token_generation_error(oidc_client): + ''' + Check that TokenError is raised for failure in token generation process + ''' + exception_404 = requests.HTTPError('404 Client Error: Not Found for url') + with mock.patch( + 'awx.main.utils.analytics_proxy.requests.post', + return_value=mock.Mock(status_code=404, json=mock.Mock(return_value={'error': 'Not Found'}), raise_for_status=mock.Mock(side_effect=exception_404)), + ): + with pytest.raises(TokenError) as exc_info: + oidc_client._generate_access_token() + + assert exc_info.value.__cause__ == exception_404 + + +def test_make_request(oidc_client, token): + ''' + Check that make_request makes an http request with a generated token. + ''' + + def fake_generate_access_token(): + oidc_client.token = token + + with ( + mock.patch.object(oidc_client, '_generate_access_token', side_effect=fake_generate_access_token), + mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request, + ): + oidc_client.make_request('GET', 'https://does_not_exist.com') + + mock_request.assert_called_with( + 'GET', + 'https://does_not_exist.com', + headers={ + 'Authorization': f'Bearer {token.access_token}', + 'Accept': 'application/json', + }, + ) + + +def test_make_request_existing_token(oidc_client, token): + ''' + Check that make_request does not try and generate a token. + ''' + oidc_client.token = token + + with ( + mock.patch.object(oidc_client, '_generate_access_token', side_effect=RuntimeError('expected not to be called')), + mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request, + ): + oidc_client.make_request('GET', 'https://does_not_exist.com') + + mock_request.assert_called_with( + 'GET', + 'https://does_not_exist.com', + headers={ + 'Authorization': f'Bearer {token.access_token}', + 'Accept': 'application/json', + }, + ) diff --git a/awx/main/utils/analytics_proxy.py b/awx/main/utils/analytics_proxy.py new file mode 100644 index 0000000000..f46ed7e0ca --- /dev/null +++ b/awx/main/utils/analytics_proxy.py @@ -0,0 +1,184 @@ +''' +Proxy requests Analytics requests +''' + +import time + +from enum import Enum + +from typing import Optional, Any + +import requests + + +class TokenError(requests.RequestException): + ''' + Raised when token generation request fails. + + Useful for differentiating request failure for make_request() vs. + other requests issued to get a token i.e.: + + try: + client = OIDCClient(...) + client.make_request(...) + except TokenGenerationError as e: + print(f"Token generation failed due to {e.__cause__}") + except requests.RequestException: + print("API request failed) + ''' + + def __init__(self, message="Token generation request failed", response=None): + super().__init__(message) + self.response = response # Store the response for debugging + + +def _now(reason: str): + ''' + Wrapper for time. Helps with testing. + ''' + return int(time.time()) + + +class TokenType(Enum): + ''' + Access token type as returned by the remote API. + ''' + + BEARER = 'Bearer' + + +class Token: + ''' + Token data generated by OIDC response. + ''' + + access_token: str + expires_in: int + refresh_expires_in: int + token_type: TokenType + not_before_policy: int # not-before-policy + scope: str + + def __init__( + self, + access_token: str, + expires_in: int, + refresh_expires_in: int, + token_type: TokenType, + not_before_policy: int, + scope: str, + ): + self.access_token = access_token + self.expires_in = expires_in + self.refresh_expires_in = refresh_expires_in + self.token_type = token_type + self.not_before_policy = not_before_policy + self.scope = scope + + self._now = _now(reason='token-creation') + + @property + def expires_at(self) -> int: + ''' + Unix timestamp in seconds of when the token expires. + ''' + return self._now + self.expires_in + + def is_expired(self) -> bool: + ''' + Check if the token is expired. + ''' + return _now(reason='token-expiration-check') >= self.expires_at + + +class OIDCClient: + ''' + Wraps requests library make_request() and manages OIDC access token. + ''' + + def __init__( + self, + client_id: str, + client_secret: str, + token_url: str, + scopes: list[str], + base_url: str = '', + ) -> None: + self.client_id: str = client_id + self.client_secret: str = client_secret + self.token_url: str = token_url + self.scopes = scopes + self.base_url: str = base_url + self.token: Optional[Token] = None + + @classmethod + def _json_response_to_token(cls, json_response: Any) -> Token: + return Token( + access_token=json_response['access_token'], + expires_in=json_response['expires_in'], + refresh_expires_in=json_response['refresh_expires_in'], + token_type=TokenType(json_response['token_type']), + not_before_policy=json_response['not-before-policy'], + scope=json_response['scope'], + ) + + def _generate_access_token(self) -> None: + ''' + Fetches the initial access token using client credentials. + ''' + response = requests.post( + self.token_url, + data={ + 'grant_type': 'client_credentials', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'scope': self.scopes, + }, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + ) + try: + response.raise_for_status() + except requests.RequestException as e: + raise TokenError() from e + self.token = OIDCClient._json_response_to_token(response.json()) + + def _add_headers(self, headers: dict[str, str]) -> None: + ''' + Add token header + ''' + headers.update( + { + 'Authorization': f'Bearer {self.token.access_token}', + 'Accept': 'application/json', + } + ) + + def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response: + ''' + Actually make an API call. + ''' + self._add_headers(headers) + return requests.request(method, url, headers=headers, **kwargs) + + def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response: + ''' + Makes an authenticated request and refreshes the token if expired. + ''' + has_generated_token = False + + def generate_access_token(): + self._generate_access_token() + return True + + if not self.token or self.token.is_expired(): + has_generated_token = generate_access_token() + + url = f'{self.base_url}{endpoint}' + headers = kwargs.pop('headers', {}) + + response = self._make_request(method, url, headers, **kwargs) + if not has_generated_token and response.status_code == 401: + generate_access_token() + response = self._make_request(method, url, headers, **kwargs) + + return response