From 5d220e82229bd66272cbe3b2338ab929726fe2b1 Mon Sep 17 00:00:00 2001 From: adamscmRH Date: Wed, 23 May 2018 14:04:36 -0400 Subject: [PATCH] add scope validator to token endpoints --- awx/api/serializers.py | 416 +++++++++----------- awx/main/tests/functional/api/test_oauth.py | 15 +- 2 files changed, 201 insertions(+), 230 deletions(-) diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 69326d265c..77b7316334 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -985,19 +985,12 @@ class UserSerializer(BaseSerializer): return self._validate_ldap_managed_field(value, 'is_superuser') -class UserAuthorizedTokenSerializer(BaseSerializer): - +class BaseOAuth2TokenSerializer(BaseSerializer): + refresh_token = serializers.SerializerMethodField() token = serializers.SerializerMethodField() + ALLOWED_SCOPES = ['read', 'write'] - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'expires', 'scope', 'application' - ) - read_only_fields = ('user', 'token', 'expires') - def get_token(self, obj): request = self.context.get('request', None) try: @@ -1006,7 +999,36 @@ class UserAuthorizedTokenSerializer(BaseSerializer): else: return TOKEN_CENSOR except ObjectDoesNotExist: - return '' + return '' + + def _is_valid_scope(self, value): + if not value or (not isinstance(value, six.string_types)): + return False + words = value.split() + for word in words: + if words.count(word) > 1: + return False # do not allow duplicates + if word not in self.ALLOWED_SCOPES: + return False + return True + + def validate_scope(self, value): + if not self._is_valid_scope(value): + raise serializers.ValidationError(_( + 'Must be a simple space-separated string with allowed scopes {}.' + ).format(self.ALLOWED_SCOPES)) + return value + + +class UserAuthorizedTokenSerializer(BaseOAuth2TokenSerializer): + + class Meta: + model = OAuth2AccessToken + fields = ( + '*', '-name', 'description', 'user', 'token', 'refresh_token', + 'expires', 'scope', 'application' + ) + read_only_fields = ('user', 'token', 'expires') def get_refresh_token(self, obj): request = self.context.get('request', None) @@ -1035,7 +1057,162 @@ class UserAuthorizedTokenSerializer(BaseSerializer): access_token=obj ) return obj + + +class OAuth2TokenSerializer(BaseOAuth2TokenSerializer): + + class Meta: + model = OAuth2AccessToken + fields = ( + '*', '-name', 'description', 'user', 'token', 'refresh_token', + 'application', 'expires', 'scope', + ) + read_only_fields = ('user', 'token', 'expires') + extra_kwargs = { + 'scope': {'allow_null': False, 'required': True}, + 'user': {'allow_null': False, 'required': True} + } + + def get_modified(self, obj): + if obj is None: + return None + return obj.updated + + def get_related(self, obj): + ret = super(OAuth2TokenSerializer, self).get_related(obj) + if obj.user: + ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) + if obj.application: + ret['application'] = self.reverse( + 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} + ) + ret['activity_stream'] = self.reverse( + 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} + ) + return ret + + def get_refresh_token(self, obj): + request = self.context.get('request', None) + try: + if request.method == 'POST': + return getattr(obj.refresh_token, 'token', '') + else: + return TOKEN_CENSOR + except ObjectDoesNotExist: + return '' + + def create(self, validated_data): + current_user = self.context['request'].user + validated_data['user'] = current_user + validated_data['token'] = generate_token() + validated_data['expires'] = now() + timedelta( + seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] + ) + obj = super(OAuth2TokenSerializer, self).create(validated_data) + if obj.application and obj.application.user: + obj.user = obj.application.user + obj.save() + if obj.application is not None: + RefreshToken.objects.create( + user=current_user, + token=generate_token(), + application=obj.application, + access_token=obj + ) + return obj + +class OAuth2TokenDetailSerializer(OAuth2TokenSerializer): + + class Meta: + read_only_fields = ('*', 'user', 'application') + + +class OAuth2AuthorizedTokenSerializer(BaseOAuth2TokenSerializer): + + class Meta: + model = OAuth2AccessToken + fields = ( + '*', '-name', 'description', '-user', 'token', 'refresh_token', + 'expires', 'scope', 'application', + ) + read_only_fields = ('user', 'token', 'expires') + extra_kwargs = { + 'scope': {'allow_null': False, 'required': True} + } + + def get_refresh_token(self, obj): + request = self.context.get('request', None) + try: + if request.method == 'POST': + return getattr(obj.refresh_token, 'token', '') + else: + return TOKEN_CENSOR + except ObjectDoesNotExist: + return '' + + def create(self, validated_data): + current_user = self.context['request'].user + validated_data['user'] = current_user + validated_data['token'] = generate_token() + validated_data['expires'] = now() + timedelta( + seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] + ) + obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data) + if obj.application and obj.application.user: + obj.user = obj.application.user + obj.save() + if obj.application is not None: + RefreshToken.objects.create( + user=current_user, + token=generate_token(), + application=obj.application, + access_token=obj + ) + return obj + + +class OAuth2PersonalTokenSerializer(BaseOAuth2TokenSerializer): + + class Meta: + model = OAuth2AccessToken + fields = ( + '*', '-name', 'description', 'user', 'token', 'refresh_token', + 'application', 'expires', 'scope', + ) + read_only_fields = ('user', 'token', 'expires', 'application') + extra_kwargs = { + 'scope': {'allow_null': False, 'required': True} + } + + def get_modified(self, obj): + if obj is None: + return None + return obj.updated + + def get_related(self, obj): + ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj) + if obj.user: + ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) + ret['activity_stream'] = self.reverse( + 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} + ) + return ret + + def get_refresh_token(self, obj): + return None + + def create(self, validated_data): + validated_data['user'] = self.context['request'].user + validated_data['token'] = generate_token() + validated_data['expires'] = now() + timedelta( + seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] + ) + validated_data['application'] = None + obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data) + obj.save() + return obj + class OAuth2ApplicationSerializer(BaseSerializer): @@ -1096,223 +1273,6 @@ class OAuth2ApplicationSerializer(BaseSerializer): return ret -class OAuth2TokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() - ALLOWED_SCOPES = ['read', 'write'] - - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'application', 'expires', 'scope', - ) - read_only_fields = ('user', 'token', 'expires') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True}, - 'user': {'allow_null': False, 'required': True} - } - - def get_modified(self, obj): - if obj is None: - return None - return obj.updated - - def get_related(self, obj): - ret = super(OAuth2TokenSerializer, self).get_related(obj) - if obj.user: - ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) - if obj.application: - ret['application'] = self.reverse( - 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} - ) - ret['activity_stream'] = self.reverse( - 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} - ) - return ret - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def _is_valid_scope(self, value): - if not value or (not isinstance(value, six.string_types)): - return False - words = value.split() - for word in words: - if words.count(word) > 1: - return False # do not allow duplicates - if word not in self.ALLOWED_SCOPES: - return False - return True - - def validate_scope(self, value): - if not self._is_valid_scope(value): - raise serializers.ValidationError(_( - 'Must be a simple space-separated string with allowed scopes {}.' - ).format(self.ALLOWED_SCOPES)) - return value - - def create(self, validated_data): - current_user = self.context['request'].user - validated_data['user'] = current_user - validated_data['token'] = generate_token() - validated_data['expires'] = now() + timedelta( - seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] - ) - obj = super(OAuth2TokenSerializer, self).create(validated_data) - if obj.application and obj.application.user: - obj.user = obj.application.user - obj.save() - if obj.application is not None: - RefreshToken.objects.create( - user=current_user, - token=generate_token(), - application=obj.application, - access_token=obj - ) - return obj - - -class OAuth2TokenDetailSerializer(OAuth2TokenSerializer): - - class Meta: - read_only_fields = ('*', 'user', 'application') - - -class OAuth2AuthorizedTokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() - - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', '-user', 'token', 'refresh_token', - 'expires', 'scope', 'application', - ) - read_only_fields = ('user', 'token', 'expires') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True} - } - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return getattr(obj.refresh_token, 'token', '') - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def create(self, validated_data): - current_user = self.context['request'].user - validated_data['user'] = current_user - validated_data['token'] = generate_token() - validated_data['expires'] = now() + timedelta( - seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] - ) - obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data) - if obj.application and obj.application.user: - obj.user = obj.application.user - obj.save() - if obj.application is not None: - RefreshToken.objects.create( - user=current_user, - token=generate_token(), - application=obj.application, - access_token=obj - ) - return obj - - -class OAuth2PersonalTokenSerializer(BaseSerializer): - - refresh_token = serializers.SerializerMethodField() - token = serializers.SerializerMethodField() - - class Meta: - model = OAuth2AccessToken - fields = ( - '*', '-name', 'description', 'user', 'token', 'refresh_token', - 'application', 'expires', 'scope', - ) - read_only_fields = ('user', 'token', 'expires', 'application') - extra_kwargs = { - 'scope': {'allow_null': False, 'required': True} - } - - def get_modified(self, obj): - if obj is None: - return None - return obj.updated - - def get_related(self, obj): - ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj) - if obj.user: - ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk}) - if obj.application: - ret['application'] = self.reverse( - 'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk} - ) - ret['activity_stream'] = self.reverse( - 'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk} - ) - return ret - - def get_token(self, obj): - request = self.context.get('request', None) - try: - if request.method == 'POST': - return obj.token - else: - return TOKEN_CENSOR - except ObjectDoesNotExist: - return '' - - def get_refresh_token(self, obj): - return None - - def create(self, validated_data): - validated_data['user'] = self.context['request'].user - validated_data['token'] = generate_token() - validated_data['expires'] = now() + timedelta( - seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] - ) - validated_data['application'] = None - obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data) - obj.save() - return obj - - class OrganizationSerializer(BaseSerializer): show_capabilities = ['edit', 'delete'] diff --git a/awx/main/tests/functional/api/test_oauth.py b/awx/main/tests/functional/api/test_oauth.py index 7e745213c8..7e8b63eb08 100644 --- a/awx/main/tests/functional/api/test_oauth.py +++ b/awx/main/tests/functional/api/test_oauth.py @@ -29,7 +29,7 @@ def test_personal_access_token_creation(oauth_application, post, alice): @pytest.mark.django_db -def test_oauth_application_create(admin, organization, post): +def test_oauth2_application_create(admin, organization, post): response = post( reverse('api:o_auth2_application_list'), { 'name': 'test app', @@ -47,7 +47,18 @@ def test_oauth_application_create(admin, organization, post): assert created_app.client_type == 'confidential' assert created_app.authorization_grant_type == 'password' assert created_app.organization == organization - + + +@pytest.mark.django_db +def test_oauth2_validator(admin, oauth_application, post): + post( + reverse('api:o_auth2_application_list'), { + 'name': 'Write App Token', + 'application': oauth_application.pk, + 'scope': 'Write', + }, admin, expect=400 + ) + @pytest.mark.django_db def test_oauth_application_update(oauth_application, organization, patch, admin, alice):