mirror of
https://github.com/ansible/awx.git
synced 2026-01-11 18:09:57 -03:30
[4.6_Backport] Added helper method for fetching serviceaccount token (#6823)
This commit is contained in:
parent
ccb6360a96
commit
cb2df43580
113
awx/main/tests/unit/utils/test_analytics_proxy.py
Normal file
113
awx/main/tests/unit/utils/test_analytics_proxy.py
Normal file
@ -0,0 +1,113 @@
|
||||
import pytest
|
||||
import requests
|
||||
from unittest import mock
|
||||
|
||||
from awx.main.utils.analytics_proxy import OIDCClient, TokenType, TokenError
|
||||
|
||||
|
||||
MOCK_TOKEN_RESPONSE = {
|
||||
'access_token': 'bob-access-token',
|
||||
'expires_in': 500,
|
||||
'refresh_expires_in': 900,
|
||||
'token_type': 'Bearer',
|
||||
'not-before-policy': 6,
|
||||
'scope': 'fake-scope1, fake-scope2',
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oidc_client():
|
||||
'''
|
||||
oidc client instantiation fixture.
|
||||
'''
|
||||
return OIDCClient(
|
||||
'fake-client-id',
|
||||
'fake-client-secret',
|
||||
'https://my-token-url.com/get/a/token/',
|
||||
['api.console'],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token():
|
||||
'''
|
||||
Create Token class out of example OIDC token response.
|
||||
'''
|
||||
return OIDCClient._json_response_to_token(MOCK_TOKEN_RESPONSE)
|
||||
|
||||
|
||||
def test_generate_access_token(oidc_client):
|
||||
with mock.patch(
|
||||
'awx.main.utils.analytics_proxy.requests.post',
|
||||
return_value=mock.Mock(json=lambda: MOCK_TOKEN_RESPONSE, raise_for_status=mock.Mock(return_value=None)), # No exception raised
|
||||
):
|
||||
oidc_client._generate_access_token()
|
||||
|
||||
assert oidc_client.token
|
||||
assert oidc_client.token.access_token == 'bob-access-token'
|
||||
assert oidc_client.token.expires_in == 500
|
||||
assert oidc_client.token.refresh_expires_in == 900
|
||||
assert oidc_client.token.token_type == TokenType.BEARER
|
||||
assert oidc_client.token.not_before_policy == 6
|
||||
assert oidc_client.token.scope == 'fake-scope1, fake-scope2'
|
||||
|
||||
|
||||
def test_token_generation_error(oidc_client):
|
||||
'''
|
||||
Check that TokenError is raised for failure in token generation process
|
||||
'''
|
||||
exception_404 = requests.HTTPError('404 Client Error: Not Found for url')
|
||||
with mock.patch(
|
||||
'awx.main.utils.analytics_proxy.requests.post',
|
||||
return_value=mock.Mock(status_code=404, json=mock.Mock(return_value={'error': 'Not Found'}), raise_for_status=mock.Mock(side_effect=exception_404)),
|
||||
):
|
||||
with pytest.raises(TokenError) as exc_info:
|
||||
oidc_client._generate_access_token()
|
||||
|
||||
assert exc_info.value.__cause__ == exception_404
|
||||
|
||||
|
||||
def test_make_request(oidc_client, token):
|
||||
'''
|
||||
Check that make_request makes an http request with a generated token.
|
||||
'''
|
||||
|
||||
def fake_generate_access_token():
|
||||
oidc_client.token = token
|
||||
|
||||
with (
|
||||
mock.patch.object(oidc_client, '_generate_access_token', side_effect=fake_generate_access_token),
|
||||
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
|
||||
):
|
||||
oidc_client.make_request('GET', 'https://does_not_exist.com')
|
||||
|
||||
mock_request.assert_called_with(
|
||||
'GET',
|
||||
'https://does_not_exist.com',
|
||||
headers={
|
||||
'Authorization': f'Bearer {token.access_token}',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_make_request_existing_token(oidc_client, token):
|
||||
'''
|
||||
Check that make_request does not try and generate a token.
|
||||
'''
|
||||
oidc_client.token = token
|
||||
|
||||
with (
|
||||
mock.patch.object(oidc_client, '_generate_access_token', side_effect=RuntimeError('expected not to be called')),
|
||||
mock.patch('awx.main.utils.analytics_proxy.requests.request') as mock_request,
|
||||
):
|
||||
oidc_client.make_request('GET', 'https://does_not_exist.com')
|
||||
|
||||
mock_request.assert_called_with(
|
||||
'GET',
|
||||
'https://does_not_exist.com',
|
||||
headers={
|
||||
'Authorization': f'Bearer {token.access_token}',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
)
|
||||
184
awx/main/utils/analytics_proxy.py
Normal file
184
awx/main/utils/analytics_proxy.py
Normal file
@ -0,0 +1,184 @@
|
||||
'''
|
||||
Proxy requests Analytics requests
|
||||
'''
|
||||
|
||||
import time
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class TokenError(requests.RequestException):
|
||||
'''
|
||||
Raised when token generation request fails.
|
||||
|
||||
Useful for differentiating request failure for make_request() vs.
|
||||
other requests issued to get a token i.e.:
|
||||
|
||||
try:
|
||||
client = OIDCClient(...)
|
||||
client.make_request(...)
|
||||
except TokenGenerationError as e:
|
||||
print(f"Token generation failed due to {e.__cause__}")
|
||||
except requests.RequestException:
|
||||
print("API request failed)
|
||||
'''
|
||||
|
||||
def __init__(self, message="Token generation request failed", response=None):
|
||||
super().__init__(message)
|
||||
self.response = response # Store the response for debugging
|
||||
|
||||
|
||||
def _now(reason: str):
|
||||
'''
|
||||
Wrapper for time. Helps with testing.
|
||||
'''
|
||||
return int(time.time())
|
||||
|
||||
|
||||
class TokenType(Enum):
|
||||
'''
|
||||
Access token type as returned by the remote API.
|
||||
'''
|
||||
|
||||
BEARER = 'Bearer'
|
||||
|
||||
|
||||
class Token:
|
||||
'''
|
||||
Token data generated by OIDC response.
|
||||
'''
|
||||
|
||||
access_token: str
|
||||
expires_in: int
|
||||
refresh_expires_in: int
|
||||
token_type: TokenType
|
||||
not_before_policy: int # not-before-policy
|
||||
scope: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token: str,
|
||||
expires_in: int,
|
||||
refresh_expires_in: int,
|
||||
token_type: TokenType,
|
||||
not_before_policy: int,
|
||||
scope: str,
|
||||
):
|
||||
self.access_token = access_token
|
||||
self.expires_in = expires_in
|
||||
self.refresh_expires_in = refresh_expires_in
|
||||
self.token_type = token_type
|
||||
self.not_before_policy = not_before_policy
|
||||
self.scope = scope
|
||||
|
||||
self._now = _now(reason='token-creation')
|
||||
|
||||
@property
|
||||
def expires_at(self) -> int:
|
||||
'''
|
||||
Unix timestamp in seconds of when the token expires.
|
||||
'''
|
||||
return self._now + self.expires_in
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
'''
|
||||
Check if the token is expired.
|
||||
'''
|
||||
return _now(reason='token-expiration-check') >= self.expires_at
|
||||
|
||||
|
||||
class OIDCClient:
|
||||
'''
|
||||
Wraps requests library make_request() and manages OIDC access token.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scopes: list[str],
|
||||
base_url: str = '',
|
||||
) -> None:
|
||||
self.client_id: str = client_id
|
||||
self.client_secret: str = client_secret
|
||||
self.token_url: str = token_url
|
||||
self.scopes = scopes
|
||||
self.base_url: str = base_url
|
||||
self.token: Optional[Token] = None
|
||||
|
||||
@classmethod
|
||||
def _json_response_to_token(cls, json_response: Any) -> Token:
|
||||
return Token(
|
||||
access_token=json_response['access_token'],
|
||||
expires_in=json_response['expires_in'],
|
||||
refresh_expires_in=json_response['refresh_expires_in'],
|
||||
token_type=TokenType(json_response['token_type']),
|
||||
not_before_policy=json_response['not-before-policy'],
|
||||
scope=json_response['scope'],
|
||||
)
|
||||
|
||||
def _generate_access_token(self) -> None:
|
||||
'''
|
||||
Fetches the initial access token using client credentials.
|
||||
'''
|
||||
response = requests.post(
|
||||
self.token_url,
|
||||
data={
|
||||
'grant_type': 'client_credentials',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'scope': self.scopes,
|
||||
},
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded'},
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as e:
|
||||
raise TokenError() from e
|
||||
self.token = OIDCClient._json_response_to_token(response.json())
|
||||
|
||||
def _add_headers(self, headers: dict[str, str]) -> None:
|
||||
'''
|
||||
Add token header
|
||||
'''
|
||||
headers.update(
|
||||
{
|
||||
'Authorization': f'Bearer {self.token.access_token}',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
)
|
||||
|
||||
def _make_request(self, method: str, url: str, headers: dict[str, str], **kwargs: Any) -> requests.Response:
|
||||
'''
|
||||
Actually make an API call.
|
||||
'''
|
||||
self._add_headers(headers)
|
||||
return requests.request(method, url, headers=headers, **kwargs)
|
||||
|
||||
def make_request(self, method: str, endpoint: str, **kwargs: Any) -> requests.Response:
|
||||
'''
|
||||
Makes an authenticated request and refreshes the token if expired.
|
||||
'''
|
||||
has_generated_token = False
|
||||
|
||||
def generate_access_token():
|
||||
self._generate_access_token()
|
||||
return True
|
||||
|
||||
if not self.token or self.token.is_expired():
|
||||
has_generated_token = generate_access_token()
|
||||
|
||||
url = f'{self.base_url}{endpoint}'
|
||||
headers = kwargs.pop('headers', {})
|
||||
|
||||
response = self._make_request(method, url, headers, **kwargs)
|
||||
if not has_generated_token and response.status_code == 401:
|
||||
generate_access_token()
|
||||
response = self._make_request(method, url, headers, **kwargs)
|
||||
|
||||
return response
|
||||
Loading…
x
Reference in New Issue
Block a user