add scope validator to token endpoints

This commit is contained in:
adamscmRH 2018-05-23 14:04:36 -04:00
parent 921c3d2535
commit 5d220e8222
2 changed files with 201 additions and 230 deletions

View File

@ -985,19 +985,12 @@ class UserSerializer(BaseSerializer):
return self._validate_ldap_managed_field(value, 'is_superuser')
class UserAuthorizedTokenSerializer(BaseSerializer):
class BaseOAuth2TokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
ALLOWED_SCOPES = ['read', 'write']
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'expires', 'scope', 'application'
)
read_only_fields = ('user', 'token', 'expires')
def get_token(self, obj):
request = self.context.get('request', None)
try:
@ -1006,7 +999,36 @@ class UserAuthorizedTokenSerializer(BaseSerializer):
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
return ''
def _is_valid_scope(self, value):
if not value or (not isinstance(value, six.string_types)):
return False
words = value.split()
for word in words:
if words.count(word) > 1:
return False # do not allow duplicates
if word not in self.ALLOWED_SCOPES:
return False
return True
def validate_scope(self, value):
if not self._is_valid_scope(value):
raise serializers.ValidationError(_(
'Must be a simple space-separated string with allowed scopes {}.'
).format(self.ALLOWED_SCOPES))
return value
class UserAuthorizedTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'expires', 'scope', 'application'
)
read_only_fields = ('user', 'token', 'expires')
def get_refresh_token(self, obj):
request = self.context.get('request', None)
@ -1035,7 +1057,162 @@ class UserAuthorizedTokenSerializer(BaseSerializer):
access_token=obj
)
return obj
class OAuth2TokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True},
'user': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2TokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['user'] = current_user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
obj = super(OAuth2TokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application is not None:
RefreshToken.objects.create(
user=current_user,
token=generate_token(),
application=obj.application,
access_token=obj
)
return obj
class OAuth2TokenDetailSerializer(OAuth2TokenSerializer):
class Meta:
read_only_fields = ('*', 'user', 'application')
class OAuth2AuthorizedTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', '-user', 'token', 'refresh_token',
'expires', 'scope', 'application',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['user'] = current_user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application is not None:
RefreshToken.objects.create(
user=current_user,
token=generate_token(),
application=obj.application,
access_token=obj
)
return obj
class OAuth2PersonalTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires', 'application')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_refresh_token(self, obj):
return None
def create(self, validated_data):
validated_data['user'] = self.context['request'].user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
validated_data['application'] = None
obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data)
obj.save()
return obj
class OAuth2ApplicationSerializer(BaseSerializer):
@ -1096,223 +1273,6 @@ class OAuth2ApplicationSerializer(BaseSerializer):
return ret
class OAuth2TokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
ALLOWED_SCOPES = ['read', 'write']
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True},
'user': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2TokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def _is_valid_scope(self, value):
if not value or (not isinstance(value, six.string_types)):
return False
words = value.split()
for word in words:
if words.count(word) > 1:
return False # do not allow duplicates
if word not in self.ALLOWED_SCOPES:
return False
return True
def validate_scope(self, value):
if not self._is_valid_scope(value):
raise serializers.ValidationError(_(
'Must be a simple space-separated string with allowed scopes {}.'
).format(self.ALLOWED_SCOPES))
return value
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['user'] = current_user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
obj = super(OAuth2TokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application is not None:
RefreshToken.objects.create(
user=current_user,
token=generate_token(),
application=obj.application,
access_token=obj
)
return obj
class OAuth2TokenDetailSerializer(OAuth2TokenSerializer):
class Meta:
read_only_fields = ('*', 'user', 'application')
class OAuth2AuthorizedTokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', '-user', 'token', 'refresh_token',
'expires', 'scope', 'application',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['user'] = current_user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application is not None:
RefreshToken.objects.create(
user=current_user,
token=generate_token(),
application=obj.application,
access_token=obj
)
return obj
class OAuth2PersonalTokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires', 'application')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
return None
def create(self, validated_data):
validated_data['user'] = self.context['request'].user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
validated_data['application'] = None
obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data)
obj.save()
return obj
class OrganizationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']

View File

@ -29,7 +29,7 @@ def test_personal_access_token_creation(oauth_application, post, alice):
@pytest.mark.django_db
def test_oauth_application_create(admin, organization, post):
def test_oauth2_application_create(admin, organization, post):
response = post(
reverse('api:o_auth2_application_list'), {
'name': 'test app',
@ -47,7 +47,18 @@ def test_oauth_application_create(admin, organization, post):
assert created_app.client_type == 'confidential'
assert created_app.authorization_grant_type == 'password'
assert created_app.organization == organization
@pytest.mark.django_db
def test_oauth2_validator(admin, oauth_application, post):
post(
reverse('api:o_auth2_application_list'), {
'name': 'Write App Token',
'application': oauth_application.pk,
'scope': 'Write',
}, admin, expect=400
)
@pytest.mark.django_db
def test_oauth_application_update(oauth_application, organization, patch, admin, alice):