Merge pull request #1932 from rooftopcellist/validate_scope

Validate scope
This commit is contained in:
Christian Adams
2018-05-25 10:49:02 -04:00
committed by GitHub
8 changed files with 132 additions and 274 deletions

View File

@@ -948,7 +948,7 @@ class UserSerializer(BaseSerializer):
access_list = self.reverse('api:user_access_list', kwargs={'pk': obj.pk}), access_list = self.reverse('api:user_access_list', kwargs={'pk': obj.pk}),
tokens = self.reverse('api:o_auth2_token_list', kwargs={'pk': obj.pk}), tokens = self.reverse('api:o_auth2_token_list', kwargs={'pk': obj.pk}),
authorized_tokens = self.reverse('api:user_authorized_token_list', kwargs={'pk': obj.pk}), authorized_tokens = self.reverse('api:user_authorized_token_list', kwargs={'pk': obj.pk}),
personal_tokens = self.reverse('api:o_auth2_personal_token_list', kwargs={'pk': obj.pk}), personal_tokens = self.reverse('api:user_personal_token_list', kwargs={'pk': obj.pk}),
)) ))
return res return res
@@ -985,19 +985,24 @@ class UserSerializer(BaseSerializer):
return self._validate_ldap_managed_field(value, 'is_superuser') return self._validate_ldap_managed_field(value, 'is_superuser')
class UserAuthorizedTokenSerializer(BaseSerializer): class BaseOAuth2TokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField() refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField() token = serializers.SerializerMethodField()
ALLOWED_SCOPES = ['read', 'write']
class Meta: class Meta:
model = OAuth2AccessToken model = OAuth2AccessToken
fields = ( fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token', '*', '-name', 'description', 'user', 'token', 'refresh_token',
'expires', 'scope', 'application' 'application', 'expires', 'scope',
) )
read_only_fields = ('user', 'token', 'expires') read_only_fields = ('user', 'token', 'expires', 'refresh_token')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True},
'user': {'allow_null': False, 'required': True}
}
def get_token(self, obj): def get_token(self, obj):
request = self.context.get('request', None) request = self.context.get('request', None)
try: try:
@@ -1006,17 +1011,60 @@ class UserAuthorizedTokenSerializer(BaseSerializer):
else: else:
return TOKEN_CENSOR return TOKEN_CENSOR
except ObjectDoesNotExist: except ObjectDoesNotExist:
return '' return ''
def get_refresh_token(self, obj): def get_refresh_token(self, obj):
request = self.context.get('request', None) request = self.context.get('request', None)
try: try:
if request.method == 'POST': if not obj.refresh_token:
return None
elif request.method == 'POST':
return getattr(obj.refresh_token, 'token', '') return getattr(obj.refresh_token, 'token', '')
else: else:
return TOKEN_CENSOR return TOKEN_CENSOR
except ObjectDoesNotExist: except ObjectDoesNotExist:
return '' return None
def get_related(self, obj):
ret = super(BaseOAuth2TokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def _is_valid_scope(self, value):
if not value or (not isinstance(value, six.string_types)):
return False
words = value.split()
for word in words:
if words.count(word) > 1:
return False # do not allow duplicates
if word not in self.ALLOWED_SCOPES:
return False
return True
def validate_scope(self, value):
if not self._is_valid_scope(value):
raise serializers.ValidationError(_(
'Must be a simple space-separated string with allowed scopes {}.'
).format(self.ALLOWED_SCOPES))
return value
class UserAuthorizedTokenSerializer(BaseOAuth2TokenSerializer):
class Meta:
extra_kwargs = {
'scope': {'allow_null': False, 'required': True},
'user': {'allow_null': False, 'required': True},
'application': {'allow_null': False, 'required': True}
}
def create(self, validated_data): def create(self, validated_data):
current_user = self.context['request'].user current_user = self.context['request'].user
@@ -1035,140 +1083,9 @@ class UserAuthorizedTokenSerializer(BaseSerializer):
access_token=obj access_token=obj
) )
return obj return obj
class OAuth2ApplicationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']
class Meta:
model = OAuth2Application
fields = (
'*', 'description', '-user', 'client_id', 'client_secret', 'client_type',
'redirect_uris', 'authorization_grant_type', 'skip_authorization', 'organization'
)
read_only_fields = ('client_id', 'client_secret')
read_only_on_update_fields = ('user', 'authorization_grant_type')
extra_kwargs = {
'user': {'allow_null': True, 'required': False},
'organization': {'allow_null': False},
'authorization_grant_type': {'allow_null': False}
}
def to_representation(self, obj):
ret = super(OAuth2ApplicationSerializer, self).to_representation(obj)
if obj.client_type == 'public':
ret.pop('client_secret', None)
return ret
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2ApplicationSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
ret['tokens'] = self.reverse(
'api:o_auth2_application_token_list', kwargs={'pk': obj.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_application_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def _summary_field_tokens(self, obj):
token_list = [{'id': x.pk, 'token': TOKEN_CENSOR, 'scope': x.scope} for x in obj.oauth2accesstoken_set.all()[:10]]
if has_model_field_prefetched(obj, 'oauth2accesstoken_set'):
token_count = len(obj.oauth2accesstoken_set.all())
else:
if len(token_list) < 10:
token_count = len(token_list)
else:
token_count = obj.oauth2accesstoken_set.count()
return {'count': token_count, 'results': token_list}
def get_summary_fields(self, obj):
ret = super(OAuth2ApplicationSerializer, self).get_summary_fields(obj)
ret['tokens'] = self._summary_field_tokens(obj)
return ret
class OAuth2TokenSerializer(BaseSerializer): class OAuth2TokenSerializer(BaseOAuth2TokenSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
ALLOWED_SCOPES = ['read', 'write']
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True},
'user': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2TokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def _is_valid_scope(self, value):
if not value or (not isinstance(value, six.string_types)):
return False
words = value.split()
for word in words:
if words.count(word) > 1:
return False # do not allow duplicates
if word not in self.ALLOWED_SCOPES:
return False
return True
def validate_scope(self, value):
if not self._is_valid_scope(value):
raise serializers.ValidationError(_(
'Must be a simple space-separated string with allowed scopes {}.'
).format(self.ALLOWED_SCOPES))
return value
def create(self, validated_data): def create(self, validated_data):
current_user = self.context['request'].user current_user = self.context['request'].user
@@ -1197,109 +1114,10 @@ class OAuth2TokenDetailSerializer(OAuth2TokenSerializer):
read_only_fields = ('*', 'user', 'application') read_only_fields = ('*', 'user', 'application')
class OAuth2AuthorizedTokenSerializer(BaseSerializer): class UserPersonalTokenSerializer(BaseOAuth2TokenSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
class Meta: class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', '-user', 'token', 'refresh_token',
'expires', 'scope', 'application',
)
read_only_fields = ('user', 'token', 'expires')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return getattr(obj.refresh_token, 'token', '')
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def create(self, validated_data):
current_user = self.context['request'].user
validated_data['user'] = current_user
validated_data['token'] = generate_token()
validated_data['expires'] = now() + timedelta(
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
)
obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data)
if obj.application and obj.application.user:
obj.user = obj.application.user
obj.save()
if obj.application is not None:
RefreshToken.objects.create(
user=current_user,
token=generate_token(),
application=obj.application,
access_token=obj
)
return obj
class OAuth2PersonalTokenSerializer(BaseSerializer):
refresh_token = serializers.SerializerMethodField()
token = serializers.SerializerMethodField()
class Meta:
model = OAuth2AccessToken
fields = (
'*', '-name', 'description', 'user', 'token', 'refresh_token',
'application', 'expires', 'scope',
)
read_only_fields = ('user', 'token', 'expires', 'application') read_only_fields = ('user', 'token', 'expires', 'application')
extra_kwargs = {
'scope': {'allow_null': False, 'required': True}
}
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def get_related(self, obj):
ret = super(OAuth2PersonalTokenSerializer, self).get_related(obj)
if obj.user:
ret['user'] = self.reverse('api:user_detail', kwargs={'pk': obj.user.pk})
if obj.application:
ret['application'] = self.reverse(
'api:o_auth2_application_detail', kwargs={'pk': obj.application.pk}
)
ret['activity_stream'] = self.reverse(
'api:o_auth2_token_activity_stream_list', kwargs={'pk': obj.pk}
)
return ret
def get_token(self, obj):
request = self.context.get('request', None)
try:
if request.method == 'POST':
return obj.token
else:
return TOKEN_CENSOR
except ObjectDoesNotExist:
return ''
def get_refresh_token(self, obj):
return None
def create(self, validated_data): def create(self, validated_data):
validated_data['user'] = self.context['request'].user validated_data['user'] = self.context['request'].user
@@ -1308,11 +1126,57 @@ class OAuth2PersonalTokenSerializer(BaseSerializer):
seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS'] seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
) )
validated_data['application'] = None validated_data['application'] = None
obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data) obj = super(UserPersonalTokenSerializer, self).create(validated_data)
obj.save() obj.save()
return obj return obj
class OAuth2ApplicationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete']
class Meta:
model = OAuth2Application
fields = (
'*', 'description', '-user', 'client_id', 'client_secret', 'client_type',
'redirect_uris', 'authorization_grant_type', 'skip_authorization', 'organization'
)
read_only_fields = ('client_id', 'client_secret')
read_only_on_update_fields = ('user', 'authorization_grant_type')
extra_kwargs = {
'user': {'allow_null': True, 'required': False},
'organization': {'allow_null': False},
'authorization_grant_type': {'allow_null': False}
}
def to_representation(self, obj):
ret = super(OAuth2ApplicationSerializer, self).to_representation(obj)
if obj.client_type == 'public':
ret.pop('client_secret', None)
return ret
def get_modified(self, obj):
if obj is None:
return None
return obj.updated
def _summary_field_tokens(self, obj):
token_list = [{'id': x.pk, 'token': TOKEN_CENSOR, 'scope': x.scope} for x in obj.oauth2accesstoken_set.all()[:10]]
if has_model_field_prefetched(obj, 'oauth2accesstoken_set'):
token_count = len(obj.oauth2accesstoken_set.all())
else:
if len(token_list) < 10:
token_count = len(token_list)
else:
token_count = obj.oauth2accesstoken_set.count()
return {'count': token_count, 'results': token_list}
def get_summary_fields(self, obj):
ret = super(OAuth2ApplicationSerializer, self).get_summary_fields(obj)
ret['tokens'] = self._summary_field_tokens(obj)
return ret
class OrganizationSerializer(BaseSerializer): class OrganizationSerializer(BaseSerializer):
show_capabilities = ['edit', 'delete'] show_capabilities = ['edit', 'delete']

