mirror of
https://github.com/ansible/awx.git
synced 2026-02-25 15:06:02 -03:30
Add functions for checking size of paginated results
This commit is contained in:
@@ -34,6 +34,11 @@ class BaseTest(django.test.TestCase):
|
|||||||
results.append(Organization.objects.create(name="org%s" % x, description="org%s" % x))
|
results.append(Organization.objects.create(name="org%s" % x, description="org%s" % x))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def check_pagination_and_size(self, data, desired_count, previous=None, next=None):
|
||||||
|
self.assertEquals(data['count'], desired_count)
|
||||||
|
self.assertEquals(data['previous'], previous)
|
||||||
|
self.assertEquals(data['next'], next)
|
||||||
|
|
||||||
def setup_users(self):
|
def setup_users(self):
|
||||||
# Create a user.
|
# Create a user.
|
||||||
self.super_username = 'admin'
|
self.super_username = 'admin'
|
||||||
@@ -48,13 +53,13 @@ class BaseTest(django.test.TestCase):
|
|||||||
(self.other_django_user, self.other_acom_user) = self.make_user(self.other_username, self.other_password, super_user=False)
|
(self.other_django_user, self.other_acom_user) = self.make_user(self.other_username, self.other_password, super_user=False)
|
||||||
|
|
||||||
def get_super_credentials(self):
|
def get_super_credentials(self):
|
||||||
return self.create_basic(self.super_username, self.super_password)
|
return (self.super_username, self.super_password)
|
||||||
|
|
||||||
def get_normal_credentials(self):
|
def get_normal_credentials(self):
|
||||||
return self.create_basic(self.normal_username, self.normal_password)
|
return (self.normal_username, self.normal_password)
|
||||||
|
|
||||||
def get_other_credentials(self):
|
def get_other_credentials(self):
|
||||||
return self.create_basic(self.other_username, self.other_password)
|
return (self.other_username, self.other_password)
|
||||||
|
|
||||||
def get_invalid_credentials(self):
|
def get_invalid_credentials(self):
|
||||||
return ('random', 'combination')
|
return ('random', 'combination')
|
||||||
@@ -63,11 +68,9 @@ class BaseTest(django.test.TestCase):
|
|||||||
assert method is not None
|
assert method is not None
|
||||||
if method != 'get':
|
if method != 'get':
|
||||||
assert data is not None
|
assert data is not None
|
||||||
client = None
|
client = Client()
|
||||||
if auth:
|
if auth:
|
||||||
client = Client(username=auth[0], password=auth[1])
|
client.login(username=auth[0], password=auth[1])
|
||||||
else:
|
|
||||||
client = Client()
|
|
||||||
method = getattr(client,method)
|
method = getattr(client,method)
|
||||||
response = None
|
response = None
|
||||||
if data is not None:
|
if data is not None:
|
||||||
@@ -75,20 +78,20 @@ class BaseTest(django.test.TestCase):
|
|||||||
else:
|
else:
|
||||||
response = method(url)
|
response = method(url)
|
||||||
if expect is not None:
|
if expect is not None:
|
||||||
assert response.status_code == expect, "expected %s got %s" % (expect, response.status_code)
|
assert response.status_code == expect, "expected status %s, got %s (%s) for url=%s as auth=%s" % (expect, response.status_code, response.status_text, url, auth)
|
||||||
data = json.loads(response.text)
|
data = json.loads(response.content)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def get(self, url, expect=200, auth=None):
|
def get(self, url, expect=200, auth=None):
|
||||||
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='get')
|
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='get')
|
||||||
|
|
||||||
def post(self, url, expect=200, auth=None):
|
def post(self, url, expect=204, auth=None):
|
||||||
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='post')
|
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='post')
|
||||||
|
|
||||||
def put(self, url, expect=200, auth=None):
|
def put(self, url, expect=200, auth=None):
|
||||||
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='put')
|
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='put')
|
||||||
|
|
||||||
def delete(self, url, expect=200, auth=None):
|
def delete(self, url, expect=201, auth=None):
|
||||||
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='delete')
|
return self._generic_rest(url, data=None, expect=expect, auth=auth, method='delete')
|
||||||
|
|
||||||
class OrganizationsTest(BaseTest):
|
class OrganizationsTest(BaseTest):
|
||||||
@@ -127,7 +130,9 @@ class OrganizationsTest(BaseTest):
|
|||||||
self.get(self.collection(), expect=401, auth=self.get_invalid_credentials())
|
self.get(self.collection(), expect=401, auth=self.get_invalid_credentials())
|
||||||
|
|
||||||
# superuser credentials == 200, full list
|
# superuser credentials == 200, full list
|
||||||
#resp = self.api_client.get(self.collection(), format='json', authentication=self.get_super_credentials())
|
data = self.get(self.collection(), expect=200, auth=self.get_super_credentials())
|
||||||
|
self.check_pagination_and_size(data, 10, previous=None, next=None)
|
||||||
|
|
||||||
#self.assertValidJSONResponse(resp)
|
#self.assertValidJSONResponse(resp)
|
||||||
#self.assertEqual(len(self.deserialize(resp)['objects']), 10)
|
#self.assertEqual(len(self.deserialize(resp)['objects']), 10)
|
||||||
# check member data
|
# check member data
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from django.views.decorators.csrf import csrf_exempt
|
|||||||
#from rest_framework.parsers import JSONParser
|
#from rest_framework.parsers import JSONParser
|
||||||
from lib.main.models import *
|
from lib.main.models import *
|
||||||
from lib.main.serializers import *
|
from lib.main.serializers import *
|
||||||
|
from django.contrib.auth.models import AnonymousUser
|
||||||
|
|
||||||
from rest_framework import mixins
|
from rest_framework import mixins
|
||||||
from rest_framework import generics
|
from rest_framework import generics
|
||||||
@@ -16,13 +17,19 @@ from rest_framework import permissions
|
|||||||
|
|
||||||
class CustomRbac(permissions.BasePermission):
|
class CustomRbac(permissions.BasePermission):
|
||||||
|
|
||||||
def has_object_permission(self, request, view, obj):
|
def has_permission(self, request, view, obj=None):
|
||||||
|
|
||||||
if request.method in permissions.SAFE_METHODS: # GET, HEAD, OPTIONS
|
if type(request.user) == AnonymousUser:
|
||||||
|
return False
|
||||||
|
|
||||||
|
#if getattr(request, 'user') is None:
|
||||||
|
# return False
|
||||||
|
|
||||||
|
if obj is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Write permissions are only allowed to the owner of the snippet
|
return True # obj.owner == request.user
|
||||||
return obj.owner == request.user
|
|
||||||
|
|
||||||
|
|
||||||
class OrganizationsList(generics.ListCreateAPIView):
|
class OrganizationsList(generics.ListCreateAPIView):
|
||||||
@@ -31,6 +38,8 @@ class OrganizationsList(generics.ListCreateAPIView):
|
|||||||
|
|
||||||
model = Organization
|
model = Organization
|
||||||
serializer_class = OrganizationSerializer
|
serializer_class = OrganizationSerializer
|
||||||
|
#authentication_classes = (SessionAuthentication, BasicAuthentication)
|
||||||
|
#permission_classes = (IsAuthenticated,)
|
||||||
|
|
||||||
permission_classes = (CustomRbac,)
|
permission_classes = (CustomRbac,)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ MANAGERS = ADMINS
|
|||||||
|
|
||||||
REST_FRAMEWORK = {
|
REST_FRAMEWORK = {
|
||||||
'PAGINATE_BY': 10,
|
'PAGINATE_BY': 10,
|
||||||
'PAGINATE_BY_PARAM': 'page_size'
|
'PAGINATE_BY_PARAM': 'page_size',
|
||||||
|
'DEFAULT_AUTHENTICATION_CLASSES': (
|
||||||
|
'rest_framework.authentication.BasicAuthentication',
|
||||||
|
'rest_framework.authentication.SessionAuthentication',
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
|
|||||||
Reference in New Issue
Block a user