diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 69326d265c..9348a6c37e 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -948,7 +948,7 @@ class UserSerializer(BaseSerializer): access_list = self.reverse('api:user_access_list', kwargs={'pk': obj.pk}), tokens = self.reverse('api:o_auth2_token_list', kwargs={'pk': obj.pk}), authorized_tokens = self.reverse('api:user_authorized_token_list', kwargs={'pk': obj.pk}), - personal_tokens = self.reverse('api:o_auth2_personal_token_list', kwargs={'pk': obj.pk}), + personal_tokens = self.reverse('api:user_personal_token_list', kwargs={'pk': obj.pk}), )) return res @@ -985,19 +985,24 @@ class UserSerializer(BaseSerializer): return self._validate_ldap_managed_field(value, 'is_superuser') -class UserAuthorizedTokenSerializer(BaseSerializer): - +class BaseOAuth2TokenSerializer(BaseSerializer): + refresh_token = serializers.SerializerMethodField() token = serializers.SerializerMethodField() + ALLOWED_SCOPES = ['read', 'write'] class Meta: model = OAuth2AccessToken fields = ( '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'expires', 'scope', 'application' + 'application', 'expires', 'scope', ) - read_only_fields = ('user', 'token', 'expires') - + read_only_fields = ('user', 'token', 'expires', 'refresh_token') + extra_kwargs = { + 'scope': {'allow_null': False, 'required': True}, + 'user': {'allow_null': False, 'required': True} + } + def get_token(self, obj): request = self.context.get('request', None) try: @@ -1006,17 +1011,60 @@ class UserAuthorizedTokenSerializer(BaseSerializer): else: return TOKEN_CENSOR except ObjectDoesNotExist: - return '' - + return '' + def get_refresh_token(self, obj): request = self.context.get('request', None) try: - if request.method == 'POST': + if not obj.refresh_token: + return None + elif request.method == 'POST': return getattr(obj.refresh_token, 'token', '') else: return TOKEN_CENSOR except ObjectDoesNotExist: - return '' + return None + + def get_related(self, obj): + ret = super(BaseOAuth2TokenSerializer, self).get_related(obj) + if obj.user: + ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) + if obj.application: + ret['application'] = self.reverse( + 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} + ) + ret['activity_stream'] = self.reverse( + 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} + ) + return ret + + def _is_valid_scope(self, value): + if not value or (not isinstance(value, six.string_types)): + return False + words = value.split() + for word in words: + if words.count(word) > 1: + return False # do not allow duplicates + if word not in self.ALLOWED_SCOPES: + return False + return True + + def validate_scope(self, value): + if not self._is_valid_scope(value): + raise serializers.ValidationError(_( + 'Must be a simple space-separated string with allowed scopes {}.' + ).format(self.ALLOWED_SCOPES)) + return value + + +class UserAuthorizedTokenSerializer(BaseOAuth2TokenSerializer): + + class Meta: + extra_kwargs = { + 'scope': {'allow_null': False, 'required': True}, + 'user': {'allow_null': False, 'required': True}, + 'application': {'allow_null': False, 'required': True} + } def create(self, validated_data): current_user = self.context['request'].user @@ -1035,140 +1083,9 @@ class UserAuthorizedTokenSerializer(BaseSerializer): access_token=obj ) return obj - - -class OAuth2ApplicationSerializer(BaseSerializer): - - show_capabilities = ['edit', 'delete'] - - class Meta: - model = OAuth2Application - fields = ( - '*', 'description', '-user', 'client_id', 'client_secret', 'client_type', - 'redirect_uris', 'authorization_grant_type', 'skip_authorization', 'organization' - ) - read_only_fields = ('client_id', 'client_secret') - read_only_on_update_fields = ('user', 'authorization_grant_type') - extra_kwargs = { - 'user': {'allow_null': True, 'required': False}, - 'organization': {'allow_null': False}, - 'authorization_grant_type': {'allow_null': False} - } - - def to_representation(self, obj): - ret = super(OAuth2ApplicationSerializer, self).to_representation(obj) - if obj.client_type == 'public': - ret.pop('client_secret', None) - return ret - - - def get_modified(self, obj): - if obj is None: - return None - return obj.updated - - def get_related(self, obj): - ret = super(OAuth2ApplicationSerializer, self).get_related(obj) - if obj.user: - ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) - ret['tokens'] = self.reverse( - 'api:o_auth2_application_token_list', kwargs={'pk': obj.pk} - ) - ret['activity_stream'] = self.reverse( - 'api:o_auth2_application_activity_stream_list', kwargs={'pk': obj.pk} - ) - return ret - - def _summary_field_tokens(self, obj): - token_list = [{'id': x.pk, 'token': TOKEN_CENSOR, 'scope': x.scope} for x in obj.oauth2accesstoken_set.all()[:10]] - if has_model_field_prefetched(obj, 'oauth2accesstoken_set'): - token_count = len(obj.oauth2accesstoken_set.all()) - else: - if len(token_list) < 10: - token_count = len(token_list) - else: - token_count = obj.oauth2accesstoken_set.count() - return {'count': token_count, 'results': token_list} - - def get_summary_fields(self, obj): - ret = super(OAuth2ApplicationSerializer, self).get_summary_fields(obj) - ret['tokens'] = self._summary_field_tokens(obj) - return ret -class OAuth2TokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() - ALLOWED_SCOPES = ['read', 'write'] - - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'application', 'expires', 'scope', - ) - read_only_fields = ('user', 'token', 'expires') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True}, - 'user': {'allow_null': False, 'required': True} - } - - def get_modified(self, obj): - if obj is None: - return None - return obj.updated - - def get_related(self, obj): - ret = super(OAuth2TokenSerializer, self).get_related(obj) - if obj.user: - ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) - if obj.application: - ret['application'] = self.reverse( - 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} - ) - ret['activity_stream'] = self.reverse( - 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} - ) - return ret - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def _is_valid_scope(self, value): - if not value or (not isinstance(value, six.string_types)): - return False - words = value.split() - for word in words: - if words.count(word) > 1: - return False # do not allow duplicates - if word not in self.ALLOWED_SCOPES: - return False - return True - - def validate_scope(self, value): - if not self._is_valid_scope(value): - raise serializers.ValidationError(_( - 'Must be a simple space-separated string with allowed scopes {}.' - ).format(self.ALLOWED_SCOPES)) - return value +class OAuth2TokenSerializer(BaseOAuth2TokenSerializer): def create(self, validated_data): current_user = self.context['request'].user @@ -1197,109 +1114,10 @@ class OAuth2TokenDetailSerializer(OAuth2TokenSerializer): read_only_fields = ('*', 'user', 'application') -class OAuth2AuthorizedTokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() +class UserPersonalTokenSerializer(BaseOAuth2TokenSerializer): class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', '-user', 'token', 'refresh_token', - 'expires', 'scope', 'application', - ) - read_only_fields = ('user', 'token', 'expires') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True} - } - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def create(self, validated_data): - current_user = self.context['request'].user - validated_data['user'] = current_user - validated_data['token'] = generate_token() - validated_data['expires'] = now() + timedelta( - seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] - ) - obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data) - if obj.application and obj.application.user: - obj.user = obj.application.user - obj.save() - if obj.application is not None: - RefreshToken.objects.create( - user=current_user, - token=generate_token(), - application=obj.application, - access_token=obj - ) - return obj - - -class OAuth2PersonalTokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() - - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'application', 'expires', 'scope', - ) read_only_fields = ('user', 'token', 'expires', 'application') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True} - } - - def get_modified(self, obj): - if obj is None: - return None - return obj.updated - - def get_related(self, obj): - ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj) - if obj.user: - ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) - if obj.application: - ret['application'] = self.reverse( - 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} - ) - ret['activity_stream'] = self.reverse( - 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} - ) - return ret - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - return None def create(self, validated_data): validated_data['user'] = self.context['request'].user @@ -1308,11 +1126,57 @@ class OAuth2PersonalTokenSerializer(BaseSerializer): seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] ) validated_data['application'] = None - obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data) + obj = super(UserPersonalTokenSerializer, self).create(validated_data) obj.save() return obj +class OAuth2ApplicationSerializer(BaseSerializer): + + show_capabilities = ['edit', 'delete'] + + class Meta: + model = OAuth2Application + fields = ( + '*', 'description', '-user', 'client_id', 'client_secret', 'client_type', + 'redirect_uris', 'authorization_grant_type', 'skip_authorization', 'organization' + ) + read_only_fields = ('client_id', 'client_secret') + read_only_on_update_fields = ('user', 'authorization_grant_type') + extra_kwargs = { + 'user': {'allow_null': True, 'required': False}, + 'organization': {'allow_null': False}, + 'authorization_grant_type': {'allow_null': False} + } + + def to_representation(self, obj): + ret = super(OAuth2ApplicationSerializer, self).to_representation(obj) + if obj.client_type == 'public': + ret.pop('client_secret', None) + return ret + + def get_modified(self, obj): + if obj is None: + return None + return obj.updated + + def _summary_field_tokens(self, obj): + token_list = [{'id': x.pk, 'token': TOKEN_CENSOR, 'scope': x.scope} for x in obj.oauth2accesstoken_set.all()[:10]] + if has_model_field_prefetched(obj, 'oauth2accesstoken_set'): + token_count = len(obj.oauth2accesstoken_set.all()) + else: + if len(token_list) < 10: + token_count = len(token_list) + else: + token_count = obj.oauth2accesstoken_set.count() + return {'count': token_count, 'results': token_list} + + def get_summary_fields(self, obj): + ret = super(OAuth2ApplicationSerializer, self).get_summary_fields(obj) + ret['tokens'] = self._summary_field_tokens(obj) + return ret + + class OrganizationSerializer(BaseSerializer): show_capabilities = ['edit', 'delete'] diff --git a/awx/api/urls/user_oauth.py b/awx/api/urls/oauth2.py similarity index 90% rename from awx/api/urls/user_oauth.py rename to awx/api/urls/oauth2.py index bec5c4332b..6e9eea3d9f 100644 --- a/awx/api/urls/user_oauth.py +++ b/awx/api/urls/oauth2.py @@ -11,7 +11,6 @@ from awx.api.views import ( OAuth2TokenList, OAuth2TokenDetail, OAuth2TokenActivityStreamList, - OAuth2PersonalTokenList ) @@ -42,8 +41,7 @@ urls = [ r'^tokens/(?P[0-9]+)/activity_stream/$', OAuth2TokenActivityStreamList.as_view(), name='o_auth2_token_activity_stream_list' - ), - url(r'^personal_tokens/$', OAuth2PersonalTokenList.as_view(), name='o_auth2_personal_token_list'), + ), ] __all__ = ['urls'] diff --git a/awx/api/urls/oauth.py b/awx/api/urls/oauth2_root.py similarity index 100% rename from awx/api/urls/oauth.py rename to awx/api/urls/oauth2_root.py diff --git a/awx/api/urls/urls.py b/awx/api/urls/urls.py index e282a73e5f..52e9ef1cf0 100644 --- a/awx/api/urls/urls.py +++ b/awx/api/urls/urls.py @@ -67,8 +67,8 @@ from .schedule import urls as schedule_urls from .activity_stream import urls as activity_stream_urls from .instance import urls as instance_urls from .instance_group import urls as instance_group_urls -from .user_oauth import urls as user_oauth_urls -from .oauth import urls as oauth_urls +from .oauth2 import urls as oauth2_urls +from .oauth2_root import urls as oauth2_root_urls v1_urls = [ @@ -130,7 +130,7 @@ v2_urls = [ url(r'^applications/(?P[0-9]+)/$', OAuth2ApplicationDetail.as_view(), name='o_auth2_application_detail'), url(r'^applications/(?P[0-9]+)/tokens/$', ApplicationOAuth2TokenList.as_view(), name='application_o_auth2_token_list'), url(r'^tokens/$', OAuth2TokenList.as_view(), name='o_auth2_token_list'), - url(r'^', include(user_oauth_urls)), + url(r'^', include(oauth2_urls)), ] app_name = 'api' @@ -145,7 +145,7 @@ urlpatterns = [ url(r'^logout/$', LoggedLogoutView.as_view( next_page='/api/', redirect_field_name='next' ), name='logout'), - url(r'^o/', include(oauth_urls)), + url(r'^o/', include(oauth2_root_urls)), ] if settings.SETTINGS_MODULE == 'awx.settings.development': from awx.api.swagger import SwaggerSchemaView diff --git a/awx/api/urls/user.py b/awx/api/urls/user.py index 9ecebbb044..ca8d531f46 100644 --- a/awx/api/urls/user.py +++ b/awx/api/urls/user.py @@ -16,7 +16,7 @@ from awx.api.views import ( UserAccessList, OAuth2ApplicationList, OAuth2UserTokenList, - OAuth2PersonalTokenList, + UserPersonalTokenList, UserAuthorizedTokenList, ) @@ -34,7 +34,7 @@ urls = [ url(r'^(?P[0-9]+)/applications/$', OAuth2ApplicationList.as_view(), name='o_auth2_application_list'), url(r'^(?P[0-9]+)/tokens/$', OAuth2UserTokenList.as_view(), name='o_auth2_token_list'), url(r'^(?P[0-9]+)/authorized_tokens/$', UserAuthorizedTokenList.as_view(), name='user_authorized_token_list'), - url(r'^(?P[0-9]+)/personal_tokens/$', OAuth2PersonalTokenList.as_view(), name='o_auth2_personal_token_list'), + url(r'^(?P[0-9]+)/personal_tokens/$', UserPersonalTokenList.as_view(), name='user_personal_token_list'), ] diff --git a/awx/api/views.py b/awx/api/views.py index d64899df39..9365062fe0 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -1611,21 +1611,6 @@ class OAuth2UserTokenList(SubListCreateAPIView): relationship = 'main_oauth2accesstoken' parent_key = 'user' swagger_topic = 'Authentication' - - -class OAuth2AuthorizedTokenList(SubListCreateAPIView): - - view_name = _("OAuth2 Authorized Access Tokens") - - model = OAuth2AccessToken - serializer_class = OAuth2AuthorizedTokenSerializer - parent_model = OAuth2Application - relationship = 'oauth2accesstoken_set' - parent_key = 'application' - swagger_topic = 'Authentication' - - def get_queryset(self): - return get_access_token_model().objects.filter(application__isnull=False, user=self.request.user) class UserAuthorizedTokenList(SubListCreateAPIView): @@ -1633,7 +1618,7 @@ class UserAuthorizedTokenList(SubListCreateAPIView): view_name = _("OAuth2 User Authorized Access Tokens") model = OAuth2AccessToken - serializer_class = OAuth2AuthorizedTokenSerializer + serializer_class = UserAuthorizedTokenSerializer parent_model = User relationship = 'oauth2accesstoken_set' parent_key = 'user' @@ -1655,12 +1640,12 @@ class OrganizationApplicationList(SubListCreateAPIView): swagger_topic = 'Authentication' -class OAuth2PersonalTokenList(SubListCreateAPIView): +class UserPersonalTokenList(SubListCreateAPIView): view_name = _("OAuth2 Personal Access Tokens") model = OAuth2AccessToken - serializer_class = OAuth2PersonalTokenSerializer + serializer_class = UserPersonalTokenSerializer parent_model = User relationship = 'main_oauth2accesstoken' parent_key = 'user' diff --git a/awx/main/tests/functional/api/test_oauth.py b/awx/main/tests/functional/api/test_oauth.py index 7e745213c8..7e8b63eb08 100644 --- a/awx/main/tests/functional/api/test_oauth.py +++ b/awx/main/tests/functional/api/test_oauth.py @@ -29,7 +29,7 @@ def test_personal_access_token_creation(oauth_application, post, alice): @pytest.mark.django_db -def test_oauth_application_create(admin, organization, post): +def test_oauth2_application_create(admin, organization, post): response = post( reverse('api:o_auth2_application_list'), { 'name': 'test app', @@ -47,7 +47,18 @@ def test_oauth_application_create(admin, organization, post): assert created_app.client_type == 'confidential' assert created_app.authorization_grant_type == 'password' assert created_app.organization == organization - + + +@pytest.mark.django_db +def test_oauth2_validator(admin, oauth_application, post): + post( + reverse('api:o_auth2_application_list'), { + 'name': 'Write App Token', + 'application': oauth_application.pk, + 'scope': 'Write', + }, admin, expect=400 + ) + @pytest.mark.django_db def test_oauth_application_update(oauth_application, organization, patch, admin, alice): diff --git a/awx/main/tests/functional/test_rbac_oauth.py b/awx/main/tests/functional/test_rbac_oauth.py index 757c55e12b..f076db3689 100644 --- a/awx/main/tests/functional/test_rbac_oauth.py +++ b/awx/main/tests/functional/test_rbac_oauth.py @@ -200,7 +200,7 @@ class TestOAuth2Token: user_list = [admin, org_admin, org_member, alice] can_access_list = [True, False, True, False] response = post( - reverse('api:o_auth2_personal_token_list', kwargs={'pk': org_member.pk}), + reverse('api:user_personal_token_list', kwargs={'pk': org_member.pk}), {'scope': 'read'}, org_member, expect=201 ) token = AccessToken.objects.get(token=response.data['token']) @@ -220,7 +220,7 @@ class TestOAuth2Token: for user, can_access in zip(user_list, can_access_list): response = post( - reverse('api:o_auth2_personal_token_list', kwargs={'pk': user.pk}), + reverse('api:user_personal_token_list', kwargs={'pk': user.pk}), {'scope': 'read', 'application':None}, user, expect=201 ) token = AccessToken.objects.get(token=response.data['token'])