From 381b47201b3189067a49797d059912e53cdd3bea Mon Sep 17 00:00:00 2001 From: Aaron Tan Date: Mon, 24 Apr 2017 16:47:17 -0400 Subject: [PATCH] Allow exception view to accept all valid HTTP methods. --- awx/main/tests/unit/test_views.py | 37 +++++++++++++++++++++++++++++++ awx/main/views.py | 18 ++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 awx/main/tests/unit/test_views.py diff --git a/awx/main/tests/unit/test_views.py b/awx/main/tests/unit/test_views.py new file mode 100644 index 0000000000..2204635eb6 --- /dev/null +++ b/awx/main/tests/unit/test_views.py @@ -0,0 +1,37 @@ +import pytest +import mock + +# Django REST Framework +from rest_framework import exceptions + +# AWX +from awx.main.views import ApiErrorView + + +HTTP_METHOD_NAMES = [ + 'get', + 'post', + 'put', + 'patch', + 'delete', + 'head', + 'options', + 'trace', +] + + +@pytest.fixture +def api_view_obj_fixture(): + return ApiErrorView() + + +@pytest.mark.parametrize('method_name', HTTP_METHOD_NAMES) +def test_exception_view_allow_http_methods(method_name): + assert hasattr(ApiErrorView, method_name) + + +@pytest.mark.parametrize('method_name', HTTP_METHOD_NAMES) +def test_exception_view_raises_exception(api_view_obj_fixture, method_name): + request_mock = mock.MagicMock() + with pytest.raises(exceptions.APIException): + getattr(api_view_obj_fixture, method_name)(request_mock) diff --git a/awx/main/views.py b/awx/main/views.py index f476f81cfd..c8bcfa304f 100644 --- a/awx/main/views.py +++ b/awx/main/views.py @@ -10,20 +10,32 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import exceptions, permissions, views +def _force_raising_exception(view_obj, request, format=None): + raise view_obj.exception_class() + + class ApiErrorView(views.APIView): authentication_classes = [] permission_classes = (permissions.AllowAny,) metadata_class = None - allowed_methods = ('GET', 'HEAD') exception_class = exceptions.APIException view_name = _('API Error') def get_view_name(self): return self.view_name - def get(self, request, format=None): - raise self.exception_class() + def finalize_response(self, request, response, *args, **kwargs): + response = super(ApiErrorView, self).finalize_response(request, response, *args, **kwargs) + try: + del response['Allow'] + except Exception: + pass + return response + + +for method_name in ApiErrorView.http_method_names: + setattr(ApiErrorView, method_name, _force_raising_exception) def handle_error(request, status=404, **kwargs):