From 7722e414e25eea70b3e034efff9501065227da64 Mon Sep 17 00:00:00 2001 From: Michael DeHaan Date: Wed, 20 Mar 2013 22:47:51 -0400 Subject: [PATCH] Add functions for checking size of paginated results --- lib/main/tests.py | 29 +++++++++++++++++------------ lib/main/views.py | 17 +++++++++++++---- lib/settings/defaults.py | 6 +++++- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/lib/main/tests.py b/lib/main/tests.py index f0d63bda30..789ff306bf 100644 --- a/lib/main/tests.py +++ b/lib/main/tests.py @@ -34,6 +34,11 @@ class BaseTest(django.test.TestCase): results.append(Organization.objects.create(name="org%s" % x, description="org%s" % x)) return results + def check_pagination_and_size(self, data, desired_count, previous=None, next=None): + self.assertEquals(data['count'], desired_count) + self.assertEquals(data['previous'], previous) + self.assertEquals(data['next'], next) + def setup_users(self): # Create a user. self.super_username = 'admin' @@ -48,13 +53,13 @@ class BaseTest(django.test.TestCase): (self.other_django_user, self.other_acom_user) = self.make_user(self.other_username, self.other_password, super_user=False) def get_super_credentials(self): - return self.create_basic(self.super_username, self.super_password) + return (self.super_username, self.super_password) def get_normal_credentials(self): - return self.create_basic(self.normal_username, self.normal_password) + return (self.normal_username, self.normal_password) def get_other_credentials(self): - return self.create_basic(self.other_username, self.other_password) + return (self.other_username, self.other_password) def get_invalid_credentials(self): return ('random', 'combination') @@ -63,11 +68,9 @@ class BaseTest(django.test.TestCase): assert method is not None if method != 'get': assert data is not None - client = None + client = Client() if auth: - client = Client(username=auth[0], password=auth[1]) - else: - client = Client() + client.login(username=auth[0], password=auth[1]) method = getattr(client,method) response = None if data is not None: @@ -75,20 +78,20 @@ class BaseTest(django.test.TestCase): else: response = method(url) if expect is not None: - assert response.status_code == expect, "expected %s got %s" % (expect, response.status_code) - data = json.loads(response.text) + assert response.status_code == expect, "expected status %s, got %s (%s) for url=%s as auth=%s" % (expect, response.status_code, response.status_text, url, auth) + data = json.loads(response.content) return data def get(self, url, expect=200, auth=None): return self._generic_rest(url, data=None, expect=expect, auth=auth, method='get') - def post(self, url, expect=200, auth=None): + def post(self, url, expect=204, auth=None): return self._generic_rest(url, data=None, expect=expect, auth=auth, method='post') def put(self, url, expect=200, auth=None): return self._generic_rest(url, data=None, expect=expect, auth=auth, method='put') - def delete(self, url, expect=200, auth=None): + def delete(self, url, expect=201, auth=None): return self._generic_rest(url, data=None, expect=expect, auth=auth, method='delete') class OrganizationsTest(BaseTest): @@ -127,7 +130,9 @@ class OrganizationsTest(BaseTest): self.get(self.collection(), expect=401, auth=self.get_invalid_credentials()) # superuser credentials == 200, full list - #resp = self.api_client.get(self.collection(), format='json', authentication=self.get_super_credentials()) + data = self.get(self.collection(), expect=200, auth=self.get_super_credentials()) + self.check_pagination_and_size(data, 10, previous=None, next=None) + #self.assertValidJSONResponse(resp) #self.assertEqual(len(self.deserialize(resp)['objects']), 10) # check member data diff --git a/lib/main/views.py b/lib/main/views.py index 4008a31557..feb414d875 100644 --- a/lib/main/views.py +++ b/lib/main/views.py @@ -4,6 +4,7 @@ from django.views.decorators.csrf import csrf_exempt #from rest_framework.parsers import JSONParser from lib.main.models import * from lib.main.serializers import * +from django.contrib.auth.models import AnonymousUser from rest_framework import mixins from rest_framework import generics @@ -16,13 +17,19 @@ from rest_framework import permissions class CustomRbac(permissions.BasePermission): - def has_object_permission(self, request, view, obj): + def has_permission(self, request, view, obj=None): - if request.method in permissions.SAFE_METHODS: # GET, HEAD, OPTIONS + if type(request.user) == AnonymousUser: + return False + + #if getattr(request, 'user') is None: + # return False + + if obj is None: return True - # Write permissions are only allowed to the owner of the snippet - return obj.owner == request.user + return True # obj.owner == request.user + class OrganizationsList(generics.ListCreateAPIView): @@ -31,6 +38,8 @@ class OrganizationsList(generics.ListCreateAPIView): model = Organization serializer_class = OrganizationSerializer + #authentication_classes = (SessionAuthentication, BasicAuthentication) + #permission_classes = (IsAuthenticated,) permission_classes = (CustomRbac,) diff --git a/lib/settings/defaults.py b/lib/settings/defaults.py index ead8366758..f8af949a81 100644 --- a/lib/settings/defaults.py +++ b/lib/settings/defaults.py @@ -24,7 +24,11 @@ MANAGERS = ADMINS REST_FRAMEWORK = { 'PAGINATE_BY': 10, - 'PAGINATE_BY_PARAM': 'page_size' + 'PAGINATE_BY_PARAM': 'page_size', + 'DEFAULT_AUTHENTICATION_CLASSES': ( + 'rest_framework.authentication.BasicAuthentication', + 'rest_framework.authentication.SessionAuthentication', + ) } DATABASES = {