From 4873e2413f325c11bd6c9377a41e80f00eeea490 Mon Sep 17 00:00:00 2001 From: Chris Church Date: Tue, 16 Feb 2016 17:49:34 -0500 Subject: [PATCH] * Populate browsable API raw data form with submitted request data in response to an update. * Remove fields from browsable API raw data that are set implicitly based on URL / parent object. * Fix issue where a group/host could be assigned to a different inventory. * Update validation to load values from existing instance if not present in new data; allows PATCH requests to succeed. * Remove job_args, job_cwd, job_env, result_stdout and result_traceback fields from job listings. --- awx/api/generics.py | 28 +++++++++++- awx/api/parsers.py | 30 +++++++++++++ awx/api/renderers.py | 16 ++++++- awx/api/serializers.py | 93 +++++++++++++++++++--------------------- awx/api/views.py | 56 +++++++++++++++++++++++- awx/settings/defaults.py | 2 +- 6 files changed, 170 insertions(+), 55 deletions(-) create mode 100644 awx/api/parsers.py diff --git a/awx/api/generics.py b/awx/api/generics.py index 6618263742..605c824468 100644 --- a/awx/api/generics.py +++ b/awx/api/generics.py @@ -2,6 +2,7 @@ # All Rights Reserved. # Python +from collections import OrderedDict import inspect import logging import time @@ -155,6 +156,22 @@ class APIView(views.APIView): context = self.get_description_context() return render_to_string(template_list, context) + def update_raw_data(self, data): + # Remove the parent key if the view is a sublist, since it will be set + # automatically. + parent_key = getattr(self, 'parent_key', None) + if parent_key: + data.pop(parent_key, None) + + # Use request data as-is when original request is an update and the + # submitted data was rejected. + request_method = getattr(self, '_raw_data_request_method', None) + response_status = getattr(self, '_raw_data_response_status', 0) + if request_method in ('POST', 'PUT', 'PATCH') and response_status in xrange(400, 500): + return self.request.data.copy() + + return data + class GenericAPIView(generics.GenericAPIView, APIView): # Base class for all model-based views. @@ -166,11 +183,14 @@ class GenericAPIView(generics.GenericAPIView, APIView): def get_serializer(self, *args, **kwargs): serializer = super(GenericAPIView, self).get_serializer(*args, **kwargs) # Override when called from browsable API to generate raw data form; - # always remove read only fields from sample raw data. + # update serializer "validated" data to be displayed by the raw data + # form. if hasattr(self, '_raw_data_form_marker'): + # Always remove read only fields from serializer. for name, field in serializer.fields.items(): if getattr(field, 'read_only', None): - del serializer.fields[name] + del serializer.fields[name] + serializer._data = self.update_raw_data(serializer.data) return serializer def get_queryset(self): @@ -439,6 +459,10 @@ class RetrieveUpdateAPIView(RetrieveAPIView, generics.RetrieveUpdateAPIView): self.update_filter(request, *args, **kwargs) return super(RetrieveUpdateAPIView, self).update(request, *args, **kwargs) + def partial_update(self, request, *args, **kwargs): + self.update_filter(request, *args, **kwargs) + return super(RetrieveUpdateAPIView, self).partial_update(request, *args, **kwargs) + def update_filter(self, request, *args, **kwargs): ''' scrub any fields the user cannot/should not put/patch, based on user context. This runs after read-only serialization filtering ''' pass diff --git a/awx/api/parsers.py b/awx/api/parsers.py new file mode 100644 index 0000000000..94ddbec561 --- /dev/null +++ b/awx/api/parsers.py @@ -0,0 +1,30 @@ +# Python +from collections import OrderedDict +import json + +# Django +from django.conf import settings +from django.utils import six + +# Django REST Framework +from rest_framework import parsers +from rest_framework.exceptions import ParseError + + +class JSONParser(parsers.JSONParser): + """ + Parses JSON-serialized data, preserving order of dictionary keys. + """ + + def parse(self, stream, media_type=None, parser_context=None): + """ + Parses the incoming bytestream as JSON and returns the resulting data. + """ + parser_context = parser_context or {} + encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + + try: + data = stream.read().decode(encoding) + return json.loads(data, object_pairs_hook=OrderedDict) + except ValueError as exc: + raise ParseError('JSON parse error - %s' % six.text_type(exc)) diff --git a/awx/api/renderers.py b/awx/api/renderers.py index 1897028333..348a8220c4 100644 --- a/awx/api/renderers.py +++ b/awx/api/renderers.py @@ -17,17 +17,29 @@ class BrowsableAPIRenderer(renderers.BrowsableAPIRenderer): return renderers.JSONRenderer() return renderer + def get_context(self, data, accepted_media_type, renderer_context): + # Store the associated response status to know how to populate the raw + # data form. + try: + setattr(renderer_context['view'], '_raw_data_response_status', renderer_context['response'].status_code) + return super(BrowsableAPIRenderer, self).get_context(data, accepted_media_type, renderer_context) + finally: + delattr(renderer_context['view'], '_raw_data_response_status') + def get_raw_data_form(self, data, view, method, request): # Set a flag on the view to indiciate to the view/serializer that we're - # creating a raw data form for the browsable API. + # creating a raw data form for the browsable API. Store the original + # request method to determine how to populate the raw data form. try: setattr(view, '_raw_data_form_marker', True) + setattr(view, '_raw_data_request_method', request.method) return super(BrowsableAPIRenderer, self).get_raw_data_form(data, view, method, request) finally: delattr(view, '_raw_data_form_marker') + delattr(view, '_raw_data_request_method') def get_rendered_html_form(self, data, view, method, request): - '''Never show auto-generated form (only raw form).''' + # Never show auto-generated form (only raw form). obj = getattr(view, 'object', None) if not self.show_form_for_method(view, method, request, obj): return diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 572e690ce6..1c8fd17bd8 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -463,24 +463,6 @@ class BaseSerializer(serializers.ModelSerializer): raise ValidationError(d) return attrs - def to_representation(self, obj): - # FIXME: Doesn't get called anymore for an new raw data form! - # When rendering the raw data form, create an instance of the model so - # that the model defaults will be filled in. - view = self.context.get('view', None) - parent_key = getattr(view, 'parent_key', None) - if not obj and hasattr(view, '_raw_data_form_marker'): - obj = self.Meta.model() - # FIXME: Would be nice to include any posted data for the raw data - # form, so that a submission with errors can be modified in place - # and resubmitted. - ret = super(BaseSerializer, self).to_representation(obj) - # Remove parent key from raw form data, since it will be automatically - # set by the sub list create view. - if parent_key and hasattr(view, '_raw_data_form_marker'): - ret.pop(parent_key, None) - return ret - class BaseFactSerializer(DocumentSerializer): @@ -611,6 +593,12 @@ class UnifiedJobListSerializer(UnifiedJobSerializer): class Meta: fields = ('*', '-job_args', '-job_cwd', '-job_env', '-result_traceback', '-result_stdout') + def get_field_names(self, declared_fields, info): + field_names = super(UnifiedJobListSerializer, self).get_field_names(declared_fields, info) + # Meta multiple inheritance and -field_name options don't seem to be + # taking effect above, so remove the undesired fields here. + return tuple(x for x in field_names if x not in ('job_args', 'job_cwd', 'job_env', 'result_traceback', 'result_stdout')) + def get_types(self): if type(self) is UnifiedJobListSerializer: return ['project_update', 'inventory_update', 'job', 'ad_hoc_command', 'system_job'] @@ -995,6 +983,14 @@ class HostSerializer(BaseSerializerWithVariables): 'last_job_host_summary') read_only_fields = ('last_job', 'last_job_host_summary') + def build_relational_field(self, field_name, relation_info): + field_class, field_kwargs = super(HostSerializer, self).build_relational_field(field_name, relation_info) + # Inventory is read-only unless creating a new host. + if self.instance and field_name == 'inventory': + field_kwargs['read_only'] = True + field_kwargs.pop('queryset', None) + return field_class, field_kwargs + def get_related(self, obj): res = super(HostSerializer, self).get_related(obj) res.update(dict( @@ -1053,15 +1049,12 @@ class HostSerializer(BaseSerializerWithVariables): return value def validate(self, attrs): - name = force_text(attrs.get('name', '')) + name = force_text(attrs.get('name', self.instance and self.instance.name or '')) host, port = self._get_host_port_from_name(name) if port: attrs['name'] = host - if self.instance: - variables = force_text(attrs.get('variables', self.instance.variables) or '') - else: - variables = force_text(attrs.get('variables', '')) + variables = force_text(attrs.get('variables', self.instance and self.instance.variables or '')) try: vars_dict = json.loads(variables.strip() or '{}') vars_dict['ansible_ssh_port'] = port @@ -1099,6 +1092,14 @@ class GroupSerializer(BaseSerializerWithVariables): 'total_hosts', 'hosts_with_active_failures', 'total_groups', 'groups_with_active_failures', 'has_inventory_sources') + def build_relational_field(self, field_name, relation_info): + field_class, field_kwargs = super(GroupSerializer, self).build_relational_field(field_name, relation_info) + # Inventory is read-only unless creating a new group. + if self.instance and field_name == 'inventory': + field_kwargs['read_only'] = True + field_kwargs.pop('queryset', None) + return field_class, field_kwargs + def get_related(self, obj): res = super(GroupSerializer, self).get_related(obj) res.update(dict( @@ -1247,9 +1248,10 @@ class InventorySourceOptionsSerializer(BaseSerializer): # TODO: Validate source, validate source_regions errors = {} - source_script = attrs.get('source_script', None) - if 'source' in attrs and attrs.get('source', '') == 'custom': - if source_script is None or source_script == '': + source = attrs.get('source', self.instance and self.instance.source or '') + source_script = attrs.get('source_script', self.instance and self.instance.source_script or '') + if source == 'custom': + if not source_script is None or source_script == '': errors['source_script'] = 'source_script must be provided' else: try: @@ -1403,15 +1405,19 @@ class PermissionSerializer(BaseSerializer): def validate(self, attrs): # Can only set either user or team. - if attrs.get('user', None) and attrs.get('team', None): + user = attrs.get('user', self.instance and self.instance.user or None) + team = attrs.get('team', self.instance and self.instance.team or None) + if user and team: raise serializers.ValidationError('permission can only be assigned' ' to a user OR a team, not both') # Cannot assign admit/read/write permissions for a project. - if attrs.get('permission_type', None) in ('admin', 'read', 'write') and attrs.get('project', None): + permission_type = attrs.get('permission_type', self.instance and self.instance.permission_type or None) + project = attrs.get('project', self.instance and self.instance.project or None) + if permission_type in ('admin', 'read', 'write') and project: raise serializers.ValidationError('project cannot be assigned for ' 'inventory-only permissions') # Project is required when setting deployment permissions. - if attrs.get('permission_type', None) in ('run', 'check') and not attrs.get('project', None): + if permission_type in ('run', 'check') and not project: raise serializers.ValidationError('project is required when ' 'assigning deployment permissions') @@ -1522,9 +1528,10 @@ class JobOptionsSerializer(BaseSerializer): def validate(self, attrs): if 'project' in self.fields and 'playbook' in self.fields: - project = attrs.get('project', None) - playbook = attrs.get('playbook', '') - if not project and attrs.get('job_type') != PERM_INVENTORY_SCAN: + project = attrs.get('project', self.instance and self.instance.project or None) + playbook = attrs.get('playbook', self.instance and self.instance.playbook or '') + job_type = attrs.get('job_type', self.instance and self.instance.job_type or None) + if not project and job_type != PERM_INVENTORY_SCAN: raise serializers.ValidationError({'project': 'This field is required.'}) if project and playbook and force_text(playbook) not in project.playbooks: raise serializers.ValidationError({'playbook': 'Playbook not found for project'}) @@ -1578,8 +1585,8 @@ class JobTemplateSerializer(UnifiedJobTemplateSerializer, JobOptionsSerializer): return d def validate(self, attrs): - survey_enabled = attrs.get('survey_enabled', False) - job_type = attrs.get('job_type', None) + survey_enabled = attrs.get('survey_enabled', self.instance and self.instance.survey_enabled or False) + job_type = attrs.get('job_type', self.instance and self.instance.job_type or None) if survey_enabled and job_type == PERM_INVENTORY_SCAN: raise serializers.ValidationError({'survey_enabled': 'Survey Enabled can not be used with scan jobs'}) @@ -1737,8 +1744,8 @@ class AdHocCommandSerializer(UnifiedJobSerializer): def get_field_names(self, declared_fields, info): field_names = super(AdHocCommandSerializer, self).get_field_names(declared_fields, info) - # Meta inheritance and -field_name options don't seem to be taking - # effect above, so remove the undesired fields here. + # Meta multiple inheritance and -field_name options don't seem to be + # taking effect above, so remove the undesired fields here. return tuple(x for x in field_names if x not in ('unified_job_template', 'description')) def build_standard_field(self, field_name, model_field): @@ -1770,19 +1777,7 @@ class AdHocCommandSerializer(UnifiedJobSerializer): return res def to_representation(self, obj): - # In raw data form, populate limit field from host/group name. - view = self.context.get('view', None) - parent_model = getattr(view, 'parent_model', None) - if not (obj and obj.pk) and view and hasattr(view, '_raw_data_form_marker'): - if not obj: - obj = self.Meta.model() ret = super(AdHocCommandSerializer, self).to_representation(obj) - # Hide inventory and limit fields from raw data, since they will be set - # automatically by sub list create view. - if not (obj and obj.pk) and view and hasattr(view, '_raw_data_form_marker'): - if parent_model in (Host, Group): - ret.pop('inventory', None) - ret.pop('limit', None) if 'inventory' in ret and (not obj.inventory or not obj.inventory.active): ret['inventory'] = None if 'credential' in ret and (not obj.credential or not obj.credential.active): @@ -1993,7 +1988,7 @@ class JobLaunchSerializer(BaseSerializer): obj = self.context.get('obj') data = self.context.get('data') - credential = attrs.get('credential', None) or (obj and obj.credential) + credential = attrs.get('credential', obj and obj.credential or None) if not credential or not credential.active: errors['credential'] = 'Credential not provided' diff --git a/awx/api/views.py b/awx/api/views.py index 9a41e779ea..aebe4e0a91 100644 --- a/awx/api/views.py +++ b/awx/api/views.py @@ -568,8 +568,21 @@ class AuthTokenView(APIView): serializer_class = AuthTokenSerializer model = AuthToken + def get_serializer(self, *args, **kwargs): + serializer = self.serializer_class(*args, **kwargs) + # Override when called from browsable API to generate raw data form; + # update serializer "validated" data to be displayed by the raw data + # form. + if hasattr(self, '_raw_data_form_marker'): + # Always remove read only fields from serializer. + for name, field in serializer.fields.items(): + if getattr(field, 'read_only', None): + del serializer.fields[name] + serializer._data = self.update_raw_data(serializer.data) + return serializer + def post(self, request): - serializer = self.serializer_class(data=request.data) + serializer = self.get_serializer(data=request.data) if serializer.is_valid(): request_hash = AuthToken.get_request_hash(self.request) try: @@ -1178,6 +1191,19 @@ class HostGroupsList(SubListCreateAttachDetachAPIView): parent_model = Host relationship = 'groups' + def update_raw_data(self, data): + data.pop('inventory', None) + return super(HostGroupsList, self).update_raw_data(data) + + def create(self, request, *args, **kwargs): + # Inject parent host inventory ID into new group data. + data = request.data + # HACK: Make request data mutable. + if getattr(data, '_mutable', None) is False: + data._mutable = True + data['inventory'] = self.get_parent_object().inventory_id + return super(HostGroupsList, self).create(request, *args, **kwargs) + class HostAllGroupsList(SubListAPIView): ''' the list of all groups of which the host is directly or indirectly a member ''' @@ -1334,6 +1360,19 @@ class GroupChildrenList(SubListCreateAttachDetachAPIView): parent_model = Group relationship = 'children' + def update_raw_data(self, data): + data.pop('inventory', None) + return super(GroupChildrenList, self).update_raw_data(data) + + def create(self, request, *args, **kwargs): + # Inject parent group inventory ID into new group data. + data = request.data + # HACK: Make request data mutable. + if getattr(data, '_mutable', None) is False: + data._mutable = True + data['inventory'] = self.get_parent_object().inventory_id + return super(GroupChildrenList, self).create(request, *args, **kwargs) + def unattach(self, request, *args, **kwargs): sub_id = request.data.get('id', None) if sub_id is not None: @@ -1394,8 +1433,14 @@ class GroupHostsList(SubListCreateAttachDetachAPIView): parent_model = Group relationship = 'hosts' + def update_raw_data(self, data): + data.pop('inventory', None) + return super(GroupHostsList, self).update_raw_data(data) + def create(self, request, *args, **kwargs): parent_group = Group.objects.get(id=self.kwargs['pk']) + # Inject parent group inventory ID into new host data. + request.data['inventory'] = parent_group.inventory_id existing_hosts = Host.objects.filter(inventory=parent_group.inventory, name=request.data['name']) if existing_hosts.count() > 0 and ('variables' not in request.data or request.data['variables'] == '' or @@ -2583,6 +2628,15 @@ class AdHocCommandList(ListCreateAPIView): def dispatch(self, *args, **kwargs): return super(AdHocCommandList, self).dispatch(*args, **kwargs) + def update_raw_data(self, data): + # Hide inventory and limit fields from raw data, since they will be set + # automatically by sub list create view. + parent_model = getattr(self, 'parent_model', None) + if parent_model in (Host, Group): + data.pop('inventory', None) + data.pop('limit', None) + return super(AdHocCommandList, self).update_raw_data(data) + def create(self, request, *args, **kwargs): # Inject inventory ID and limit if parent objects is a host/group. if hasattr(self, 'get_parent_object') and not getattr(self, 'parent_key', None): diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index 36f39ac3ec..d56c16fbef 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -216,7 +216,7 @@ REST_FRAMEWORK = { 'awx.api.filters.OrderByBackend', ), 'DEFAULT_PARSER_CLASSES': ( - 'rest_framework.parsers.JSONParser', + 'awx.api.parsers.JSONParser', ), 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer',