[4.6_Backport] Added helper method for fetching serviceaccount token (#6823)

This commit is contained in:
TVo 2025-02-17 13:31:33 -07:00 committed by GitHub
parent ccb6360a96
commit cb2df43580
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 297 additions and 0 deletions

View File

@ -0,0 +1,113 @@
import pytest
import requests
from unittest import mock
from awx.main.utils.analytics_proxy import OIDCClient, TokenType, TokenError
MOCK_TOKEN_RESPONSE = {
'access_token': 'bob-access-token',
'expires_in': 500,
'refresh_expires_in': 900,
'token_type': 'Bearer',
'not-before-policy': 6,
'scope': 'fake-scope1, fake-scope2',
}
@pytest.fixture
def oidc_client():
'''
oidc client instantiation fixture.
'''
return OIDCClient(
'fake-client-id',
'fake-client-secret',
'https://my-token-url.com/get/a/token/',
['api.console'],
)
@pytest.fixture
def token():
'''
Create Token class out of example OIDC token response.
'''
return OIDCClient._json_response_to_token(MOCK_TOKEN_RESPONSE)
def test_generate_access_token(oidc_client):
with mock.patch(
'awx.main.utils.analytics_proxy.requests.post',
return_value=mock.Mock(json=lambda: MOCK_TOKEN_RESPONSE, raise_for_status=mock.Mock(return_value=None)), # No exception raised
):
oidc_client._generate_access_token()
assert oidc_client.token
assert oidc_client.token.access_token == 'bob-access-token'
assert oidc_client.token.expires_in == 500
assert oidc_client.token.refresh_expires_in == 900
assert oidc_client.token.token_type == TokenType.BEARER
assert oidc_client.token.not_before_policy == 6
assert oidc_client.token.scope == 'fake-scope1, fake-scope2'
def test_token_generation_error(oidc_client):
'''
Check that TokenError is raised for failure in token generation process
'''
exception_404 = requests.HTTPError('404 Client Error: Not Found for url')
with mock.patch(
'awx.main.utils.analytics_proxy.requests.post',
return_value=mock.Mock(status_code=404, json=mock.Mock(return_value={'error': 'Not Found'}), raise_for_status=mock.Mock(side_effect=exception_404)),
):
with pytest.raises(TokenError) as exc_info:
oidc_client._generate_access_token()
assert exc_info.value.__cause__ == exception_404
def test_make_request(oidc_client, token):
'''
Check that make_request makes an http request with a generated token.
'''
def fake_generate_access_token():
oidc_client.token = token
with (
mock.patch.object(oidc_client, '_generate_access_token', side_effect=fake_generate_access_token),
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
):
oidc_client.make_request('GET', 'https://does_not_exist.com')
mock_request.assert_called_with(
'GET',
'https://does_not_exist.com',
headers={
'Authorization': f'Bearer {token.access_token}',
'Accept': 'application/json',
},
)
def test_make_request_existing_token(oidc_client, token):
'''
Check that make_request does not try and generate a token.
'''
oidc_client.token = token
with (
mock.patch.object(oidc_client, '_generate_access_token', side_effect=RuntimeError('expected not to be called')),
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
):
oidc_client.make_request('GET', 'https://does_not_exist.com')
mock_request.assert_called_with(
'GET',
'https://does_not_exist.com',
headers={
'Authorization': f'Bearer {token.access_token}',
'Accept': 'application/json',
},
)

View File

@ -0,0 +1,184 @@
'''
Proxy requests Analytics requests
'''
import time
from enum import Enum
from typing import Optional, Any
import requests
class TokenError(requests.RequestException):
'''
Raised when token generation request fails.
Useful for differentiating request failure for make_request() vs.
other requests issued to get a token i.e.:
try:
client = OIDCClient(...)
client.make_request(...)
except TokenGenerationError as e:
print(f"Token generation failed due to {e.__cause__}")
except requests.RequestException:
print("API request failed)
'''
def __init__(self, message="Token generation request failed", response=None):
super().__init__(message)
self.response = response # Store the response for debugging
def _now(reason: str):
'''
Wrapper for time. Helps with testing.
'''
return int(time.time())
class TokenType(Enum):
'''
Access token type as returned by the remote API.
'''
BEARER = 'Bearer'
class Token:
'''
Token data generated by OIDC response.
'''
access_token: str
expires_in: int
refresh_expires_in: int
token_type: TokenType
not_before_policy: int # not-before-policy
scope: str
def __init__(
self,
access_token: str,
expires_in: int,
refresh_expires_in: int,
token_type: TokenType,
not_before_policy: int,
scope: str,
):
self.access_token = access_token
self.expires_in = expires_in
self.refresh_expires_in = refresh_expires_in
self.token_type = token_type
self.not_before_policy = not_before_policy
self.scope = scope
self._now = _now(reason='token-creation')
@property
def expires_at(self) -> int:
'''
Unix timestamp in seconds of when the token expires.
'''
return self._now + self.expires_in
def is_expired(self) -> bool:
'''
Check if the token is expired.
'''
return _now(reason='token-expiration-check') >= self.expires_at
class OIDCClient:
'''
Wraps requests library make_request() and manages OIDC access token.
'''
def __init__(
self,
client_id: str,
client_secret: str,
token_url: str,
scopes: list[str],
base_url: str = '',
) -> None:
self.client_id: str = client_id
self.client_secret: str = client_secret
self.token_url: str = token_url
self.scopes = scopes
self.base_url: str = base_url
self.token: Optional[Token] = None
@classmethod
def _json_response_to_token(cls, json_response: Any) -> Token:
return Token(
access_token=json_response['access_token'],
expires_in=json_response['expires_in'],
refresh_expires_in=json_response['refresh_expires_in'],
token_type=TokenType(json_response['token_type']),
not_before_policy=json_response['not-before-policy'],
scope=json_response['scope'],
)
def _generate_access_token(self) -> None:
'''
Fetches the initial access token using client credentials.
'''
response = requests.post(
self.token_url,
data={
'grant_type': 'client_credentials',
'client_id': self.client_id,
'client_secret': self.client_secret,
'scope': self.scopes,
},
headers={'Content-Type': 'application/x-www-form-urlencoded'},
)
try:
response.raise_for_status()
except requests.RequestException as e:
raise TokenError() from e
self.token = OIDCClient._json_response_to_token(response.json())
def _add_headers(self, headers: dict[str, str]) -> None:
'''
Add token header
'''
headers.update(
{
'Authorization': f'Bearer {self.token.access_token}',
'Accept': 'application/json',
}
)
def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response:
'''
Actually make an API call.
'''
self._add_headers(headers)
return requests.request(method, url, headers=headers, **kwargs)
def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response:
'''
Makes an authenticated request and refreshes the token if expired.
'''
has_generated_token = False
def generate_access_token():
self._generate_access_token()
return True
if not self.token or self.token.is_expired():
has_generated_token = generate_access_token()
url = f'{self.base_url}{endpoint}'
headers = kwargs.pop('headers', {})
response = self._make_request(method, url, headers, **kwargs)
if not has_generated_token and response.status_code == 401:
generate_access_token()
response = self._make_request(method, url, headers, **kwargs)
return response