View File

@@ -11,7 +11,6 @@ from awx.api.views import (
OAuth2TokenList, OAuth2TokenList,
OAuth2TokenDetail, OAuth2TokenDetail,
OAuth2TokenActivityStreamList, OAuth2TokenActivityStreamList,
OAuth2PersonalTokenList
) )
@@ -42,8 +41,7 @@ urls = [
r'^tokens/(?P<pk>[0-9]+)/activity_stream/$', r'^tokens/(?P<pk>[0-9]+)/activity_stream/$',
OAuth2TokenActivityStreamList.as_view(), OAuth2TokenActivityStreamList.as_view(),
name='o_auth2_token_activity_stream_list' name='o_auth2_token_activity_stream_list'
), ),
url(r'^personal_tokens/$', OAuth2PersonalTokenList.as_view(), name='o_auth2_personal_token_list'),
] ]
__all__ = ['urls'] __all__ = ['urls']

View File

@@ -67,8 +67,8 @@ from .schedule import urls as schedule_urls
from .activity_stream import urls as activity_stream_urls from .activity_stream import urls as activity_stream_urls
from .instance import urls as instance_urls from .instance import urls as instance_urls
from .instance_group import urls as instance_group_urls from .instance_group import urls as instance_group_urls
from .user_oauth import urls as user_oauth_urls from .oauth2 import urls as oauth2_urls
from .oauth import urls as oauth_urls from .oauth2_root import urls as oauth2_root_urls
v1_urls = [ v1_urls = [
@@ -130,7 +130,7 @@ v2_urls = [
url(r'^applications/(?P<pk>[0-9]+)/$', OAuth2ApplicationDetail.as_view(), name='o_auth2_application_detail'), url(r'^applications/(?P<pk>[0-9]+)/$', OAuth2ApplicationDetail.as_view(), name='o_auth2_application_detail'),
url(r'^applications/(?P<pk>[0-9]+)/tokens/$', ApplicationOAuth2TokenList.as_view(), name='application_o_auth2_token_list'), url(r'^applications/(?P<pk>[0-9]+)/tokens/$', ApplicationOAuth2TokenList.as_view(), name='application_o_auth2_token_list'),
url(r'^tokens/$', OAuth2TokenList.as_view(), name='o_auth2_token_list'), url(r'^tokens/$', OAuth2TokenList.as_view(), name='o_auth2_token_list'),
url(r'^', include(user_oauth_urls)), url(r'^', include(oauth2_urls)),
] ]
app_name = 'api' app_name = 'api'
@@ -145,7 +145,7 @@ urlpatterns = [
url(r'^logout/$', LoggedLogoutView.as_view( url(r'^logout/$', LoggedLogoutView.as_view(
next_page='/api/', redirect_field_name='next' next_page='/api/', redirect_field_name='next'
), name='logout'), ), name='logout'),
url(r'^o/', include(oauth_urls)), url(r'^o/', include(oauth2_root_urls)),
] ]
if settings.SETTINGS_MODULE == 'awx.settings.development': if settings.SETTINGS_MODULE == 'awx.settings.development':
from awx.api.swagger import SwaggerSchemaView from awx.api.swagger import SwaggerSchemaView

