diff --git a/awx/api/generics.py b/awx/api/generics.py index 6618263742..da9ea6dfa2 100644 --- a/awx/api/generics.py +++ b/awx/api/generics.py @@ -348,7 +348,7 @@ class SubListCreateAPIView(SubListAPIView, ListCreateAPIView): # object deserialized obj = serializer.save() serializer = self.get_serializer(instance=obj) - + headers = {'Location': obj.get_absolute_url()} return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) @@ -359,7 +359,7 @@ class SubListCreateAttachDetachAPIView(SubListCreateAPIView): def attach(self, request, *args, **kwargs): created = False parent = self.get_parent_object() - relationship = getattr(parent, self.relationship) + relationship = getattrd(parent, self.relationship) sub_id = request.data.get('id', None) data = request.data @@ -378,7 +378,7 @@ class SubListCreateAttachDetachAPIView(SubListCreateAPIView): # Retrive the sub object (whether created or by ID). sub = get_object_or_400(self.model, pk=sub_id) - + # Verify we have permission to attach. if not request.user.can_access(self.parent_model, 'attach', parent, sub, self.relationship, data, @@ -405,7 +405,7 @@ class SubListCreateAttachDetachAPIView(SubListCreateAPIView): parent = self.get_parent_object() parent_key = getattr(self, 'parent_key', None) - relationship = getattr(parent, self.relationship) + relationship = getattrd(parent, self.relationship) sub = get_object_or_400(self.model, pk=sub_id) if not request.user.can_access(self.parent_model, 'unattach', parent, diff --git a/awx/main/utils.py b/awx/main/utils.py index 5bd00c2da6..00bfc74608 100644 --- a/awx/main/utils.py +++ b/awx/main/utils.py @@ -30,7 +30,7 @@ __all__ = ['get_object_or_400', 'get_object_or_403', 'camelcase_to_underscore', 'get_ansible_version', 'get_ssh_version', 'get_awx_version', 'update_scm_url', 'get_type_for_model', 'get_model_for_type', 'to_python_boolean', 'ignore_inventory_computed_fields', 'ignore_inventory_group_removal', - '_inventory_updates', 'get_pk_from_dict'] + '_inventory_updates', 'get_pk_from_dict', 'getattrd', 'NoDefaultProvided'] def get_object_or_400(klass, *args, **kwargs): @@ -521,3 +521,21 @@ def timedelta_total_seconds(timedelta): timedelta.microseconds + 0.0 + (timedelta.seconds + timedelta.days * 24 * 3600) * 10 ** 6) / 10 ** 6 + +class NoDefaultProvided(object): + pass + +def getattrd(obj, name, default=NoDefaultProvided): + """ + Same as getattr(), but allows dot notation lookup + Discussed in: + http://stackoverflow.com/questions/11975781 + """ + + try: + return reduce(getattr, name.split("."), obj) + except AttributeError: + if default != NoDefaultProvided: + return default + raise +