diff --git a/awx/api/pagination.py b/awx/api/pagination.py index dfb57699c3..9a416e9995 100644 --- a/awx/api/pagination.py +++ b/awx/api/pagination.py @@ -1,78 +1,16 @@ # Copyright (c) 2015 Ansible, Inc. # All Rights Reserved. -# Django -from django.core.paginator import Paginator as DjangoPaginator -from django.core.paginator import PageNotAnInteger, EmptyPage -from django.db import connections - # Django REST Framework from django.conf import settings from rest_framework import pagination from rest_framework.utils.urls import replace_query_param -class Paginator(DjangoPaginator): - - def __init__(self, object_list, per_page, orphans=0, allow_empty_first_page=True): - self.count_field = None - # Based on http://stackoverflow.com/questions/156114/best-way-to-get-result-count-before-limit-was-applied - # With PostgreSQL, we can use a window function to include the total - # count of results (before limit and offset are applied) as an extra - # column and avoid having to issue a separate COUNT(*) query. - if hasattr(object_list, 'extra'): - if connections[getattr(object_list, 'db', None) or 'default'].vendor == 'postgresql': - object_list = object_list.extra(select=dict(__count='COUNT(*) OVER()')) - self.count_field = '__count' - super(Paginator, self).__init__(object_list, per_page, orphans, allow_empty_first_page) - assert self.orphans == 0 - - def validate_number(self, number, check_num_pages=True): - """ - Validates the given 1-based page number. - """ - try: - number = int(number) - except (TypeError, ValueError): - raise PageNotAnInteger('That page number is not an integer') - if number < 1: - raise EmptyPage('That page number is less than 1') - # Optionally skip checking num_pages, since that will result in a - # COUNT(*) query. - if check_num_pages and number > self.num_pages: - if number == 1 and self.allow_empty_first_page: - pass - else: - raise EmptyPage('That page contains no results') - return number - - def page(self, number): - """ - Returns a Page object for the given 1-based page number. - """ - number = self.validate_number(number, check_num_pages=bool(self.count_field is None)) - bottom = (number - 1) * self.per_page - top = bottom + self.per_page - sub_list = self.object_list[bottom:top] - if self.count_field and self._count is None: - # Execute one query to fetch all results. - sub_list = list(sub_list) - try: - # Get the total count from the first result. - self._count = getattr(sub_list[0], self.count_field) - except IndexError: - # If no results were returned, we still don't know the total - # count, but do know that we've reached an empty page. - if number > 1 or not self.allow_empty_first_page: - raise EmptyPage('That page contains no results') - return self._get_page(sub_list, number, self) - - class Pagination(pagination.PageNumberPagination): page_size_query_param = 'page_size' max_page_size = settings.MAX_PAGE_SIZE - django_paginator_class = Paginator def get_next_link(self): if not self.page.has_next(): diff --git a/awx/main/tests/functional/api/test_pagination.py b/awx/main/tests/functional/api/test_pagination.py new file mode 100644 index 0000000000..aed0f2e034 --- /dev/null +++ b/awx/main/tests/functional/api/test_pagination.py @@ -0,0 +1,40 @@ +import pytest + +from awx.main.models.inventory import Group, Host +from awx.api.pagination import Pagination + + +@pytest.fixture +def host(inventory): + def handler(name, groups): + h = Host(name=name, inventory=inventory) + h.save() + h = Host.objects.get(name=name, inventory=inventory) + for g in groups: + h.groups.add(g) + h.save() + h = Host.objects.get(name=name, inventory=inventory) + return h + return handler + + +@pytest.fixture +def group(inventory): + def handler(name): + g = Group(name=name, inventory=inventory) + g.save() + g = Group.objects.get(name=name, inventory=inventory) + return g + return handler + + +@pytest.mark.django_db +def test_pagination_backend_output_correct_total_count(group, host): + # NOTE: this test might not be db-backend-agnostic. Manual tests might be needed also + g1 = group('pg_group1') + g2 = group('pg_group2') + host('pg_host1', [g1, g2]) + queryset = Host.objects.filter(groups__name__in=('pg_group1', 'pg_group2')).distinct() + p = Pagination().django_paginator_class(queryset, 10) + p.page(1) + assert p.count == 1