From 5768f544ec45d1620fa23b79ef89c6e6a4788d4d Mon Sep 17 00:00:00 2001 From: Chris Church Date: Sun, 29 Sep 2013 21:37:09 -0400 Subject: [PATCH] Fixes to get tests to pass after updating vendored packages. --- awx/main/base_views.py | 37 +++++++++++++++++++++++++++++++ awx/main/tests/projects.py | 21 ++++++++++++------ awx/main/urls.py | 45 -------------------------------------- awx/settings/defaults.py | 2 ++ 4 files changed, 53 insertions(+), 52 deletions(-) diff --git a/awx/main/base_views.py b/awx/main/base_views.py index 0ac7b3e9c1..2a45fa4328 100644 --- a/awx/main/base_views.py +++ b/awx/main/base_views.py @@ -10,6 +10,7 @@ from django.http import HttpResponse, Http404 from django.contrib.auth.models import User from django.shortcuts import get_object_or_404 from django.template.loader import render_to_string +from django.utils.safestring import mark_safe from django.utils.timezone import now # Django REST Framework @@ -30,6 +31,42 @@ __all__ = ['APIView', 'GenericAPIView', 'ListAPIView', 'ListCreateAPIView', 'SubListAPIView', 'SubListCreateAPIView', 'RetrieveAPIView', 'RetrieveUpdateAPIView', 'RetrieveUpdateDestroyAPIView'] +def get_view_name(cls, suffix=None): + ''' + Wrapper around REST framework get_view_name() to support get_name() method + and view_name property on a view class. + ''' + name = '' + if hasattr(cls, 'get_name') and callable(cls.get_name): + name = cls().get_name() + elif hasattr(cls, 'view_name'): + if callable(cls.view_name): + name = cls.view_name() + else: + name = cls.view_name + if name: + return ('%s %s' % (name, suffix)) if suffix else name + return views.get_view_name(cls, suffix=None) + +def get_view_description(cls, html=False): + ''' + Wrapper around REST framework get_view_description() to support + get_description() method and view_description property on a view class. + ''' + if hasattr(cls, 'get_description') and callable(cls.get_description): + desc = cls().get_description(html=html) + cls = type(cls.__name__, (object,), {'__doc__': desc}) + elif hasattr(cls, 'view_description'): + if callable(cls.view_description): + view_desc = cls.view_description() + else: + view_desc = cls.view_description + cls = type(cls.__name__, (object,), {'__doc__': view_desc}) + desc = views.get_view_description(cls, html=html) + if html: + desc = '
%s
' % desc + return mark_safe(desc) + class APIView(views.APIView): def get_authenticate_header(self, request): diff --git a/awx/main/tests/projects.py b/awx/main/tests/projects.py index d2ea7f42b6..0d65bd9c05 100644 --- a/awx/main/tests/projects.py +++ b/awx/main/tests/projects.py @@ -850,11 +850,15 @@ class ProjectUpdatesTest(BaseTransactionTest): def check_project_update(self, project, should_fail=False, **kwargs): pu = kwargs.pop('project_update', None) + should_error = kwargs.pop('should_error', False) if not pu: pu = project.update(**kwargs) self.assertTrue(pu) pu = ProjectUpdate.objects.get(pk=pu.pk) - if should_fail: + if should_error: + self.assertEqual(pu.status, 'error', + pu.result_stdout + pu.result_traceback) + elif should_fail: self.assertEqual(pu.status, 'failed', pu.result_stdout + pu.result_traceback) elif should_fail is False: @@ -864,11 +868,12 @@ class ProjectUpdatesTest(BaseTransactionTest): pass # If should_fail is None, we don't care. # Get the SCM URL from the job args, if it starts with a '/' we aren't # handling the URL correctly. - scm_url_in_args_re = re.compile(r'\\(?:\\\\)??"scm_url\\(?:\\\\)??": \\(?:\\\\)??"(.*?)\\(?:\\\\)??"') - match = scm_url_in_args_re.search(pu.job_args) - self.assertTrue(match, pu.job_args) - scm_url_in_args = match.groups()[0] - self.assertFalse(scm_url_in_args.startswith('/'), scm_url_in_args) + if not should_error: + scm_url_in_args_re = re.compile(r'\\(?:\\\\)??"scm_url\\(?:\\\\)??": \\(?:\\\\)??"(.*?)\\(?:\\\\)??"') + match = scm_url_in_args_re.search(pu.job_args) + self.assertTrue(match, pu.job_args) + scm_url_in_args = match.groups()[0] + self.assertFalse(scm_url_in_args.startswith('/'), scm_url_in_args) #return pu # Make sure scm_password doesn't show up anywhere in args or output # from project update. @@ -1088,7 +1093,9 @@ class ProjectUpdatesTest(BaseTransactionTest): scm_username=scm_username, scm_password=scm_password, ) - self.check_project_update(project2, should_fail=True) + should_error = bool('github.com' in scm_url and scm_username != 'git') + self.check_project_update(project2, should_fail=True, + should_error=should_error) def create_local_git_repo(self): repo_dir = tempfile.mkdtemp() diff --git a/awx/main/urls.py b/awx/main/urls.py index f58a6830f8..99109b5235 100644 --- a/awx/main/urls.py +++ b/awx/main/urls.py @@ -162,48 +162,3 @@ urlpatterns = patterns('awx.main.views', url(r'^$', 'api_root_view'), url(r'^v1/', include(v1_urls)), ) - -# Monkeypatch get_view_name and get_view_description in Django REST Framework -# 2.3.x to allow a custom view name or description to be defined on the view -# class, instead of always using __name__ and __doc__. Used to be possible in -# 2.2.x by defining get_name() and get_description() methods on a view. - -try: - import rest_framework.utils.formatting - from django.utils.safestring import mark_safe - - original_get_view_name = rest_framework.utils.formatting.get_view_name - def get_view_name(cls, suffix=None): - name = '' - # Support for get_name method on views compatible with 2.2.x. - if hasattr(cls, 'get_name') and callable(cls.get_name): - name = cls().get_name() - elif hasattr(cls, 'view_name'): - if callable(cls.view_name): - name = cls.view_name() - else: - name = cls.view_name - if name: - return ('%s %s' % (name, suffix)) if suffix else name - return original_get_view_name(cls, suffix=None) - rest_framework.utils.formatting.get_view_name = get_view_name - - original_get_view_description = rest_framework.utils.formatting.get_view_description - def get_view_description(cls, html=False): - # Support for get_description method on views compatible with 2.2.x. - if hasattr(cls, 'get_description') and callable(cls.get_description): - desc = cls().get_description(html=html) - cls = type(cls.__name__, (object,), {'__doc__': desc}) - elif hasattr(cls, 'view_description'): - if callable(cls.view_description): - view_desc = cls.view_description() - else: - view_desc = cls.view_description - cls = type(cls.__name__, (object,), {'__doc__': view_desc}) - desc = original_get_view_description(cls, html=html) - if html: - desc = '
%s
' % desc - return mark_safe(desc) - rest_framework.utils.formatting.get_view_description = get_view_description -except ImportError: - pass diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index 95f680466a..061dd21776 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -161,6 +161,8 @@ REST_FRAMEWORK = { 'rest_framework.renderers.JSONRenderer', 'awx.main.renderers.BrowsableAPIRenderer', ), + 'VIEW_NAME_FUNCTION': 'awx.main.base_views.get_view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'awx.main.base_views.get_view_description', } AUTHENTICATION_BACKENDS = (