View File

@@ -16,7 +16,7 @@ from awx.api.views import (
UserAccessList, UserAccessList,
OAuth2ApplicationList, OAuth2ApplicationList,
OAuth2UserTokenList, OAuth2UserTokenList,
OAuth2PersonalTokenList, UserPersonalTokenList,
UserAuthorizedTokenList, UserAuthorizedTokenList,
) )
@@ -34,7 +34,7 @@ urls = [
url(r'^(?P<pk>[0-9]+)/applications/$', OAuth2ApplicationList.as_view(), name='o_auth2_application_list'), url(r'^(?P<pk>[0-9]+)/applications/$', OAuth2ApplicationList.as_view(), name='o_auth2_application_list'),
url(r'^(?P<pk>[0-9]+)/tokens/$', OAuth2UserTokenList.as_view(), name='o_auth2_token_list'), url(r'^(?P<pk>[0-9]+)/tokens/$', OAuth2UserTokenList.as_view(), name='o_auth2_token_list'),
url(r'^(?P<pk>[0-9]+)/authorized_tokens/$', UserAuthorizedTokenList.as_view(), name='user_authorized_token_list'), url(r'^(?P<pk>[0-9]+)/authorized_tokens/$', UserAuthorizedTokenList.as_view(), name='user_authorized_token_list'),
url(r'^(?P<pk>[0-9]+)/personal_tokens/$', OAuth2PersonalTokenList.as_view(), name='o_auth2_personal_token_list'), url(r'^(?P<pk>[0-9]+)/personal_tokens/$', UserPersonalTokenList.as_view(), name='user_personal_token_list'),
] ]

