Fixes to get tests to pass after updating vendored packages.

This commit is contained in:
Chris Church 2013-09-29 21:37:09 -04:00
parent e0a94cbf32
commit 5768f544ec
4 changed files with 53 additions and 52 deletions

View File

@ -10,6 +10,7 @@ from django.http import HttpResponse, Http404
from django.contrib.auth.models import User
from django.shortcuts import get_object_or_404
from django.template.loader import render_to_string
from django.utils.safestring import mark_safe
from django.utils.timezone import now
# Django REST Framework
@ -30,6 +31,42 @@ __all__ = ['APIView', 'GenericAPIView', 'ListAPIView', 'ListCreateAPIView',
'SubListAPIView', 'SubListCreateAPIView', 'RetrieveAPIView',
'RetrieveUpdateAPIView', 'RetrieveUpdateDestroyAPIView']
def get_view_name(cls, suffix=None):
'''
Wrapper around REST framework get_view_name() to support get_name() method
and view_name property on a view class.
'''
name = ''
if hasattr(cls, 'get_name') and callable(cls.get_name):
name = cls().get_name()
elif hasattr(cls, 'view_name'):
if callable(cls.view_name):
name = cls.view_name()
else:
name = cls.view_name
if name:
return ('%s %s' % (name, suffix)) if suffix else name
return views.get_view_name(cls, suffix=None)
def get_view_description(cls, html=False):
'''
Wrapper around REST framework get_view_description() to support
get_description() method and view_description property on a view class.
'''
if hasattr(cls, 'get_description') and callable(cls.get_description):
desc = cls().get_description(html=html)
cls = type(cls.__name__, (object,), {'__doc__': desc})
elif hasattr(cls, 'view_description'):
if callable(cls.view_description):
view_desc = cls.view_description()
else:
view_desc = cls.view_description
cls = type(cls.__name__, (object,), {'__doc__': view_desc})
desc = views.get_view_description(cls, html=html)
if html:
desc = '<div class="description">%s</div>' % desc
return mark_safe(desc)
class APIView(views.APIView):
def get_authenticate_header(self, request):

View File

@ -850,11 +850,15 @@ class ProjectUpdatesTest(BaseTransactionTest):
def check_project_update(self, project, should_fail=False, **kwargs):
pu = kwargs.pop('project_update', None)
should_error = kwargs.pop('should_error', False)
if not pu:
pu = project.update(**kwargs)
self.assertTrue(pu)
pu = ProjectUpdate.objects.get(pk=pu.pk)
if should_fail:
if should_error:
self.assertEqual(pu.status, 'error',
pu.result_stdout + pu.result_traceback)
elif should_fail:
self.assertEqual(pu.status, 'failed',
pu.result_stdout + pu.result_traceback)
elif should_fail is False:
@ -864,11 +868,12 @@ class ProjectUpdatesTest(BaseTransactionTest):
pass # If should_fail is None, we don't care.
# Get the SCM URL from the job args, if it starts with a '/' we aren't
# handling the URL correctly.
scm_url_in_args_re = re.compile(r'\\(?:\\\\)??"scm_url\\(?:\\\\)??": \\(?:\\\\)??"(.*?)\\(?:\\\\)??"')
match = scm_url_in_args_re.search(pu.job_args)
self.assertTrue(match, pu.job_args)
scm_url_in_args = match.groups()[0]
self.assertFalse(scm_url_in_args.startswith('/'), scm_url_in_args)
if not should_error:
scm_url_in_args_re = re.compile(r'\\(?:\\\\)??"scm_url\\(?:\\\\)??": \\(?:\\\\)??"(.*?)\\(?:\\\\)??"')
match = scm_url_in_args_re.search(pu.job_args)
self.assertTrue(match, pu.job_args)
scm_url_in_args = match.groups()[0]
self.assertFalse(scm_url_in_args.startswith('/'), scm_url_in_args)
#return pu
# Make sure scm_password doesn't show up anywhere in args or output
# from project update.
@ -1088,7 +1093,9 @@ class ProjectUpdatesTest(BaseTransactionTest):
scm_username=scm_username,
scm_password=scm_password,
)
self.check_project_update(project2, should_fail=True)
should_error = bool('github.com' in scm_url and scm_username != 'git')
self.check_project_update(project2, should_fail=True,
should_error=should_error)
def create_local_git_repo(self):
repo_dir = tempfile.mkdtemp()

View File

@ -162,48 +162,3 @@ urlpatterns = patterns('awx.main.views',
url(r'^$', 'api_root_view'),
url(r'^v1/', include(v1_urls)),
)
# Monkeypatch get_view_name and get_view_description in Django REST Framework
# 2.3.x to allow a custom view name or description to be defined on the view
# class, instead of always using __name__ and __doc__. Used to be possible in
# 2.2.x by defining get_name() and get_description() methods on a view.
try:
import rest_framework.utils.formatting
from django.utils.safestring import mark_safe
original_get_view_name = rest_framework.utils.formatting.get_view_name
def get_view_name(cls, suffix=None):
name = ''
# Support for get_name method on views compatible with 2.2.x.
if hasattr(cls, 'get_name') and callable(cls.get_name):
name = cls().get_name()
elif hasattr(cls, 'view_name'):
if callable(cls.view_name):
name = cls.view_name()
else:
name = cls.view_name
if name:
return ('%s %s' % (name, suffix)) if suffix else name
return original_get_view_name(cls, suffix=None)
rest_framework.utils.formatting.get_view_name = get_view_name
original_get_view_description = rest_framework.utils.formatting.get_view_description
def get_view_description(cls, html=False):
# Support for get_description method on views compatible with 2.2.x.
if hasattr(cls, 'get_description') and callable(cls.get_description):
desc = cls().get_description(html=html)
cls = type(cls.__name__, (object,), {'__doc__': desc})
elif hasattr(cls, 'view_description'):
if callable(cls.view_description):
view_desc = cls.view_description()
else:
view_desc = cls.view_description
cls = type(cls.__name__, (object,), {'__doc__': view_desc})
desc = original_get_view_description(cls, html=html)
if html:
desc = '<div class="description">%s</div>' % desc
return mark_safe(desc)
rest_framework.utils.formatting.get_view_description = get_view_description
except ImportError:
pass

View File

@ -161,6 +161,8 @@ REST_FRAMEWORK = {
'rest_framework.renderers.JSONRenderer',
'awx.main.renderers.BrowsableAPIRenderer',
),
'VIEW_NAME_FUNCTION': 'awx.main.base_views.get_view_name',
'VIEW_DESCRIPTION_FUNCTION': 'awx.main.base_views.get_view_description',
}
AUTHENTICATION_BACKENDS = (