diff --git a/awx/main/models/oauth.py b/awx/main/models/oauth.py index 3a89790e80..5f719f894e 100644 --- a/awx/main/models/oauth.py +++ b/awx/main/models/oauth.py @@ -1,4 +1,5 @@ # Python +import logging import re # Django @@ -22,6 +23,9 @@ DATA_URI_RE = re.compile(r'.*') # FIXME __all__ = ['OAuth2AccessToken', 'OAuth2Application'] +logger = logging.getLogger('awx.main.models.oauth') + + class OAuth2Application(AbstractApplication): class Meta: @@ -120,15 +124,27 @@ class OAuth2AccessToken(AbstractAccessToken): def is_valid(self, scopes=None): valid = super(OAuth2AccessToken, self).is_valid(scopes) if valid: + try: + self.validate_external_users() + except oauth2.AccessDeniedError: + logger.exception(f'Failed to authenticate {self.user.username}') + return False self.last_used = now() - connection.on_commit(lambda: self.save(update_fields=['last_used'])) + + def _update_last_used(): + if OAuth2AccessToken.objects.filter(pk=self.pk).exists(): + self.save(update_fields=['last_used']) + connection.on_commit(_update_last_used) return valid - def save(self, *args, **kwargs): + def validate_external_users(self): if self.user and settings.ALLOW_OAUTH2_FOR_EXTERNAL_USERS is False: external_account = get_external_account(self.user) if external_account is not None: raise oauth2.AccessDeniedError(_( 'OAuth2 Tokens cannot be created by users associated with an external authentication provider ({})' ).format(external_account)) + + def save(self, *args, **kwargs): + self.validate_external_users() super(OAuth2AccessToken, self).save(*args, **kwargs) diff --git a/awx/main/tests/functional/api/test_oauth.py b/awx/main/tests/functional/api/test_oauth.py index 22ae98b710..7fc0d65977 100644 --- a/awx/main/tests/functional/api/test_oauth.py +++ b/awx/main/tests/functional/api/test_oauth.py @@ -1,6 +1,8 @@ import pytest import base64 +import contextlib import json +from unittest import mock from django.db import connection from django.test.utils import override_settings @@ -14,6 +16,18 @@ from awx.sso.models import UserEnterpriseAuth from oauth2_provider.models import RefreshToken +@contextlib.contextmanager +def immediate_on_commit(): + """ + Context manager executing transaction.on_commit() hooks immediately as + if the connection was in auto-commit mode. + """ + def on_commit(func): + func() + with mock.patch('django.db.connection.on_commit', side_effect=on_commit) as patch: + yield patch + + @pytest.mark.django_db def test_personal_access_token_creation(oauth_application, post, alice): url = drf_reverse('api:oauth_authorization_root_view') + 'token/' @@ -54,6 +68,41 @@ def test_token_creation_disabled_for_external_accounts(oauth_application, post, assert AccessToken.objects.count() == 0 +@pytest.mark.django_db +def test_existing_token_disabled_for_external_accounts(oauth_application, get, post, admin): + UserEnterpriseAuth(user=admin, provider='radius').save() + url = drf_reverse('api:oauth_authorization_root_view') + 'token/' + with override_settings(RADIUS_SERVER='example.org', ALLOW_OAUTH2_FOR_EXTERNAL_USERS=True): + resp = post( + url, + data='grant_type=password&username=admin&password=admin&scope=read', + content_type='application/x-www-form-urlencoded', + HTTP_AUTHORIZATION='Basic ' + smart_str(base64.b64encode(smart_bytes(':'.join([ + oauth_application.client_id, oauth_application.client_secret + ])))), + status=201 + ) + token = json.loads(resp.content)['access_token'] + assert AccessToken.objects.count() == 1 + + with immediate_on_commit(): + resp = get( + drf_reverse('api:user_me_list', kwargs={'version': 'v2'}), + HTTP_AUTHORIZATION='Bearer ' + token, + status=200 + ) + assert json.loads(resp.content)['results'][0]['username'] == 'admin' + + with override_settings(RADIUS_SERVER='example.org', ALLOW_OAUTH2_FOR_EXTERNAL_USER=False): + with immediate_on_commit(): + resp = get( + drf_reverse('api:user_me_list', kwargs={'version': 'v2'}), + HTTP_AUTHORIZATION='Bearer ' + token, + status=401 + ) + assert b'To establish a login session' in resp.content + + @pytest.mark.django_db def test_pat_creation_no_default_scope(oauth_application, post, admin): # tests that the default scope is overriden