View File

@@ -1611,21 +1611,6 @@ class OAuth2UserTokenList(SubListCreateAPIView):
relationship = 'main_oauth2accesstoken' relationship = 'main_oauth2accesstoken'
parent_key = 'user' parent_key = 'user'
swagger_topic = 'Authentication' swagger_topic = 'Authentication'
class OAuth2AuthorizedTokenList(SubListCreateAPIView):
view_name = _("OAuth2 Authorized Access Tokens")
model = OAuth2AccessToken
serializer_class = OAuth2AuthorizedTokenSerializer
parent_model = OAuth2Application
relationship = 'oauth2accesstoken_set'
parent_key = 'application'
swagger_topic = 'Authentication'
def get_queryset(self):
return get_access_token_model().objects.filter(application__isnull=False, user=self.request.user)
class UserAuthorizedTokenList(SubListCreateAPIView): class UserAuthorizedTokenList(SubListCreateAPIView):
@@ -1633,7 +1618,7 @@ class UserAuthorizedTokenList(SubListCreateAPIView):
view_name = _("OAuth2 User Authorized Access Tokens") view_name = _("OAuth2 User Authorized Access Tokens")
model = OAuth2AccessToken model = OAuth2AccessToken
serializer_class = OAuth2AuthorizedTokenSerializer serializer_class = UserAuthorizedTokenSerializer
parent_model = User parent_model = User
relationship = 'oauth2accesstoken_set' relationship = 'oauth2accesstoken_set'
parent_key = 'user' parent_key = 'user'
@@ -1655,12 +1640,12 @@ class OrganizationApplicationList(SubListCreateAPIView):
swagger_topic = 'Authentication' swagger_topic = 'Authentication'
class OAuth2PersonalTokenList(SubListCreateAPIView): class UserPersonalTokenList(SubListCreateAPIView):
view_name = _("OAuth2 Personal Access Tokens") view_name = _("OAuth2 Personal Access Tokens")
model = OAuth2AccessToken model = OAuth2AccessToken
serializer_class = OAuth2PersonalTokenSerializer serializer_class = UserPersonalTokenSerializer
parent_model = User parent_model = User
relationship = 'main_oauth2accesstoken' relationship = 'main_oauth2accesstoken'
parent_key = 'user' parent_key = 'user'

