Merge branch 'release_3.3.0' into devel

Matthew Jones
2018-05-17 16:07:47 -04:00
321 changed files with 8920 additions and 6224 deletions

View File

@@ -296,7 +296,7 @@ uwsgi: collectstatic
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \ . $(VENV_BASE)/awx/bin/activate; \
fi; \ fi; \
uwsgi -b 32768 --socket 127.0.0.1:8050 --module=awx.wsgi:application --home=/venv/awx --chdir=/awx_devel/ --vacuum --processes=5 --harakiri=120 --master --no-orphans --py-autoreload 1 --max-requests=1000 --stats /tmp/stats.socket --master-fifo=/awxfifo --lazy-apps --logformat "%(addr) %(method) %(uri) - %(proto) %(status)" --hook-accepting1-once="exec:/bin/sh -c '[ -f /tmp/celery_pid ] && kill -1 `cat /tmp/celery_pid` || true'" uwsgi -b 32768 --socket 127.0.0.1:8050 --module=awx.wsgi:application --home=/venv/awx --chdir=/awx_devel/ --vacuum --processes=5 --harakiri=120 --master --no-orphans --py-autoreload 1 --max-requests=1000 --stats /tmp/stats.socket --lazy-apps --logformat "%(addr) %(method) %(uri) - %(proto) %(status)" --hook-accepting1-once="exec:/bin/sh -c '[ -f /tmp/celery_pid ] && kill -1 `cat /tmp/celery_pid` || true'"
daphne: daphne:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
@@ -372,13 +372,14 @@ awx-link:
sed -i "s/placeholder/$(shell git describe --long | sed 's/\./\\./g')/" /awx_devel/awx.egg-info/PKG-INFO sed -i "s/placeholder/$(shell git describe --long | sed 's/\./\\./g')/" /awx_devel/awx.egg-info/PKG-INFO
cp /tmp/awx.egg-link /venv/awx/lib/python2.7/site-packages/awx.egg-link cp /tmp/awx.egg-link /venv/awx/lib/python2.7/site-packages/awx.egg-link
TEST_DIRS ?= awx/main/tests/unit awx/main/tests/functional awx/conf/tests awx/sso/tests TEST_DIRS ?= awx/main/tests/unit awx/main/tests/functional awx/conf/tests awx/sso/tests awx/network_ui/tests/unit
# Run all API unit tests. # Run all API unit tests.
test: test:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \ . $(VENV_BASE)/awx/bin/activate; \
fi; \ fi; \
py.test $(TEST_DIRS) py.test -n auto $(TEST_DIRS)
test_combined: test_ansible test test_combined: test_ansible test
@@ -386,7 +387,7 @@ test_unit:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \
. $(VENV_BASE)/awx/bin/activate; \ . $(VENV_BASE)/awx/bin/activate; \
fi; \ fi; \
py.test awx/main/tests/unit awx/conf/tests/unit awx/sso/tests/unit py.test awx/main/tests/unit awx/conf/tests/unit awx/sso/tests/unit awx/network_ui/tests/unit
test_ansible: test_ansible:
@if [ "$(VENV_BASE)" ]; then \ @if [ "$(VENV_BASE)" ]; then \

View File

@@ -4,6 +4,7 @@ from django.utils.translation import ugettext_lazy as _
# AWX
from awx.conf import fields, register
from awx.api.fields import OAuth2ProviderField
+from oauth2_provider.settings import oauth2_settings


register(
@@ -36,7 +37,7 @@ register(
register(
    'OAUTH2_PROVIDER',
    field_class=OAuth2ProviderField,
-    default={'ACCESS_TOKEN_EXPIRE_SECONDS': 315360000000,
+    default={'ACCESS_TOKEN_EXPIRE_SECONDS': oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS,
             'AUTHORIZATION_CODE_EXPIRE_SECONDS': 600},
    label=_('OAuth 2 Timeout Settings'),
    help_text=_('Dictionary for customizing OAuth 2 timeouts, available items are '

View File

@@ -77,6 +77,63 @@ class TypeFilterBackend(BaseFilterBackend):
            raise ParseError(*e.args)
+def get_field_from_path(model, path):
+    '''
+    Given a Django ORM lookup path (possibly over multiple models)
+    Returns the last field in the line, and also the revised lookup path
+    ex., given
+        model=Organization
+        path='project__timeout'
+    returns tuple of field at the end of the line as well as a corrected
+    path, for special cases we do substitutions
+        (<IntegerField for timeout>, 'project__timeout')
+    '''
+    # Store of all the fields used to detect repeats
+    field_set = set([])
+    new_parts = []
+    for name in path.split('__'):
+        if model is None:
+            raise ParseError(_('No related model for field {}.').format(name))
+        # HACK: Make project and inventory source filtering by old field names work for backwards compatibility.
+        if model._meta.object_name in ('Project', 'InventorySource'):
+            name = {
+                'current_update': 'current_job',
+                'last_update': 'last_job',
+                'last_update_failed': 'last_job_failed',
+                'last_updated': 'last_job_run',
+            }.get(name, name)
+
+        if name == 'type' and 'polymorphic_ctype' in get_all_field_names(model):
+            name = 'polymorphic_ctype'
+            new_parts.append('polymorphic_ctype__model')
+        else:
+            new_parts.append(name)
+
+        if name in getattr(model, 'PASSWORD_FIELDS', ()):
+            raise PermissionDenied(_('Filtering on password fields is not allowed.'))
+        elif name == 'pk':
+            field = model._meta.pk
+        else:
+            name_alt = name.replace("_", "")
+            if name_alt in model._meta.fields_map.keys():
+                field = model._meta.fields_map[name_alt]
+                new_parts.pop()
+                new_parts.append(name_alt)
+            else:
+                field = model._meta.get_field(name)
+            if isinstance(field, ForeignObjectRel) and getattr(field.field, '__prevent_search__', False):
+                raise PermissionDenied(_('Filtering on %s is not allowed.' % name))
+            elif getattr(field, '__prevent_search__', False):
+                raise PermissionDenied(_('Filtering on %s is not allowed.' % name))
+        if field in field_set:
+            # Field traversed twice, could create infinite JOINs, DoSing Tower
+            raise ParseError(_('Loops not allowed in filters, detected on field {}.').format(field.name))
+        field_set.add(field)
+        model = getattr(field, 'related_model', None)
+
+    return field, '__'.join(new_parts)
class FieldLookupBackend(BaseFilterBackend):
    '''
    Filter using field lookups provided via query string parameters.
@@ -91,61 +148,23 @@ class FieldLookupBackend(BaseFilterBackend):
                         'isnull', 'search')

    def get_field_from_lookup(self, model, lookup):
-        field = None
-        parts = lookup.split('__')
-        if parts and parts[-1] not in self.SUPPORTED_LOOKUPS:
-            parts.append('exact')
+        if '__' in lookup and lookup.rsplit('__', 1)[-1] in self.SUPPORTED_LOOKUPS:
+            path, suffix = lookup.rsplit('__', 1)
+        else:
+            path = lookup
+            suffix = 'exact'
+
+        if not path:
+            raise ParseError(_('Query string field name not provided.'))

        # FIXME: Could build up a list of models used across relationships, use
        # those lookups combined with request.user.get_queryset(Model) to make
        # sure user cannot query using objects he could not view.
-        new_parts = []
-
-        # Store of all the fields used to detect repeats
-        field_set = set([])
-
-        for name in parts[:-1]:
-            # HACK: Make project and inventory source filtering by old field names work for backwards compatibility.
-            if model._meta.object_name in ('Project', 'InventorySource'):
-                name = {
-                    'current_update': 'current_job',
-                    'last_update': 'last_job',
-                    'last_update_failed': 'last_job_failed',
-                    'last_updated': 'last_job_run',
-                }.get(name, name)
-            if name == 'type' and 'polymorphic_ctype' in get_all_field_names(model):
-                name = 'polymorphic_ctype'
-                new_parts.append('polymorphic_ctype__model')
-            else:
-                new_parts.append(name)
-            if name in getattr(model, 'PASSWORD_FIELDS', ()):
-                raise PermissionDenied(_('Filtering on password fields is not allowed.'))
-            elif name == 'pk':
-                field = model._meta.pk
-            else:
-                name_alt = name.replace("_", "")
-                if name_alt in model._meta.fields_map.keys():
-                    field = model._meta.fields_map[name_alt]
-                    new_parts.pop()
-                    new_parts.append(name_alt)
-                else:
-                    field = model._meta.get_field(name)
-                if 'auth' in name or 'token' in name:
-                    raise PermissionDenied(_('Filtering on %s is not allowed.' % name))
-                if isinstance(field, ForeignObjectRel) and getattr(field.field, '__prevent_search__', False):
-                    raise PermissionDenied(_('Filtering on %s is not allowed.' % name))
-                elif getattr(field, '__prevent_search__', False):
-                    raise PermissionDenied(_('Filtering on %s is not allowed.' % name))
-            if field in field_set:
-                # Field traversed twice, could create infinite JOINs, DoSing Tower
-                raise ParseError(_('Loops not allowed in filters, detected on field {}.').format(field.name))
-            field_set.add(field)
-            model = getattr(field, 'related_model', None) or field.model
-        if parts:
-            new_parts.append(parts[-1])
-        new_lookup = '__'.join(new_parts)
+        field, new_path = get_field_from_path(model, path)
+        new_lookup = new_path
+        new_lookup = '__'.join([new_path, suffix])
        return field, new_lookup
    def to_python_related(self, value):
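The refactor above collapses the old inline traversal into two steps: split the lookup into an ORM path plus a lookup suffix, then resolve the path with the new get_field_from_path. A minimal sketch of the splitting step in plain Python (SUPPORTED_LOOKUPS is abbreviated here, and split_lookup is an illustrative name, not the AWX API):

    SUPPORTED_LOOKUPS = ('exact', 'iexact', 'contains', 'icontains', 'isnull', 'search')

    def split_lookup(lookup):
        # Peel a trailing lookup suffix off the path if present,
        # otherwise default to 'exact' -- mirroring get_field_from_lookup.
        if '__' in lookup and lookup.rsplit('__', 1)[-1] in SUPPORTED_LOOKUPS:
            return tuple(lookup.rsplit('__', 1))
        return lookup, 'exact'

    assert split_lookup('project__name__icontains') == ('project__name', 'icontains')
    assert split_lookup('project__timeout') == ('project__timeout', 'exact')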
@@ -371,7 +390,7 @@ class OrderByBackend(BaseFilterBackend):
        else:
            order_by = (value,)
        if order_by:
-            order_by = self._strip_sensitive_model_fields(queryset.model, order_by)
+            order_by = self._validate_ordering_fields(queryset.model, order_by)

            # Special handling of the type field for ordering. In this
            # case, we're not sorting exactly on the type field, but
@@ -396,15 +415,17 @@ class OrderByBackend(BaseFilterBackend):
            # Return a 400 for invalid field names.
            raise ParseError(*e.args)

-    def _strip_sensitive_model_fields(self, model, order_by):
+    def _validate_ordering_fields(self, model, order_by):
        for field_name in order_by:
            # strip off the negation prefix `-` if it exists
-            _field_name = field_name.split('-')[-1]
+            prefix = ''
+            path = field_name
+            if field_name[0] == '-':
+                prefix = field_name[0]
+                path = field_name[1:]
            try:
-                # if the field name is encrypted/sensitive, don't sort on it
-                if _field_name in getattr(model, 'PASSWORD_FIELDS', ()) or \
-                        getattr(model._meta.get_field(_field_name), '__prevent_search__', False):
-                    raise ParseError(_('cannot order by field %s') % _field_name)
-            except FieldDoesNotExist:
-                pass
-            yield field_name
+                field, new_path = get_field_from_path(model, path)
+                new_path = '{}{}'.format(prefix, new_path)
+            except (FieldError, FieldDoesNotExist) as e:
+                raise ParseError(e.args[0])
+            yield new_path
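The renamed _validate_ordering_fields sets the leading `-` (descending) prefix aside, validates the bare path through get_field_from_path, and re-attaches the prefix; unlike the old code, an invalid path now raises a 400 (ParseError) instead of being silently yielded. A standalone sketch of the prefix handling (normalize_ordering and validate_path are illustrative stand-ins):

    def normalize_ordering(field_name, validate_path=lambda path: path):
        # Split off a leading '-' (descending sort), validate the bare path,
        # then re-attach the prefix, as _validate_ordering_fields does.
        prefix, path = '', field_name
        if field_name.startswith('-'):
            prefix, path = '-', field_name[1:]
        return prefix + validate_path(path)

    assert normalize_ordering('-created') == '-created'
    assert normalize_ordering('name') == 'name'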

View File

@@ -6,6 +6,7 @@ import inspect
import logging
import time
import six
+import urllib

# Django
from django.conf import settings
@@ -29,6 +30,7 @@ from rest_framework.response import Response
from rest_framework import status
from rest_framework import views
from rest_framework.permissions import AllowAny
+from rest_framework.renderers import JSONRenderer

# cryptography
from cryptography.fernet import InvalidToken
@@ -39,7 +41,7 @@ from awx.main.models import * # noqa
from awx.main.access import access_registry
from awx.main.utils import * # noqa
from awx.main.utils.db import get_all_field_names
-from awx.api.serializers import ResourceAccessListElementSerializer, CopySerializer
+from awx.api.serializers import ResourceAccessListElementSerializer, CopySerializer, UserSerializer
from awx.api.versioning import URLPathVersioning, get_request_version
from awx.api.metadata import SublistAttachDetatchMetadata, Metadata
@@ -70,6 +72,13 @@ class LoggedLoginView(auth_views.LoginView):
        if current_user and getattr(current_user, 'pk', None) and current_user != original_user:
            logger.info("User {} logged in.".format(current_user.username))
        if request.user.is_authenticated:
+            logger.info(smart_text(u"User {} logged in".format(self.request.user.username)))
+            ret.set_cookie('userLoggedIn', 'true')
+            current_user = UserSerializer(self.request.user)
+            current_user = JSONRenderer().render(current_user.data)
+            current_user = urllib.quote('%s' % current_user, '')
+            ret.set_cookie('current_user', current_user)
+
            return ret
        else:
            ret.status_code = 401
@@ -82,6 +91,7 @@ class LoggedLogoutView(auth_views.LogoutView):
        original_user = getattr(request, 'user', None)
        ret = super(LoggedLogoutView, self).dispatch(request, *args, **kwargs)
        current_user = getattr(request, 'user', None)
+        ret.set_cookie('userLoggedIn', 'false')
        if (not current_user or not getattr(current_user, 'pk', True)) \
                and current_user != original_user:
            logger.info("User {} logged out.".format(original_user.username))
@@ -868,6 +878,9 @@ class CopyAPIView(GenericAPIView):
                        obj, field.name, field_val
                    )
            new_obj = model.objects.create(**create_kwargs)
+            logger.debug(six.text_type('Deep copy: Created new object {}({})').format(
+                new_obj, model
+            ))
            # Need to save separatedly because Djang-crum get_current_user would
            # not work properly in non-request-response-cycle context.
            new_obj.created_by = creater

View File

@@ -62,15 +62,11 @@ class Metadata(metadata.SimpleMetadata):
                opts = serializer.Meta.model._meta.concrete_model._meta
                verbose_name = smart_text(opts.verbose_name)
                field_info['help_text'] = field_help_text[field.field_name].format(verbose_name)

-        # If field is not part of the model, then show it as non-filterable
-        else:
-            is_model_field = False
-            for model_field in serializer.Meta.model._meta.fields:
-                if field.field_name == model_field.name:
-                    is_model_field = True
-                    break
-            if not is_model_field:
-                field_info['filterable'] = False
+        for model_field in serializer.Meta.model._meta.fields:
+            if field.field_name == model_field.name:
+                field_info['filterable'] = True
+                break

        # Indicate if a field has a default value.
        # FIXME: Still isn't showing all default values?

View File

@@ -14,7 +14,6 @@ from datetime import timedelta
# OAuth2
from oauthlib.common import generate_token
-from oauth2_provider.settings import oauth2_settings

# Django
from django.conf import settings
@@ -1024,7 +1023,7 @@ class UserAuthorizedTokenSerializer(BaseSerializer):
        validated_data['user'] = current_user
        validated_data['token'] = generate_token()
        validated_data['expires'] = now() + timedelta(
-            seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
+            seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
        )
        obj = super(OAuth2TokenSerializer, self).create(validated_data)
        obj.save()
@@ -1176,7 +1175,7 @@ class OAuth2TokenSerializer(BaseSerializer):
        validated_data['user'] = current_user
        validated_data['token'] = generate_token()
        validated_data['expires'] = now() + timedelta(
-            seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
+            seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
        )
        obj = super(OAuth2TokenSerializer, self).create(validated_data)
        if obj.application and obj.application.user:
@@ -1239,7 +1238,7 @@ class OAuth2AuthorizedTokenSerializer(BaseSerializer):
        validated_data['user'] = current_user
        validated_data['token'] = generate_token()
        validated_data['expires'] = now() + timedelta(
-            seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
+            seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
        )
        obj = super(OAuth2AuthorizedTokenSerializer, self).create(validated_data)
        if obj.application and obj.application.user:
@@ -1306,7 +1305,7 @@ class OAuth2PersonalTokenSerializer(BaseSerializer):
        validated_data['user'] = self.context['request'].user
        validated_data['token'] = generate_token()
        validated_data['expires'] = now() + timedelta(
-            seconds=oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
+            seconds=settings.OAUTH2_PROVIDER['ACCESS_TOKEN_EXPIRE_SECONDS']
        )
        validated_data['application'] = None
        obj = super(OAuth2PersonalTokenSerializer, self).create(validated_data)
@@ -4512,9 +4511,19 @@ class SchedulePreviewSerializer(BaseSerializer):
class ScheduleSerializer(LaunchConfigurationBaseSerializer, SchedulePreviewSerializer):
    show_capabilities = ['edit', 'delete']

+    timezone = serializers.SerializerMethodField()
+    until = serializers.SerializerMethodField()
+
    class Meta:
        model = Schedule
-        fields = ('*', 'unified_job_template', 'enabled', 'dtstart', 'dtend', 'rrule', 'next_run',)
+        fields = ('*', 'unified_job_template', 'enabled', 'dtstart', 'dtend', 'rrule', 'next_run', 'timezone',
+                  'until')
+
+    def get_timezone(self, obj):
+        return obj.timezone
+
+    def get_until(self, obj):
+        return obj.until

    def get_related(self, obj):
        res = super(ScheduleSerializer, self).get_related(obj)
@@ -4600,7 +4609,7 @@ class InstanceGroupSerializer(BaseSerializer):
"this group when new instances come online.") "this group when new instances come online.")
) )
policy_instance_list = serializers.ListField( policy_instance_list = serializers.ListField(
child=serializers.CharField(), child=serializers.CharField(), required=False,
help_text=_("List of exact-match Instances that will be assigned to this group") help_text=_("List of exact-match Instances that will be assigned to this group")
) )
@@ -4627,6 +4636,11 @@ class InstanceGroupSerializer(BaseSerializer):
            raise serializers.ValidationError(_('{} is not a valid hostname of an existing instance.').format(instance_name))
        return value

+    def validate_name(self, value):
+        if self.instance and self.instance.name == 'tower' and value != 'tower':
+            raise serializers.ValidationError(_('tower instance group name may not be changed.'))
+        return value
+
    def get_jobs_qs(self):
        # Store running jobs queryset in context, so it will be shared in ListView
        if 'running_jobs' not in self.context:

View File

@@ -15,7 +15,7 @@ from awx.api.views import (
    UserActivityStreamList,
    UserAccessList,
    OAuth2ApplicationList,
-    OAuth2TokenList,
+    OAuth2UserTokenList,
    OAuth2PersonalTokenList,
    UserAuthorizedTokenList,
)
@@ -32,7 +32,7 @@ urls = [
    url(r'^(?P<pk>[0-9]+)/activity_stream/$', UserActivityStreamList.as_view(), name='user_activity_stream_list'),
    url(r'^(?P<pk>[0-9]+)/access_list/$', UserAccessList.as_view(), name='user_access_list'),
    url(r'^(?P<pk>[0-9]+)/applications/$', OAuth2ApplicationList.as_view(), name='o_auth2_application_list'),
-    url(r'^(?P<pk>[0-9]+)/tokens/$', OAuth2TokenList.as_view(), name='o_auth2_token_list'),
+    url(r'^(?P<pk>[0-9]+)/tokens/$', OAuth2UserTokenList.as_view(), name='o_auth2_token_list'),
    url(r'^(?P<pk>[0-9]+)/authorized_tokens/$', UserAuthorizedTokenList.as_view(), name='user_authorized_token_list'),
    url(r'^(?P<pk>[0-9]+)/personal_tokens/$', OAuth2PersonalTokenList.as_view(), name='o_auth2_personal_token_list'),

View File

@@ -404,9 +404,11 @@ class ApiV1ConfigView(APIView):
        data.update(dict(
            project_base_dir = settings.PROJECTS_ROOT,
            project_local_paths = Project.get_local_path_choices(),
-            custom_virtualenvs = get_custom_venv_choices(),
        ))
+        if JobTemplate.accessible_objects(request.user, 'admin_role').exists():
+            data['custom_virtualenvs'] = get_custom_venv_choices()

        return Response(data)
    def post(self, request):
@@ -610,6 +612,7 @@ class InstanceList(ListAPIView):
view_name = _("Instances") view_name = _("Instances")
model = Instance model = Instance
serializer_class = InstanceSerializer serializer_class = InstanceSerializer
search_fields = ('hostname',)
class InstanceDetail(RetrieveUpdateAPIView): class InstanceDetail(RetrieveUpdateAPIView):
@@ -696,6 +699,7 @@ class InstanceGroupInstanceList(InstanceGroupMembershipMixin, SubListAttachDetac
    serializer_class = InstanceSerializer
    parent_model = InstanceGroup
    relationship = "instances"
+    search_fields = ('hostname',)


class ScheduleList(ListAPIView):
@@ -745,11 +749,11 @@ class ScheduleZoneInfo(APIView):
    swagger_topic = 'System Configuration'

    def get(self, request):
-        from dateutil.zoneinfo import get_zonefile_instance
-        return Response([
+        zones = [
            {'name': zone}
-            for zone in sorted(get_zonefile_instance().zones)
-        ])
+            for zone in Schedule.get_zoneinfo()
+        ]
+        return Response(zones)
@@ -1072,6 +1076,7 @@ class OrganizationActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIV
    serializer_class = ActivityStreamSerializer
    parent_model = Organization
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class OrganizationNotificationTemplatesList(SubListCreateAttachDetachAPIView):
@@ -1126,6 +1131,7 @@ class OrganizationObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = Organization
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -1165,6 +1171,7 @@ class TeamRolesList(SubListAttachDetachAPIView):
    metadata_class = RoleMetadata
    parent_model = Team
    relationship='member_role.children'
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        team = get_object_or_404(Team, pk=self.kwargs['pk'])
@@ -1202,6 +1209,7 @@ class TeamObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = Team
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -1233,6 +1241,7 @@ class TeamActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = Team
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -1328,6 +1337,7 @@ class ProjectActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = Project
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -1415,6 +1425,7 @@ class ProjectUpdateEventsList(SubListAPIView):
    parent_model = ProjectUpdate
    relationship = 'project_update_events'
    view_name = _('Project Update Events List')
+    search_fields = ('stdout',)

    def finalize_response(self, request, response, *args, **kwargs):
        response['X-UI-Max-Events'] = settings.MAX_UI_JOB_EVENTS
@@ -1428,6 +1439,7 @@ class SystemJobEventsList(SubListAPIView):
    parent_model = SystemJob
    relationship = 'system_job_events'
    view_name = _('System Job Events List')
+    search_fields = ('stdout',)

    def finalize_response(self, request, response, *args, **kwargs):
        response['X-UI-Max-Events'] = settings.MAX_UI_JOB_EVENTS
@@ -1441,6 +1453,7 @@ class InventoryUpdateEventsList(SubListAPIView):
    parent_model = InventoryUpdate
    relationship = 'inventory_update_events'
    view_name = _('Inventory Update Events List')
+    search_fields = ('stdout',)

    def finalize_response(self, request, response, *args, **kwargs):
        response['X-UI-Max-Events'] = settings.MAX_UI_JOB_EVENTS
@@ -1468,6 +1481,7 @@ class ProjectUpdateNotificationsList(SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = ProjectUpdate
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class ProjectUpdateScmInventoryUpdates(SubListCreateAPIView):
@@ -1491,6 +1505,7 @@ class ProjectObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = Project
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -1573,6 +1588,7 @@ class OAuth2ApplicationActivityStreamList(ActivityStreamEnforcementMixin, SubLis
    parent_model = OAuth2Application
    relationship = 'activitystream_set'
    swagger_topic = 'Authentication'
+    search_fields = ('changes',)


class OAuth2TokenList(ListCreateAPIView):
@@ -1582,6 +1598,18 @@ class OAuth2TokenList(ListCreateAPIView):
    model = OAuth2AccessToken
    serializer_class = OAuth2TokenSerializer
    swagger_topic = 'Authentication'


+class OAuth2UserTokenList(SubListCreateAPIView):
+
+    view_name = _("OAuth2 User Tokens")
+
+    model = OAuth2AccessToken
+    serializer_class = OAuth2TokenSerializer
+    parent_model = User
+    relationship = 'main_oauth2accesstoken'
+    parent_key = 'user'
+    swagger_topic = 'Authentication'
+
+
class OAuth2AuthorizedTokenList(SubListCreateAPIView):
@@ -1657,6 +1685,7 @@ class OAuth2TokenActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIVi
    parent_model = OAuth2AccessToken
    relationship = 'activitystream_set'
    swagger_topic = 'Authentication'
+    search_fields = ('changes',)


class UserTeamsList(ListAPIView):
@@ -1680,6 +1709,7 @@ class UserRolesList(SubListAttachDetachAPIView):
    parent_model = User
    relationship='roles'
    permission_classes = (IsAuthenticated,)
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        u = get_object_or_404(User, pk=self.kwargs['pk'])
@@ -1766,6 +1796,7 @@ class UserActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = User
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -1852,6 +1883,7 @@ class CredentialTypeActivityStreamList(ActivityStreamEnforcementMixin, SubListAP
    serializer_class = ActivityStreamSerializer
    parent_model = CredentialType
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


# remove in 3.3
@@ -1965,6 +1997,7 @@ class CredentialActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIVie
    serializer_class = ActivityStreamSerializer
    parent_model = Credential
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class CredentialAccessList(ResourceAccessList):
@@ -1978,6 +2011,7 @@ class CredentialObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = Credential
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -2018,6 +2052,7 @@ class InventoryScriptObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = CustomInventoryScript
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -2105,6 +2140,7 @@ class InventoryActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView
    serializer_class = ActivityStreamSerializer
    parent_model = Inventory
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -2132,6 +2168,7 @@ class InventoryObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = Inventory
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -2275,6 +2312,7 @@ class HostActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = Host
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -2288,6 +2326,7 @@ class HostFactVersionsList(SystemTrackingEnforcementMixin, ParentMixin, ListAPIV
    model = Fact
    serializer_class = FactVersionSerializer
    parent_model = Host
+    search_fields = ('facts',)

    def get_queryset(self):
        from_spec = self.request.query_params.get('from', None)
@@ -2521,6 +2560,7 @@ class GroupActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = Group
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -2752,6 +2792,7 @@ class InventorySourceActivityStreamList(ActivityStreamEnforcementMixin, SubListA
    serializer_class = ActivityStreamSerializer
    parent_model = InventorySource
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class InventorySourceNotificationTemplatesAnyList(SubListCreateAttachDetachAPIView):
@@ -2891,6 +2932,7 @@ class InventoryUpdateNotificationsList(SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = InventoryUpdate
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class JobTemplateList(ListCreateAPIView):
@@ -3229,6 +3271,7 @@ class JobTemplateActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIVi
    serializer_class = ActivityStreamSerializer
    parent_model = JobTemplate
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class JobTemplateNotificationTemplatesAnyList(SubListCreateAttachDetachAPIView):
@@ -3512,6 +3555,7 @@ class JobTemplateObjectRolesList(SubListAPIView):
    model = Role
    serializer_class = RoleSerializer
    parent_model = JobTemplate
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -3529,6 +3573,7 @@ class WorkflowJobNodeList(WorkflowsEnforcementMixin, ListAPIView):
    model = WorkflowJobNode
    serializer_class = WorkflowJobNodeListSerializer
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)


class WorkflowJobNodeDetail(WorkflowsEnforcementMixin, RetrieveAPIView):
@@ -3549,6 +3594,7 @@ class WorkflowJobTemplateNodeList(WorkflowsEnforcementMixin, ListCreateAPIView):
    model = WorkflowJobTemplateNode
    serializer_class = WorkflowJobTemplateNodeSerializer
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)


class WorkflowJobTemplateNodeDetail(WorkflowsEnforcementMixin, RetrieveUpdateDestroyAPIView):
@@ -3570,6 +3616,7 @@ class WorkflowJobTemplateNodeChildrenBaseList(WorkflowsEnforcementMixin, Enforce
    parent_model = WorkflowJobTemplateNode
    relationship = ''
    enforce_parent_relationship = 'workflow_job_template'
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)

    '''
    Limit the set of WorkflowJobTemplateNodes to the related nodes of specified by
@@ -3639,6 +3686,7 @@ class WorkflowJobNodeChildrenBaseList(WorkflowsEnforcementMixin, SubListAPIView)
    serializer_class = WorkflowJobNodeListSerializer
    parent_model = WorkflowJobNode
    relationship = ''
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)

    #
    #Limit the set of WorkflowJobeNodes to the related nodes of specified by
@@ -3702,12 +3750,18 @@ class WorkflowJobTemplateCopy(WorkflowsEnforcementMixin, CopyAPIView):
            item = getattr(obj, field_name, None)
            if item is None:
                continue
-            if field_name in ['inventory']:
+            elif field_name in ['inventory']:
                if not user.can_access(item.__class__, 'use', item):
                    setattr(obj, field_name, None)
-            if field_name in ['unified_job_template']:
+            elif field_name in ['unified_job_template']:
                if not user.can_access(item.__class__, 'start', item, validate_license=False):
                    setattr(obj, field_name, None)
+            elif field_name in ['credentials']:
+                for cred in item.all():
+                    if not user.can_access(cred.__class__, 'use', cred):
+                        logger.debug(six.text_type(
+                            'Deep copy: removing {} from relationship due to permissions').format(cred))
+                        item.remove(cred.pk)
        obj.save()
@@ -3788,6 +3842,7 @@ class WorkflowJobTemplateWorkflowNodesList(WorkflowsEnforcementMixin, SubListCre
    parent_model = WorkflowJobTemplate
    relationship = 'workflow_job_template_nodes'
    parent_key = 'workflow_job_template'
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)

    def get_queryset(self):
        return super(WorkflowJobTemplateWorkflowNodesList, self).get_queryset().order_by('id')
@@ -3848,6 +3903,7 @@ class WorkflowJobTemplateObjectRolesList(WorkflowsEnforcementMixin, SubListAPIVi
    model = Role
    serializer_class = RoleSerializer
    parent_model = WorkflowJobTemplate
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        po = self.get_parent_object()
@@ -3861,6 +3917,7 @@ class WorkflowJobTemplateActivityStreamList(WorkflowsEnforcementMixin, ActivityS
    serializer_class = ActivityStreamSerializer
    parent_model = WorkflowJobTemplate
    relationship = 'activitystream_set'
+    search_fields = ('changes',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -3890,6 +3947,7 @@ class WorkflowJobWorkflowNodesList(WorkflowsEnforcementMixin, SubListAPIView):
    parent_model = WorkflowJob
    relationship = 'workflow_job_nodes'
    parent_key = 'workflow_job'
+    search_fields = ('unified_job_template__name', 'unified_job_template__description',)

    def get_queryset(self):
        return super(WorkflowJobWorkflowNodesList, self).get_queryset().order_by('id')
@@ -3918,6 +3976,7 @@ class WorkflowJobNotificationsList(WorkflowsEnforcementMixin, SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = WorkflowJob
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class WorkflowJobActivityStreamList(WorkflowsEnforcementMixin, ActivityStreamEnforcementMixin, SubListAPIView):
@@ -3926,6 +3985,7 @@ class WorkflowJobActivityStreamList(WorkflowsEnforcementMixin, ActivityStreamEnf
    serializer_class = ActivityStreamSerializer
    parent_model = WorkflowJob
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class SystemJobTemplateList(ListAPIView):
@@ -4081,6 +4141,7 @@ class JobActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIView):
    serializer_class = ActivityStreamSerializer
    parent_model = Job
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


# TODO: remove endpoint in 3.3
@@ -4284,6 +4345,7 @@ class JobNotificationsList(SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = Job
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class BaseJobHostSummariesList(SubListAPIView):
@@ -4293,6 +4355,7 @@ class BaseJobHostSummariesList(SubListAPIView):
    parent_model = None # Subclasses must define this attribute.
    relationship = 'job_host_summaries'
    view_name = _('Job Host Summaries List')
+    search_fields = ('host_name',)

    def get_queryset(self):
        parent = self.get_parent_object()
@@ -4325,6 +4388,7 @@ class JobEventList(ListAPIView):
    model = JobEvent
    serializer_class = JobEventSerializer
+    search_fields = ('stdout',)


class JobEventDetail(RetrieveAPIView):
@@ -4340,6 +4404,7 @@ class JobEventChildrenList(SubListAPIView):
    parent_model = JobEvent
    relationship = 'children'
    view_name = _('Job Event Children List')
+    search_fields = ('stdout',)


class JobEventHostsList(HostRelatedSearchMixin, SubListAPIView):
@@ -4553,6 +4618,7 @@ class AdHocCommandEventList(ListAPIView):
    model = AdHocCommandEvent
    serializer_class = AdHocCommandEventSerializer
+    search_fields = ('stdout',)


class AdHocCommandEventDetail(RetrieveAPIView):
@@ -4568,6 +4634,7 @@ class BaseAdHocCommandEventsList(SubListAPIView):
    parent_model = None # Subclasses must define this attribute.
    relationship = 'ad_hoc_command_events'
    view_name = _('Ad Hoc Command Events List')
+    search_fields = ('stdout',)


class HostAdHocCommandEventsList(BaseAdHocCommandEventsList):
@@ -4590,6 +4657,7 @@ class AdHocCommandActivityStreamList(ActivityStreamEnforcementMixin, SubListAPIV
    serializer_class = ActivityStreamSerializer
    parent_model = AdHocCommand
    relationship = 'activitystream_set'
+    search_fields = ('changes',)


class AdHocCommandNotificationsList(SubListAPIView):
@@ -4598,6 +4666,7 @@ class AdHocCommandNotificationsList(SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = AdHocCommand
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class SystemJobList(ListCreateAPIView):
@@ -4638,6 +4707,7 @@ class SystemJobNotificationsList(SubListAPIView):
    serializer_class = NotificationSerializer
    parent_model = SystemJob
    relationship = 'notifications'
+    search_fields = ('subject', 'notification_type', 'body',)


class UnifiedJobTemplateList(ListAPIView):
@@ -4706,7 +4776,6 @@ class UnifiedJobStdout(RetrieveAPIView):
        try:
            target_format = request.accepted_renderer.format
            if target_format in ('html', 'api', 'json'):
-                content_format = request.query_params.get('content_format', 'html')
                content_encoding = request.query_params.get('content_encoding', None)
                start_line = request.query_params.get('start_line', 0)
                end_line = request.query_params.get('end_line', None)
@@ -4732,10 +4801,10 @@ class UnifiedJobStdout(RetrieveAPIView):
                if target_format == 'api':
                    return Response(mark_safe(data))
                if target_format == 'json':
-                    if content_encoding == 'base64' and content_format == 'ansi':
-                        return Response({'range': {'start': start, 'end': end, 'absolute_end': absolute_end}, 'content': b64encode(content.encode('utf-8'))})
-                    elif content_format == 'html':
-                        return Response({'range': {'start': start, 'end': end, 'absolute_end': absolute_end}, 'content': body})
+                    content = content.encode('utf-8')
+                    if content_encoding == 'base64':
+                        content = b64encode(content)
+                    return Response({'range': {'start': start, 'end': end, 'absolute_end': absolute_end}, 'content': content})
                return Response(data)
            elif target_format == 'txt':
                return Response(unified_job.result_stdout)
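With the simplification above, a JSON-format stdout request always gets UTF-8 encoded content, base64-wrapped only when content_encoding=base64 is passed; the old content_format html/ansi branching is gone. A client-side decoding sketch (the payload literal is a fabricated example of the response shape):

    import base64
    import json

    # Shape returned by e.g. /api/v2/jobs/<id>/stdout/?format=json&content_encoding=base64
    payload = json.loads('{"range": {"start": 0, "end": 2, "absolute_end": 2}, '
                         '"content": "aGVsbG8Kd29ybGQK"}')
    stdout_text = base64.b64decode(payload['content']).decode('utf-8')
    assert stdout_text == 'hello\nworld\n'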
@@ -4843,6 +4912,7 @@ class NotificationTemplateNotificationList(SubListAPIView):
    parent_model = NotificationTemplate
    relationship = 'notifications'
    parent_key = 'notification_template'
+    search_fields = ('subject', 'notification_type', 'body',)


class NotificationTemplateCopy(CopyAPIView):
@@ -4855,6 +4925,7 @@ class NotificationList(ListAPIView):
    model = Notification
    serializer_class = NotificationSerializer
+    search_fields = ('subject', 'notification_type', 'body',)


class NotificationDetail(RetrieveAPIView):
@@ -4879,6 +4950,7 @@ class ActivityStreamList(ActivityStreamEnforcementMixin, SimpleListAPIView):
    model = ActivityStream
    serializer_class = ActivityStreamSerializer
+    search_fields = ('changes',)


class ActivityStreamDetail(ActivityStreamEnforcementMixin, RetrieveAPIView):
@@ -4892,6 +4964,7 @@ class RoleList(ListAPIView):
    model = Role
    serializer_class = RoleSerializer
    permission_classes = (IsAuthenticated,)
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        result = Role.visible_roles(self.request.user)
@@ -5004,6 +5077,7 @@ class RoleParentsList(SubListAPIView):
    parent_model = Role
    relationship = 'parents'
    permission_classes = (IsAuthenticated,)
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        role = Role.objects.get(pk=self.kwargs['pk'])
@@ -5017,6 +5091,7 @@ class RoleChildrenList(SubListAPIView):
    parent_model = Role
    relationship = 'children'
    permission_classes = (IsAuthenticated,)
+    search_fields = ('role_field', 'content_type__model',)

    def get_queryset(self):
        role = Role.objects.get(pk=self.kwargs['pk'])

View File

@@ -2,8 +2,6 @@
from django.apps import AppConfig
# from django.core import checks
from django.utils.translation import ugettext_lazy as _
-from awx.main.utils.handlers import configure_external_logger
-from django.conf import settings


class ConfConfig(AppConfig):
@@ -11,16 +9,7 @@ class ConfConfig(AppConfig):
    name = 'awx.conf'
    verbose_name = _('Configuration')

-    def configure_oauth2_provider(self, settings):
-        from oauth2_provider import settings as o_settings
-        o_settings.oauth2_settings = o_settings.OAuth2ProviderSettings(
-            settings.OAUTH2_PROVIDER, o_settings.DEFAULTS,
-            o_settings.IMPORT_STRINGS, o_settings.MANDATORY
-        )
-
    def ready(self):
        self.module.autodiscover()

        from .settings import SettingsWrapper
        SettingsWrapper.initialize()
-        configure_external_logger(settings)
-        self.configure_oauth2_provider(settings)

View File

@@ -5,6 +5,8 @@ import logging
import sys
import threading
import time
+import StringIO
+import traceback

import six
@@ -62,11 +64,19 @@ __all__ = ['SettingsWrapper', 'get_settings_to_cache', 'SETTING_CACHE_NOTSET']
def _log_database_error():
    try:
        yield
-    except (ProgrammingError, OperationalError) as e:
-        if get_tower_migration_version() < '310':
+    except (ProgrammingError, OperationalError):
+        if 'migrate' in sys.argv and get_tower_migration_version() < '310':
            logger.info('Using default settings until version 3.1 migration.')
        else:
-            logger.warning('Database settings are not available, using defaults (%s)', e, exc_info=True)
+            # Somewhat ugly - craming the full stack trace into the log message
+            # the available exc_info does not give information about the real caller
+            # TODO: replace in favor of stack_info kwarg in python 3
+            sio = StringIO.StringIO()
+            traceback.print_stack(file=sio)
+            sinfo = sio.getvalue()
+            sio.close()
+            sinfo = sinfo.strip('\n')
+            logger.warning('Database settings are not available, using defaults, logged from:\n{}'.format(sinfo))
    finally:
        pass

View File

@@ -338,13 +338,14 @@ def test_setting_singleton_delete_no_read_only_fields(api_request, dummy_setting
@pytest.mark.django_db
def test_setting_logging_test(api_request):
-    with mock.patch('awx.conf.views.BaseHTTPSHandler.perform_test') as mock_func:
+    with mock.patch('awx.conf.views.AWXProxyHandler.perform_test') as mock_func:
        api_request(
            'post',
            reverse('api:setting_logging_test'),
            data={'LOG_AGGREGATOR_HOST': 'http://foobar', 'LOG_AGGREGATOR_TYPE': 'logstash'}
        )
-        test_arguments = mock_func.call_args[0][0]
-        assert test_arguments.LOG_AGGREGATOR_HOST == 'http://foobar'
-        assert test_arguments.LOG_AGGREGATOR_TYPE == 'logstash'
-        assert test_arguments.LOG_AGGREGATOR_LEVEL == 'DEBUG'
+        call = mock_func.call_args_list[0]
+        args, kwargs = call
+        given_settings = kwargs['custom_settings']
+        assert given_settings.LOG_AGGREGATOR_HOST == 'http://foobar'
+        assert given_settings.LOG_AGGREGATOR_TYPE == 'logstash'

View File

@@ -0,0 +1,6 @@
+# Ensure that our autouse overwrites are working
+def test_cache(settings):
+    assert settings.CACHES['default']['BACKEND'] == 'django.core.cache.backends.locmem.LocMemCache'
+    assert settings.CACHES['default']['LOCATION'].startswith('unique-')

View File

@@ -21,7 +21,7 @@ from awx.api.generics import * # noqa
from awx.api.permissions import IsSuperUser
from awx.api.versioning import reverse, get_request_version
from awx.main.utils import * # noqa
-from awx.main.utils.handlers import BaseHTTPSHandler, UDPHandler, LoggingConnectivityException
+from awx.main.utils.handlers import AWXProxyHandler, LoggingConnectivityException
from awx.main.tasks import handle_setting_changes
from awx.conf.license import get_licensed_features
from awx.conf.models import Setting
@@ -198,12 +198,9 @@ class SettingLoggingTest(GenericAPIView):
            mock_settings = MockSettings()
            for k, v in serializer.validated_data.items():
                setattr(mock_settings, k, v)
-            mock_settings.LOG_AGGREGATOR_LEVEL = 'DEBUG'
+            AWXProxyHandler().perform_test(custom_settings=mock_settings)
            if mock_settings.LOG_AGGREGATOR_PROTOCOL.upper() == 'UDP':
-                UDPHandler.perform_test(mock_settings)
                return Response(status=status.HTTP_201_CREATED)
-            else:
-                BaseHTTPSHandler.perform_test(mock_settings)
        except LoggingConnectivityException as e:
            return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        return Response(status=status.HTTP_200_OK)

File diff suppressed because it is too large

View File

@@ -1366,6 +1366,7 @@ class JobTemplateAccess(BaseAccess):
            'job_tags', 'force_handlers', 'skip_tags', 'ask_variables_on_launch',
            'ask_tags_on_launch', 'ask_job_type_on_launch', 'ask_skip_tags_on_launch',
            'ask_inventory_on_launch', 'ask_credential_on_launch', 'survey_enabled',
+            'custom_virtualenv', 'diff_mode',

            # These fields are ignored, but it is convenient for QA to allow clients to post them
            'last_job_run', 'created', 'modified',
@@ -1814,13 +1815,14 @@ class WorkflowJobTemplateAccess(BaseAccess):
        missing_credentials = []
        missing_inventories = []
        qs = obj.workflow_job_template_nodes
-        qs = qs.prefetch_related('unified_job_template', 'inventory__use_role', 'credential__use_role')
+        qs = qs.prefetch_related('unified_job_template', 'inventory__use_role', 'credentials__use_role')
        for node in qs.all():
            node_errors = {}
            if node.inventory and self.user not in node.inventory.use_role:
                missing_inventories.append(node.inventory.name)
-            if node.credential and self.user not in node.credential.use_role:
-                missing_credentials.append(node.credential.name)
+            for cred in node.credentials.all():
+                if self.user not in cred.use_role:
+                    missing_credentials.append(cred.name)
            ujt = node.unified_job_template
            if ujt and not self.user.can_access(UnifiedJobTemplate, 'start', ujt, validate_license=False):
                missing_ujt.append(ujt.name)
@@ -1924,7 +1926,7 @@ class WorkflowJobAccess(BaseAccess):
        return self.can_recreate(obj)

    def can_recreate(self, obj):
-        node_qs = obj.workflow_job_nodes.all().prefetch_related('inventory', 'credential', 'unified_job_template')
+        node_qs = obj.workflow_job_nodes.all().prefetch_related('inventory', 'credentials', 'unified_job_template')
        node_access = WorkflowJobNodeAccess(user=self.user)
        wj_add_perm = True
        for node in node_qs:

View File

@@ -193,8 +193,10 @@ def update_role_parentage_for_instance(instance):
    '''
    for implicit_role_field in getattr(instance.__class__, '__implicit_role_fields'):
        cur_role = getattr(instance, implicit_role_field.name)
+        original_parents = set(json.loads(cur_role.implicit_parents))
        new_parents = implicit_role_field._resolve_parent_roles(instance)
-        cur_role.parents.set(new_parents)
+        cur_role.parents.remove(*list(original_parents - new_parents))
+        cur_role.parents.add(*list(new_parents - original_parents))
        new_parents_list = list(new_parents)
        new_parents_list.sort()
        new_parents_json = json.dumps(new_parents_list)
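The switch from a wholesale parents.set(...) to explicit set arithmetic means only rows that actually changed are written; a standalone illustration of the delta logic with made-up role ids:

    original_parents = {1, 2, 3}   # parent role ids currently stored
    new_parents = {2, 3, 4}        # parent role ids resolved from the instance

    to_remove = original_parents - new_parents   # {1}
    to_add = new_parents - original_parents      # {4}

    # only parent 1 is removed and parent 4 added; parents 2 and 3,
    # which did not change, are never rewritten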
@@ -802,23 +804,33 @@ class CredentialTypeInjectorField(JSONSchemaField):
            for field in model_instance.defined_fields
        )

+        class ExplodingNamespace:
+            def __unicode__(self):
+                raise UndefinedError(_('Must define unnamed file injector in order to reference `tower.filename`.'))

        class TowerNamespace:
-            filename = None
+            def __init__(self):
+                self.filename = ExplodingNamespace()
+
+            def __unicode__(self):
+                raise UndefinedError(_('Cannot directly reference reserved `tower` namespace container.'))

        valid_namespace['tower'] = TowerNamespace()

        # ensure either single file or multi-file syntax is used (but not both)
        template_names = [x for x in value.get('file', {}).keys() if x.startswith('template')]
-        if 'template' in template_names and len(template_names) > 1:
-            raise django_exceptions.ValidationError(
-                _('Must use multi-file syntax when injecting multiple files'),
-                code='invalid',
-                params={'value': value},
-            )
-        if 'template' not in template_names:
-            valid_namespace['tower'].filename = TowerNamespace()
+        if 'template' in template_names:
+            valid_namespace['tower'].filename = 'EXAMPLE_FILENAME'
+            if len(template_names) > 1:
+                raise django_exceptions.ValidationError(
+                    _('Must use multi-file syntax when injecting multiple files'),
+                    code='invalid',
+                    params={'value': value},
+                )
        elif template_names:
            for template_name in template_names:
                template_name = template_name.split('.')[1]
-                setattr(valid_namespace['tower'].filename, template_name, 'EXAMPLE')
+                setattr(valid_namespace['tower'].filename, template_name, 'EXAMPLE_FILENAME')

        for type_, injector in value.items():
            for key, tmpl in injector.items():
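For reference, roughly what the two injector syntaxes being validated here look like in a credential type definition; the field names (api_token, ssh_cert, ssh_key) are illustrative:

    # single-file syntax: the rendered file is referenced as {{ tower.filename }}
    injectors = {
        'file': {'template': '[mysection]\ntoken={{ api_token }}'},
        'env': {'MY_CONFIG_PATH': '{{ tower.filename }}'},
    }

    # multi-file syntax: each file is referenced as {{ tower.filename.NAME }}
    injectors = {
        'file': {
            'template.cert': '{{ ssh_cert }}',
            'template.key': '{{ ssh_key }}',
        },
        'env': {
            'CERT_PATH': '{{ tower.filename.cert }}',
            'KEY_PATH': '{{ tower.filename.key }}',
        },
    }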

View File

@@ -135,8 +135,7 @@ class AnsibleInventoryLoader(object):
            self.tmp_private_dir = build_proot_temp_dir()
            logger.debug("Using fresh temporary directory '{}' for isolation.".format(self.tmp_private_dir))
            kwargs['proot_temp_dir'] = self.tmp_private_dir
-        # Run from source's location so that custom script contents are in `show_paths`
-        cwd = functioning_dir(self.source)
+        kwargs['proot_show_paths'] = [functioning_dir(self.source)]
        logger.debug("Running from `{}` working directory.".format(cwd))
        return wrap_args_with_proot(cmd, cwd, **kwargs)

View File

@@ -3,6 +3,7 @@
from awx.main.models import Instance, InstanceGroup
from django.core.management.base import BaseCommand
+import six


class Command(BaseCommand):
@@ -13,10 +14,11 @@ class Command(BaseCommand):
        super(Command, self).__init__()
        for instance in Instance.objects.all():
-            print("hostname: {}; created: {}; heartbeat: {}; capacity: {}".format(instance.hostname, instance.created,
-                                                                                  instance.modified, instance.capacity))
+            print(six.text_type(
+                "hostname: {0.hostname}; created: {0.created}; "
+                "heartbeat: {0.modified}; capacity: {0.capacity}").format(instance))
        for instance_group in InstanceGroup.objects.all():
-            print("Instance Group: {}; created: {}; capacity: {}; members: {}".format(instance_group.name,
-                                                                                      instance_group.created,
-                                                                                      instance_group.capacity,
+            print(six.text_type(
+                "Instance Group: {0.name}; created: {0.created}; "
+                "capacity: {0.capacity}; members: {1}").format(instance_group,
                [x.hostname for x in instance_group.instances.all()]))

View File

@@ -19,11 +19,11 @@ class InstanceNotFound(Exception):
class Command(BaseCommand):

    def add_arguments(self, parser):
-        parser.add_argument('--queuename', dest='queuename', type=str,
+        parser.add_argument('--queuename', dest='queuename', type=lambda s: six.text_type(s, 'utf8'),
                            help='Queue to create/update')
-        parser.add_argument('--hostnames', dest='hostnames', type=str,
+        parser.add_argument('--hostnames', dest='hostnames', type=lambda s: six.text_type(s, 'utf8'),
                            help='Comma-Delimited Hosts to add to the Queue')
-        parser.add_argument('--controller', dest='controller', type=str,
+        parser.add_argument('--controller', dest='controller', type=lambda s: six.text_type(s, 'utf8'),
                            default='', help='The controlling group (makes this an isolated group)')
        parser.add_argument('--instance_percent', dest='instance_percent', type=int, default=0,
                            help='The percentage of active instances that will be assigned to this group'),
@@ -96,7 +96,7 @@ class Command(BaseCommand):
        if options.get('hostnames'):
            hostname_list = options.get('hostnames').split(",")

-        with advisory_lock('instance_group_registration_%s' % queuename):
+        with advisory_lock(six.text_type('instance_group_registration_{}').format(queuename)):
            (ig, created, changed) = self.get_create_update_instance_group(queuename, inst_per, inst_min)
            if created:
                print(six.text_type("Creating instance group {}".format(ig.name)))

View File

@@ -95,7 +95,7 @@ class ReplayJobEvents():
            raise RuntimeError("Job is of type {} and replay is not yet supported.".format(type(job)))
            sys.exit(1)

-    def run(self, job_id, speed=1.0, verbosity=0):
+    def run(self, job_id, speed=1.0, verbosity=0, skip=0):
        stats = {
            'events_ontime': {
                'total': 0,
@@ -126,7 +126,10 @@ class ReplayJobEvents():
            sys.exit(1)

        je_previous = None
-        for je_current in job_events:
+        for n, je_current in enumerate(job_events):
+            if n < skip:
+                continue
            if not je_previous:
                stats['recording_start'] = je_current.created
                self.start(je_current.created)
@@ -163,21 +166,25 @@ class ReplayJobEvents():
            stats['events_total'] += 1
            je_previous = je_current

-        stats['replay_end'] = self.now()
-        stats['replay_duration'] = (stats['replay_end'] - stats['replay_start']).total_seconds()
-        stats['replay_start'] = stats['replay_start'].isoformat()
-        stats['replay_end'] = stats['replay_end'].isoformat()
-        stats['recording_end'] = je_current.created
-        stats['recording_duration'] = (stats['recording_end'] - stats['recording_start']).total_seconds()
-        stats['recording_start'] = stats['recording_start'].isoformat()
-        stats['recording_end'] = stats['recording_end'].isoformat()
-        stats['events_ontime']['percentage'] = (stats['events_ontime']['total'] / float(stats['events_total'])) * 100.00
-        stats['events_late']['percentage'] = (stats['events_late']['total'] / float(stats['events_total'])) * 100.00
-        stats['events_distance_average'] = stats['events_distance_total'] / stats['events_total']
-        stats['events_late']['lateness_average'] = stats['events_late']['lateness_total'] / stats['events_late']['total']
+        if stats['events_total'] > 2:
+            stats['replay_end'] = self.now()
+            stats['replay_duration'] = (stats['replay_end'] - stats['replay_start']).total_seconds()
+            stats['replay_start'] = stats['replay_start'].isoformat()
+            stats['replay_end'] = stats['replay_end'].isoformat()
+            stats['recording_end'] = je_current.created
+            stats['recording_duration'] = (stats['recording_end'] - stats['recording_start']).total_seconds()
+            stats['recording_start'] = stats['recording_start'].isoformat()
+            stats['recording_end'] = stats['recording_end'].isoformat()
+            stats['events_ontime']['percentage'] = (stats['events_ontime']['total'] / float(stats['events_total'])) * 100.00
+            stats['events_late']['percentage'] = (stats['events_late']['total'] / float(stats['events_total'])) * 100.00
+            stats['events_distance_average'] = stats['events_distance_total'] / stats['events_total']
+            stats['events_late']['lateness_average'] = stats['events_late']['lateness_total'] / stats['events_late']['total']
+        else:
+            stats = {'events_total': stats['events_total']}

        if verbosity >= 2:
            print(json.dumps(stats, indent=4, sort_keys=True))
@@ -191,11 +198,14 @@ class Command(BaseCommand):
                            help='Id of the job to replay (job or adhoc)')
        parser.add_argument('--speed', dest='speed', type=int, metavar='s',
                            help='Speedup factor.')
+        parser.add_argument('--skip', dest='skip', type=int, metavar='k',
+                            help='Number of events to skip.')

    def handle(self, *args, **options):
        job_id = options.get('job_id')
        speed = options.get('speed') or 1
        verbosity = options.get('verbosity') or 0
+        skip = options.get('skip') or 0
        replayer = ReplayJobEvents()
-        replayer.run(job_id, speed, verbosity)
+        replayer.run(job_id, speed, verbosity, skip)

View File

@@ -77,7 +77,7 @@ class InstanceManager(models.Manager):
    def me(self):
        """Return the currently active instance."""
        # If we are running unit tests, return a stub record.
-        if settings.IS_TESTING(sys.argv):
+        if settings.IS_TESTING(sys.argv) or hasattr(sys, '_called_from_test'):
            return self.model(id=1,
                              hostname='localhost',
                              uuid='00000000-0000-0000-0000-000000000000')

View File

@@ -3,6 +3,7 @@
# Django
from django.conf import settings # noqa
+from django.db.models.signals import pre_delete # noqa

# AWX
from awx.main.models.base import * # noqa
@@ -58,6 +59,18 @@ User.add_to_class('can_access_with_errors', check_user_access_with_errors)
User.add_to_class('accessible_objects', user_accessible_objects)
+def cleanup_created_modified_by(sender, **kwargs):
+    # work around a bug in django-polymorphic that doesn't properly
+    # handle cascades for reverse foreign keys on the polymorphic base model
+    # https://github.com/django-polymorphic/django-polymorphic/issues/229
+    for cls in (UnifiedJobTemplate, UnifiedJob):
+        cls.objects.filter(created_by=kwargs['instance']).update(created_by=None)
+        cls.objects.filter(modified_by=kwargs['instance']).update(modified_by=None)
+
+pre_delete.connect(cleanup_created_modified_by, sender=User)
@property
def user_get_organizations(user):
    return Organization.objects.filter(member_role__members=user)
@@ -169,3 +182,9 @@ activity_stream_registrar.connect(OAuth2AccessToken)
# prevent API filtering on certain Django-supplied sensitive fields
prevent_search(User._meta.get_field('password'))
+prevent_search(OAuth2AccessToken._meta.get_field('token'))
+prevent_search(RefreshToken._meta.get_field('token'))
+prevent_search(OAuth2Application._meta.get_field('client_secret'))
+prevent_search(OAuth2Application._meta.get_field('client_id'))
+prevent_search(Grant._meta.get_field('code'))

View File

@@ -153,7 +153,7 @@ class AdHocCommand(UnifiedJob, JobNotificationMixin):
        return reverse('api:ad_hoc_command_detail', kwargs={'pk': self.pk}, request=request)

    def get_ui_url(self):
-        return urljoin(settings.TOWER_URL_BASE, "/#/ad_hoc_commands/{}".format(self.pk))
+        return urljoin(settings.TOWER_URL_BASE, "/#/jobs/command/{}".format(self.pk))

    @property
    def notification_templates(self):

View File

@@ -4,9 +4,11 @@ import logging
from django.conf import settings
from django.db import models, DatabaseError
from django.utils.dateparse import parse_datetime
+from django.utils.text import Truncator
from django.utils.timezone import utc
from django.utils.translation import ugettext_lazy as _
from django.utils.encoding import force_text
+import six

from awx.api.versioning import reverse
from awx.main.fields import JSONField
@@ -22,6 +24,22 @@ __all__ = ['JobEvent', 'ProjectUpdateEvent', 'AdHocCommandEvent',
           'InventoryUpdateEvent', 'SystemJobEvent']
+def sanitize_event_keys(kwargs, valid_keys):
+    # Sanity check: Don't honor keys that we don't recognize.
+    for key in kwargs.keys():
+        if key not in valid_keys:
+            kwargs.pop(key)
+    # Truncate certain values over 1k
+    for key in [
+        'play', 'role', 'task', 'playbook'
+    ]:
+        if isinstance(kwargs.get(key), six.string_types):
+            if len(kwargs[key]) > 1024:
+                kwargs[key] = Truncator(kwargs[key]).chars(1024)
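Django's Truncator keeps at most the requested number of characters, ellipsis included, so an oversized play or task name cannot bloat the event row; a quick standalone check:

    from django.utils.text import Truncator

    long_task = 'x' * 5000
    truncated = Truncator(long_task).chars(1024)
    print(len(truncated))  # 1024 -- capped, with a trailing ellipsis character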
class BasePlaybookEvent(CreatedModifiedModel):
    '''
    An event/message logged from a playbook callback for each host.
@@ -257,7 +275,7 @@ class BasePlaybookEvent(CreatedModifiedModel):
        return updated_fields

    @classmethod
-    def create_from_data(self, **kwargs):
+    def create_from_data(cls, **kwargs):
        pk = None
        for key in ('job_id', 'project_update_id'):
            if key in kwargs:
@@ -279,12 +297,8 @@ class BasePlaybookEvent(CreatedModifiedModel):
        except (KeyError, ValueError):
            kwargs.pop('created', None)

-        # Sanity check: Don't honor keys that we don't recognize.
-        for key in kwargs.keys():
-            if key not in self.VALID_KEYS:
-                kwargs.pop(key)
-        job_event = self.objects.create(**kwargs)
+        sanitize_event_keys(kwargs, cls.VALID_KEYS)
+        job_event = cls.objects.create(**kwargs)
        analytics_logger.info('Event data saved.', extra=dict(python_objects=dict(job_event=job_event)))
        return job_event
@@ -551,7 +565,7 @@ class BaseCommandEvent(CreatedModifiedModel):
        return u'%s @ %s' % (self.get_event_display(), self.created.isoformat())

    @classmethod
-    def create_from_data(self, **kwargs):
+    def create_from_data(cls, **kwargs):
        # Convert the datetime for the event's creation
        # appropriately, and include a time zone for it.
        #
@@ -565,12 +579,8 @@ class BaseCommandEvent(CreatedModifiedModel):
        except (KeyError, ValueError):
            kwargs.pop('created', None)

-        # Sanity check: Don't honor keys that we don't recognize.
-        for key in kwargs.keys():
-            if key not in self.VALID_KEYS:
-                kwargs.pop(key)
-        return self.objects.create(**kwargs)
+        sanitize_event_keys(kwargs, cls.VALID_KEYS)
+        return cls.objects.create(**kwargs)

    def get_event_display(self):
        '''

View File

@@ -1646,7 +1646,7 @@ class InventoryUpdate(UnifiedJob, InventorySourceOptions, JobNotificationMixin,
        return reverse('api:inventory_update_detail', kwargs={'pk': self.pk}, request=request)

    def get_ui_url(self):
-        return urljoin(settings.TOWER_URL_BASE, "/#/inventory_sync/{}".format(self.pk))
+        return urljoin(settings.TOWER_URL_BASE, "/#/jobs/inventory/{}".format(self.pk))

    def get_actual_source_path(self):
        '''Alias to source_path that combines with project path for for SCM file based sources'''

View File

@@ -530,7 +530,7 @@ class Job(UnifiedJob, JobOptions, SurveyJobMixin, JobNotificationMixin, TaskMana
        return reverse('api:job_detail', kwargs={'pk': self.pk}, request=request)

    def get_ui_url(self):
-        return urljoin(settings.TOWER_URL_BASE, "/#/jobs/{}".format(self.pk))
+        return urljoin(settings.TOWER_URL_BASE, "/#/jobs/playbook/{}".format(self.pk))

    @property
    def ansible_virtualenv_path(self):
@@ -1192,7 +1192,7 @@ class SystemJob(UnifiedJob, SystemJobOptions, JobNotificationMixin):
        return reverse('api:system_job_detail', kwargs={'pk': self.pk}, request=request)

    def get_ui_url(self):
-        return urljoin(settings.TOWER_URL_BASE, "/#/management_jobs/{}".format(self.pk))
+        return urljoin(settings.TOWER_URL_BASE, "/#/jobs/system/{}".format(self.pk))

    @property
    def event_class(self):

View File

@@ -241,6 +241,7 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn
    SOFT_UNIQUE_TOGETHER = [('polymorphic_ctype', 'name', 'organization')]
    FIELDS_TO_PRESERVE_AT_COPY = ['labels', 'instance_groups', 'credentials']
    FIELDS_TO_DISCARD_AT_COPY = ['local_path']
+    FIELDS_TRIGGER_UPDATE = frozenset(['scm_url', 'scm_branch', 'scm_type'])

    class Meta:
        app_label = 'main'
@@ -323,6 +324,11 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn
        ['name', 'description', 'schedule']
    )

+    def __init__(self, *args, **kwargs):
+        r = super(Project, self).__init__(*args, **kwargs)
+        self._prior_values_store = self._current_sensitive_fields()
+        return r

    def save(self, *args, **kwargs):
        new_instance = not bool(self.pk)
        # If update_fields has been specified, add our field names to it,
@@ -354,9 +360,22 @@ class Project(UnifiedJobTemplate, ProjectOptions, ResourceMixin, CustomVirtualEn
            with disable_activity_stream():
                self.save(update_fields=update_fields)

        # If we just created a new project with SCM, start the initial update.
-        if new_instance and self.scm_type and not skip_update:
+        # also update if certain fields have changed
+        relevant_change = False
+        new_values = self._current_sensitive_fields()
+        if hasattr(self, '_prior_values_store') and self._prior_values_store != new_values:
+            relevant_change = True
+            self._prior_values_store = new_values
+        if (relevant_change or new_instance) and (not skip_update) and self.scm_type:
            self.update()
+    def _current_sensitive_fields(self):
+        new_values = {}
+        for attr, val in self.__dict__.items():
+            if attr in Project.FIELDS_TRIGGER_UPDATE:
+                new_values[attr] = val
+        return new_values
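An illustration of the behavior this buys (the project and values are hypothetical): saving a project only spawns an SCM update when one of the FIELDS_TRIGGER_UPDATE values actually changed:

    p = Project.objects.get(name='demo')
    p.description = 'cosmetic edit'
    p.save()                      # no relevant change: no SCM update spawned

    p.scm_branch = 'release_3.3.0'
    p.save()                      # scm_branch is in FIELDS_TRIGGER_UPDATE, so self.update() fires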
    def _get_current_status(self):
        if self.scm_type:
            if self.current_job and self.current_job.status:
@@ -533,7 +552,7 @@ class ProjectUpdate(UnifiedJob, ProjectOptions, JobNotificationMixin, TaskManage
        return reverse('api:project_update_detail', kwargs={'pk': self.pk}, request=request)

    def get_ui_url(self):
-        return urlparse.urljoin(settings.TOWER_URL_BASE, "/#/scm_update/{}".format(self.pk))
+        return urlparse.urljoin(settings.TOWER_URL_BASE, "/#/jobs/project/{}".format(self.pk))

    def _update_parent_instance(self):
        parent_instance = self._get_parent_instance()

View File

@@ -172,7 +172,7 @@ class Role(models.Model):
        elif accessor.__class__.__name__ == 'Team':
            return self.ancestors.filter(pk=accessor.member_role.id).exists()
        elif type(accessor) == Role:
-            return self.ancestors.filter(pk=accessor).exists()
+            return self.ancestors.filter(pk=accessor.pk).exists()
        else:
            accessor_type = ContentType.objects.get_for_model(accessor)
            roles = Role.objects.filter(content_type__pk=accessor_type.id,

View File

@@ -1,15 +1,19 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.

-import logging
import datetime
+import logging
+import re

import dateutil.rrule
-from dateutil.tz import datetime_exists
+import dateutil.parser
+from dateutil.tz import datetime_exists, tzutc
+from dateutil.zoneinfo import get_zonefile_instance

# Django
from django.db import models
from django.db.models.query import QuerySet
-from django.utils.timezone import now
+from django.utils.timezone import now, make_aware
from django.utils.translation import ugettext_lazy as _

# AWX
@@ -27,6 +31,9 @@ logger = logging.getLogger('awx.main.models.schedule')
__all__ = ['Schedule']

+UTC_TIMEZONES = {x: tzutc() for x in dateutil.parser.parserinfo().UTCZONE}


class ScheduleFilterMethods(object):

    def enabled(self, enabled=True):
@@ -94,13 +101,98 @@ class Schedule(CommonModel, LaunchTimeConfig):
help_text=_("The next time that the scheduled action will run.") help_text=_("The next time that the scheduled action will run.")
) )
+    @classmethod
+    def get_zoneinfo(self):
+        return sorted(get_zonefile_instance().zones)
+
+    @property
+    def timezone(self):
+        utc = tzutc()
+        all_zones = Schedule.get_zoneinfo()
+        all_zones.sort(key = lambda x: -len(x))
+        for r in Schedule.rrulestr(self.rrule)._rrule:
+            if r._dtstart:
+                tzinfo = r._dtstart.tzinfo
+                if tzinfo is utc:
+                    return 'UTC'
+                fname = tzinfo._filename
+                for zone in all_zones:
+                    if fname.endswith(zone):
+                        return zone
+        logger.warn('Could not detect valid zoneinfo for {}'.format(self.rrule))
+        return ''
+
+    @property
+    def until(self):
+        # The UNTIL= datestamp (if any) coerced from UTC to the local naive time
+        # of the DTSTART
+        for r in Schedule.rrulestr(self.rrule)._rrule:
+            if r._until:
+                local_until = r._until.astimezone(r._dtstart.tzinfo)
+                naive_until = local_until.replace(tzinfo=None)
+                return naive_until.isoformat()
+        return ''
+    @classmethod
+    def coerce_naive_until(cls, rrule):
+        #
+        # RFC5545 specifies that the UNTIL rule part MUST ALWAYS be a date
+        # with UTC time. This is extra work for API implementers because
+        # it requires them to perform DTSTART local -> UTC datetime coercion on
+        # POST and UTC -> DTSTART local coercion on GET.
+        #
+        # This block of code is a departure from the RFC. If you send an
+        # rrule like this to the API (without a Z on the UNTIL):
+        #
+        # DTSTART;TZID=America/New_York:20180502T150000 RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20180502T180000
+        #
+        # ...we'll assume that the naive UNTIL is intended to match the DTSTART
+        # timezone (America/New_York), and so we'll coerce to UTC _for you_
+        # automatically.
+        #
+        if 'until=' in rrule.lower():
+            # if DTSTART;TZID= is used, coerce "naive" UNTIL values
+            # to the proper UTC date
+            match_until = re.match(".*?(?P<until>UNTIL\=[0-9]+T[0-9]+)(?P<utcflag>Z?)", rrule)
+            if not len(match_until.group('utcflag')):
+                # rrule = DTSTART;TZID=America/New_York:20200601T120000 RRULE:...;UNTIL=20200601T170000
+                # Find the UNTIL=N part of the string
+                # naive_until = UNTIL=20200601T170000
+                naive_until = match_until.group('until')
+                # What is the DTSTART timezone for:
+                # DTSTART;TZID=America/New_York:20200601T120000 RRULE:...;UNTIL=20200601T170000Z
+                # local_tz = tzfile('/usr/share/zoneinfo/America/New_York')
+                local_tz = dateutil.rrule.rrulestr(
+                    rrule.replace(naive_until, naive_until + 'Z'),
+                    tzinfos=UTC_TIMEZONES
+                )._dtstart.tzinfo
+                # Make a datetime object with tzinfo=<the DTSTART timezone>
+                # localized_until = datetime.datetime(2020, 6, 1, 17, 0, tzinfo=tzfile('/usr/share/zoneinfo/America/New_York'))
+                localized_until = make_aware(
+                    datetime.datetime.strptime(re.sub('^UNTIL=', '', naive_until), "%Y%m%dT%H%M%S"),
+                    local_tz
+                )
+                # Coerce the datetime to UTC and format it as a string w/ Zulu format
+                # utc_until = UNTIL=20200601T220000Z
+                utc_until = 'UNTIL=' + localized_until.astimezone(pytz.utc).strftime('%Y%m%dT%H%M%SZ')
+                # rrule was: DTSTART;TZID=America/New_York:20200601T120000 RRULE:...;UNTIL=20200601T170000
+                # rrule is now: DTSTART;TZID=America/New_York:20200601T120000 RRULE:...;UNTIL=20200601T220000Z
+                rrule = rrule.replace(naive_until, utc_until)
+        return rrule
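Using the example rrule from the comment block above, the coercion behaves like this (America/New_York is UTC-4 on that date, so the naive 18:00 UNTIL becomes 22:00 Zulu):

    rrule = ('DTSTART;TZID=America/New_York:20180502T150000 '
             'RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20180502T180000')
    print(Schedule.coerce_naive_until(rrule))
    # DTSTART;TZID=America/New_York:20180502T150000 RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20180502T220000Z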
    @classmethod
    def rrulestr(cls, rrule, **kwargs):
        """
        Apply our own custom rrule parsing requirements
        """
+        rrule = Schedule.coerce_naive_until(rrule)
        kwargs['forceset'] = True
-        x = dateutil.rrule.rrulestr(rrule, **kwargs)
+        x = dateutil.rrule.rrulestr(rrule, tzinfos=UTC_TIMEZONES, **kwargs)
        for r in x._rrule:
            if r._dtstart and r._dtstart.tzinfo is None:
@@ -158,4 +250,5 @@ class Schedule(CommonModel, LaunchTimeConfig):
    def save(self, *args, **kwargs):
        self.update_computed_fields()
+        self.rrule = Schedule.coerce_naive_until(self.rrule)
        super(Schedule, self).save(*args, **kwargs)

View File

@@ -838,8 +838,11 @@ class UnifiedJob(PolymorphicModel, PasswordFieldsModel, CommonModelNameNotUnique
                setattr(unified_job, fd, val)
        unified_job.save()

-        # Labels coppied here
-        copy_m2m_relationships(self, unified_job, fields)
+        # Labels copied here
+        from awx.main.signals import disable_activity_stream
+        with disable_activity_stream():
+            copy_m2m_relationships(self, unified_job, fields)
        return unified_job

    def launch_prompts(self):

View File

@@ -205,7 +205,7 @@ def set_original_organization(sender, instance, **kwargs):
    pre-save organization, so we can later determine if the organization
    field is dirty.
    '''
-    instance.__original_org = instance.organization
+    instance.__original_org_id = instance.organization_id


def save_related_job_templates(sender, instance, **kwargs):
@@ -217,7 +217,7 @@ def save_related_job_templates(sender, instance, **kwargs):
    if sender not in (Project, Inventory):
        raise ValueError('This signal callback is only intended for use with Project or Inventory')

-    if instance.__original_org != instance.organization:
+    if instance.__original_org_id != instance.organization_id:
        jtq = JobTemplate.objects.filter(**{sender.__name__.lower(): instance})
        for jt in jtq:
            update_role_parentage_for_instance(jt)
@@ -494,6 +494,8 @@ def activity_stream_delete(sender, instance, **kwargs):
        return
    changes = model_to_dict(instance)
    object1 = camelcase_to_underscore(instance.__class__.__name__)
+    if type(instance) == OAuth2AccessToken:
+        changes['token'] = TOKEN_CENSOR
    activity_entry = ActivityStream(
        operation='delete',
        changes=json.dumps(changes),

View File

@@ -29,7 +29,7 @@ except Exception:
# Celery
from celery import Task, shared_task, Celery
-from celery.signals import celeryd_init, worker_process_init, worker_shutdown, worker_ready, celeryd_after_setup
+from celery.signals import celeryd_init, worker_shutdown, worker_ready, celeryd_after_setup

# Django
from django.conf import settings
@@ -49,6 +49,7 @@ from crum import impersonate
# AWX
from awx import __version__ as awx_application_version
from awx.main.constants import CLOUD_PROVIDERS, PRIVILEGE_ESCALATION_METHODS
+from awx.main.access import access_registry
from awx.main.models import * # noqa
from awx.main.constants import ACTIVE_STATES
from awx.main.exceptions import AwxTaskError
@@ -59,13 +60,15 @@ from awx.main.utils import (get_ansible_version, get_ssh_version, decrypt_field,
                            wrap_args_with_proot, OutputEventFilter, OutputVerboseFilter, ignore_inventory_computed_fields,
                            ignore_inventory_group_removal, get_type_for_model, extract_ansible_vars)
from awx.main.utils.safe_yaml import safe_dump, sanitize_jinja
-from awx.main.utils.reload import restart_local_services, stop_local_services
+from awx.main.utils.reload import stop_local_services
from awx.main.utils.pglock import advisory_lock
-from awx.main.utils.ha import update_celery_worker_routes, register_celery_worker_queues
-from awx.main.utils.handlers import configure_external_logger
+from awx.main.utils.ha import register_celery_worker_queues
from awx.main.consumers import emit_channel_notification
from awx.conf import settings_registry
+from rest_framework.exceptions import PermissionDenied

__all__ = ['RunJob', 'RunSystemJob', 'RunProjectUpdate', 'RunInventoryUpdate',
           'RunAdHocCommand', 'handle_work_error', 'handle_work_success', 'apply_cluster_membership_policies',
           'update_inventory_computed_fields', 'update_host_smart_inventory_memberships',
@@ -117,15 +120,6 @@ def celery_startup(conf=None, **kwargs):
logger.exception(six.text_type("Failed to rebuild schedule {}.").format(sch)) logger.exception(six.text_type("Failed to rebuild schedule {}.").format(sch))
@worker_process_init.connect
def task_set_logger_pre_run(*args, **kwargs):
try:
cache.close()
configure_external_logger(settings, is_startup=False)
except Exception:
logger.exception('Encountered error on initial log configuration.')
@worker_shutdown.connect @worker_shutdown.connect
def inform_cluster_of_shutdown(*args, **kwargs): def inform_cluster_of_shutdown(*args, **kwargs):
try: try:
@@ -152,7 +146,7 @@ def apply_cluster_membership_policies(self):
    # Process policy instance list first, these will represent manually managed instances
    # that will not go through automatic policy determination
    for ig in InstanceGroup.objects.all():
-        logger.info(six.text_type("Considering group {}").format(ig.name))
+        logger.info(six.text_type("Applying cluster membership policies to Group {}").format(ig.name))
        ig.instances.clear()
        group_actual = Group(obj=ig, instances=[])
        for i in ig.policy_instance_list:
@@ -160,7 +154,7 @@ def apply_cluster_membership_policies(self):
            if not inst.exists():
                continue
            inst = inst[0]
-            logger.info(six.text_type("Policy List, adding {} to {}").format(inst.hostname, ig.name))
+            logger.info(six.text_type("Policy List, adding Instance {} to Group {}").format(inst.hostname, ig.name))
            group_actual.instances.append(inst.id)
            ig.instances.add(inst)
            filtered_instances.append(inst)
@@ -173,7 +167,7 @@ def apply_cluster_membership_policies(self):
        for i in sorted(actual_instances, cmp=lambda x,y: len(x.groups) - len(y.groups)):
            if len(g.instances) >= g.obj.policy_instance_minimum:
                break
-            logger.info(six.text_type("Policy minimum, adding {} to {}").format(i.obj.hostname, g.obj.name))
+            logger.info(six.text_type("Policy minimum, adding Instance {} to Group {}").format(i.obj.hostname, g.obj.name))
            g.obj.instances.add(i.obj)
            g.instances.append(i.obj.id)
            i.groups.append(g.obj.id)
@@ -182,14 +176,14 @@ def apply_cluster_membership_policies(self):
        for i in sorted(actual_instances, cmp=lambda x,y: len(x.groups) - len(y.groups)):
            if 100 * float(len(g.instances)) / len(actual_instances) >= g.obj.policy_instance_percentage:
                break
-            logger.info(six.text_type("Policy percentage, adding {} to {}").format(i.obj.hostname, g.obj.name))
+            logger.info(six.text_type("Policy percentage, adding Instance {} to Group {}").format(i.obj.hostname, g.obj.name))
            g.instances.append(i.obj.id)
            g.obj.instances.add(i.obj)
            i.groups.append(g.obj.id)
    handle_ha_toplogy_changes.apply([])


-@shared_task(queue='tower_broadcast_all', bind=True)
+@shared_task(exchange='tower_broadcast_all', bind=True)
def handle_setting_changes(self, setting_keys):
    orig_len = len(setting_keys)
    for i in range(orig_len):
@@ -200,15 +194,9 @@ def handle_setting_changes(self, setting_keys):
    cache_keys = set(setting_keys)
    logger.debug('cache delete_many(%r)', cache_keys)
    cache.delete_many(cache_keys)
-    for key in cache_keys:
-        if key.startswith('LOG_AGGREGATOR_'):
-            restart_local_services(['uwsgi', 'celery', 'beat', 'callback'])
-            break
-        elif key == 'OAUTH2_PROVIDER':
-            restart_local_services(['uwsgi'])


-@shared_task(bind=True, queue='tower_broadcast_all')
+@shared_task(bind=True, exchange='tower_broadcast_all')
def handle_ha_toplogy_changes(self):
    (changed, instance) = Instance.objects.get_or_register()
    if changed:
@@ -217,39 +205,24 @@ def handle_ha_toplogy_changes(self):
    awx_app = Celery('awx')
    awx_app.config_from_object('django.conf:settings')
    instances, removed_queues, added_queues = register_celery_worker_queues(awx_app, self.request.hostname)
-    for instance in instances:
-        logger.info(six.text_type("Workers on tower node '{}' removed from queues {} and added to queues {}")
-                    .format(instance.hostname, removed_queues, added_queues))
-        updated_routes = update_celery_worker_routes(instance, settings)
-        logger.info(six.text_type("Worker on tower node '{}' updated celery routes {} all routes are now {}")
-                    .format(instance.hostname, updated_routes, self.app.conf.CELERY_ROUTES))
+    if len(removed_queues) + len(added_queues) > 0:
+        logger.info(six.text_type("Workers on tower node(s) '{}' removed from queues {} and added to queues {}")
+                    .format([i.hostname for i in instances], removed_queues, added_queues))


@worker_ready.connect
def handle_ha_toplogy_worker_ready(sender, **kwargs):
    logger.debug(six.text_type("Configure celeryd queues task on host {}").format(sender.hostname))
    instances, removed_queues, added_queues = register_celery_worker_queues(sender.app, sender.hostname)
-    for instance in instances:
-        logger.info(six.text_type("Workers on tower node '{}' unsubscribed from queues {} and subscribed to queues {}")
-                    .format(instance.hostname, removed_queues, added_queues))
+    if len(removed_queues) + len(added_queues) > 0:
+        logger.info(six.text_type("Workers on tower node(s) '{}' removed from queues {} and added to queues {}")
+                    .format([i.hostname for i in instances], removed_queues, added_queues))

    # Expedite the first hearbeat run so a node comes online quickly.
    cluster_node_heartbeat.apply([])
    apply_cluster_membership_policies.apply([])


-@celeryd_init.connect
-def handle_update_celery_routes(sender=None, conf=None, **kwargs):
-    conf = conf if conf else sender.app.conf
-    logger.debug(six.text_type("Registering celery routes for {}").format(sender))
-    (changed, instance) = Instance.objects.get_or_register()
-    if changed:
-        logger.info(six.text_type("Registered tower node '{}'").format(instance.hostname))
-    added_routes = update_celery_worker_routes(instance, conf)
-    logger.info(six.text_type("Workers on tower node '{}' added routes {} all routes are now {}")
-                .format(instance.hostname, added_routes, conf.CELERY_ROUTES))


@celeryd_after_setup.connect
def handle_update_celery_hostname(sender, instance, **kwargs):
    (changed, tower_instance) = Instance.objects.get_or_register()
@@ -282,7 +255,10 @@ def send_notifications(notification_list, job_id=None):
                notification.error = smart_str(e)
                update_fields.append('error')
            finally:
-                notification.save(update_fields=update_fields)
+                try:
+                    notification.save(update_fields=update_fields)
+                except Exception as e:
+                    logger.exception(six.text_type('Error saving notification {} result.').format(notification.id))


@shared_task(bind=True, queue=settings.CELERY_DEFAULT_QUEUE)
@@ -426,6 +402,13 @@ def awx_periodic_scheduler(self):
    for schedule in old_schedules:
        schedule.save()
    schedules = Schedule.objects.enabled().between(last_run, run_now)

+    invalid_license = False
+    try:
+        access_registry[Job](None).check_license()
+    except PermissionDenied as e:
+        invalid_license = e

    for schedule in schedules:
        template = schedule.unified_job_template
        schedule.save() # To update next_run timestamp.
@@ -435,6 +418,13 @@ def awx_periodic_scheduler(self):
        try:
            job_kwargs = schedule.get_job_kwargs()
            new_unified_job = schedule.unified_job_template.create_unified_job(**job_kwargs)
+            if invalid_license:
+                new_unified_job.status = 'failed'
+                new_unified_job.job_explanation = str(invalid_license)
+                new_unified_job.save(update_fields=['status', 'job_explanation'])
+                new_unified_job.websocket_emit_status("failed")
+                raise invalid_license
            can_start = new_unified_job.signal_start()
        except Exception:
            logger.exception('Error spawning scheduled job.')
@@ -561,6 +551,8 @@ def delete_inventory(self, inventory_id, user_id):
    with ignore_inventory_computed_fields(), ignore_inventory_group_removal(), impersonate(user):
        try:
            i = Inventory.objects.get(id=inventory_id)
+            for host in i.hosts.iterator():
+                host.job_events_as_primary_host.update(host=None)
            i.delete()
            emit_channel_notification(
                'inventories-status_changed',
@@ -1677,7 +1669,13 @@ class RunProjectUpdate(BaseTask):
            raise
        try:
+            start_time = time.time()
            fcntl.flock(self.lock_fd, fcntl.LOCK_EX)
+            waiting_time = time.time() - start_time
+            if waiting_time > 1.0:
+                logger.info(six.text_type(
+                    '{} spent {} waiting to acquire lock for local source tree '
+                    'for path {}.').format(instance.log_format, waiting_time, lock_path))
        except IOError as e:
            os.close(self.lock_fd)
            logger.error(six.text_type("I/O error({0}) while trying to aquire lock on file [{1}]: {2}").format(e.errno, lock_path, e.strerror))
@@ -1725,6 +1723,10 @@ class RunInventoryUpdate(BaseTask):
    event_model = InventoryUpdateEvent
    event_data_key = 'inventory_update_id'

+    @property
+    def proot_show_paths(self):
+        return [self.get_path_to('..', 'plugins', 'inventory')]

    def build_private_data(self, inventory_update, **kwargs):
        """
        Return private data needed for inventory update.
@@ -2080,6 +2082,8 @@ class RunInventoryUpdate(BaseTask):
        return args

    def build_cwd(self, inventory_update, **kwargs):
+        if inventory_update.source == 'scm' and inventory_update.source_project_update:
+            return inventory_update.source_project_update.get_project_path(check_if_exists=False)
        return self.get_path_to('..', 'plugins', 'inventory')

    def get_idle_timeout(self):
@@ -2331,6 +2335,9 @@ def _reconstruct_relationships(copy_mapping):
                setattr(new_obj, field_name, related_obj)
        elif field.many_to_many:
            for related_obj in getattr(old_obj, field_name).all():
+                logger.debug(six.text_type('Deep copy: Adding {} to {}({}).{} relationship').format(
+                    related_obj, new_obj, model, field_name
+                ))
                getattr(new_obj, field_name).add(copy_mapping.get(related_obj, related_obj))
    new_obj.save()
@@ -2352,7 +2359,7 @@ def deep_copy_model_obj(
    except ObjectDoesNotExist:
        logger.warning("Object or user no longer exists.")
        return
-    with transaction.atomic():
+    with transaction.atomic(), ignore_inventory_computed_fields():
        copy_mapping = {}
        for sub_obj_setup in sub_obj_list:
            sub_model = getattr(importlib.import_module(sub_obj_setup[0]),
@@ -2372,3 +2379,5 @@ def deep_copy_model_obj(
            importlib.import_module(permission_check_func[0]), permission_check_func[1]
        ), permission_check_func[2])
        permission_check_func(creater, copy_mapping.values())
+    if isinstance(new_obj, Inventory):
+        update_inventory_computed_fields.delay(new_obj.id, True)

View File

@@ -15,6 +15,16 @@ from awx.main.tests.factories import (
)


+def pytest_configure(config):
+    import sys
+    sys._called_from_test = True
+
+
+def pytest_unconfigure(config):
+    import sys
+    del sys._called_from_test


@pytest.fixture
def mock_access():
    @contextmanager
@@ -96,3 +106,21 @@ def get_ssh_version(mocker):
@pytest.fixture
def job_template_with_survey_passwords_unit(job_template_with_survey_passwords_factory):
    return job_template_with_survey_passwords_factory(persisted=False)
+@pytest.fixture
+def mock_cache():
+    class MockCache(object):
+        cache = {}
+
+        def get(self, key, default=None):
+            return self.cache.get(key, default)
+
+        def set(self, key, value, timeout=60):
+            self.cache[key] = value
+
+        def delete(self, key):
+            del self.cache[key]
+
+    return MockCache()
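A typical use of this fixture in a test body; the patch target below is illustrative, not a path the suite necessarily uses:

    def test_caches_value(mock_cache, mocker):
        mocker.patch('awx.main.tasks.cache', mock_cache)  # swap the real cache for the stub
        mock_cache.set('setting_key', 'value')
        assert mock_cache.get('setting_key') == 'value'
        assert mock_cache.get('missing', 'fallback') == 'fallback'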

View File

@@ -1,6 +1,7 @@
from django.db import connection
from django.db.models.signals import post_migrate
from django.apps import apps
+from django.conf import settings


def app_post_migration(sender, app_config, **kwargs):
@@ -17,7 +18,8 @@ def app_post_migration(sender, app_config, **kwargs):
    )


-post_migrate.connect(app_post_migration, sender=apps.get_app_config('main'))
+if settings.DATABASES['default']['ENGINE'] == 'django.db.backends.sqlite3':
+    post_migrate.connect(app_post_migration, sender=apps.get_app_config('main'))

View File

@@ -94,10 +94,16 @@ class TestDeleteViews:
@pytest.mark.django_db
-def test_non_filterable_field(options, instance, admin_user):
+def test_filterable_fields(options, instance, admin_user):
    r = options(
        url=instance.get_absolute_url(),
        user=admin_user
    )
-    field_info = r.data['actions']['GET']['percent_capacity_remaining']
-    assert 'filterable' in field_info
+    filterable_info = r.data['actions']['GET']['created']
+    non_filterable_info = r.data['actions']['GET']['percent_capacity_remaining']
+    assert 'filterable' in filterable_info
+    assert filterable_info['filterable'] is True
+    assert 'filterable' not in non_filterable_info

View File

@@ -87,7 +87,7 @@ def test_delete_instance_group_jobs_running(delete, instance_group_jobs_running,
@pytest.mark.django_db
-def test_modify_delete_tower_instance_group_prevented(delete, options, tower_instance_group, user, patch, put):
+def test_delete_rename_tower_instance_group_prevented(delete, options, tower_instance_group, instance_group, user, patch):
    url = reverse("api:instance_group_detail", kwargs={'pk': tower_instance_group.pk})
    super_user = user('bob', True)
@@ -99,6 +99,13 @@ def test_modify_delete_tower_instance_group_prevented(delete, options, tower_ins
    assert 'GET' in resp.data['actions']
    assert 'PUT' in resp.data['actions']

+    # Rename 'tower' instance group denied
+    patch(url, {'name': 'tower_prime'}, super_user, expect=400)
+
+    # Rename, other instance group OK
+    url = reverse("api:instance_group_detail", kwargs={'pk': instance_group.pk})
+    patch(url, {'name': 'foobar'}, super_user, expect=200)


@pytest.mark.django_db
def test_prevent_delete_iso_and_control_groups(delete, isolated_instance_group, admin):

View File

@@ -126,9 +126,8 @@ def test_list_cannot_order_by_unsearchable_field(get, organization, alice, order
    )
    custom_script.admin_role.members.add(alice)

-    response = get(reverse('api:inventory_script_list'), alice,
-                   QUERY_STRING='order_by=%s' % order_by, status=400)
-    assert response.status_code == 400
+    get(reverse('api:inventory_script_list'), alice,
+        QUERY_STRING='order_by=%s' % order_by, expect=403)


@pytest.mark.parametrize("role_field,expected_status_code", [

View File

@@ -625,17 +625,31 @@ def test_save_survey_passwords_on_migration(job_template_with_survey_passwords):
@pytest.mark.django_db
-def test_job_template_custom_virtualenv(get, patch, organization_factory, job_template_factory):
+@pytest.mark.parametrize('access', ["superuser", "admin", "peon"])
+def test_job_template_custom_virtualenv(get, patch, organization_factory, job_template_factory, alice, access):
    objs = organization_factory("org", superusers=['admin'])
    jt = job_template_factory("jt", organization=objs.organization,
                              inventory='test_inv', project='test_proj').job_template

+    user = alice
+    if access == "superuser":
+        user = objs.superusers.admin
+    elif access == "admin":
+        jt.admin_role.members.add(alice)
+    else:
+        jt.read_role.members.add(alice)

    with TemporaryDirectory(dir=settings.BASE_VENV_PATH) as temp_dir:
-        admin = objs.superusers.admin
        os.makedirs(os.path.join(temp_dir, 'bin', 'activate'))
        url = reverse('api:job_template_detail', kwargs={'pk': jt.id})
-        patch(url, {'custom_virtualenv': temp_dir}, user=admin, expect=200)
-        assert get(url, user=admin).data['custom_virtualenv'] == os.path.join(temp_dir, '')
+        if access == "peon":
+            patch(url, {'custom_virtualenv': temp_dir}, user=user, expect=403)
+            assert 'custom_virtualenv' not in get(url, user=user)
+            assert JobTemplate.objects.get(pk=jt.id).custom_virtualenv is None
+        else:
+            patch(url, {'custom_virtualenv': temp_dir}, user=user, expect=200)
+            assert get(url, user=user).data['custom_virtualenv'] == os.path.join(temp_dir, '')


@pytest.mark.django_db

View File

@@ -172,3 +172,12 @@ def test_oauth_application_delete(oauth_application, post, delete, admin):
    assert Application.objects.filter(client_id=oauth_application.client_id).count() == 0
    assert RefreshToken.objects.filter(application=oauth_application).count() == 0
    assert AccessToken.objects.filter(application=oauth_application).count() == 0
+@pytest.mark.django_db
+def test_oauth_list_user_tokens(oauth_application, post, get, admin, alice):
+    for user in (admin, alice):
+        url = reverse('api:o_auth2_token_list', kwargs={'pk': user.pk})
+        post(url, {'scope': 'read'}, user, expect=201)
+        response = get(url, admin, expect=200)
+        assert response.data['count'] == 1

View File

@@ -14,7 +14,7 @@ import mock
# AWX # AWX
from awx.api.versioning import reverse from awx.api.versioning import reverse
from awx.conf.models import Setting from awx.conf.models import Setting
from awx.main.utils.handlers import BaseHTTPSHandler, LoggingConnectivityException from awx.main.utils.handlers import AWXProxyHandler, LoggingConnectivityException
import six import six
@@ -217,7 +217,7 @@ def test_logging_aggregrator_connection_test_bad_request(get, post, admin, key):
@pytest.mark.django_db @pytest.mark.django_db
def test_logging_aggregrator_connection_test_valid(mocker, get, post, admin): def test_logging_aggregrator_connection_test_valid(mocker, get, post, admin):
with mock.patch.object(BaseHTTPSHandler, 'perform_test') as perform_test: with mock.patch.object(AWXProxyHandler, 'perform_test') as perform_test:
url = reverse('api:setting_logging_test') url = reverse('api:setting_logging_test')
user_data = { user_data = {
'LOG_AGGREGATOR_TYPE': 'logstash', 'LOG_AGGREGATOR_TYPE': 'logstash',
@@ -227,7 +227,8 @@ def test_logging_aggregrator_connection_test_valid(mocker, get, post, admin):
'LOG_AGGREGATOR_PASSWORD': 'mcstash' 'LOG_AGGREGATOR_PASSWORD': 'mcstash'
} }
post(url, user_data, user=admin, expect=200) post(url, user_data, user=admin, expect=200)
create_settings = perform_test.call_args[0][0] args, kwargs = perform_test.call_args_list[0]
create_settings = kwargs['custom_settings']
for k, v in user_data.items(): for k, v in user_data.items():
assert hasattr(create_settings, k) assert hasattr(create_settings, k)
assert getattr(create_settings, k) == v assert getattr(create_settings, k) == v
@@ -238,7 +239,7 @@ def test_logging_aggregrator_connection_test_with_masked_password(mocker, patch,
url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'logging'}) url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'logging'})
patch(url, user=admin, data={'LOG_AGGREGATOR_PASSWORD': 'password123'}, expect=200) patch(url, user=admin, data={'LOG_AGGREGATOR_PASSWORD': 'password123'}, expect=200)
with mock.patch.object(BaseHTTPSHandler, 'perform_test') as perform_test: with mock.patch.object(AWXProxyHandler, 'perform_test') as perform_test:
url = reverse('api:setting_logging_test') url = reverse('api:setting_logging_test')
user_data = { user_data = {
'LOG_AGGREGATOR_TYPE': 'logstash', 'LOG_AGGREGATOR_TYPE': 'logstash',
@@ -248,13 +249,14 @@ def test_logging_aggregrator_connection_test_with_masked_password(mocker, patch,
'LOG_AGGREGATOR_PASSWORD': '$encrypted$' 'LOG_AGGREGATOR_PASSWORD': '$encrypted$'
} }
post(url, user_data, user=admin, expect=200) post(url, user_data, user=admin, expect=200)
create_settings = perform_test.call_args[0][0] args, kwargs = perform_test.call_args_list[0]
create_settings = kwargs['custom_settings']
assert getattr(create_settings, 'LOG_AGGREGATOR_PASSWORD') == 'password123' assert getattr(create_settings, 'LOG_AGGREGATOR_PASSWORD') == 'password123'
@pytest.mark.django_db @pytest.mark.django_db
def test_logging_aggregrator_connection_test_invalid(mocker, get, post, admin): def test_logging_aggregrator_connection_test_invalid(mocker, get, post, admin):
with mock.patch.object(BaseHTTPSHandler, 'perform_test') as perform_test: with mock.patch.object(AWXProxyHandler, 'perform_test') as perform_test:
perform_test.side_effect = LoggingConnectivityException('404: Not Found') perform_test.side_effect = LoggingConnectivityException('404: Not Found')
url = reverse('api:setting_logging_test') url = reverse('api:setting_logging_test')
resp = post(url, { resp = post(url, {

View File

@@ -8,6 +8,7 @@ import tempfile
from django.conf import settings from django.conf import settings
from django.db.backends.sqlite3.base import SQLiteCursorWrapper from django.db.backends.sqlite3.base import SQLiteCursorWrapper
import mock
import pytest import pytest
from awx.api.versioning import reverse from awx.api.versioning import reverse
@@ -184,6 +185,7 @@ def test_text_stdout_with_max_stdout(sqlite_copy_expert, get, admin):
[_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'], [_mk_inventory_update, InventoryUpdateEvent, 'inventory_update', 'api:inventory_update_stdout'],
]) ])
@pytest.mark.parametrize('fmt', ['txt', 'ansi']) @pytest.mark.parametrize('fmt', ['txt', 'ansi'])
@mock.patch('awx.main.redact.UriCleaner.SENSITIVE_URI_PATTERN', mock.Mock(**{'search.return_value': None})) # really slow for large strings
def test_max_bytes_display(sqlite_copy_expert, Parent, Child, relation, view, fmt, get, admin): def test_max_bytes_display(sqlite_copy_expert, Parent, Child, relation, view, fmt, get, admin):
job = Parent() job = Parent()
job.save() job.save()
@@ -231,6 +233,7 @@ def test_legacy_result_stdout_text_fallback(Cls, view, fmt, get, admin):
[_mk_inventory_update, 'api:inventory_update_stdout'] [_mk_inventory_update, 'api:inventory_update_stdout']
]) ])
@pytest.mark.parametrize('fmt', ['txt', 'ansi']) @pytest.mark.parametrize('fmt', ['txt', 'ansi'])
@mock.patch('awx.main.redact.UriCleaner.SENSITIVE_URI_PATTERN', mock.Mock(**{'search.return_value': None})) # really slow for large strings
def test_legacy_result_stdout_with_max_bytes(Cls, view, fmt, get, admin): def test_legacy_result_stdout_with_max_bytes(Cls, view, fmt, get, admin):
job = Cls() job = Cls()
job.save() job.save()
@@ -282,7 +285,7 @@ def test_unicode_with_base64_ansi(sqlite_copy_expert, get, admin):
url = reverse( url = reverse(
'api:job_stdout', 'api:job_stdout',
kwargs={'pk': job.pk} kwargs={'pk': job.pk}
) + '?format=json&content_encoding=base64&content_format=ansi' ) + '?format=json&content_encoding=base64'
response = get(url, user=admin, expect=200) response = get(url, user=admin, expect=200)
content = base64.b64decode(json.loads(response.content)['content']) content = base64.b64decode(json.loads(response.content)['content'])

View File

@@ -9,7 +9,6 @@ from six.moves import xrange
# Django # Django
from django.core.urlresolvers import resolve from django.core.urlresolvers import resolve
from django.core.cache import cache
from django.utils.six.moves.urllib.parse import urlparse from django.utils.six.moves.urllib.parse import urlparse
from django.utils import timezone from django.utils import timezone
from django.contrib.auth.models import User from django.contrib.auth.models import User
@@ -57,14 +56,6 @@ def swagger_autogen(requests=__SWAGGER_REQUESTS__):
return requests return requests
@pytest.fixture(autouse=True)
def clear_cache():
'''
Clear cache (local memory) for each test to prevent using cached settings.
'''
cache.clear()
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def celery_memory_broker(): def celery_memory_broker():
''' '''

View File

@@ -0,0 +1,33 @@
import pytest
import mock
from awx.main.models import Project
@pytest.mark.django_db
def test_project_initial_update():
with mock.patch.object(Project, "update") as mock_update:
Project.objects.create(name='foo', scm_type='git')
mock_update.assert_called_once_with()
@pytest.mark.django_db
def test_does_not_update_nonsensitive_change(project):
with mock.patch.object(Project, "update") as mock_update:
project.scm_update_on_launch = not project.scm_update_on_launch
project.save()
mock_update.assert_not_called()
@pytest.mark.django_db
def test_sensitive_change_triggers_update(project):
with mock.patch.object(Project, "update") as mock_update:
project.scm_url = 'https://foo.invalid'
project.save()
mock_update.assert_called_once_with()
# test other means of initialization
project = Project.objects.get(pk=project.pk)
with mock.patch.object(Project, "update") as mock_update:
project.scm_url = 'https://foo2.invalid'
project.save()
mock_update.assert_called_once_with()
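These tests pin down the save() contract for projects: creating a project with an SCM type kicks off an initial update, while later saves trigger one only when a source-control field actually changes (toggling scm_update_on_launch does not count). A minimal sketch of that change-detection rule; the helper name and field list are illustrative, not taken from AWX's models:

def needs_scm_update(old_values, new_values, sensitive_fields=('scm_type', 'scm_url')):
    # A brand-new project with a configured SCM type needs an initial update.
    if old_values is None:
        return bool(new_values.get('scm_type'))
    # Otherwise only a change to a source-control field should trigger one.
    return any(old_values.get(f) != new_values.get(f) for f in sensitive_fields)

needs_scm_update(None, {'scm_type': 'git'})            # True: initial update
needs_scm_update({'scm_url': 'a'}, {'scm_url': 'b'})   # True: sensitive change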

View File

@@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from django.utils.timezone import now
import mock import mock
import pytest import pytest
import pytz import pytz
@@ -131,31 +132,19 @@ def test_utc_until(job_template, until, dtend):
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize('dtstart, until', [ @pytest.mark.parametrize('dtstart, until', [
['20180601T120000Z', '20180602T170000'], ['DTSTART:20380601T120000Z', '20380601T170000'], # noon UTC to 5PM UTC
['TZID=America/New_York:20180601T120000', '20180602T170000'], ['DTSTART;TZID=America/New_York:20380601T120000', '20380601T170000'], # noon EST to 5PM EST
]) ])
def test_tzinfo_naive_until(job_template, dtstart, until): def test_tzinfo_naive_until(job_template, dtstart, until):
rrule = 'DTSTART;{} RRULE:FREQ=DAILY;INTERVAL=1;UNTIL={}'.format(dtstart, until) # noqa rrule = '{} RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL={}'.format(dtstart, until) # noqa
s = Schedule( s = Schedule(
name='Some Schedule', name='Some Schedule',
rrule=rrule, rrule=rrule,
unified_job_template=job_template unified_job_template=job_template
) )
with pytest.raises(ValueError): s.save()
s.save() gen = Schedule.rrulestr(s.rrule).xafter(now(), count=20)
assert len(list(gen)) == 6 # noon, 1PM, 2, 3, 4, 5PM
@pytest.mark.django_db
def test_until_must_be_utc(job_template):
rrule = 'DTSTART;TZID=America/New_York:20180601T120000 RRULE:FREQ=DAILY;INTERVAL=1;UNTIL=20180602T000000' # noqa the Z is required
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
with pytest.raises(ValueError) as e:
s.save()
assert 'RRULE UNTIL values must be specified in UTC' in str(e)
@pytest.mark.django_db @pytest.mark.django_db
@@ -203,3 +192,85 @@ def test_beginning_of_time(job_template):
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
s.save() s.save()
@pytest.mark.django_db
@pytest.mark.parametrize('rrule, tz', [
['DTSTART:20300112T210000Z RRULE:FREQ=DAILY;INTERVAL=1', 'UTC'],
['DTSTART;TZID=America/New_York:20300112T210000 RRULE:FREQ=DAILY;INTERVAL=1', 'America/New_York']
])
def test_timezone_property(job_template, rrule, tz):
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
assert s.timezone == tz
@pytest.mark.django_db
def test_utc_until_property(job_template):
rrule = 'DTSTART:20380601T120000Z RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20380601T170000Z'
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
s.save()
assert s.rrule.endswith('20380601T170000Z')
assert s.until == '2038-06-01T17:00:00'
@pytest.mark.django_db
def test_localized_until_property(job_template):
rrule = 'DTSTART;TZID=America/New_York:20380601T120000 RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20380601T220000Z'
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
s.save()
assert s.rrule.endswith('20380601T220000Z')
assert s.until == '2038-06-01T17:00:00'
@pytest.mark.django_db
def test_utc_naive_coercion(job_template):
rrule = 'DTSTART:20380601T120000Z RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20380601T170000'
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
s.save()
assert s.rrule.endswith('20380601T170000Z')
assert s.until == '2038-06-01T17:00:00'
@pytest.mark.django_db
def test_est_naive_coercion(job_template):
rrule = 'DTSTART;TZID=America/New_York:20380601T120000 RRULE:FREQ=HOURLY;INTERVAL=1;UNTIL=20380601T170000'
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
s.save()
assert s.rrule.endswith('20380601T220000Z') # 5PM EDT = 10PM UTC
assert s.until == '2038-06-01T17:00:00'
@pytest.mark.django_db
def test_empty_until_property(job_template):
rrule = 'DTSTART;TZID=America/New_York:20380601T120000 RRULE:FREQ=HOURLY;INTERVAL=1'
s = Schedule(
name='Some Schedule',
rrule=rrule,
unified_job_template=job_template
)
s.save()
assert s.until == ''
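Taken together, the new property tests encode one coercion rule: a naive UNTIL stamp is interpreted in the DTSTART timezone and rewritten as UTC with the trailing Z, while the until property reports the schedule's local wall-clock time. A sketch of the naive-to-UTC half of that rule, assuming pytz-style localization (the helper name is illustrative):

from datetime import datetime

import pytz

def coerce_naive_until(until_stamp, tz_name):
    # Interpret a naive RRULE UNTIL value in the schedule's timezone and
    # return the equivalent UTC stamp with the required trailing 'Z'.
    naive = datetime.strptime(until_stamp, '%Y%m%dT%H%M%S')
    localized = pytz.timezone(tz_name).localize(naive)
    return localized.astimezone(pytz.utc).strftime('%Y%m%dT%H%M%SZ')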

View File

@@ -113,7 +113,7 @@ def test_single_job_dependencies_project_launch(default_instance_group, job_temp
p.scm_update_cache_timeout = 0 p.scm_update_cache_timeout = 0
p.scm_type = "git" p.scm_type = "git"
p.scm_url = "http://github.com/ansible/ansible.git" p.scm_url = "http://github.com/ansible/ansible.git"
p.save() p.save(skip_update=True)
with mock.patch("awx.main.scheduler.TaskManager.start_task"): with mock.patch("awx.main.scheduler.TaskManager.start_task"):
tm = TaskManager() tm = TaskManager()
with mock.patch.object(TaskManager, "create_project_update", wraps=tm.create_project_update) as mock_pu: with mock.patch.object(TaskManager, "create_project_update", wraps=tm.create_project_update) as mock_pu:
@@ -241,15 +241,16 @@ def test_shared_dependencies_launch(default_instance_group, job_template_factory
@pytest.mark.django_db @pytest.mark.django_db
def test_cleanup_interval(): def test_cleanup_interval(mock_cache):
assert cache.get('last_celery_task_cleanup') is None with mock.patch.multiple('awx.main.scheduler.task_manager.cache', get=mock_cache.get, set=mock_cache.set):
assert mock_cache.get('last_celery_task_cleanup') is None
TaskManager().cleanup_inconsistent_celery_tasks() TaskManager().cleanup_inconsistent_celery_tasks()
last_cleanup = cache.get('last_celery_task_cleanup') last_cleanup = mock_cache.get('last_celery_task_cleanup')
assert isinstance(last_cleanup, datetime) assert isinstance(last_cleanup, datetime)
TaskManager().cleanup_inconsistent_celery_tasks() TaskManager().cleanup_inconsistent_celery_tasks()
assert cache.get('last_celery_task_cleanup') == last_cleanup assert cache.get('last_celery_task_cleanup') == last_cleanup
class TestReaper(): class TestReaper():
@@ -326,7 +327,8 @@ class TestReaper():
@pytest.mark.django_db @pytest.mark.django_db
@mock.patch.object(JobNotificationMixin, 'send_notification_templates') @mock.patch.object(JobNotificationMixin, 'send_notification_templates')
@mock.patch.object(TaskManager, 'get_active_tasks', lambda self: ([], [])) @mock.patch.object(TaskManager, 'get_active_tasks', lambda self: ([], []))
def test_cleanup_inconsistent_task(self, notify, active_tasks, considered_jobs, reapable_jobs, running_tasks, waiting_tasks, mocker): def test_cleanup_inconsistent_task(self, notify, active_tasks, considered_jobs, reapable_jobs, running_tasks, waiting_tasks, mocker, settings):
settings.AWX_INCONSISTENT_TASK_INTERVAL = 0
tm = TaskManager() tm = TaskManager()
tm.get_running_tasks = mocker.Mock(return_value=(running_tasks, waiting_tasks)) tm.get_running_tasks = mocker.Mock(return_value=(running_tasks, waiting_tasks))

View File

@@ -5,6 +5,7 @@ from awx.main.models import (
Organization, Organization,
Project, Project,
) )
from awx.main.fields import update_role_parentage_for_instance
@pytest.mark.django_db @pytest.mark.django_db
@@ -202,3 +203,11 @@ def test_auto_parenting():
assert org1.admin_role.is_ancestor_of(prj2.admin_role) is False assert org1.admin_role.is_ancestor_of(prj2.admin_role) is False
assert org2.admin_role.is_ancestor_of(prj1.admin_role) assert org2.admin_role.is_ancestor_of(prj1.admin_role)
assert org2.admin_role.is_ancestor_of(prj2.admin_role) assert org2.admin_role.is_ancestor_of(prj2.admin_role)
@pytest.mark.django_db
def test_update_parents_keeps_teams(team, project):
project.update_role.parents.add(team.member_role)
assert team.member_role in project.update_role # test prep sanity check
update_role_parentage_for_instance(project)
assert team.member_role in project.update_role # actual assertion

View File

@@ -102,21 +102,21 @@ class TestOAuth2Application:
assert access.can_delete(app) is can_access assert access.can_delete(app) is can_access
def test_superuser_can_always_create(self, admin, org_admin, org_member, alice): def test_superuser_can_always_create(self, admin, org_admin, org_member, alice, organization):
access = OAuth2ApplicationAccess(admin) access = OAuth2ApplicationAccess(admin)
for user in [admin, org_admin, org_member, alice]: for user in [admin, org_admin, org_member, alice]:
assert access.can_add({ assert access.can_add({
'name': 'test app', 'user': user.pk, 'client_type': 'confidential', 'name': 'test app', 'user': user.pk, 'client_type': 'confidential',
'authorization_grant_type': 'password', 'organization': 1 'authorization_grant_type': 'password', 'organization': organization.id
}) })
def test_normal_user_cannot_create(self, admin, org_admin, org_member, alice): def test_normal_user_cannot_create(self, admin, org_admin, org_member, alice, organization):
for access_user in [org_member, alice]: for access_user in [org_member, alice]:
access = OAuth2ApplicationAccess(access_user) access = OAuth2ApplicationAccess(access_user)
for user in [admin, org_admin, org_member, alice]: for user in [admin, org_admin, org_member, alice]:
assert not access.can_add({ assert not access.can_add({
'name': 'test app', 'user': user.pk, 'client_type': 'confidential', 'name': 'test app', 'user': user.pk, 'client_type': 'confidential',
'authorization_grant_type': 'password', 'organization': 1 'authorization_grant_type': 'password', 'organization': organization.id
}) })

View File

@@ -0,0 +1,6 @@
# Ensure that our autouse settings overrides are working
def test_cache(settings):
assert settings.CACHES['default']['BACKEND'] == 'django.core.cache.backends.locmem.LocMemCache'
assert settings.CACHES['default']['LOCATION'].startswith('unique-')
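This guard only passes if a conftest fixture swaps every test over to an isolated local-memory cache. A sketch of the kind of autouse override the assertions imply, using pytest-django's settings fixture (the fixture name and exact shape are assumptions, not the project's actual conftest):

import uuid

import pytest

@pytest.fixture(autouse=True)
def isolated_cache(settings):
    # Hypothetical override: give each test its own locmem backend so
    # cached settings can never leak between tests.
    settings.CACHES = {
        'default': {
            'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
            'LOCATION': 'unique-{}'.format(uuid.uuid4()),
        }
    }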

View File

@@ -3,14 +3,19 @@
import pytest import pytest
from rest_framework.exceptions import PermissionDenied, ParseError from rest_framework.exceptions import PermissionDenied, ParseError
from awx.api.filters import FieldLookupBackend from awx.api.filters import FieldLookupBackend, OrderByBackend, get_field_from_path
from awx.main.models import (AdHocCommand, ActivityStream, from awx.main.models import (AdHocCommand, ActivityStream,
CustomInventoryScript, Credential, Job, CustomInventoryScript, Credential, Job,
JobTemplate, SystemJob, UnifiedJob, User, JobTemplate, SystemJob, UnifiedJob, User,
WorkflowJob, WorkflowJobTemplate, WorkflowJob, WorkflowJobTemplate,
WorkflowJobOptions, InventorySource) WorkflowJobOptions, InventorySource,
JobEvent)
from awx.main.models.oauth import OAuth2Application
from awx.main.models.jobs import JobOptions from awx.main.models.jobs import JobOptions
# Django
from django.db.models.fields import FieldDoesNotExist
def test_related(): def test_related():
field_lookup = FieldLookupBackend() field_lookup = FieldLookupBackend()
@@ -20,6 +25,27 @@ def test_related():
print(new_lookup) print(new_lookup)
def test_invalid_filter_key():
field_lookup = FieldLookupBackend()
# FieldDoesNotExist is caught and converted to ParseError by filter_queryset
with pytest.raises(FieldDoesNotExist) as excinfo:
field_lookup.value_to_python(JobEvent, 'event_data.task_action', 'foo')
assert 'has no field named' in str(excinfo)
def test_invalid_field_hop():
with pytest.raises(ParseError) as excinfo:
get_field_from_path(Credential, 'organization__description__user')
assert 'No related model for' in str(excinfo)
def test_invalid_order_by_key():
field_order_by = OrderByBackend()
with pytest.raises(ParseError) as excinfo:
[f for f in field_order_by._validate_ordering_fields(JobEvent, ('event_data.task_action',))]
assert 'has no field named' in str(excinfo)
@pytest.mark.parametrize(u"empty_value", [u'', '']) @pytest.mark.parametrize(u"empty_value", [u'', ''])
def test_empty_in(empty_value): def test_empty_in(empty_value):
field_lookup = FieldLookupBackend() field_lookup = FieldLookupBackend()
@@ -57,7 +83,6 @@ def test_filter_on_password_field(password_field, lookup_suffix):
(User, 'password__icontains'), (User, 'password__icontains'),
(User, 'settings__value__icontains'), (User, 'settings__value__icontains'),
(User, 'main_oauth2accesstoken__token__gt'), (User, 'main_oauth2accesstoken__token__gt'),
(User, 'main_oauth2application__name__gt'),
(UnifiedJob, 'job_args__icontains'), (UnifiedJob, 'job_args__icontains'),
(UnifiedJob, 'job_env__icontains'), (UnifiedJob, 'job_env__icontains'),
(UnifiedJob, 'start_args__icontains'), (UnifiedJob, 'start_args__icontains'),
@@ -70,8 +95,8 @@ def test_filter_on_password_field(password_field, lookup_suffix):
(JobTemplate, 'survey_spec__icontains'), (JobTemplate, 'survey_spec__icontains'),
(WorkflowJobTemplate, 'survey_spec__icontains'), (WorkflowJobTemplate, 'survey_spec__icontains'),
(CustomInventoryScript, 'script__icontains'), (CustomInventoryScript, 'script__icontains'),
(ActivityStream, 'o_auth2_access_token__gt'), (ActivityStream, 'o_auth2_application__client_secret__gt'),
(ActivityStream, 'o_auth2_application__gt') (OAuth2Application, 'grant__code__gt')
]) ])
def test_filter_sensitive_fields_and_relations(model, query): def test_filter_sensitive_fields_and_relations(model, query):
field_lookup = FieldLookupBackend() field_lookup = FieldLookupBackend()

View File

@@ -16,6 +16,9 @@ from awx.api.views import (
from awx.main.models import ( from awx.main.models import (
Host, Host,
) )
from awx.main.views import handle_error
from rest_framework.test import APIRequestFactory
@pytest.fixture @pytest.fixture
@@ -25,6 +28,12 @@ def mock_response_new(mocker):
return m return m
def test_handle_error():
# Assure that templating of error does not raise errors
request = APIRequestFactory().get('/fooooo/')
handle_error(request)
class TestApiRootView: class TestApiRootView:
def test_get_endpoints(self, mocker, mock_response_new): def test_get_endpoints(self, mocker, mock_response_new):
endpoints = [ endpoints = [

View File

@@ -1,4 +1,5 @@
import pytest import pytest
import logging
from mock import PropertyMock from mock import PropertyMock
@@ -7,3 +8,16 @@ from mock import PropertyMock
def _disable_database_settings(mocker): def _disable_database_settings(mocker):
m = mocker.patch('awx.conf.settings.SettingsWrapper.all_supported_settings', new_callable=PropertyMock) m = mocker.patch('awx.conf.settings.SettingsWrapper.all_supported_settings', new_callable=PropertyMock)
m.return_value = [] m.return_value = []
@pytest.fixture()
def dummy_log_record():
return logging.LogRecord(
'awx', # logger name
20, # loglevel INFO
'./awx/some/module.py', # pathname
100, # lineno
'User joe logged in', # msg
tuple(), # args,
None # exc_info
)

View File

@@ -90,7 +90,7 @@ def test_cancel_callback_error():
extra_fields = {} extra_fields = {}
status, rc = run.run_pexpect( status, rc = run.run_pexpect(
['ls', '-la'], ['sleep', '2'],
HERE, HERE,
{}, {},
stdout, stdout,

View File

@@ -44,3 +44,18 @@ def test_playbook_event_strip_invalid_keys(job_identifier, cls):
'extra_key': 'extra_value' 'extra_key': 'extra_value'
}) })
manager.create.assert_called_with(**{job_identifier: 123}) manager.create.assert_called_with(**{job_identifier: 123})
@pytest.mark.parametrize('field', [
'play', 'role', 'task', 'playbook'
])
def test_really_long_event_fields(field):
with mock.patch.object(JobEvent, 'objects') as manager:
JobEvent.create_from_data(**{
'job_id': 123,
field: 'X' * 4096
})
manager.create.assert_called_with(**{
'job_id': 123,
field: 'X' * 1021 + '...'
})
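The expected value above implies a simple clipping rule: overlong event fields are cut so the stored string, ellipsis included, fits in 1024 characters. A minimal sketch of such a helper (the name and the limit's placement are illustrative):

def truncate_event_field(value, max_len=1024):
    # 'X' * 4096 becomes 'X' * 1021 + '...': the marker counts against
    # the limit rather than being appended on top of it.
    if len(value) > max_len:
        return value[:max_len - 3] + '...'
    return value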

View File

@@ -1,8 +1,9 @@
import tempfile import tempfile
import json import json
import yaml import yaml
import pytest import pytest
from itertools import count
from awx.main.utils.encryption import encrypt_value from awx.main.utils.encryption import encrypt_value
from awx.main.tasks import RunJob from awx.main.tasks import RunJob
from awx.main.models import ( from awx.main.models import (
@@ -16,6 +17,15 @@ from awx.main.utils.safe_yaml import SafeLoader
ENCRYPTED_SECRET = encrypt_value('secret') ENCRYPTED_SECRET = encrypt_value('secret')
class DistinctParametrize(object):
def __init__(self):
self._gen = count(0)
def __call__(self, value):
return str(next(self._gen))
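# Used as ids= below so each parametrize case gets a short, distinct numeric
# id instead of one derived from the long encrypted default values.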
@pytest.mark.survey @pytest.mark.survey
class SurveyVariableValidation: class SurveyVariableValidation:
@@ -243,7 +253,7 @@ def test_optional_survey_question_defaults(
('password', 'foo', 5, {'extra_vars': {'x': ''}}, {'x': ''}), ('password', 'foo', 5, {'extra_vars': {'x': ''}}, {'x': ''}),
('password', ENCRYPTED_SECRET, 5, {'extra_vars': {'x': '$encrypted$'}}, {}), ('password', ENCRYPTED_SECRET, 5, {'extra_vars': {'x': '$encrypted$'}}, {}),
('password', ENCRYPTED_SECRET, 10, {'extra_vars': {'x': '$encrypted$'}}, {'x': ENCRYPTED_SECRET}), ('password', ENCRYPTED_SECRET, 10, {'extra_vars': {'x': '$encrypted$'}}, {'x': ENCRYPTED_SECRET}),
]) ], ids=DistinctParametrize())
def test_survey_encryption_defaults(survey_spec_factory, question_type, default, maxlen, kwargs, expected): def test_survey_encryption_defaults(survey_spec_factory, question_type, default, maxlen, kwargs, expected):
spec = survey_spec_factory([ spec = survey_spec_factory([
{ {

View File

@@ -158,7 +158,7 @@ def test_jt_existing_values_are_nonsensitive(job_template_with_ids, user_unit):
"""Assure that permission checks are not required if submitted data is """Assure that permission checks are not required if submitted data is
identical to what the job template already has.""" identical to what the job template already has."""
data = model_to_dict(job_template_with_ids) data = model_to_dict(job_template_with_ids, exclude=['unifiedjobtemplate_ptr'])
access = JobTemplateAccess(user_unit) access = JobTemplateAccess(user_unit)
assert access.changes_are_non_sensitive(job_template_with_ids, data) assert access.changes_are_non_sensitive(job_template_with_ids, data)

View File

@@ -96,10 +96,26 @@ def test_cred_type_input_schema_validity(input_, valid):
({'invalid-injector': {}}, False), ({'invalid-injector': {}}, False),
({'file': 123}, False), ({'file': 123}, False),
({'file': {}}, True), ({'file': {}}, True),
# Uses credential inputs inside the unnamed file's contents
({'file': {'template': '{{username}}'}}, True), ({'file': {'template': '{{username}}'}}, True),
# Uses named file
({'file': {'template.username': '{{username}}'}}, True), ({'file': {'template.username': '{{username}}'}}, True),
# Uses multiple named files
({'file': {'template.username': '{{username}}', 'template.password': '{{pass}}'}}, True), ({'file': {'template.username': '{{username}}', 'template.password': '{{pass}}'}}, True),
# Use of unnamed file mutually exclusive with use of named files
({'file': {'template': '{{username}}', 'template.password': '{{pass}}'}}, False), ({'file': {'template': '{{username}}', 'template.password': '{{pass}}'}}, False),
# References a non-existent named file
({'env': {'FROM_FILE': "{{tower.filename.cert}}"}}, False),
# References unnamed file, but a file was never defined
({'env': {'FROM_FILE': "{{tower.filename}}"}}, False),
# Cannot reference the tower namespace itself (there is nothing it could resolve to)
({'env': {'FROM_FILE': "{{tower}}"}}, False),
# References filename of a named file
({'file': {'template.cert': '{{awx_secret}}'}, 'env': {'FROM_FILE': "{{tower.filename.cert}}"}}, True),
# With named files, `tower.filename` is another namespace, so it cannot be referenced
({'file': {'template.cert': '{{awx_secret}}'}, 'env': {'FROM_FILE': "{{tower.filename}}"}}, False),
# With an unnamed file, `tower.filename` is just the filename
({'file': {'template': '{{awx_secret}}'}, 'env': {'THE_FILENAME': "{{tower.filename}}"}}, True),
({'file': {'foo': 'bar'}}, False), ({'file': {'foo': 'bar'}}, False),
({'env': 123}, False), ({'env': 123}, False),
({'env': {}}, True), ({'env': {}}, True),

View File

@@ -2155,7 +2155,7 @@ def test_aquire_lock_open_fail_logged(logging_getLogger, os_open):
ProjectUpdate = tasks.RunProjectUpdate() ProjectUpdate = tasks.RunProjectUpdate()
with pytest.raises(OSError, errno=3, strerror='dummy message'): with pytest.raises(OSError, message='dummy message'):
ProjectUpdate.acquire_lock(instance) ProjectUpdate.acquire_lock(instance)
assert logger.err.called_with("I/O error({0}) while trying to open lock file [{1}]: {2}".format(3, 'this_file_does_not_exist', 'dummy message')) assert logger.err.called_with("I/O error({0}) while trying to open lock file [{1}]: {2}".format(3, 'this_file_does_not_exist', 'dummy message'))
@@ -2181,7 +2181,7 @@ def test_aquire_lock_acquisition_fail_logged(fcntl_flock, logging_getLogger, os_
ProjectUpdate = tasks.RunProjectUpdate() ProjectUpdate = tasks.RunProjectUpdate()
with pytest.raises(IOError, errno=3, strerror='dummy message'): with pytest.raises(IOError, message='dummy message'):
ProjectUpdate.acquire_lock(instance) ProjectUpdate.acquire_lock(instance)
os_close.assert_called_with(3) os_close.assert_called_with(3)
assert logger.err.called_with("I/O error({0}) while trying to aquire lock on file [{1}]: {2}".format(3, 'this_file_does_not_exist', 'dummy message')) assert logger.err.called_with("I/O error({0}) while trying to aquire lock on file [{1}]: {2}".format(3, 'this_file_does_not_exist', 'dummy message'))

View File

@@ -3,6 +3,10 @@ import mock
# Django REST Framework # Django REST Framework
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.generics import ListAPIView
# Django
from django.core.urlresolvers import RegexURLResolver, RegexURLPattern
# AWX # AWX
from awx.main.views import ApiErrorView from awx.main.views import ApiErrorView
@@ -52,3 +56,44 @@ def test_disable_post_on_v1_inventory_source_list(version, supports_post):
inv_source_list.request = mock.MagicMock() inv_source_list.request = mock.MagicMock()
with mock.patch('awx.api.views.get_request_version', return_value=version): with mock.patch('awx.api.views.get_request_version', return_value=version):
assert ('POST' in inv_source_list.allowed_methods) == supports_post assert ('POST' in inv_source_list.allowed_methods) == supports_post
def test_views_have_search_fields():
from awx.api.urls import urlpatterns as api_patterns
patterns = set([])
url_views = set([])
# Add recursive URL patterns
unprocessed = set(api_patterns)
while unprocessed:
to_process = unprocessed.copy()
unprocessed = set([])
for pattern in to_process:
if hasattr(pattern, 'lookup_str') and not pattern.lookup_str.startswith('awx.api'):
continue
patterns.add(pattern)
if isinstance(pattern, RegexURLResolver):
for sub_pattern in pattern.url_patterns:
if sub_pattern not in patterns:
unprocessed.add(sub_pattern)
# Get view classes
for pattern in patterns:
if isinstance(pattern, RegexURLPattern) and hasattr(pattern.callback, 'view_class'):
cls = pattern.callback.view_class
if issubclass(cls, ListAPIView):
url_views.add(pattern.callback.view_class)
# Gather any views that don't have search fields defined
views_missing_search = []
for View in url_views:
view = View()
if not hasattr(view, 'search_fields') or len(view.search_fields) == 0:
views_missing_search.append(view)
if views_missing_search:
raise Exception('{} views do not have search fields defined:\n{}'.format(
len(views_missing_search),
'\n'.join([
v.__class__.__name__ + ' (model: {})'.format(getattr(v, 'model', type(None)).__name__)
for v in views_missing_search
]))
)
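The guard above fails the suite whenever a ListAPIView reachable from awx.api.urls ships without search fields. For reference, the declaration it enforces looks like this (a hypothetical view, not one of AWX's):

from rest_framework.generics import ListAPIView

class ExampleWidgetList(ListAPIView):
    # Any list view must declare at least one searchable field, or
    # test_views_have_search_fields will name it in its failure output.
    search_fields = ('name', 'description')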

View File

@@ -7,10 +7,10 @@ import pytest
from uuid import uuid4 from uuid import uuid4
import json import json
import yaml import yaml
import mock
from backports.tempfile import TemporaryDirectory from backports.tempfile import TemporaryDirectory
from django.conf import settings from django.conf import settings
from django.core.cache import cache
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
@@ -26,14 +26,6 @@ from awx.main.models import (
) )
@pytest.fixture(autouse=True)
def clear_cache():
'''
Clear cache (local memory) for each test to prevent using cached settings.
'''
cache.clear()
@pytest.mark.parametrize('input_, output', [ @pytest.mark.parametrize('input_, output', [
({"foo": "bar"}, {"foo": "bar"}), ({"foo": "bar"}, {"foo": "bar"}),
('{"foo": "bar"}', {"foo": "bar"}), ('{"foo": "bar"}', {"foo": "bar"}),
@@ -114,46 +106,48 @@ def test_get_type_for_model(model, name):
@pytest.fixture @pytest.fixture
def memoized_function(mocker): def memoized_function(mocker, mock_cache):
@common.memoize(track_function=True) with mock.patch('awx.main.utils.common.get_memoize_cache', return_value=mock_cache):
def myfunction(key, value): @common.memoize(track_function=True)
if key not in myfunction.calls: def myfunction(key, value):
myfunction.calls[key] = 0 if key not in myfunction.calls:
myfunction.calls[key] = 0
myfunction.calls[key] += 1 myfunction.calls[key] += 1
if myfunction.calls[key] == 1: if myfunction.calls[key] == 1:
return value return value
else: else:
return '%s called %s times' % (value, myfunction.calls[key]) return '%s called %s times' % (value, myfunction.calls[key])
myfunction.calls = dict() myfunction.calls = dict()
return myfunction return myfunction
def test_memoize_track_function(memoized_function): def test_memoize_track_function(memoized_function, mock_cache):
assert memoized_function('scott', 'scotterson') == 'scotterson' assert memoized_function('scott', 'scotterson') == 'scotterson'
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson'} assert mock_cache.get('myfunction') == {u'scott-scotterson': 'scotterson'}
assert memoized_function('scott', 'scotterson') == 'scotterson' assert memoized_function('scott', 'scotterson') == 'scotterson'
assert memoized_function.calls['scott'] == 1 assert memoized_function.calls['scott'] == 1
assert memoized_function('john', 'smith') == 'smith' assert memoized_function('john', 'smith') == 'smith'
assert cache.get('myfunction') == {u'scott-scotterson': 'scotterson', u'john-smith': 'smith'} assert mock_cache.get('myfunction') == {u'scott-scotterson': 'scotterson', u'john-smith': 'smith'}
assert memoized_function('john', 'smith') == 'smith' assert memoized_function('john', 'smith') == 'smith'
assert memoized_function.calls['john'] == 1 assert memoized_function.calls['john'] == 1
def test_memoize_delete(memoized_function): def test_memoize_delete(memoized_function, mock_cache):
assert memoized_function('john', 'smith') == 'smith' assert memoized_function('john', 'smith') == 'smith'
assert memoized_function('john', 'smith') == 'smith' assert memoized_function('john', 'smith') == 'smith'
assert memoized_function.calls['john'] == 1 assert memoized_function.calls['john'] == 1
assert cache.get('myfunction') == {u'john-smith': 'smith'} assert mock_cache.get('myfunction') == {u'john-smith': 'smith'}
common.memoize_delete('myfunction') with mock.patch('awx.main.utils.common.memoize_delete', side_effect=mock_cache.delete):
common.memoize_delete('myfunction')
assert cache.get('myfunction') is None assert mock_cache.get('myfunction') is None
assert memoized_function('john', 'smith') == 'smith called 2 times' assert memoized_function('john', 'smith') == 'smith called 2 times'
assert memoized_function.calls['john'] == 2 assert memoized_function.calls['john'] == 2
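Read together, these tests describe the memoize contract: results live in the shared cache under the function's name, as a dict keyed by a '-'-join of the call arguments, and memoize_delete drops the whole entry at once. A compact sketch of a decorator meeting that contract, ignoring the track_function flag (a reconstruction from the tests, not AWX's implementation):

import functools

def memoize(cache):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            store = cache.get(fn.__name__) or {}
            key = '-'.join(str(a) for a in args)
            if key not in store:
                store[key] = fn(*args)
                cache.set(fn.__name__, store)  # one cache entry per function
            return store[key]
        return wrapper
    return decorator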

View File

@@ -5,7 +5,7 @@ import mock
from collections import namedtuple from collections import namedtuple
# AWX # AWX
from awx.main.utils.filters import SmartFilter from awx.main.utils.filters import SmartFilter, ExternalLoggerEnabled
# Django # Django
from django.db.models import Q from django.db.models import Q
@@ -13,6 +13,37 @@ from django.db.models import Q
import six import six
@pytest.mark.parametrize('params, logger_name, expected', [
# skip all records if enabled_flag = False
({'enabled_flag': False}, 'awx.main', False),
# skip all records if the host is undefined
({'enabled_flag': True}, 'awx.main', False),
# skip all records if underlying logger is used by handlers themselves
({'enabled_flag': True}, 'awx.main.utils.handlers', False),
({'enabled_flag': True, 'enabled_loggers': ['awx']}, 'awx.main', True),
({'enabled_flag': True, 'enabled_loggers': ['abc']}, 'awx.analytics.xyz', False),
({'enabled_flag': True, 'enabled_loggers': ['xyz']}, 'awx.analytics.xyz', True),
])
def test_base_logging_handler_skip_log(params, logger_name, expected, dummy_log_record):
filter = ExternalLoggerEnabled(**params)
dummy_log_record.name = logger_name
assert filter.filter(dummy_log_record) is expected, (params, logger_name)
@pytest.mark.parametrize('level, expect', [
(30, True), # warning
(20, False) # info
])
def test_log_configurable_severity(level, expect, dummy_log_record):
dummy_log_record.levelno = level
filter = ExternalLoggerEnabled(
enabled_flag=True,
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'],
lvl='WARNING'
)
assert filter.filter(dummy_log_record) is expect
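The parametrized cases pin down ExternalLoggerEnabled well enough to reconstruct it in outline: nothing passes while external logging is disabled, the handlers' own logger is always muted, awx.analytics.* records are matched by their leaf name against the enabled loggers, other records by their top-level name, and anything below the configured level is dropped. A sketch of a logging.Filter with that behavior (a reconstruction from the tests, not the class in awx.main.utils.filters):

import logging

class EnabledLoggerFilter(logging.Filter):
    def __init__(self, enabled_flag=False, enabled_loggers=None, lvl='INFO'):
        super(EnabledLoggerFilter, self).__init__()
        self.enabled_flag = enabled_flag
        self.enabled_loggers = enabled_loggers or []
        self.level = logging.getLevelName(lvl)  # 'WARNING' -> 30

    def filter(self, record):
        if not self.enabled_flag or record.levelno < self.level:
            return False
        if record.name.startswith('awx.main.utils.handlers'):
            return False  # never recurse into the handlers' own logging
        if record.name.startswith('awx.analytics.'):
            name = record.name.rsplit('.', 1)[-1]
        else:
            name = record.name.split('.', 1)[0]
        return name in self.enabled_loggers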
Field = namedtuple('Field', 'name') Field = namedtuple('Field', 'name')
Meta = namedtuple('Meta', 'fields') Meta = namedtuple('Meta', 'fields')

View File

@@ -6,22 +6,15 @@
# python # python
import pytest import pytest
import mock import mock
from contextlib import nested
# AWX # AWX
from awx.main.utils.ha import ( from awx.main.utils.ha import (
_add_remove_celery_worker_queues, _add_remove_celery_worker_queues,
update_celery_worker_routes, AWXCeleryRouter,
) )
@pytest.fixture
def conf():
class Conf():
CELERY_ROUTES = dict()
CELERYBEAT_SCHEDULE = dict()
return Conf()
class TestAddRemoveCeleryWorkerQueues(): class TestAddRemoveCeleryWorkerQueues():
@pytest.fixture @pytest.fixture
def instance_generator(self, mocker): def instance_generator(self, mocker):
@@ -47,54 +40,54 @@ class TestAddRemoveCeleryWorkerQueues():
app.control.cancel_consumer = mocker.MagicMock() app.control.cancel_consumer = mocker.MagicMock()
return app return app
@pytest.mark.parametrize("static_queues,_worker_queues,groups,hostname,added_expected,removed_expected", [ @pytest.mark.parametrize("broadcast_queues,static_queues,_worker_queues,groups,hostname,added_expected,removed_expected", [
(['east', 'west'], ['east', 'west', 'east-1'], [], 'east-1', [], []), (['tower_broadcast_all'], ['east', 'west'], ['east', 'west', 'east-1'], [], 'east-1', ['tower_broadcast_all_east-1'], []),
([], ['east', 'west', 'east-1'], ['east', 'west'], 'east-1', [], []), ([], [], ['east', 'west', 'east-1'], ['east', 'west'], 'east-1', [], []),
([], ['east', 'west'], ['east', 'west'], 'east-1', ['east-1'], []), ([], [], ['east', 'west'], ['east', 'west'], 'east-1', ['east-1'], []),
([], [], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], []), ([], [], [], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], []),
([], ['china', 'russia'], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], ['china', 'russia']), ([], [], ['china', 'russia'], ['east', 'west'], 'east-1', ['east', 'west', 'east-1'], ['china', 'russia']),
]) ])
def test__add_remove_celery_worker_queues_noop(self, mock_app, def test__add_remove_celery_worker_queues_noop(self, mock_app,
instance_generator, instance_generator,
worker_queues_generator, worker_queues_generator,
static_queues, _worker_queues, broadcast_queues,
static_queues, _worker_queues,
groups, hostname, groups, hostname,
added_expected, removed_expected): added_expected, removed_expected):
instance = instance_generator(groups=groups, hostname=hostname) instance = instance_generator(groups=groups, hostname=hostname)
worker_queues = worker_queues_generator(_worker_queues) worker_queues = worker_queues_generator(_worker_queues)
with mock.patch('awx.main.utils.ha.settings.AWX_CELERY_QUEUES_STATIC', static_queues): with nested(
mock.patch('awx.main.utils.ha.settings.AWX_CELERY_QUEUES_STATIC', static_queues),
mock.patch('awx.main.utils.ha.settings.AWX_CELERY_BCAST_QUEUES_STATIC', broadcast_queues),
mock.patch('awx.main.utils.ha.settings.CLUSTER_HOST_ID', hostname)):
(added_queues, removed_queues) = _add_remove_celery_worker_queues(mock_app, [instance], worker_queues, hostname) (added_queues, removed_queues) = _add_remove_celery_worker_queues(mock_app, [instance], worker_queues, hostname)
assert set(added_queues) == set(added_expected) assert set(added_queues) == set(added_expected)
assert set(removed_queues) == set(removed_expected) assert set(removed_queues) == set(removed_expected)
class TestUpdateCeleryWorkerRoutes(): class TestUpdateCeleryWorkerRouter():
@pytest.mark.parametrize("is_controller,expected_routes", [ @pytest.mark.parametrize("is_controller,expected_routes", [
(False, { (False, {
'awx.main.tasks.cluster_node_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'}, 'awx.main.tasks.cluster_node_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'},
'awx.main.tasks.purge_old_stdout_files': {'queue': 'east-1', 'routing_key': 'east-1'} 'awx.main.tasks.purge_old_stdout_files': {'queue': 'east-1', 'routing_key': 'east-1'}
}), }),
(True, { (True, {
'awx.main.tasks.cluster_node_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'}, 'awx.main.tasks.cluster_node_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'},
'awx.main.tasks.purge_old_stdout_files': {'queue': 'east-1', 'routing_key': 'east-1'}, 'awx.main.tasks.purge_old_stdout_files': {'queue': 'east-1', 'routing_key': 'east-1'},
'awx.main.tasks.awx_isolated_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'}, 'awx.main.tasks.awx_isolated_heartbeat': {'queue': 'east-1', 'routing_key': 'east-1'},
}), }),
]) ])
def test_update_celery_worker_routes(self, mocker, conf, is_controller, expected_routes): def test_update_celery_worker_routes(self, mocker, is_controller, expected_routes):
instance = mocker.MagicMock() def get_or_register():
instance.hostname = 'east-1' instance = mock.MagicMock()
instance.is_controller = mocker.MagicMock(return_value=is_controller) instance.hostname = 'east-1'
instance.is_controller = mock.MagicMock(return_value=is_controller)
return (False, instance)
assert update_celery_worker_routes(instance, conf) == expected_routes with mock.patch('awx.main.models.Instance.objects.get_or_register', get_or_register):
assert conf.CELERY_ROUTES == expected_routes router = AWXCeleryRouter()
def test_update_celery_worker_routes_deleted(self, mocker, conf): for k,v in expected_routes.iteritems():
instance = mocker.MagicMock() assert router.route_for_task(k) == v
instance.hostname = 'east-1'
instance.is_controller = mocker.MagicMock(return_value=False)
conf.CELERY_ROUTES = {'awx.main.tasks.awx_isolated_heartbeat': 'foobar'}
update_celery_worker_routes(instance, conf)
assert 'awx.main.tasks.awx_isolated_heartbeat' not in conf.CELERY_ROUTES
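The router tests expect node-local housekeeping tasks to be pinned to the node's own queue, with awx_isolated_heartbeat added only when the node controls isolated instances. A sketch of a celery router meeting those expectations (a reconstruction, not awx.main.utils.ha.AWXCeleryRouter):

class LocalTaskRouter(object):
    LOCAL_TASKS = (
        'awx.main.tasks.cluster_node_heartbeat',
        'awx.main.tasks.purge_old_stdout_files',
    )
    CONTROLLER_TASKS = ('awx.main.tasks.awx_isolated_heartbeat',)

    def __init__(self, hostname, is_controller):
        self.hostname = hostname
        self.is_controller = is_controller

    def route_for_task(self, task, *args, **kwargs):
        # Pin housekeeping to this node; let everything else use defaults.
        if task in self.LOCAL_TASKS or (self.is_controller and task in self.CONTROLLER_TASKS):
            return {'queue': self.hostname, 'routing_key': self.hostname}
        return None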

View File

@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import base64 import base64
import cStringIO import cStringIO
import json
import logging import logging
import socket import socket
import datetime import datetime
@@ -10,7 +9,6 @@ from uuid import uuid4
import mock import mock
from django.conf import settings
from django.conf import LazySettings from django.conf import LazySettings
import pytest import pytest
import requests import requests
@@ -18,23 +16,11 @@ from requests_futures.sessions import FuturesSession
from awx.main.utils.handlers import (BaseHandler, BaseHTTPSHandler as HTTPSHandler, from awx.main.utils.handlers import (BaseHandler, BaseHTTPSHandler as HTTPSHandler,
TCPHandler, UDPHandler, _encode_payload_for_socket, TCPHandler, UDPHandler, _encode_payload_for_socket,
PARAM_NAMES, LoggingConnectivityException) PARAM_NAMES, LoggingConnectivityException,
AWXProxyHandler)
from awx.main.utils.formatters import LogstashFormatter from awx.main.utils.formatters import LogstashFormatter
@pytest.fixture()
def dummy_log_record():
return logging.LogRecord(
'awx', # logger name
20, # loglevel INFO
'./awx/some/module.py', # pathname
100, # lineno
'User joe logged in', # msg
tuple(), # args,
None # exc_info
)
@pytest.fixture() @pytest.fixture()
def http_adapter(): def http_adapter():
class FakeHTTPAdapter(requests.adapters.HTTPAdapter): class FakeHTTPAdapter(requests.adapters.HTTPAdapter):
@@ -80,105 +66,91 @@ def test_https_logging_handler_requests_async_implementation():
def test_https_logging_handler_has_default_http_timeout(): def test_https_logging_handler_has_default_http_timeout():
handler = HTTPSHandler.from_django_settings(settings) handler = TCPHandler()
assert handler.tcp_timeout == 5 assert handler.tcp_timeout == 5
@pytest.mark.parametrize('param', PARAM_NAMES.keys()) @pytest.mark.parametrize('param', ['host', 'port', 'indv_facts'])
def test_base_logging_handler_defaults(param): def test_base_logging_handler_defaults(param):
handler = BaseHandler() handler = BaseHandler()
assert hasattr(handler, param) and getattr(handler, param) is None assert hasattr(handler, param) and getattr(handler, param) is None
@pytest.mark.parametrize('param', PARAM_NAMES.keys()) @pytest.mark.parametrize('param', ['host', 'port', 'indv_facts'])
def test_base_logging_handler_kwargs(param): def test_base_logging_handler_kwargs(param):
handler = BaseHandler(**{param: 'EXAMPLE'}) handler = BaseHandler(**{param: 'EXAMPLE'})
assert hasattr(handler, param) and getattr(handler, param) == 'EXAMPLE' assert hasattr(handler, param) and getattr(handler, param) == 'EXAMPLE'
@pytest.mark.parametrize('param, django_settings_name', PARAM_NAMES.items()) @pytest.mark.parametrize('params', [
def test_base_logging_handler_from_django_settings(param, django_settings_name): {
'LOG_AGGREGATOR_HOST': 'https://server.invalid',
'LOG_AGGREGATOR_PORT': 22222,
'LOG_AGGREGATOR_TYPE': 'loggly',
'LOG_AGGREGATOR_USERNAME': 'foo',
'LOG_AGGREGATOR_PASSWORD': 'bar',
'LOG_AGGREGATOR_INDIVIDUAL_FACTS': True,
'LOG_AGGREGATOR_TCP_TIMEOUT': 96,
'LOG_AGGREGATOR_VERIFY_CERT': False,
'LOG_AGGREGATOR_PROTOCOL': 'https'
},
{
'LOG_AGGREGATOR_HOST': 'https://server.invalid',
'LOG_AGGREGATOR_PORT': 22222,
'LOG_AGGREGATOR_PROTOCOL': 'udp'
}
])
def test_real_handler_from_django_settings(params):
settings = LazySettings()
settings.configure(**params)
handler = AWXProxyHandler().get_handler(custom_settings=settings)
# need the _reverse_ dictionary from PARAM_NAMES
attr_lookup = {}
for attr_name, setting_name in PARAM_NAMES.items():
attr_lookup[setting_name] = attr_name
for setting_name, val in params.items():
attr_name = attr_lookup[setting_name]
if attr_name == 'protocol':
continue
assert hasattr(handler, attr_name)
def test_invalid_kwarg_to_real_handler():
settings = LazySettings() settings = LazySettings()
settings.configure(**{ settings.configure(**{
django_settings_name: 'EXAMPLE' 'LOG_AGGREGATOR_HOST': 'https://server.invalid',
'LOG_AGGREGATOR_PORT': 22222,
'LOG_AGGREGATOR_PROTOCOL': 'udp',
'LOG_AGGREGATOR_VERIFY_CERT': False # setting not valid for UDP handler
}) })
handler = BaseHandler.from_django_settings(settings) handler = AWXProxyHandler().get_handler(custom_settings=settings)
assert hasattr(handler, param) and getattr(handler, param) == 'EXAMPLE' assert not hasattr(handler, 'verify_cert')
@pytest.mark.parametrize('params, logger_name, expected', [ def test_base_logging_handler_emit_system_tracking(dummy_log_record):
# skip all records if enabled_flag = False handler = BaseHandler(host='127.0.0.1', indv_facts=True)
({'enabled_flag': False}, 'awx.main', True),
# skip all records if the host is undefined
({'host': '', 'enabled_flag': True}, 'awx.main', True),
# skip all records if underlying logger is used by handlers themselves
({'host': '127.0.0.1', 'enabled_flag': True}, 'awx.main.utils.handlers', True),
({'host': '127.0.0.1', 'enabled_flag': True}, 'awx.main', False),
({'host': '127.0.0.1', 'enabled_flag': True, 'enabled_loggers': ['abc']}, 'awx.analytics.xyz', True),
({'host': '127.0.0.1', 'enabled_flag': True, 'enabled_loggers': ['xyz']}, 'awx.analytics.xyz', False),
])
def test_base_logging_handler_skip_log(params, logger_name, expected):
handler = BaseHandler(**params)
assert handler._skip_log(logger_name) is expected
def test_base_logging_handler_emit(dummy_log_record):
handler = BaseHandler(host='127.0.0.1', enabled_flag=True,
message_type='logstash', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
sent_payloads = handler.emit(dummy_log_record) dummy_log_record.name = 'awx.analytics.system_tracking'
dummy_log_record.msg = None
assert len(sent_payloads) == 1 dummy_log_record.inventory_id = 11
body = json.loads(sent_payloads[0]) dummy_log_record.host_name = 'my_lucky_host'
dummy_log_record.job_id = 777
assert body['level'] == 'INFO' dummy_log_record.ansible_facts = {
assert body['logger_name'] == 'awx'
assert body['message'] == 'User joe logged in'
def test_base_logging_handler_ignore_low_severity_msg(dummy_log_record):
handler = BaseHandler(host='127.0.0.1', enabled_flag=True,
message_type='logstash', lvl='WARNING',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter())
sent_payloads = handler.emit(dummy_log_record)
assert len(sent_payloads) == 0
def test_base_logging_handler_emit_system_tracking():
handler = BaseHandler(host='127.0.0.1', enabled_flag=True,
message_type='logstash', indv_facts=True, lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter())
record = logging.LogRecord(
'awx.analytics.system_tracking', # logger name
20, # loglevel INFO
'./awx/some/module.py', # pathname
100, # lineno
None, # msg
tuple(), # args,
None # exc_info
)
record.inventory_id = 11
record.host_name = 'my_lucky_host'
record.job_id = 777
record.ansible_facts = {
"ansible_kernel": "4.4.66-boot2docker", "ansible_kernel": "4.4.66-boot2docker",
"ansible_machine": "x86_64", "ansible_machine": "x86_64",
"ansible_swapfree_mb": 4663, "ansible_swapfree_mb": 4663,
} }
record.ansible_facts_modified = datetime.datetime.now(tzutc()).isoformat() dummy_log_record.ansible_facts_modified = datetime.datetime.now(tzutc()).isoformat()
sent_payloads = handler.emit(record) sent_payloads = handler.emit(dummy_log_record)
assert len(sent_payloads) == 1 assert len(sent_payloads) == 1
assert sent_payloads[0]['ansible_facts'] == record.ansible_facts assert sent_payloads[0]['ansible_facts'] == dummy_log_record.ansible_facts
assert sent_payloads[0]['ansible_facts_modified'] == record.ansible_facts_modified assert sent_payloads[0]['ansible_facts_modified'] == dummy_log_record.ansible_facts_modified
assert sent_payloads[0]['level'] == 'INFO' assert sent_payloads[0]['level'] == 'INFO'
assert sent_payloads[0]['logger_name'] == 'awx.analytics.system_tracking' assert sent_payloads[0]['logger_name'] == 'awx.analytics.system_tracking'
assert sent_payloads[0]['job_id'] == record.job_id assert sent_payloads[0]['job_id'] == dummy_log_record.job_id
assert sent_payloads[0]['inventory_id'] == record.inventory_id assert sent_payloads[0]['inventory_id'] == dummy_log_record.inventory_id
assert sent_payloads[0]['host_name'] == record.host_name assert sent_payloads[0]['host_name'] == dummy_log_record.host_name
@pytest.mark.parametrize('host, port, normalized, hostname_only', [ @pytest.mark.parametrize('host, port, normalized, hostname_only', [
@@ -236,16 +208,18 @@ def test_https_logging_handler_connectivity_test(http_adapter, status, reason, e
def emit(self, record): def emit(self, record):
return super(FakeHTTPSHandler, self).emit(record) return super(FakeHTTPSHandler, self).emit(record)
if exc: with mock.patch.object(AWXProxyHandler, 'get_handler_class') as mock_get_class:
with pytest.raises(exc) as e: mock_get_class.return_value = FakeHTTPSHandler
FakeHTTPSHandler.perform_test(settings) if exc:
assert str(e).endswith('%s: %s' % (status, reason)) with pytest.raises(exc) as e:
else: AWXProxyHandler().perform_test(settings)
assert FakeHTTPSHandler.perform_test(settings) is None assert str(e).endswith('%s: %s' % (status, reason))
else:
assert AWXProxyHandler().perform_test(settings) is None
def test_https_logging_handler_logstash_auth_info(): def test_https_logging_handler_logstash_auth_info():
handler = HTTPSHandler(message_type='logstash', username='bob', password='ansible', lvl='INFO') handler = HTTPSHandler(message_type='logstash', username='bob', password='ansible')
handler._add_auth_information() handler._add_auth_information()
assert isinstance(handler.session.auth, requests.auth.HTTPBasicAuth) assert isinstance(handler.session.auth, requests.auth.HTTPBasicAuth)
assert handler.session.auth.username == 'bob' assert handler.session.auth.username == 'bob'
@@ -261,9 +235,7 @@ def test_https_logging_handler_splunk_auth_info():
def test_https_logging_handler_connection_error(connection_error_adapter, def test_https_logging_handler_connection_error(connection_error_adapter,
dummy_log_record): dummy_log_record):
handler = HTTPSHandler(host='127.0.0.1', enabled_flag=True, handler = HTTPSHandler(host='127.0.0.1', message_type='logstash')
message_type='logstash', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
handler.session.mount('http://', connection_error_adapter) handler.session.mount('http://', connection_error_adapter)
@@ -289,9 +261,7 @@ def test_https_logging_handler_connection_error(connection_error_adapter,
@pytest.mark.parametrize('message_type', ['logstash', 'splunk']) @pytest.mark.parametrize('message_type', ['logstash', 'splunk'])
def test_https_logging_handler_emit_without_cred(http_adapter, dummy_log_record, def test_https_logging_handler_emit_without_cred(http_adapter, dummy_log_record,
message_type): message_type):
handler = HTTPSHandler(host='127.0.0.1', enabled_flag=True, handler = HTTPSHandler(host='127.0.0.1', message_type=message_type)
message_type=message_type, lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
handler.session.mount('http://', http_adapter) handler.session.mount('http://', http_adapter)
async_futures = handler.emit(dummy_log_record) async_futures = handler.emit(dummy_log_record)
@@ -312,10 +282,9 @@ def test_https_logging_handler_emit_without_cred(http_adapter, dummy_log_record,
def test_https_logging_handler_emit_logstash_with_creds(http_adapter, def test_https_logging_handler_emit_logstash_with_creds(http_adapter,
dummy_log_record): dummy_log_record):
handler = HTTPSHandler(host='127.0.0.1', enabled_flag=True, handler = HTTPSHandler(host='127.0.0.1',
username='user', password='pass', username='user', password='pass',
message_type='logstash', lvl='INFO', message_type='logstash')
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
handler.session.mount('http://', http_adapter) handler.session.mount('http://', http_adapter)
async_futures = handler.emit(dummy_log_record) async_futures = handler.emit(dummy_log_record)
@@ -328,9 +297,8 @@ def test_https_logging_handler_emit_logstash_with_creds(http_adapter,
def test_https_logging_handler_emit_splunk_with_creds(http_adapter, def test_https_logging_handler_emit_splunk_with_creds(http_adapter,
dummy_log_record): dummy_log_record):
handler = HTTPSHandler(host='127.0.0.1', enabled_flag=True, handler = HTTPSHandler(host='127.0.0.1',
password='pass', message_type='splunk', lvl='INFO', password='pass', message_type='splunk')
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
handler.session.mount('http://', http_adapter) handler.session.mount('http://', http_adapter)
async_futures = handler.emit(dummy_log_record) async_futures = handler.emit(dummy_log_record)
@@ -351,9 +319,7 @@ def test_encode_payload_for_socket(payload, encoded_payload):
def test_udp_handler_create_socket_at_init(): def test_udp_handler_create_socket_at_init():
handler = UDPHandler(host='127.0.0.1', port=4399, handler = UDPHandler(host='127.0.0.1', port=4399)
enabled_flag=True, message_type='splunk', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
assert hasattr(handler, 'socket') assert hasattr(handler, 'socket')
assert isinstance(handler.socket, socket.socket) assert isinstance(handler.socket, socket.socket)
assert handler.socket.family == socket.AF_INET assert handler.socket.family == socket.AF_INET
@@ -361,9 +327,7 @@ def test_udp_handler_create_socket_at_init():
def test_udp_handler_send(dummy_log_record): def test_udp_handler_send(dummy_log_record):
handler = UDPHandler(host='127.0.0.1', port=4399, handler = UDPHandler(host='127.0.0.1', port=4399)
enabled_flag=True, message_type='splunk', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
with mock.patch('awx.main.utils.handlers._encode_payload_for_socket', return_value="des") as encode_mock,\ with mock.patch('awx.main.utils.handlers._encode_payload_for_socket', return_value="des") as encode_mock,\
mock.patch.object(handler, 'socket') as socket_mock: mock.patch.object(handler, 'socket') as socket_mock:
@@ -373,9 +337,7 @@ def test_udp_handler_send(dummy_log_record):
def test_tcp_handler_send(fake_socket, dummy_log_record): def test_tcp_handler_send(fake_socket, dummy_log_record):
handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5, handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5)
enabled_flag=True, message_type='splunk', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\ with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\
mock.patch('select.select', return_value=([], [fake_socket], [])): mock.patch('select.select', return_value=([], [fake_socket], [])):
@@ -388,9 +350,7 @@ def test_tcp_handler_send(fake_socket, dummy_log_record):
def test_tcp_handler_return_if_socket_unavailable(fake_socket, dummy_log_record): def test_tcp_handler_return_if_socket_unavailable(fake_socket, dummy_log_record):
handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5, handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5)
enabled_flag=True, message_type='splunk', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\ with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\
mock.patch('select.select', return_value=([], [], [])): mock.patch('select.select', return_value=([], [], [])):
@@ -403,9 +363,7 @@ def test_tcp_handler_return_if_socket_unavailable(fake_socket, dummy_log_record)
def test_tcp_handler_log_exception(fake_socket, dummy_log_record): def test_tcp_handler_log_exception(fake_socket, dummy_log_record):
handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5, handler = TCPHandler(host='127.0.0.1', port=4399, tcp_timeout=5)
enabled_flag=True, message_type='splunk', lvl='INFO',
enabled_loggers=['awx', 'activity_stream', 'job_events', 'system_tracking'])
handler.setFormatter(LogstashFormatter()) handler.setFormatter(LogstashFormatter())
with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\ with mock.patch('socket.socket', return_value=fake_socket) as sok_init_mock,\
mock.patch('select.select', return_value=([], [], [])),\ mock.patch('select.select', return_value=([], [], [])),\

View File

@@ -13,31 +13,3 @@ def test_produce_supervisor_command(mocker):
['supervisorctl', 'restart', 'tower-processes:receiver',], ['supervisorctl', 'restart', 'tower-processes:receiver',],
stderr=-1, stdin=-1, stdout=-1) stderr=-1, stdin=-1, stdout=-1)
def test_routing_of_service_restarts_works(mocker):
'''
This tests that the parent restart method will call the appropriate
service restart methods, depending on which services are given in args
'''
with mocker.patch.object(reload, '_uwsgi_fifo_command'),\
mocker.patch.object(reload, '_reset_celery_thread_pool'),\
mocker.patch.object(reload, '_supervisor_service_command'):
reload.restart_local_services(['uwsgi', 'celery', 'flower', 'daphne'])
reload._uwsgi_fifo_command.assert_called_once_with(uwsgi_command="c")
reload._reset_celery_thread_pool.assert_called_once_with()
reload._supervisor_service_command.assert_called_once_with(['flower', 'daphne'], command="restart")
def test_routing_of_service_restarts_disables(mocker):
'''
Test that restart methods are not called for services absent from the args
'''
with mocker.patch.object(reload, '_uwsgi_fifo_command'),\
mocker.patch.object(reload, '_reset_celery_thread_pool'),\
mocker.patch.object(reload, '_supervisor_service_command'):
reload.restart_local_services(['flower'])
reload._uwsgi_fifo_command.assert_not_called()
reload._reset_celery_thread_pool.assert_not_called()
reload._supervisor_service_command.assert_called_once_with(['flower'], command="restart")

View File

@@ -127,12 +127,16 @@ class IllegalArgumentError(ValueError):
pass pass
def get_memoize_cache():
from django.core.cache import cache
return cache
def memoize(ttl=60, cache_key=None, track_function=False): def memoize(ttl=60, cache_key=None, track_function=False):
''' '''
Decorator to wrap a function and cache its result. Decorator to wrap a function and cache its result.
''' '''
from django.core.cache import cache cache = get_memoize_cache()
def _memoizer(f, *args, **kwargs): def _memoizer(f, *args, **kwargs):
if cache_key and track_function: if cache_key and track_function:
@@ -160,8 +164,7 @@ def memoize(ttl=60, cache_key=None, track_function=False):
def memoize_delete(function_name): def memoize_delete(function_name):
from django.core.cache import cache cache = get_memoize_cache()
return cache.delete(function_name) return cache.delete(function_name)
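
Routing cache access through get_memoize_cache gives tests a single seam to patch instead of Django's cache plumbing. A minimal sketch of that seam in a test, assuming the helpers live in awx.main.utils.common and the mock library is available ('some_function' is an illustrative key):

    import mock
    from awx.main.utils.common import memoize_delete  # module path assumed

    def test_memoize_delete_uses_injected_cache():
        fake_cache = mock.MagicMock()
        # get_memoize_cache is looked up at call time, so patching the
        # module attribute redirects memoize_delete to the fake cache
        with mock.patch('awx.main.utils.common.get_memoize_cache',
                        return_value=fake_cache):
            memoize_delete('some_function')
        fake_cache.delete.assert_called_once_with('some_function')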

View File

@@ -8,14 +8,106 @@ from pyparsing import (
CharsNotIn, CharsNotIn,
ParseException, ParseException,
) )
from logging import Filter, _levelNames
import six import six
import django from django.apps import apps
from django.db import models
from django.conf import settings
from awx.main.utils.common import get_search_fields from awx.main.utils.common import get_search_fields
__all__ = ['SmartFilter'] __all__ = ['SmartFilter', 'ExternalLoggerEnabled']
class FieldFromSettings(object):
"""
Field interface - defaults to getting its value from settings;
if a value is set explicitly, that value takes precedence
over the value in settings
"""
def __init__(self, setting_name):
self.setting_name = setting_name
def __get__(self, instance, type=None):
if self.setting_name in getattr(instance, 'settings_override', {}):
return instance.settings_override[self.setting_name]
return getattr(settings, self.setting_name, None)
def __set__(self, instance, value):
if value is None:
if hasattr(instance, 'settings_override'):
instance.settings_override.pop(self.setting_name, None)
else:
if not hasattr(instance, 'settings_override'):
instance.settings_override = {}
instance.settings_override[self.setting_name] = value
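
Standalone, the same descriptor pattern behaves like this (a self-contained sketch with a stand-in settings object, not Django's settings module):

    class _FakeSettings(object):
        LOG_AGGREGATOR_LEVEL = 'WARNING'

    settings = _FakeSettings()   # stand-in for django.conf.settings

    class _SettingBacked(object):
        def __init__(self, setting_name):
            self.setting_name = setting_name
        def __get__(self, instance, owner=None):
            if self.setting_name in getattr(instance, 'settings_override', {}):
                return instance.settings_override[self.setting_name]
            return getattr(settings, self.setting_name, None)
        def __set__(self, instance, value):
            if not hasattr(instance, 'settings_override'):
                instance.settings_override = {}
            instance.settings_override[self.setting_name] = value

    class _Filter(object):
        lvl = _SettingBacked('LOG_AGGREGATOR_LEVEL')

    f = _Filter()
    assert f.lvl == 'WARNING'   # falls through to settings
    f.lvl = 'DEBUG'
    assert f.lvl == 'DEBUG'     # explicit value takes precedence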
class ExternalLoggerEnabled(Filter):
# Prevents recursive logging loops from swamping the server
LOGGER_BLACKLIST = (
# loggers that may be called in process of emitting a log
'awx.main.utils.handlers',
'awx.main.utils.formatters',
'awx.main.utils.filters',
'awx.main.utils.encryption',
'awx.main.utils.log',
# loggers that may be called getting logging settings
'awx.conf'
)
lvl = FieldFromSettings('LOG_AGGREGATOR_LEVEL')
enabled_loggers = FieldFromSettings('LOG_AGGREGATOR_LOGGERS')
enabled_flag = FieldFromSettings('LOG_AGGREGATOR_ENABLED')
def __init__(self, **kwargs):
super(ExternalLoggerEnabled, self).__init__()
for field_name, field_value in kwargs.items():
if not isinstance(ExternalLoggerEnabled.__dict__.get(field_name, None), FieldFromSettings):
raise Exception('%s is not a valid kwarg' % field_name)
if field_value is None:
continue
setattr(self, field_name, field_value)
def filter(self, record):
"""
Uses the database settings to determine if the current
external log configuration says that this particular record
should be sent to the external log aggregator
False - should not be logged
True - should be logged
"""
# Logger exceptions
for logger_name in self.LOGGER_BLACKLIST:
if record.name.startswith(logger_name):
return False
# General enablement
if not self.enabled_flag:
return False
# Level enablement
if record.levelno < _levelNames[self.lvl]:
# logging._levelNames -> logging._nameToLevel in python 3
return False
# Logger type enablement
loggers = self.enabled_loggers
if not loggers:
return False
if record.name.startswith('awx.analytics'):
base_path, headline_name = record.name.rsplit('.', 1)
return bool(headline_name in loggers)
else:
if '.' in record.name:
base_name, trailing_path = record.name.split('.', 1)
else:
base_name = record.name
return bool(base_name in loggers)
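
In use, the filter is attached to the external logging handler so every record is re-checked against live settings at emit time. A hedged sketch of the wiring (module path awx.main.utils.filters inferred from this diff; the NullHandler stands in for the real aggregator handler):

    import logging
    from awx.network_ui import __  # placeholder removed; real import below
    from awx.main.utils.filters import ExternalLoggerEnabled

    handler = logging.NullHandler()            # stand-in for the aggregator handler
    handler.addFilter(ExternalLoggerEnabled()) # reads LOG_AGGREGATOR_* settings
    # records from blacklisted loggers such as 'awx.main.utils.handlers'
    # are always dropped, preventing recursive logging loops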
def string_to_type(t): def string_to_type(t):
@@ -36,7 +128,7 @@ def string_to_type(t):
def get_model(name): def get_model(name):
return django.apps.apps.get_model('main', name) return apps.get_model('main', name)
class SmartFilter(object): class SmartFilter(object):
@@ -52,7 +144,7 @@ class SmartFilter(object):
search_kwargs = self._expand_search(k, v) search_kwargs = self._expand_search(k, v)
if search_kwargs: if search_kwargs:
kwargs.update(search_kwargs) kwargs.update(search_kwargs)
q = reduce(lambda x, y: x | y, [django.db.models.Q(**{u'%s__contains' % _k:_v}) for _k, _v in kwargs.items()]) q = reduce(lambda x, y: x | y, [models.Q(**{u'%s__contains' % _k:_v}) for _k, _v in kwargs.items()])
self.result = Host.objects.filter(q) self.result = Host.objects.filter(q)
else: else:
kwargs[k] = v kwargs[k] = v

View File

@@ -9,6 +9,8 @@ import logging
import six import six
from django.conf import settings
class TimeFormatter(logging.Formatter): class TimeFormatter(logging.Formatter):
''' '''
@@ -20,15 +22,6 @@ class TimeFormatter(logging.Formatter):
class LogstashFormatter(LogstashFormatterVersion1): class LogstashFormatter(LogstashFormatterVersion1):
def __init__(self, **kwargs):
settings_module = kwargs.pop('settings_module', None)
ret = super(LogstashFormatter, self).__init__(**kwargs)
if settings_module:
self.host_id = getattr(settings_module, 'CLUSTER_HOST_ID', None)
if hasattr(settings_module, 'LOG_AGGREGATOR_TOWER_UUID'):
self.tower_uuid = settings_module.LOG_AGGREGATOR_TOWER_UUID
self.message_type = getattr(settings_module, 'LOG_AGGREGATOR_TYPE', 'other')
return ret
def reformat_data_for_log(self, raw_data, kind=None): def reformat_data_for_log(self, raw_data, kind=None):
''' '''
@@ -147,6 +140,15 @@ class LogstashFormatter(LogstashFormatterVersion1):
if record.name.startswith('awx.analytics'): if record.name.startswith('awx.analytics'):
log_kind = record.name[len('awx.analytics.'):] log_kind = record.name[len('awx.analytics.'):]
fields = self.reformat_data_for_log(fields, kind=log_kind) fields = self.reformat_data_for_log(fields, kind=log_kind)
# General AWX metadata
for log_name, setting_name in [
('type', 'LOG_AGGREGATOR_TYPE'),
('cluster_host_id', 'CLUSTER_HOST_ID'),
('tower_uuid', 'LOG_AGGREGATOR_TOWER_UUID')]:
if hasattr(settings, setting_name):
fields[log_name] = getattr(settings, setting_name, None)
elif log_name == 'type':
fields[log_name] = 'other'
return fields return fields
def format(self, record): def format(self, record):
@@ -158,18 +160,12 @@ class LogstashFormatter(LogstashFormatterVersion1):
'@timestamp': self.format_timestamp(record.created), '@timestamp': self.format_timestamp(record.created),
'message': record.getMessage(), 'message': record.getMessage(),
'host': self.host, 'host': self.host,
'type': self.message_type,
# Extra Fields # Extra Fields
'level': record.levelname, 'level': record.levelname,
'logger_name': record.name, 'logger_name': record.name,
} }
if getattr(self, 'tower_uuid', None):
message['tower_uuid'] = self.tower_uuid
if getattr(self, 'host_id', None):
message['cluster_host_id'] = self.host_id
# Add extra fields # Add extra fields
message.update(self.get_extra_fields(record)) message.update(self.get_extra_fields(record))
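
With the metadata lookup moved into the format step, every record reflects the settings that hold at emit time rather than at formatter construction. Illustratively, an emitted message now carries fields shaped like this (values are examples only, not output of any real run):

    {
        "@timestamp": "2018-05-17T20:07:47.000Z",
        "message": "...",
        "host": "awx-node-1",
        "level": "INFO",
        "logger_name": "awx.main.tasks",
        "type": "logstash",               # LOG_AGGREGATOR_TYPE, else 'other'
        "cluster_host_id": "awx-node-1",  # CLUSTER_HOST_ID, when set
        "tower_uuid": "...",              # LOG_AGGREGATOR_TOWER_UUID, when set
    }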

View File

@@ -10,6 +10,10 @@ from django.conf import settings
from awx.main.models import Instance from awx.main.models import Instance
def construct_bcast_queue_name(common_name):
return common_name.encode('utf8') + '_' + settings.CLUSTER_HOST_ID
def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, worker_name): def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, worker_name):
removed_queues = [] removed_queues = []
added_queues = [] added_queues = []
@@ -19,17 +23,15 @@ def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, w
ig_names.update(instance.rampart_groups.values_list('name', flat=True)) ig_names.update(instance.rampart_groups.values_list('name', flat=True))
worker_queue_names = set([q['name'] for q in worker_queues]) worker_queue_names = set([q['name'] for q in worker_queues])
bcast_queue_names = set([construct_bcast_queue_name(n) for n in settings.AWX_CELERY_BCAST_QUEUES_STATIC])
all_queue_names = ig_names | hostnames | set(settings.AWX_CELERY_QUEUES_STATIC) all_queue_names = ig_names | hostnames | set(settings.AWX_CELERY_QUEUES_STATIC)
desired_queues = bcast_queue_names | (all_queue_names if instance.enabled else set())
# Remove queues that aren't in the instance group # Remove queues
for queue in worker_queues: for queue_name in worker_queue_names:
if queue['name'] in settings.AWX_CELERY_QUEUES_STATIC or \ if queue_name not in desired_queues:
queue['alias'] in settings.AWX_CELERY_BCAST_QUEUES_STATIC: app.control.cancel_consumer(queue_name.encode("utf8"), reply=True, destination=[worker_name])
continue removed_queues.append(queue_name.encode("utf8"))
if queue['name'] not in all_queue_names or not instance.enabled:
app.control.cancel_consumer(queue['name'].encode("utf8"), reply=True, destination=[worker_name])
removed_queues.append(queue['name'].encode("utf8"))
# Add queues for instance and instance groups # Add queues for instance and instance groups
for queue_name in all_queue_names: for queue_name in all_queue_names:
@@ -37,27 +39,35 @@ def _add_remove_celery_worker_queues(app, controlled_instances, worker_queues, w
app.control.add_consumer(queue_name.encode("utf8"), reply=True, destination=[worker_name]) app.control.add_consumer(queue_name.encode("utf8"), reply=True, destination=[worker_name])
added_queues.append(queue_name.encode("utf8")) added_queues.append(queue_name.encode("utf8"))
# Add stable-named broadcast queues
for queue_name in settings.AWX_CELERY_BCAST_QUEUES_STATIC:
bcast_queue_name = construct_bcast_queue_name(queue_name)
if bcast_queue_name not in worker_queue_names:
app.control.add_consumer(bcast_queue_name,
exchange=queue_name.encode("utf8"),
exchange_type='fanout',
routing_key=queue_name.encode("utf8"),
reply=True)
added_queues.append(bcast_queue_name)
return (added_queues, removed_queues) return (added_queues, removed_queues)
def update_celery_worker_routes(instance, conf): class AWXCeleryRouter(object):
tasks = [ def route_for_task(self, task, args=None, kwargs=None):
'awx.main.tasks.cluster_node_heartbeat', (changed, instance) = Instance.objects.get_or_register()
'awx.main.tasks.purge_old_stdout_files', tasks = [
] 'awx.main.tasks.cluster_node_heartbeat',
routes_updated = {} 'awx.main.tasks.purge_old_stdout_files',
# Instance is, effectively, a controller node ]
if instance.is_controller(): isolated_tasks = [
tasks.append('awx.main.tasks.awx_isolated_heartbeat') 'awx.main.tasks.awx_isolated_heartbeat',
else: ]
if 'awx.main.tasks.awx_isolated_heartbeat' in conf.CELERY_ROUTES: if task in tasks:
del conf.CELERY_ROUTES['awx.main.tasks.awx_isolated_heartbeat'] return {'queue': instance.hostname.encode("utf8"), 'routing_key': instance.hostname.encode("utf8")}
for t in tasks: if instance.is_controller() and task in isolated_tasks:
conf.CELERY_ROUTES[t] = {'queue': instance.hostname.encode("utf8"), 'routing_key': instance.hostname.encode("utf8")} return {'queue': instance.hostname.encode("utf8"), 'routing_key': instance.hostname.encode("utf8")}
routes_updated[t] = conf.CELERY_ROUTES[t]
return routes_updated
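
The router class replaces runtime mutation of CELERY_ROUTES; Celery consults it once per task dispatch. A hedged sketch of how it would be wired and what it returns (setting name and module path are assumptions for this version of Celery):

    # in Django settings; Celery accepts dotted paths to router classes
    CELERY_ROUTES = ('awx.main.utils.ha.AWXCeleryRouter',)

    # Celery calls route_for_task(name, args, kwargs) for each task;
    # returning None falls through to the default queue
    router = AWXCeleryRouter()
    router.route_for_task('awx.main.tasks.cluster_node_heartbeat')
    # -> {'queue': <this node's hostname>, 'routing_key': <this node's hostname>}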
def register_celery_worker_queues(app, celery_worker_name): def register_celery_worker_queues(app, celery_worker_name):

View File

@@ -13,40 +13,35 @@ import six
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from requests.exceptions import RequestException from requests.exceptions import RequestException
# loggly # Django
import traceback
from django.conf import settings from django.conf import settings
# requests futures, a dependency used by these handlers
from requests_futures.sessions import FuturesSession from requests_futures.sessions import FuturesSession
# AWX # AWX
from awx.main.utils.formatters import LogstashFormatter from awx.main.utils.formatters import LogstashFormatter
__all__ = ['HTTPSNullHandler', 'BaseHTTPSHandler', 'TCPHandler', 'UDPHandler', __all__ = ['BaseHTTPSHandler', 'TCPHandler', 'UDPHandler',
'configure_external_logger'] 'AWXProxyHandler']
logger = logging.getLogger('awx.main.utils.handlers') logger = logging.getLogger('awx.main.utils.handlers')
# AWX external logging handler, generally designed to be used
# with the accompanying LogstashHandler, derives from python-logstash library
# Non-blocking request accomplished by FuturesSession, similar
# to the loggly-python-handler library (not used)
# Translation of parameter names to names in Django settings # Translation of parameter names to names in Django settings
# in the logging settings category; only those related to handler / log emission
PARAM_NAMES = { PARAM_NAMES = {
'host': 'LOG_AGGREGATOR_HOST', 'host': 'LOG_AGGREGATOR_HOST',
'port': 'LOG_AGGREGATOR_PORT', 'port': 'LOG_AGGREGATOR_PORT',
'message_type': 'LOG_AGGREGATOR_TYPE', 'message_type': 'LOG_AGGREGATOR_TYPE',
'username': 'LOG_AGGREGATOR_USERNAME', 'username': 'LOG_AGGREGATOR_USERNAME',
'password': 'LOG_AGGREGATOR_PASSWORD', 'password': 'LOG_AGGREGATOR_PASSWORD',
'enabled_loggers': 'LOG_AGGREGATOR_LOGGERS',
'indv_facts': 'LOG_AGGREGATOR_INDIVIDUAL_FACTS', 'indv_facts': 'LOG_AGGREGATOR_INDIVIDUAL_FACTS',
'enabled_flag': 'LOG_AGGREGATOR_ENABLED',
'tcp_timeout': 'LOG_AGGREGATOR_TCP_TIMEOUT', 'tcp_timeout': 'LOG_AGGREGATOR_TCP_TIMEOUT',
'verify_cert': 'LOG_AGGREGATOR_VERIFY_CERT', 'verify_cert': 'LOG_AGGREGATOR_VERIFY_CERT',
'lvl': 'LOG_AGGREGATOR_LEVEL', 'protocol': 'LOG_AGGREGATOR_PROTOCOL'
} }
@@ -58,13 +53,6 @@ class LoggingConnectivityException(Exception):
pass pass
class HTTPSNullHandler(logging.NullHandler):
"Placeholder null handler to allow loading without database access"
def __init__(self, *args, **kwargs):
return super(HTTPSNullHandler, self).__init__()
class VerboseThreadPoolExecutor(ThreadPoolExecutor): class VerboseThreadPoolExecutor(ThreadPoolExecutor):
last_log_emit = 0 last_log_emit = 0
@@ -91,32 +79,25 @@ class VerboseThreadPoolExecutor(ThreadPoolExecutor):
**kwargs) **kwargs)
LEVEL_MAPPING = { class SocketResult:
'DEBUG': logging.DEBUG, '''
'INFO': logging.INFO, Return type for methods that send data over a socket;
'WARNING': logging.WARNING, lets the object be used in the same way as a requests-futures result
'ERROR': logging.ERROR, '''
'CRITICAL': logging.CRITICAL, def __init__(self, ok, reason=None):
} self.ok = ok
self.reason = reason
def result(self):
return self
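
SocketResult gives the socket transports the same .result()/.ok surface as a requests-futures future, so callers can consume every transport uniformly. A sketch of the consuming side (handler and record are placeholders, not names from this diff):

    futures = handler.emit(record)    # HTTPS: real futures; TCP/UDP: SocketResult
    for future in futures:
        resp = future.result()        # SocketResult.result() returns itself
        if not resp.ok:
            print(resp.reason)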
class BaseHandler(logging.Handler): class BaseHandler(logging.Handler):
def __init__(self, **kwargs): def __init__(self, host=None, port=None, indv_facts=None, **kwargs):
super(BaseHandler, self).__init__() super(BaseHandler, self).__init__()
for fd in PARAM_NAMES: self.host = host
setattr(self, fd, kwargs.get(fd, None)) self.port = port
self.indv_facts = indv_facts
@classmethod
def from_django_settings(cls, settings, *args, **kwargs):
for param, django_setting_name in PARAM_NAMES.items():
kwargs[param] = getattr(settings, django_setting_name, None)
return cls(*args, **kwargs)
def get_full_message(self, record):
if record.exc_info:
return '\n'.join(traceback.format_exception(*record.exc_info))
else:
return record.getMessage()
def _send(self, payload): def _send(self, payload):
"""Actually send message to log aggregator. """Actually send message to log aggregator.
@@ -128,26 +109,11 @@ class BaseHandler(logging.Handler):
return [self._send(json.loads(self.format(record)))] return [self._send(json.loads(self.format(record)))]
return [self._send(self.format(record))] return [self._send(self.format(record))]
def _skip_log(self, logger_name):
if self.host == '' or (not self.enabled_flag):
return True
# Don't send handler-related records.
if logger_name == logger.name:
return True
# AWX log emission is only turned off by enablement setting
if not logger_name.startswith('awx.analytics'):
return False
return self.enabled_loggers is None or logger_name[len('awx.analytics.'):] not in self.enabled_loggers
def emit(self, record): def emit(self, record):
""" """
Emit a log record. Returns a list of zero or more Emit a log record. Returns a list of zero or more
implementation-specific objects for tests. implementation-specific objects for tests.
""" """
if not record.name.startswith('awx.analytics') and record.levelno < LEVEL_MAPPING[self.lvl]:
return []
if self._skip_log(record.name):
return []
try: try:
return self._format_and_send_record(record) return self._format_and_send_record(record)
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
@@ -181,6 +147,11 @@ class BaseHandler(logging.Handler):
class BaseHTTPSHandler(BaseHandler): class BaseHTTPSHandler(BaseHandler):
'''
Originally derived from python-logstash library
Non-blocking requests accomplished via FuturesSession, similar
to the loggly-python-handler library
'''
def _add_auth_information(self): def _add_auth_information(self):
if self.message_type == 'logstash': if self.message_type == 'logstash':
if not self.username: if not self.username:
@@ -196,39 +167,20 @@ class BaseHTTPSHandler(BaseHandler):
} }
self.session.headers.update(headers) self.session.headers.update(headers)
def __init__(self, fqdn=False, **kwargs): def __init__(self, fqdn=False, message_type=None, username=None, password=None,
tcp_timeout=5, verify_cert=True, **kwargs):
self.fqdn = fqdn self.fqdn = fqdn
self.message_type = message_type
self.username = username
self.password = password
self.tcp_timeout = tcp_timeout
self.verify_cert = verify_cert
super(BaseHTTPSHandler, self).__init__(**kwargs) super(BaseHTTPSHandler, self).__init__(**kwargs)
self.session = FuturesSession(executor=VerboseThreadPoolExecutor( self.session = FuturesSession(executor=VerboseThreadPoolExecutor(
max_workers=2 # this is the default used by requests_futures max_workers=2 # this is the default used by requests_futures
)) ))
self._add_auth_information() self._add_auth_information()
@classmethod
def perform_test(cls, settings):
"""
Tests logging connectivity for the current logging settings.
@raises LoggingConnectivityException
"""
handler = cls.from_django_settings(settings)
handler.enabled_flag = True
handler.setFormatter(LogstashFormatter(settings_module=settings))
logger = logging.getLogger(__file__)
fn, lno, func = logger.findCaller()
record = logger.makeRecord('awx', 10, fn, lno,
'AWX Connection Test', tuple(),
None, func)
futures = handler.emit(record)
for future in futures:
try:
resp = future.result()
if not resp.ok:
raise LoggingConnectivityException(
': '.join([str(resp.status_code), resp.reason or ''])
)
except RequestException as e:
raise LoggingConnectivityException(str(e))
def _get_post_kwargs(self, payload_input): def _get_post_kwargs(self, payload_input):
if self.message_type == 'splunk': if self.message_type == 'splunk':
# Splunk needs data nested under key "event" # Splunk needs data nested under key "event"
@@ -265,6 +217,10 @@ def _encode_payload_for_socket(payload):
class TCPHandler(BaseHandler): class TCPHandler(BaseHandler):
def __init__(self, tcp_timeout=5, **kwargs):
self.tcp_timeout = tcp_timeout
super(TCPHandler, self).__init__(**kwargs)
def _send(self, payload): def _send(self, payload):
payload = _encode_payload_for_socket(payload) payload = _encode_payload_for_socket(payload)
sok = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sok = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -273,39 +229,32 @@ class TCPHandler(BaseHandler):
sok.setblocking(0) sok.setblocking(0)
_, ready_to_send, _ = select.select([], [sok], [], float(self.tcp_timeout)) _, ready_to_send, _ = select.select([], [sok], [], float(self.tcp_timeout))
if len(ready_to_send) == 0: if len(ready_to_send) == 0:
logger.warning("Socket currently busy, failed to send message") ret = SocketResult(False, "Socket currently busy, failed to send message")
sok.close() logger.warning(ret.reason)
return else:
sok.send(payload) sok.send(payload)
ret = SocketResult(True) # success!
except Exception as e: except Exception as e:
logger.exception("Error sending message from %s: %s" % ret = SocketResult(False, "Error sending message from %s: %s" %
(TCPHandler.__name__, e.message)) (TCPHandler.__name__,
sok.close() ' '.join(six.text_type(arg) for arg in e.args)))
logger.exception(ret.reason)
finally:
sok.close()
return ret
class UDPHandler(BaseHandler): class UDPHandler(BaseHandler):
message = "Cannot determine if UDP messages are received."
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(UDPHandler, self).__init__(**kwargs) super(UDPHandler, self).__init__(**kwargs)
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
def _send(self, payload): def _send(self, payload):
payload = _encode_payload_for_socket(payload) payload = _encode_payload_for_socket(payload)
return self.socket.sendto(payload, (self._get_host(hostname_only=True), self.port or 0)) self.socket.sendto(payload, (self._get_host(hostname_only=True), self.port or 0))
return SocketResult(True, reason=self.message)
@classmethod
def perform_test(cls, settings):
"""
Tests logging connectivity for the current logging settings.
"""
handler = cls.from_django_settings(settings)
handler.enabled_flag = True
handler.setFormatter(LogstashFormatter(settings_module=settings))
logger = logging.getLogger(__file__)
fn, lno, func = logger.findCaller()
record = logger.makeRecord('awx', 10, fn, lno,
'AWX Connection Test', tuple(),
None, func)
handler.emit(_encode_payload_for_socket(record))
HANDLER_MAPPING = { HANDLER_MAPPING = {
@@ -315,6 +264,88 @@ HANDLER_MAPPING = {
} }
class AWXProxyHandler(logging.Handler):
'''
Handler specific to the AWX external logging feature
Will dynamically create a handler specific to the configured
protocol, and will create a new one automatically on setting change
Managing parameters:
All parameters get their value from settings as a default;
if a parameter was either provided on init or set manually,
that value takes precedence.
Parameters match the same-named parameters in the concrete handler classes.
'''
def __init__(self, **kwargs):
# TODO: process 'level' kwarg
super(AWXProxyHandler, self).__init__(**kwargs)
self._handler = None
self._old_kwargs = {}
def get_handler_class(self, protocol):
return HANDLER_MAPPING[protocol]
def get_handler(self, custom_settings=None, force_create=False):
new_kwargs = {}
use_settings = custom_settings or settings
for field_name, setting_name in PARAM_NAMES.items():
val = getattr(use_settings, setting_name, None)
if val is None:
continue
new_kwargs[field_name] = val
if new_kwargs == self._old_kwargs and self._handler and (not force_create):
# avoids re-creating session objects, and other such things
return self._handler
self._old_kwargs = new_kwargs.copy()
# TODO: remove any kwargs not applicable to that particular handler
protocol = new_kwargs.pop('protocol', None)
HandlerClass = self.get_handler_class(protocol)
# cleanup old handler and make new one
if self._handler:
self._handler.close()
logger.debug('Creating external log handler due to startup or settings change.')
self._handler = HandlerClass(**new_kwargs)
if self.formatter:
# self.format(record) is called inside the wrapped handler's emit
# method, so the formatter must be propagated to it
self._handler.setFormatter(self.formatter)
return self._handler
def emit(self, record):
actual_handler = self.get_handler()
return actual_handler.emit(record)
def perform_test(self, custom_settings):
"""
Tests logging connectivity for given settings module.
@raises LoggingConnectivityException
"""
handler = self.get_handler(custom_settings=custom_settings, force_create=True)
handler.setFormatter(LogstashFormatter())
logger = logging.getLogger(__file__)
fn, lno, func = logger.findCaller()
record = logger.makeRecord('awx', 10, fn, lno,
'AWX Connection Test', tuple(),
None, func)
futures = handler.emit(record)
for future in futures:
try:
resp = future.result()
if not resp.ok:
if isinstance(resp, SocketResult):
raise LoggingConnectivityException(
'Socket error: {}'.format(resp.reason or '')
)
else:
raise LoggingConnectivityException(
': '.join([str(resp.status_code), resp.reason or ''])
)
except RequestException as e:
raise LoggingConnectivityException(str(e))
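
The proxy is meant to be the one handler named in the LOGGING config, with the concrete protocol handler created and swapped behind it as settings change. A hedged sketch of such an entry (keys are illustrative, not the shipped config; the dotted path matches this module):

    LOGGING['handlers']['external_logger'] = {
        'class': 'awx.main.utils.handlers.AWXProxyHandler',
        'formatter': 'json',
    }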
ColorHandler = logging.StreamHandler ColorHandler = logging.StreamHandler
if settings.COLOR_LOGS is True: if settings.COLOR_LOGS is True:
@@ -340,41 +371,3 @@ if settings.COLOR_LOGS is True:
except ImportError: except ImportError:
# logutils is only used for colored logs in the dev environment # logutils is only used for colored logs in the dev environment
pass pass
def _add_or_remove_logger(address, instance):
specific_logger = logging.getLogger(address)
for i, handler in enumerate(specific_logger.handlers):
if isinstance(handler, (HTTPSNullHandler, BaseHTTPSHandler)):
specific_logger.handlers[i] = instance or HTTPSNullHandler()
break
else:
if instance is not None:
specific_logger.handlers.append(instance)
def configure_external_logger(settings_module, is_startup=True):
is_enabled = settings_module.LOG_AGGREGATOR_ENABLED
if is_startup and (not is_enabled):
# Pass-through if external logging not being used
return
instance = None
if is_enabled:
handler_class = HANDLER_MAPPING[settings_module.LOG_AGGREGATOR_PROTOCOL]
instance = handler_class.from_django_settings(settings_module)
# Obtain the Formatter class from settings to maintain customizations
configurator = logging.config.DictConfigurator(settings_module.LOGGING)
formatter_config = settings_module.LOGGING['formatters']['json'].copy()
formatter_config['settings_module'] = settings_module
formatter = configurator.configure_custom(formatter_config)
instance.setFormatter(formatter)
awx_logger_instance = instance
if is_enabled and 'awx' not in settings_module.LOG_AGGREGATOR_LOGGERS:
awx_logger_instance = None
_add_or_remove_logger('awx.analytics', instance)
_add_or_remove_logger('awx', awx_logger_instance)

View File

@@ -8,29 +8,9 @@ import logging
# Django # Django
from django.conf import settings from django.conf import settings
# Celery
from celery import Celery
logger = logging.getLogger('awx.main.utils.reload') logger = logging.getLogger('awx.main.utils.reload')
def _uwsgi_fifo_command(uwsgi_command):
# http://uwsgi-docs.readthedocs.io/en/latest/MasterFIFO.html#available-commands
logger.warn('Initiating uWSGI chain reload of server')
TRIGGER_COMMAND = uwsgi_command
with open(settings.UWSGI_FIFO_LOCATION, 'w') as awxfifo:
awxfifo.write(TRIGGER_COMMAND)
def _reset_celery_thread_pool():
# Do not use current_app because of this outstanding issue:
# https://github.com/celery/celery/issues/4410
app = Celery('awx')
app.config_from_object('django.conf:settings')
app.control.broadcast('pool_restart', arguments={'reload': True},
destination=['celery@{}'.format(settings.CLUSTER_HOST_ID)], reply=False)
def _supervisor_service_command(service_internal_names, command, communicate=True): def _supervisor_service_command(service_internal_names, command, communicate=True):
''' '''
Service internal name options: Service internal name options:
@@ -68,21 +48,6 @@ def _supervisor_service_command(service_internal_names, command, communicate=Tru
logger.info('Submitted supervisorctl {} command, not waiting for result'.format(command)) logger.info('Submitted supervisorctl {} command, not waiting for result'.format(command))
def restart_local_services(service_internal_names):
logger.warn('Restarting services {} on this node in response to user action'.format(service_internal_names))
if 'uwsgi' in service_internal_names:
_uwsgi_fifo_command(uwsgi_command='c')
service_internal_names.remove('uwsgi')
restart_celery = False
if 'celery' in service_internal_names:
restart_celery = True
service_internal_names.remove('celery')
_supervisor_service_command(service_internal_names, command='restart')
if restart_celery:
# Celery restarted last because this probably includes current process
_reset_celery_thread_pool()
def stop_local_services(service_internal_names, communicate=True): def stop_local_services(service_internal_names, communicate=True):
logger.warn('Stopping services {} on this node in response to user action'.format(service_internal_names)) logger.warn('Stopping services {} on this node in response to user action'.format(service_internal_names))
_supervisor_service_command(service_internal_names, command='stop', communicate=communicate) _supervisor_service_command(service_internal_names, command='stop', communicate=communicate)

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2017 Red Hat, Inc # Copyright (c) 2017 Red Hat, Inc
from channels import Group import channels
from channels.sessions import channel_session from channels.auth import channel_session_user, channel_session_user_from_http
from awx.network_ui.models import Topology, Device, Link, Client, Interface from awx.network_ui.models import Topology, Device, Link, Client, Interface
from awx.network_ui.models import TopologyInventory from awx.network_ui.models import TopologyInventory
import urlparse import urlparse
@@ -22,6 +22,10 @@ def parse_inventory_id(data):
inventory_id = int(inventory_id[0]) inventory_id = int(inventory_id[0])
except ValueError: except ValueError:
inventory_id = None inventory_id = None
except IndexError:
inventory_id = None
except TypeError:
inventory_id = None
if not inventory_id: if not inventory_id:
inventory_id = None inventory_id = None
return inventory_id return inventory_id
@@ -42,10 +46,10 @@ class NetworkingEvents(object):
message_type = data.pop(0) message_type = data.pop(0)
message_value = data.pop(0) message_value = data.pop(0)
if isinstance(message_value, list): if isinstance(message_value, list):
logger.error("Message has no sender") logger.warning("Message has no sender")
return None, None return None, None
if isinstance(message_value, dict) and client_id != message_value.get('sender'): if isinstance(message_value, dict) and client_id != message_value.get('sender'):
logger.error("client_id mismatch expected: %s actual %s", client_id, message_value.get('sender')) logger.warning("client_id mismatch expected: %s actual %s", client_id, message_value.get('sender'))
return None, None return None, None
return message_type, message_value return message_type, message_value
else: else:
@@ -58,11 +62,19 @@ class NetworkingEvents(object):
of name onX where X is the message type. of name onX where X is the message type.
''' '''
topology_id = message.get('topology') topology_id = message.get('topology')
assert topology_id is not None, "No topology_id" if topology_id is None:
logger.warning("Unsupported message %s: no topology", message)
return
client_id = message.get('client') client_id = message.get('client')
assert client_id is not None, "No client_id" if client_id is None:
logger.warning("Unsupported message %s: no client", message)
return
if 'text' not in message:
logger.warning("Unsupported message %s: no data", message)
return
message_type, message_value = self.parse_message_text(message['text'], client_id) message_type, message_value = self.parse_message_text(message['text'], client_id)
if message_type is None: if message_type is None:
logger.warning("Unsupported message %s: no message type", message)
return return
handler = self.get_handler(message_type) handler = self.get_handler(message_type)
if handler is not None: if handler is not None:
@@ -98,9 +110,6 @@ class NetworkingEvents(object):
def onDeviceMove(self, device, topology_id, client_id): def onDeviceMove(self, device, topology_id, client_id):
Device.objects.filter(topology_id=topology_id, cid=device['id']).update(x=device['x'], y=device['y']) Device.objects.filter(topology_id=topology_id, cid=device['id']).update(x=device['x'], y=device['y'])
def onDeviceInventoryUpdate(self, device, topology_id, client_id):
Device.objects.filter(topology_id=topology_id, cid=device['id']).update(host_id=device['host_id'])
def onDeviceLabelEdit(self, device, topology_id, client_id): def onDeviceLabelEdit(self, device, topology_id, client_id):
logger.debug("Device label edited %s", device) logger.debug("Device label edited %s", device)
Device.objects.filter(topology_id=topology_id, cid=device['id']).update(name=device['name']) Device.objects.filter(topology_id=topology_id, cid=device['id']).update(name=device['name'])
@@ -132,6 +141,12 @@ class NetworkingEvents(object):
device_map = dict(Device.objects device_map = dict(Device.objects
.filter(topology_id=topology_id, cid__in=[link['from_device_id'], link['to_device_id']]) .filter(topology_id=topology_id, cid__in=[link['from_device_id'], link['to_device_id']])
.values_list('cid', 'pk')) .values_list('cid', 'pk'))
if link['from_device_id'] not in device_map:
logger.warning('Device not found')
return
if link['to_device_id'] not in device_map:
logger.warning('Device not found')
return
Link.objects.get_or_create(cid=link['id'], Link.objects.get_or_create(cid=link['id'],
name=link['name'], name=link['name'],
from_device_id=device_map[link['from_device_id']], from_device_id=device_map[link['from_device_id']],
@@ -150,8 +165,10 @@ class NetworkingEvents(object):
.filter(topology_id=topology_id, cid__in=[link['from_device_id'], link['to_device_id']]) .filter(topology_id=topology_id, cid__in=[link['from_device_id'], link['to_device_id']])
.values_list('cid', 'pk')) .values_list('cid', 'pk'))
if link['from_device_id'] not in device_map: if link['from_device_id'] not in device_map:
logger.warning('Device not found')
return return
if link['to_device_id'] not in device_map: if link['to_device_id'] not in device_map:
logger.warning('Device not found')
return return
Link.objects.filter(cid=link['id'], Link.objects.filter(cid=link['id'],
from_device_id=device_map[link['from_device_id']], from_device_id=device_map[link['from_device_id']],
@@ -189,8 +206,15 @@ class NetworkingEvents(object):
networking_events_dispatcher = NetworkingEvents() networking_events_dispatcher = NetworkingEvents()
@channel_session @channel_session_user_from_http
def ws_connect(message): def ws_connect(message):
if not message.user.is_authenticated():
logger.error("Request user is not authenticated to use websocket.")
message.reply_channel.send({"close": True})
return
else:
message.reply_channel.send({"accept": True})
data = urlparse.parse_qs(message.content['query_string']) data = urlparse.parse_qs(message.content['query_string'])
inventory_id = parse_inventory_id(data) inventory_id = parse_inventory_id(data)
topology_ids = list(TopologyInventory.objects.filter(inventory_id=inventory_id).values_list('pk', flat=True)) topology_ids = list(TopologyInventory.objects.filter(inventory_id=inventory_id).values_list('pk', flat=True))
@@ -205,11 +229,11 @@ def ws_connect(message):
TopologyInventory(inventory_id=inventory_id, topology_id=topology.pk).save() TopologyInventory(inventory_id=inventory_id, topology_id=topology.pk).save()
topology_id = topology.pk topology_id = topology.pk
message.channel_session['topology_id'] = topology_id message.channel_session['topology_id'] = topology_id
Group("topology-%s" % topology_id).add(message.reply_channel) channels.Group("topology-%s" % topology_id).add(message.reply_channel)
client = Client() client = Client()
client.save() client.save()
message.channel_session['client_id'] = client.pk message.channel_session['client_id'] = client.pk
Group("client-%s" % client.pk).add(message.reply_channel) channels.Group("client-%s" % client.pk).add(message.reply_channel)
message.reply_channel.send({"text": json.dumps(["id", client.pk])}) message.reply_channel.send({"text": json.dumps(["id", client.pk])})
message.reply_channel.send({"text": json.dumps(["topology_id", topology_id])}) message.reply_channel.send({"text": json.dumps(["topology_id", topology_id])})
topology_data = transform_dict(dict(id='topology_id', topology_data = transform_dict(dict(id='topology_id',
@@ -268,18 +292,18 @@ def send_snapshot(channel, topology_id):
channel.send({"text": json.dumps(["Snapshot", snapshot])}) channel.send({"text": json.dumps(["Snapshot", snapshot])})
@channel_session @channel_session_user
def ws_message(message): def ws_message(message):
# Send to all clients editing the topology # Send to all clients editing the topology
Group("topology-%s" % message.channel_session['topology_id']).send({"text": message['text']}) channels.Group("topology-%s" % message.channel_session['topology_id']).send({"text": message['text']})
# Send to networking_events handler # Send to networking_events handler
networking_events_dispatcher.handle({"text": message['text'], networking_events_dispatcher.handle({"text": message['text'],
"topology": message.channel_session['topology_id'], "topology": message.channel_session['topology_id'],
"client": message.channel_session['client_id']}) "client": message.channel_session['client_id']})
@channel_session @channel_session_user
def ws_disconnect(message): def ws_disconnect(message):
if 'topology_id' in message.channel_session: if 'topology_id' in message.channel_session:
Group("topology-%s" % message.channel_session['topology_id']).discard(message.reply_channel) channels.Group("topology-%s" % message.channel_session['topology_id']).discard(message.reply_channel)

View File

@@ -3,7 +3,7 @@ from channels.routing import route
from awx.network_ui.consumers import ws_connect, ws_message, ws_disconnect from awx.network_ui.consumers import ws_connect, ws_message, ws_disconnect
channel_routing = [ channel_routing = [
route("websocket.connect", ws_connect, path=r"^/network_ui/topology"), route("websocket.connect", ws_connect, path=r"^/network_ui/topology/"),
route("websocket.receive", ws_message, path=r"^/network_ui/topology"), route("websocket.receive", ws_message, path=r"^/network_ui/topology/"),
route("websocket.disconnect", ws_disconnect, path=r"^/network_ui/topology"), route("websocket.disconnect", ws_disconnect, path=r"^/network_ui/topology/"),
] ]
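
The added trailing slash narrows the match; a quick check of the regex behavior with plain re:

    import re
    pattern = r"^/network_ui/topology/"
    assert re.match(pattern, "/network_ui/topology/")            # routed
    assert re.match(pattern, "/network_ui/topology") is None     # no longer routed
    assert re.match(pattern, "/network_ui/topologyXYZ") is None  # was routed before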

View File

View File

@@ -0,0 +1,9 @@
import pytest
from mock import PropertyMock
@pytest.fixture(autouse=True)
def _disable_database_settings(mocker):
m = mocker.patch('awx.conf.settings.SettingsWrapper.all_supported_settings', new_callable=PropertyMock)
m.return_value = []

View File

View File

@@ -0,0 +1,240 @@
import mock
import logging
import json
import imp
from mock import patch
patch('channels.auth.channel_session_user', lambda x: x).start()
patch('channels.auth.channel_session_user_from_http', lambda x: x).start()
from awx.network_ui.consumers import parse_inventory_id, networking_events_dispatcher, send_snapshot # noqa
from awx.network_ui.models import Topology, Device, Link, Interface, TopologyInventory, Client # noqa
import awx # noqa
import awx.network_ui # noqa
import awx.network_ui.consumers # noqa
imp.reload(awx.network_ui.consumers)
def test_parse_inventory_id():
assert parse_inventory_id({}) is None
assert parse_inventory_id({'inventory_id': ['1']}) == 1
assert parse_inventory_id({'inventory_id': ['0']}) is None
assert parse_inventory_id({'inventory_id': ['X']}) is None
assert parse_inventory_id({'inventory_id': []}) is None
assert parse_inventory_id({'inventory_id': 'x'}) is None
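# a plain string is indexed like a list, so '12345'[0] -> '1' -> 1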
assert parse_inventory_id({'inventory_id': '12345'}) == 1
assert parse_inventory_id({'inventory_id': 1}) is None
def test_network_events_handle_message_incomplete_message1():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle({})
log_mock.assert_called_once_with(
'Unsupported message %s: no topology', {})
def test_network_events_handle_message_incomplete_message2():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle({'topology': [0]})
log_mock.assert_called_once_with(
'Unsupported message %s: no client', {'topology': [0]})
def test_network_events_handle_message_incomplete_message3():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle({'topology': [1]})
log_mock.assert_called_once_with(
'Unsupported message %s: no client', {'topology': [1]})
def test_network_events_handle_message_incomplete_message4():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle({'topology': 1, 'client': 1})
log_mock.assert_called_once_with('Unsupported message %s: no data', {
'client': 1, 'topology': 1})
def test_network_events_handle_message_incomplete_message5():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
message = ['DeviceCreate']
networking_events_dispatcher.handle(
{'topology': 1, 'client': 1, 'text': json.dumps(message)})
log_mock.assert_called_once_with('Unsupported message %s: no message type', {
'text': '["DeviceCreate"]', 'client': 1, 'topology': 1})
def test_network_events_handle_message_incomplete_message6():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
message = ['DeviceCreate', []]
networking_events_dispatcher.handle(
{'topology': 1, 'client': 1, 'text': json.dumps(message)})
log_mock.assert_has_calls([
mock.call('Message has no sender'),
mock.call('Unsupported message %s: no message type', {'text': '["DeviceCreate", []]', 'client': 1, 'topology': 1})])
def test_network_events_handle_message_incomplete_message7():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
message = ['DeviceCreate', {}]
networking_events_dispatcher.handle(
{'topology': 1, 'client': 1, 'text': json.dumps(message)})
log_mock.assert_has_calls([
mock.call('client_id mismatch expected: %s actual %s', 1, None),
mock.call('Unsupported message %s: no message type', {'text': '["DeviceCreate", {}]', 'client': 1, 'topology': 1})])
def test_network_events_handle_message_incomplete_message8():
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock:
message = ['Unsupported', {'sender': 1}]
networking_events_dispatcher.handle(
{'topology': 1, 'client': 1, 'text': json.dumps(message)})
log_mock.assert_called_once_with(
'Unsupported message %s: no handler', u'Unsupported')
def test_send_snapshot_empty():
channel = mock.MagicMock()
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects'),\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects'),\
mock.patch.object(Topology, 'objects'):
send_snapshot(channel, 1)
log_mock.assert_not_called()
channel.send.assert_called_once_with(
{'text': '["Snapshot", {"links": [], "devices": [], "sender": 0}]'})
def test_send_snapshot_single():
channel = mock.MagicMock()
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects') as interface_objects_mock:
interface_objects_mock.filter.return_value.values.return_value = [
dict(cid=1, device_id=1, id=1, name="eth0")]
device_objects_mock.filter.return_value.values.return_value = [
dict(cid=1, id=1, device_type="host", name="host1", x=0, y=0,
interface_id_seq=1, host_id=1)]
send_snapshot(channel, 1)
device_objects_mock.filter.assert_called_once_with(topology_id=1)
device_objects_mock.filter.return_value.values.assert_called_once_with()
interface_objects_mock.filter.assert_called_once_with(
device__topology_id=1)
interface_objects_mock.filter.return_value.values.assert_called_once_with()
log_mock.assert_not_called()
channel.send.assert_called_once_with(
{'text': '''["Snapshot", {"links": [], "devices": [{"interface_id_seq": 1, \
"name": "host1", "interfaces": [{"id": 1, "device_id": 1, "name": "eth0", "interface_id": 1}], \
"device_type": "host", "host_id": 1, "y": 0, "x": 0, "id": 1, "device_id": 1}], "sender": 0}]'''})
def test_ws_disconnect():
message = mock.MagicMock()
message.channel_session = dict(topology_id=1)
message.reply_channel = 'foo'
with mock.patch('channels.Group') as group_mock:
awx.network_ui.consumers.ws_disconnect(message)
group_mock.assert_called_once_with('topology-1')
group_mock.return_value.discard.assert_called_once_with('foo')
def test_ws_disconnect_no_topology():
message = mock.MagicMock()
with mock.patch('channels.Group') as group_mock:
awx.network_ui.consumers.ws_disconnect(message)
group_mock.assert_not_called()
def test_ws_message():
message = mock.MagicMock()
message.channel_session = dict(topology_id=1, client_id=1)
message.__getitem__.return_value = json.dumps([])
with mock.patch('channels.Group') as group_mock:
awx.network_ui.consumers.ws_message(message)
group_mock.assert_called_once_with('topology-1')
group_mock.return_value.send.assert_called_once_with({'text': '[]'})
def test_ws_connect_unauthenticated():
message = mock.MagicMock()
message.user.is_authenticated.return_value = False
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch.object(logger, 'error') as log_mock:
awx.network_ui.consumers.ws_connect(message)
log_mock.assert_called_once_with('Request user is not authenticated to use websocket.')
def test_ws_connect_new_topology():
message = mock.MagicMock()
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch('awx.network_ui.consumers.Client') as client_mock,\
mock.patch('awx.network_ui.consumers.Topology') as topology_mock,\
mock.patch('channels.Group'),\
mock.patch('awx.network_ui.consumers.send_snapshot') as send_snapshot_mock,\
mock.patch.object(logger, 'warning'),\
mock.patch.object(TopologyInventory, 'objects'),\
mock.patch.object(TopologyInventory, 'save'),\
mock.patch.object(Topology, 'save'),\
mock.patch.object(Topology, 'objects'),\
mock.patch.object(Device, 'objects'),\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects'):
client_mock.return_value.pk = 777
topology_mock.return_value = Topology(
name="topology", scale=1.0, panX=0, panY=0, pk=999)
awx.network_ui.consumers.ws_connect(message)
message.reply_channel.send.assert_has_calls([
mock.call({'text': '["id", 777]'}),
mock.call({'text': '["topology_id", 999]'}),
mock.call(
{'text': '["Topology", {"scale": 1.0, "name": "topology", "device_id_seq": 0, "panY": 0, "panX": 0, "topology_id": 999, "link_id_seq": 0}]'}),
])
send_snapshot_mock.assert_called_once_with(message.reply_channel, 999)
def test_ws_connect_existing_topology():
message = mock.MagicMock()
logger = logging.getLogger('awx.network_ui.consumers')
with mock.patch('awx.network_ui.consumers.Client') as client_mock,\
mock.patch('awx.network_ui.consumers.send_snapshot') as send_snapshot_mock,\
mock.patch('channels.Group'),\
mock.patch.object(logger, 'warning'),\
mock.patch.object(TopologyInventory, 'objects') as topology_inventory_objects_mock,\
mock.patch.object(TopologyInventory, 'save'),\
mock.patch.object(Topology, 'save'),\
mock.patch.object(Topology, 'objects') as topology_objects_mock,\
mock.patch.object(Device, 'objects'),\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects'):
topology_inventory_objects_mock.filter.return_value.values_list.return_value = [
1]
client_mock.return_value.pk = 888
topology_objects_mock.get.return_value = Topology(pk=1001,
id=1,
name="topo",
panX=0,
panY=0,
scale=1.0,
link_id_seq=1,
device_id_seq=1)
awx.network_ui.consumers.ws_connect(message)
message.reply_channel.send.assert_has_calls([
mock.call({'text': '["id", 888]'}),
mock.call({'text': '["topology_id", 1001]'}),
mock.call(
{'text': '["Topology", {"scale": 1.0, "name": "topo", "device_id_seq": 1, "panY": 0, "panX": 0, "topology_id": 1001, "link_id_seq": 1}]'}),
])
send_snapshot_mock.assert_called_once_with(message.reply_channel, 1001)

View File

@@ -0,0 +1,15 @@
from awx.network_ui.models import Device, Topology, Interface
def test_device():
assert str(Device(name="foo")) == "foo"
def test_topology():
assert str(Topology(name="foo")) == "foo"
def test_interface():
assert str(Interface(name="foo")) == "foo"

View File

@@ -0,0 +1,451 @@
import mock
import json
import logging
from awx.network_ui.consumers import networking_events_dispatcher
from awx.network_ui.models import Topology, Device, Link, Interface
def message(message):
def wrapper(fn):
fn.tests_message = message
return fn
return wrapper
@message('DeviceMove')
def test_network_events_handle_message_DeviceMove():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceMove', dict(
msg_type='DeviceMove',
sender=1,
id=1,
x=100,
y=100,
previous_x=0,
previous_y=0
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock:
networking_events_dispatcher.handle(message)
device_objects_mock.filter.assert_called_once_with(
cid=1, topology_id=1)
device_objects_mock.filter.return_value.update.assert_called_once_with(
x=100, y=100)
log_mock.assert_not_called()
@message('DeviceCreate')
def test_network_events_handle_message_DeviceCreate():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceCreate', dict(msg_type='DeviceCreate',
sender=1,
id=1,
x=0,
y=0,
name="test_created",
type='host',
host_id=None)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Topology.objects, 'filter') as topology_objects_mock,\
mock.patch.object(Device.objects, 'get_or_create') as device_objects_mock:
device_mock = mock.MagicMock()
filter_mock = mock.MagicMock()
device_objects_mock.return_value = [device_mock, True]
topology_objects_mock.return_value = filter_mock
networking_events_dispatcher.handle(message)
device_objects_mock.assert_called_once_with(
cid=1,
defaults={'name': u'test_created', 'cid': 1, 'device_type': u'host',
'x': 0, 'y': 0, 'host_id': None},
topology_id=1)
device_mock.save.assert_called_once_with()
topology_objects_mock.assert_called_once_with(
device_id_seq__lt=1, pk=1)
filter_mock.update.assert_called_once_with(device_id_seq=1)
log_mock.assert_not_called()
@message('DeviceLabelEdit')
def test_network_events_handle_message_DeviceLabelEdit():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceLabelEdit', dict(
msg_type='DeviceLabelEdit',
sender=1,
id=1,
name='test_changed',
previous_name='test_created'
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device.objects, 'filter') as device_objects_filter_mock:
networking_events_dispatcher.handle(message)
device_objects_filter_mock.assert_called_once_with(
cid=1, topology_id=1)
log_mock.assert_not_called()
@message('DeviceSelected')
def test_network_events_handle_message_DeviceSelected():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceSelected', dict(
msg_type='DeviceSelected',
sender=1,
id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_not_called()
@message('DeviceUnSelected')
def test_network_events_handle_message_DeviceUnSelected():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceUnSelected', dict(
msg_type='DeviceUnSelected',
sender=1,
id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_not_called()
@message('DeviceDestroy')
def test_network_events_handle_message_DeviceDestroy():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['DeviceDestroy', dict(
msg_type='DeviceDestroy',
sender=1,
id=1,
previous_x=0,
previous_y=0,
previous_name="",
previous_type="host",
previous_host_id="1")]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock:
networking_events_dispatcher.handle(message)
device_objects_mock.filter.assert_called_once_with(
cid=1, topology_id=1)
device_objects_mock.filter.return_value.delete.assert_called_once_with()
log_mock.assert_not_called()
@message('InterfaceCreate')
def test_network_events_handle_message_InterfaceCreate():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['InterfaceCreate', dict(
msg_type='InterfaceCreate',
sender=1,
device_id=1,
id=1,
name='eth0'
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Interface, 'objects') as interface_objects_mock:
device_objects_mock.get.return_value.pk = 99
networking_events_dispatcher.handle(message)
device_objects_mock.get.assert_called_once_with(cid=1, topology_id=1)
device_objects_mock.filter.assert_called_once_with(
cid=1, interface_id_seq__lt=1, topology_id=1)
interface_objects_mock.get_or_create.assert_called_once_with(
cid=1, defaults={'name': u'eth0'}, device_id=99)
log_mock.assert_not_called()
@message('InterfaceLabelEdit')
def test_network_events_handle_message_InterfaceLabelEdit():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['InterfaceLabelEdit', dict(
msg_type='InterfaceLabelEdit',
sender=1,
id=1,
device_id=1,
name='new name',
previous_name='old name'
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Interface, 'objects') as interface_objects_mock:
networking_events_dispatcher.handle(message)
interface_objects_mock.filter.assert_called_once_with(
cid=1, device__cid=1, device__topology_id=1)
interface_objects_mock.filter.return_value.update.assert_called_once_with(
name=u'new name')
log_mock.assert_not_called()
@message('LinkLabelEdit')
def test_network_events_handle_message_LinkLabelEdit():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkLabelEdit', dict(
msg_type='LinkLabelEdit',
sender=1,
id=1,
name='new name',
previous_name='old name'
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Link, 'objects') as link_objects_mock:
networking_events_dispatcher.handle(message)
link_objects_mock.filter.assert_called_once_with(
cid=1, from_device__topology_id=1)
link_objects_mock.filter.return_value.update.assert_called_once_with(
name=u'new name')
log_mock.assert_not_called()
@message('LinkCreate')
def test_network_events_handle_message_LinkCreate():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkCreate', dict(
msg_type='LinkCreate',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Link, 'objects') as link_objects_mock,\
mock.patch.object(Interface, 'objects') as interface_objects_mock,\
mock.patch.object(Topology, 'objects') as topology_objects_mock:
values_list_mock = mock.MagicMock()
values_list_mock.values_list.return_value = [(1,1), (2,2)]
interface_objects_mock.get.return_value = mock.MagicMock()
interface_objects_mock.get.return_value.pk = 7
device_objects_mock.filter.return_value = values_list_mock
topology_objects_mock.filter.return_value = mock.MagicMock()
networking_events_dispatcher.handle(message)
device_objects_mock.filter.assert_called_once_with(
cid__in=[1, 2], topology_id=1)
values_list_mock.values_list.assert_called_once_with('cid', 'pk')
link_objects_mock.get_or_create.assert_called_once_with(
cid=1, from_device_id=1, from_interface_id=7, name=u'',
to_device_id=2, to_interface_id=7)
topology_objects_mock.filter.assert_called_once_with(
link_id_seq__lt=1, pk=1)
topology_objects_mock.filter.return_value.update.assert_called_once_with(
link_id_seq=1)
log_mock.assert_not_called()
@message('LinkCreate')
def test_network_events_handle_message_LinkCreate_bad_device1():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkCreate', dict(
msg_type='LinkCreate',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects') as interface_objects_mock,\
mock.patch.object(Topology, 'objects') as topology_objects_mock:
values_list_mock = mock.MagicMock()
values_list_mock.values_list.return_value = [(9,1), (2,2)]
interface_objects_mock.get.return_value = mock.MagicMock()
interface_objects_mock.get.return_value.pk = 7
device_objects_mock.filter.return_value = values_list_mock
topology_objects_mock.filter.return_value = mock.MagicMock()
networking_events_dispatcher.handle(message)
device_objects_mock.filter.assert_called_once_with(
cid__in=[1, 2], topology_id=1)
values_list_mock.values_list.assert_called_once_with('cid', 'pk')
log_mock.assert_called_once_with('Device not found')
@message('LinkCreate')
def test_network_events_handle_message_LinkCreate_bad_device2():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkCreate', dict(
msg_type='LinkCreate',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Link, 'objects'),\
mock.patch.object(Interface, 'objects') as interface_objects_mock,\
mock.patch.object(Topology, 'objects') as topology_objects_mock:
values_list_mock = mock.MagicMock()
values_list_mock.values_list.return_value = [(1,1), (9,2)]
interface_objects_mock.get.return_value = mock.MagicMock()
interface_objects_mock.get.return_value.pk = 7
device_objects_mock.filter.return_value = values_list_mock
topology_objects_mock.filter.return_value = mock.MagicMock()
networking_events_dispatcher.handle(message)
device_objects_mock.filter.assert_called_once_with(
cid__in=[1, 2], topology_id=1)
values_list_mock.values_list.assert_called_once_with('cid', 'pk')
log_mock.assert_called_once_with('Device not found')
@message('LinkDestroy')
def test_network_events_handle_message_LinkDestroy():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkDestroy', dict(
msg_type='LinkDestroy',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device.objects, 'filter') as device_filter_mock,\
mock.patch.object(Link.objects, 'filter') as link_filter_mock,\
mock.patch.object(Interface.objects, 'get') as interface_get_mock:
values_mock = mock.MagicMock()
interface_get_mock.return_value = mock.MagicMock()
interface_get_mock.return_value.pk = 7
device_filter_mock.return_value = values_mock
values_mock.values_list.return_value = [(1,1), (2,2)]
networking_events_dispatcher.handle(message)
device_filter_mock.assert_called_once_with(
cid__in=[1, 2], topology_id=1)
values_mock.values_list.assert_called_once_with('cid', 'pk')
link_filter_mock.assert_called_once_with(
cid=1, from_device_id=1, from_interface_id=7, to_device_id=2, to_interface_id=7)
log_mock.assert_not_called()
@message('LinkDestroy')
def test_network_events_handle_message_LinkDestroy_bad_device_map1():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkDestroy', dict(
msg_type='LinkDestroy',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device.objects, 'filter') as device_filter_mock,\
mock.patch.object(Link.objects, 'filter'),\
mock.patch.object(Interface.objects, 'get') as interface_get_mock:
values_mock = mock.MagicMock()
interface_get_mock.return_value = mock.MagicMock()
interface_get_mock.return_value.pk = 7
device_filter_mock.return_value = values_mock
values_mock.values_list.return_value = [(9,1), (2,2)]
networking_events_dispatcher.handle(message)
log_mock.assert_called_once_with('Device not found')
@message('LinkDestroy')
def test_network_events_handle_message_LinkDestroy_bad_device_map2():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkDestroy', dict(
msg_type='LinkDestroy',
id=1,
sender=1,
name="",
from_device_id=1,
to_device_id=2,
from_interface_id=1,
to_interface_id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock,\
mock.patch.object(Device.objects, 'filter') as device_filter_mock,\
mock.patch.object(Link.objects, 'filter'),\
mock.patch.object(Interface.objects, 'get') as interface_get_mock:
values_mock = mock.MagicMock()
interface_get_mock.return_value = mock.MagicMock()
interface_get_mock.return_value.pk = 7
device_filter_mock.return_value = values_mock
values_mock.values_list.return_value = [(1,1), (9,2)]
networking_events_dispatcher.handle(message)
log_mock.assert_called_once_with('Device not found')
@message('LinkSelected')
def test_network_events_handle_message_LinkSelected():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkSelected', dict(
msg_type='LinkSelected',
sender=1,
id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_not_called()
@message('LinkUnSelected')
def test_network_events_handle_message_LinkUnSelected():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['LinkUnSelected', dict(
msg_type='LinkUnSelected',
sender=1,
id=1
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_not_called()
@message('MultipleMessage')
def test_network_events_handle_message_MultipleMessage_unsupported_message():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['MultipleMessage', dict(
msg_type='MultipleMessage',
sender=1,
messages=[dict(msg_type="Unsupported")]
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_called_once_with(
'Unsupported message %s', u'Unsupported')
@message('MultipleMessage')
def test_network_events_handle_message_MultipleMessage():
logger = logging.getLogger('awx.network_ui.consumers')
message_data = ['MultipleMessage', dict(
msg_type='MultipleMessage',
sender=1,
messages=[dict(msg_type="DeviceSelected")]
)]
message = {'topology': 1, 'client': 1, 'text': json.dumps(message_data)}
with mock.patch.object(logger, 'warning') as log_mock:
networking_events_dispatcher.handle(message)
log_mock.assert_not_called()
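
All of these tests drive networking_events_dispatcher.handle, which is not shown in this diff. A hypothetical sketch of the dispatch pattern they exercise (the real implementation lives in awx.network_ui.consumers and may differ in detail):

import json
import logging

logger = logging.getLogger('awx.network_ui.consumers')


class ExampleDispatcher(object):
    # Hypothetical sketch only; the real dispatcher may differ.
    # Messages arrive as JSON pairs: ['MsgType', {...payload...}].

    def handle(self, message):
        msg_type, data = json.loads(message['text'])
        handler = getattr(self, 'on' + msg_type, None)
        if handler is None:
            # The warning these tests assert on (or assert is absent).
            logger.warning('Unsupported message %s', msg_type)
            return
        handler(message['topology'], message['client'], data)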

View File

@@ -0,0 +1,9 @@
import awx.network_ui.routing
def test_routing():
'''
Tests that the number of routes in awx.network_ui.routing is 3.
'''
assert len(awx.network_ui.routing.channel_routing) == 3
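
For reference, a minimal sketch of a three-route routing table in the Channels 1.x style this test assumes (placeholder consumer functions, not AWX's actual ones):

from channels.routing import route


def ws_connect(message):
    pass  # placeholder consumer


def ws_message(message):
    pass  # placeholder consumer


def ws_disconnect(message):
    pass  # placeholder consumer


channel_routing = [
    route('websocket.connect', ws_connect),
    route('websocket.receive', ws_message),
    route('websocket.disconnect', ws_disconnect),
]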

View File

@@ -0,0 +1,65 @@
import mock
from awx.network_ui.views import topology_data, NetworkAnnotatedInterface, json_topology_data, yaml_topology_data
from awx.network_ui.models import Topology, Device, Link, Interface
def test_topology_data():
with mock.patch.object(Topology, 'objects'),\
mock.patch.object(Device, 'objects') as device_objects_mock,\
mock.patch.object(Link, 'objects') as link_objects_mock,\
mock.patch.object(Interface, 'objects'),\
mock.patch.object(NetworkAnnotatedInterface, 'filter'):
device_objects_mock.filter.return_value.order_by.return_value = [
Device(pk=1), Device(pk=2)]
link_objects_mock.filter.return_value = [Link(from_device=Device(name='from', cid=1),
to_device=Device(
name='to', cid=2),
from_interface=Interface(
name="eth0", cid=1),
to_interface=Interface(
name="eth0", cid=1),
name="",
pk=1
)]
data = topology_data(1)
assert len(data['devices']) == 2
assert len(data['links']) == 1
def test_json_topology_data():
request = mock.MagicMock()
request.GET = dict(topology_id=1)
with mock.patch('awx.network_ui.views.topology_data') as topology_data_mock:
topology_data_mock.return_value = dict()
json_topology_data(request)
topology_data_mock.assert_called_once_with(1)
def test_yaml_topology_data():
request = mock.MagicMock()
request.GET = dict(topology_id=1)
with mock.patch('awx.network_ui.views.topology_data') as topology_data_mock:
topology_data_mock.return_value = dict()
yaml_topology_data(request)
topology_data_mock.assert_called_once_with(1)
def test_json_topology_data_no_topology_id():
request = mock.MagicMock()
request.GET = dict()
with mock.patch('awx.network_ui.views.topology_data') as topology_data_mock:
topology_data_mock.return_value = dict()
json_topology_data(request)
topology_data_mock.assert_not_called()
def test_yaml_topology_data_no_topology_id():
request = mock.MagicMock()
request.GET = dict()
with mock.patch('awx.network_ui.views.topology_data') as topology_data_mock:
topology_data_mock.return_value = dict()
yaml_topology_data(request)
topology_data_mock.assert_not_called()

View File

@@ -5,6 +5,6 @@ from awx.network_ui import views
 app_name = 'network_ui'
 urlpatterns = [
-    url(r'^topology.json$', views.json_topology_data, name='json_topology_data'),
-    url(r'^topology.yaml$', views.yaml_topology_data, name='yaml_topology_data'),
+    url(r'^topology.json/?$', views.json_topology_data, name='json_topology_data'),
+    url(r'^topology.yaml/?$', views.yaml_topology_data, name='yaml_topology_data'),
 ]

View File

@@ -1,11 +1,9 @@
 # Copyright (c) 2017 Red Hat, Inc
-from django.shortcuts import render
 from django import forms
 from django.http import JsonResponse, HttpResponseBadRequest, HttpResponse
 from awx.network_ui.models import Topology, Device, Link, Interface
 from django.db.models import Q
 import yaml
-import json
 NetworkAnnotatedInterface = Interface.objects.values('name',
                                                      'cid',
@@ -63,18 +61,6 @@ def topology_data(topology_id):
     return data
-def yaml_serialize_topology(topology_id):
-    return yaml.safe_dump(topology_data(topology_id), default_flow_style=False)
-def json_serialize_topology(topology_id):
-    return json.dumps(topology_data(topology_id))
-def index(request):
-    return render(request, "network_ui/index.html", dict(topologies=Topology.objects.all().order_by('-pk')))
 class TopologyForm(forms.Form):
     topology_id = forms.IntegerField()
@@ -82,7 +68,10 @@ class TopologyForm(forms.Form):
 def json_topology_data(request):
     form = TopologyForm(request.GET)
     if form.is_valid():
-        return JsonResponse(topology_data(form.cleaned_data['topology_id']))
+        response = JsonResponse(topology_data(form.cleaned_data['topology_id']),
+                                content_type='application/force-download')
+        response['Content-Disposition'] = 'attachment; filename="{}"'.format('topology.json')
+        return response
     else:
         return HttpResponseBadRequest(form.errors)
@@ -90,9 +79,11 @@ def json_topology_data(request):
 def yaml_topology_data(request):
     form = TopologyForm(request.GET)
     if form.is_valid():
-        return HttpResponse(yaml.safe_dump(topology_data(form.cleaned_data['topology_id']),
-                                           default_flow_style=False),
-                            content_type='application/yaml')
+        response = HttpResponse(yaml.safe_dump(topology_data(form.cleaned_data['topology_id']),
+                                               default_flow_style=False),
+                                content_type='application/force-download')
+        response['Content-Disposition'] = 'attachment; filename="{}"'.format('topology.yaml')
+        return response
     else:
         return HttpResponseBadRequest(form.errors)
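
Both endpoints now force a download instead of rendering inline; the pattern is a generic content type plus a Content-Disposition attachment header. A standalone sketch under a hypothetical view name and payload:

from django.http import JsonResponse


def download_topology(request):
    # Hypothetical payload; the Content-Disposition header is what
    # triggers the browser's save-as dialog.
    response = JsonResponse({'devices': [], 'links': []},
                            content_type='application/force-download')
    response['Content-Disposition'] = 'attachment; filename="topology.json"'
    return response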

View File

@@ -90,6 +90,7 @@ def read_tower_inventory(tower_host, tower_user, tower_pass, inventory, license_
     tower_host = "https://{}".format(tower_host)
     inventory_url = urljoin(tower_host, "/api/v2/inventories/{}/script/?hostvars=1&towervars=1&all=1".format(inventory.replace('/', '')))
     config_url = urljoin(tower_host, "/api/v2/config/")
+    reason = None
     try:
         if license_type != "open":
             config_response = requests.get(config_url,
@@ -106,14 +107,16 @@
         response = requests.get(inventory_url,
                                 auth=HTTPBasicAuth(tower_user, tower_pass),
                                 verify=not ignore_ssl)
+        try:
+            json_response = response.json()
+        except (ValueError, TypeError) as e:
+            reason = "Failed to parse json from host: {}".format(e)
         if response.ok:
-            return response.json()
-        json_reason = response.json()
-        reason = json_reason.get('detail', 'Retrieving Tower Inventory Failed')
+            return json_response
+        if not reason:
+            reason = json_response.get('detail', 'Retrieving Tower Inventory Failed')
     except requests.ConnectionError as e:
         reason = "Connection to remote host failed: {}".format(e)
-    except json.JSONDecodeError as e:
-        reason = "Failed to parse json from host: {}".format(e)
     raise RuntimeError(reason)
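
The rewrite parses the body once, up front, so the success and failure paths share json_response and a bad body yields a readable reason. It also retires the old except json.JSONDecodeError clause, which was broken on Python 2: that exception class only exists on Python 3.5+, and requests raises ValueError there. The same parse-then-branch pattern in isolation, with a hypothetical URL and function name:

import requests


def fetch_json(url):
    # Catching (ValueError, TypeError) is portable where
    # json.JSONDecodeError is not.
    reason = None
    response = requests.get(url)
    try:
        payload = response.json()
    except (ValueError, TypeError) as e:
        reason = "Failed to parse json from host: {}".format(e)
        payload = None
    if response.ok and payload is not None:
        return payload
    raise RuntimeError(reason or 'Request failed')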

View File

@@ -4,13 +4,10 @@
 import os
 import re  # noqa
 import sys
-import ldap
 import djcelery
 import six
 from datetime import timedelta
-from kombu.common import Broadcast
 # global settings
 from django.conf import global_settings
 # ugettext lazy
@@ -41,6 +38,13 @@ def IS_TESTING(argv=None):
     return is_testing(argv)
+if "pytest" in sys.modules:
+    import mock
+    with mock.patch('__main__.__builtins__.dir', return_value=[]):
+        import ldap
+else:
+    import ldap
 DEBUG = True
 SQL_DEBUG = DEBUG
@@ -456,6 +460,9 @@ BROKER_POOL_LIMIT = None
 BROKER_URL = 'amqp://guest:guest@localhost:5672//'
 CELERY_EVENT_QUEUE_TTL = 5
 CELERY_DEFAULT_QUEUE = 'awx_private_queue'
+CELERY_DEFAULT_EXCHANGE = 'awx_private_queue'
+CELERY_DEFAULT_ROUTING_KEY = 'awx_private_queue'
+CELERY_DEFAULT_EXCHANGE_TYPE = 'direct'
 CELERY_TASK_SERIALIZER = 'json'
 CELERY_RESULT_SERIALIZER = 'json'
 CELERY_ACCEPT_CONTENT = ['json']
@@ -466,10 +473,8 @@ CELERYD_POOL_RESTARTS = True
 CELERYD_AUTOSCALER = 'awx.main.utils.autoscale:DynamicAutoScaler'
 CELERY_RESULT_BACKEND = 'djcelery.backends.database:DatabaseBackend'
 CELERY_IMPORTS = ('awx.main.scheduler.tasks',)
-CELERY_QUEUES = (
-    Broadcast('tower_broadcast_all'),
-)
-CELERY_ROUTES = {}
+CELERY_QUEUES = ()
+CELERY_ROUTES = ('awx.main.utils.ha.AWXCeleryRouter',)
 def log_celery_failure(*args):
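
CELERY_ROUTES now names a router class rather than a static map. In the Celery 3.x API a router is any object with a route_for_task method; a minimal sketch of that interface follows (the real AWXCeleryRouter lives in awx.main.utils.ha and its logic is not shown in this diff):

class ExampleCeleryRouter(object):
    # Illustrative only; AWX's real routing logic differs.

    def route_for_task(self, task, args=None, kwargs=None):
        if task.startswith('awx.main.scheduler.'):
            return {'queue': 'awx_private_queue'}
        return None  # defer to the next router / default queue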
@@ -532,19 +537,12 @@ ASGI_AMQP = {
 }
 # Django Caching Configuration
-if is_testing():
-    CACHES = {
-        'default': {
-            'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
-        },
-    }
-else:
-    CACHES = {
-        'default': {
-            'BACKEND': 'django.core.cache.backends.memcached.MemcachedCache',
-            'LOCATION': 'memcached:11211',
-        },
-    }
+CACHES = {
+    'default': {
+        'BACKEND': 'django.core.cache.backends.memcached.MemcachedCache',
+        'LOCATION': 'memcached:11211',
+    },
+}
 # Social Auth configuration.
 SOCIAL_AUTH_STRATEGY = 'social_django.strategy.DjangoStrategy'
@@ -1005,6 +1003,9 @@ LOGGING = {
         'require_debug_true_or_test': {
             '()': 'awx.main.utils.RequireDebugTrueOrTest',
         },
+        'external_log_enabled': {
+            '()': 'awx.main.utils.filters.ExternalLoggerEnabled'
+        },
     },
     'formatters': {
         'simple': {
@@ -1038,11 +1039,10 @@ LOGGING = {
             'class': 'logging.NullHandler',
             'formatter': 'simple',
         },
-        'http_receiver': {
-            'class': 'awx.main.utils.handlers.HTTPSNullHandler',
-            'level': 'DEBUG',
+        'external_logger': {
+            'class': 'awx.main.utils.handlers.AWXProxyHandler',
             'formatter': 'json',
-            'host': '',
+            'filters': ['external_log_enabled'],
         },
         'mail_admins': {
             'level': 'ERROR',
@@ -1135,7 +1135,7 @@ LOGGING = {
             'handlers': ['console'],
         },
         'awx': {
-            'handlers': ['console', 'file', 'tower_warnings'],
+            'handlers': ['console', 'file', 'tower_warnings', 'external_logger'],
             'level': 'DEBUG',
         },
         'awx.conf': {
@@ -1160,16 +1160,13 @@ LOGGING = {
             'propagate': False
         },
         'awx.main.tasks': {
-            'handlers': ['task_system'],
+            'handlers': ['task_system', 'external_logger'],
             'propagate': False
         },
         'awx.main.scheduler': {
-            'handlers': ['task_system'],
+            'handlers': ['task_system', 'external_logger'],
             'propagate': False
         },
-        'awx.main.consumers': {
-            'handlers': ['null']
-        },
         'awx.main.access': {
             'handlers': ['null'],
             'propagate': False,
@@ -1183,7 +1180,7 @@ LOGGING = {
             'propagate': False,
         },
         'awx.analytics': {
-            'handlers': ['http_receiver'],
+            'handlers': ['external_logger'],
             'level': 'INFO',
             'propagate': False
         },
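
In Django's LOGGING dict, the '()' key names a callable used to instantiate a filter, and a filter whose filter() method returns False drops the record. A minimal sketch of the shape external_log_enabled plugs in (the real ExternalLoggerEnabled lives in awx.main.utils.filters; the toggle below is made up):

import logging


class ExampleExternalLoggerEnabled(logging.Filter):

    def filter(self, record):
        # Hypothetical toggle; the real filter consults AWX settings
        # to decide whether external log aggregation is enabled.
        external_logging_on = False
        return external_logging_on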

View File

@@ -9,6 +9,7 @@ import socket
 import copy
 import sys
 import traceback
+import uuid
 # Centos-7 doesn't include the svg mime type
 # /usr/lib64/python/mimetypes.py
@@ -20,6 +21,15 @@ from split_settings.tools import optional, include
 # Load default settings.
 from defaults import *  # NOQA
+# don't use memcache when running tests
+if "pytest" in sys.modules:
+    CACHES = {
+        'default': {
+            'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
+            'LOCATION': 'unique-{}'.format(str(uuid.uuid4())),
+        },
+    }
 # awx-manage shell_plus --notebook
 NOTEBOOK_ARGUMENTS = [
     '--NotebookApp.token=',
@@ -103,13 +113,6 @@ if 'django_jenkins' in INSTALLED_APPS:
 INSTALLED_APPS += ('rest_framework_swagger',)
-# Much faster than the default
-# https://docs.djangoproject.com/en/1.6/topics/auth/passwords/#how-django-stores-passwords
-PASSWORD_HASHERS = (
-    'django.contrib.auth.hashers.MD5PasswordHasher',
-    'django.contrib.auth.hashers.PBKDF2PasswordHasher',
-)
 # Configure a default UUID for development only.
 SYSTEM_UUID = '00000000-0000-0000-0000-000000000000'
@@ -149,8 +152,6 @@ SERVICE_NAME_DICT = {
     "uwsgi": "uwsgi",
     "daphne": "daphne",
     "nginx": "nginx"}
-# Used for sending commands in automatic restart
-UWSGI_FIFO_LOCATION = '/awxfifo'
 try:
     socket.gethostbyname('docker.for.mac.internal')
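
The pytest override above works because locmem caches are namespaced by their LOCATION string, so embedding a per-run UUID guarantees a fresh, isolated cache for every test process. A quick illustration (requires Django on the path; the location names are arbitrary):

from django.core.cache.backends.locmem import LocMemCache

a = LocMemCache('unique-a', {})
b = LocMemCache('unique-b', {})
a.set('key', 'value')
assert a.get('key') == 'value'
assert b.get('key') is None  # distinct LOCATION, distinct namespace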

View File

@@ -13,6 +13,7 @@
 ###############################################################################
 import os
 import urllib
+import sys
 def patch_broken_pipe_error():
@@ -66,7 +67,7 @@ DATABASES = {
 # Use SQLite for unit tests instead of PostgreSQL. If the lines below are
 # commented out, Django will create the test_awx-dev database in PostgreSQL to
 # run unit tests.
-if is_testing(sys.argv):
+if "pytest" in sys.modules:
     DATABASES = {
         'default': {
             'ENGINE': 'django.db.backends.sqlite3',
@@ -195,11 +196,10 @@ LOGGING['handlers']['syslog'] = {
 LOGGING['loggers']['django.request']['handlers'] = ['console']
 LOGGING['loggers']['rest_framework.request']['handlers'] = ['console']
-LOGGING['loggers']['awx']['handlers'] = ['console']
+LOGGING['loggers']['awx']['handlers'] = ['console', 'external_logger']
 LOGGING['loggers']['awx.main.commands.run_callback_receiver']['handlers'] = ['console']
-LOGGING['loggers']['awx.main.commands.inventory_import']['handlers'] = ['console']
-LOGGING['loggers']['awx.main.tasks']['handlers'] = ['console']
-LOGGING['loggers']['awx.main.scheduler']['handlers'] = ['console']
+LOGGING['loggers']['awx.main.tasks']['handlers'] = ['console', 'external_logger']
+LOGGING['loggers']['awx.main.scheduler']['handlers'] = ['console', 'external_logger']
 LOGGING['loggers']['django_auth_ldap']['handlers'] = ['console']
 LOGGING['loggers']['social']['handlers'] = ['console']
 LOGGING['loggers']['system_tracking_migrations']['handlers'] = ['console']

View File

@@ -68,8 +68,6 @@ SERVICE_NAME_DICT = {
     "channels": "awx-channels-worker",
     "uwsgi": "awx-uwsgi",
     "daphne": "awx-daphne"}
-# Used for sending commands in automatic restart
-UWSGI_FIFO_LOCATION = '/var/lib/awx/awxfifo'
 # Store a snapshot of default settings at this point before loading any
 # customizable config files.

View File

@@ -0,0 +1,6 @@
# Ensure that our autouse overwrites are working
def test_cache(settings):
assert settings.CACHES['default']['BACKEND'] == 'django.core.cache.backends.locmem.LocMemCache'
assert settings.CACHES['default']['LOCATION'].startswith('unique-')

View File

@@ -157,8 +157,8 @@
         </div>
         <div class="response-info" aria-label="{% trans "response info" %}">
-            <pre class="prettyprint"><span class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}{% for key, val in response_headers|items %}
-<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>{% endfor %}
+            <pre class="prettyprint"><span class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}{% if response_headers %}{% for key, val in response_headers|items %}
+<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>{% endfor %}{% endif %}
 {# Original line below had the side effect of also escaping content: #}
 {# </span>{{ content|urlize_quoted_links }}</pre>{% endautoescape %} #}
 {# For AWX, disable automatic URL creation and move content outside of autoescape off block. #}

View File

@@ -54,6 +54,18 @@ module.exports = {
         'no-multiple-empty-lines': ['error', { max: 1 }],
         'object-curly-newline': 'off',
         'space-before-function-paren': ['error', 'always'],
-        'no-trailing-spaces': ['error']
+        'no-trailing-spaces': ['error'],
+        'prefer-destructuring': ['error', {
+            'VariableDeclarator': {
+                'array': false,
+                'object': true
+            },
+            'AssignmentExpression': {
+                'array': false,
+                'object': true
+            }
+        }, {
+            'enforceForRenamedProperties': false
+        }]
     }
 };

View File

@@ -1,3 +1,2 @@
-@import 'credentials/_index';
 @import 'output/_index';
 @import 'users/tokens/_index';

View File

@@ -3,10 +3,8 @@ function AddApplicationsController (models, $state, strings) {
     const { application, me, organization } = models;
     const omit = [
-        'authorization_grant_type',
         'client_id',
         'client_secret',
-        'client_type',
         'created',
         'modified',
         'related',
@@ -44,19 +42,16 @@ function AddApplicationsController (models, $state, strings) {
     vm.form.organization._resource = 'organization';
     vm.form.organization._route = 'applications.add.organization';
     vm.form.organization._model = organization;
-    vm.form.organization._placeholder = strings.get('SELECT AN ORGANIZATION');
+    vm.form.organization._placeholder = strings.get('inputs.ORGANIZATION_PLACEHOLDER');
     vm.form.name.required = true;
     vm.form.organization.required = true;
-    vm.form.redirect_uris.required = true;
     delete vm.form.name.help_text;
     vm.form.save = data => {
         const hiddenData = {
-            authorization_grant_type: 'implicit',
-            user: me.get('id'),
-            client_type: 'public'
+            user: me.get('id')
         };
         const payload = _.merge(data, hiddenData);

View File

@@ -14,7 +14,9 @@
     <at-input-text col="4" tab="2" state="vm.form.description"></at-input-text>
     <at-input-lookup col="4" tab="3" state="vm.form.organization"></at-input-lookup>
     <at-divider></at-divider>
-    <at-input-text col="4" tab="4" state="vm.form.redirect_uris"></at-input-text>
+    <at-input-select col="4" tab="4" state="vm.form.authorization_grant_type"></at-input-select>
+    <at-input-text col="4" tab="5" state="vm.form.redirect_uris"></at-input-text>
+    <at-input-select col="4" tab="6" state="vm.form.client_type"></at-input-select>
     <at-action-group col="12" pos="right">
         <at-form-action type="cancel" to="applications"></at-form-action>

View File

@@ -16,6 +16,10 @@ function ApplicationsStrings (BaseString) {
         USERS: t.s('Tokens')
     };
+    ns.tooltips = {
+        ADD: t.s('Create a new Application')
+    };
     ns.add = {
         PANEL_TITLE: t.s('NEW APPLICATION')
     };
@@ -25,6 +29,10 @@ function ApplicationsStrings (BaseString) {
         ROW_ITEM_LABEL_ORGANIZATION: t.s('ORG'),
         ROW_ITEM_LABEL_MODIFIED: t.s('LAST MODIFIED')
     };
+    ns.inputs = {
+        ORGANIZATION_PLACEHOLDER: t.s('SELECT AN ORGANIZATION')
+    };
 }
 ApplicationsStrings.$inject = ['BaseStringService'];

View File

@@ -4,10 +4,8 @@ function EditApplicationsController (models, $state, strings, $scope) {
     const { me, application, organization } = models;
     const omit = [
-        'authorization_grant_type',
         'client_id',
         'client_secret',
-        'client_type',
         'created',
         'modified',
         'related',
@@ -54,45 +52,30 @@ function EditApplicationsController (models, $state, strings, $scope) {
     vm.form.disabled = !isEditable;
+    vm.form.name.required = true;
     const isOrgAdmin = _.some(me.get('related.admin_of_organizations.results'), (org) => org.id === organization.get('id'));
     const isSuperuser = me.get('is_superuser');
     const isCurrentAuthor = Boolean(application.get('summary_fields.created_by.id') === me.get('id'));
-    vm.form.organization = {
-        type: 'field',
-        label: 'Organization',
-        id: 'organization'
-    };
-    vm.form.description = {
-        type: 'String',
-        label: 'Description',
-        id: 'description'
-    };
-    vm.form.organization._resource = 'organization';
-    vm.form.organization._route = 'applications.edit.organization';
-    vm.form.organization._model = organization;
-    vm.form.organization._placeholder = strings.get('SELECT AN ORGANIZATION');
-    // TODO: org not returned via api endpoint, check on this
-    vm.form.organization._value = application.get('organization');
     vm.form.organization._disabled = true;
     if (isSuperuser || isOrgAdmin || (application.get('organization') === null && isCurrentAuthor)) {
         vm.form.organization._disabled = false;
     }
-    vm.form.name.required = true;
+    vm.form.organization._resource = 'organization';
+    vm.form.organization._model = organization;
+    vm.form.organization._route = 'applications.edit.organization';
+    vm.form.organization._value = application.get('summary_fields.organization.id');
+    vm.form.organization._displayValue = application.get('summary_fields.organization.name');
+    vm.form.organization._placeholder = strings.get('inputs.ORGANIZATION_PLACEHOLDER');
     vm.form.organization.required = true;
-    vm.form.redirect_uris.required = true;
     delete vm.form.name.help_text;
     vm.form.save = data => {
         const hiddenData = {
-            authorization_grant_type: 'implicit',
-            user: me.get('id'),
-            client_type: 'public'
+            user: me.get('id')
        };
         const payload = _.merge(data, hiddenData);

View File

@@ -62,8 +62,7 @@ function ApplicationsRun ($stateExtender, strings) {
     },
     data: {
         activityStream: true,
-        // TODO: double-check activity stream works
-        activityStreamTarget: 'application'
+        activityStreamTarget: 'o_auth2_application'
     },
     views: {
         '@': {
@@ -111,8 +110,7 @@ function ApplicationsRun ($stateExtender, strings) {
     },
     data: {
         activityStream: true,
-        // TODO: double-check activity stream works
-        activityStreamTarget: 'application'
+        activityStreamTarget: 'o_auth2_application'
     },
     views: {
         'add@applications': {
@@ -134,7 +132,7 @@ function ApplicationsRun ($stateExtender, strings) {
     },
     data: {
         activityStream: true,
-        activityStreamTarget: 'application',
+        activityStreamTarget: 'o_auth2_application',
         activityStreamId: 'application_id'
     },
     views: {
@@ -264,8 +262,7 @@ function ApplicationsRun ($stateExtender, strings) {
     },
     data: {
         activityStream: true,
-        // TODO: double-check activity stream works
-        activityStreamTarget: 'application'
+        activityStreamTarget: 'o_auth2_application'
     },
     views: {
         'userList@applications.edit': {

View File

@@ -38,6 +38,10 @@ function ListApplicationsController (
         vm.applicationsCount = dataset.count;
     });
+    vm.tooltips = {
+        add: strings.get('tooltips.ADD')
+    };
     vm.getModified = app => {
         const modified = _.get(app, 'modified');
@@ -74,7 +78,7 @@ function ListApplicationsController (
         }
         if (parseInt($state.params.application_id, 10) === app.id) {
-            $state.go('^', reloadListStateParams, { reload: true });
+            $state.go('applications', reloadListStateParams, { reload: true });
         } else {
             $state.go('.', reloadListStateParams, { reload: true });
         }

View File

@@ -23,6 +23,9 @@
     type="button"
     ui-sref="applications.add"
     class="at-Button--add"
+    id="button-add"
+    aw-tool-tip="{{vm.tooltips.add}}"
+    data-placement="top"
     aria-haspopup="true"
     aria-expanded="false">
 </button>

View File

@@ -1,3 +0,0 @@
-.at-CredentialsPermissions {
-    margin-top: 50px;
-}

View File

@@ -69,8 +69,8 @@ function LegacyCredentialsService () {
     ngClick: '$state.go(\'.add\')',
     label: 'Add',
     awToolTip: N_('Add a permission'),
-    actionClass: 'btn List-buttonSubmit',
-    buttonContent: `&#43; ${N_('ADD')}`,
+    actionClass: 'at-Button--add',
+    actionId: 'button-add',
     ngShow: '(credential_obj.summary_fields.user_capabilities.edit || canAdd)'
 }

Some files were not shown because too many files have changed in this diff.