View File

@@ -29,7 +29,7 @@ def test_personal_access_token_creation(oauth_application, post, alice):
@pytest.mark.django_db @pytest.mark.django_db
def test_oauth_application_create(admin, organization, post): def test_oauth2_application_create(admin, organization, post):
response = post( response = post(
reverse('api:o_auth2_application_list'), { reverse('api:o_auth2_application_list'), {
'name': 'test app', 'name': 'test app',
@@ -47,7 +47,18 @@ def test_oauth_application_create(admin, organization, post):
assert created_app.client_type == 'confidential' assert created_app.client_type == 'confidential'
assert created_app.authorization_grant_type == 'password' assert created_app.authorization_grant_type == 'password'
assert created_app.organization == organization assert created_app.organization == organization
@pytest.mark.django_db
def test_oauth2_validator(admin, oauth_application, post):
post(
reverse('api:o_auth2_application_list'), {
'name': 'Write App Token',
'application': oauth_application.pk,
'scope': 'Write',
}, admin, expect=400
)
@pytest.mark.django_db @pytest.mark.django_db
def test_oauth_application_update(oauth_application, organization, patch, admin, alice): def test_oauth_application_update(oauth_application, organization, patch, admin, alice):

View File

@@ -200,7 +200,7 @@ class TestOAuth2Token:
user_list = [admin, org_admin, org_member, alice] user_list = [admin, org_admin, org_member, alice]
can_access_list = [True, False, True, False] can_access_list = [True, False, True, False]
response = post( response = post(
reverse('api:o_auth2_personal_token_list', kwargs={'pk': org_member.pk}), reverse('api:user_personal_token_list', kwargs={'pk': org_member.pk}),
{'scope': 'read'}, org_member, expect=201 {'scope': 'read'}, org_member, expect=201
) )
token = AccessToken.objects.get(token=response.data['token']) token = AccessToken.objects.get(token=response.data['token'])
@@ -220,7 +220,7 @@ class TestOAuth2Token:
for user, can_access in zip(user_list, can_access_list): for user, can_access in zip(user_list, can_access_list):
response = post( response = post(
reverse('api:o_auth2_personal_token_list', kwargs={'pk': user.pk}), reverse('api:user_personal_token_list', kwargs={'pk': user.pk}),
{'scope': 'read', 'application':None}, user, expect=201 {'scope': 'read', 'application':None}, user, expect=201
) )
token = AccessToken.objects.get(token=response.data['token']) token = AccessToken.objects.get(token=response.data['token'])