diff --git a/awx/sso/fields.py b/awx/sso/fields.py index 8220feed5b..09b87fff02 100644 --- a/awx/sso/fields.py +++ b/awx/sso/fields.py @@ -1,4 +1,5 @@ import collections +import copy import inspect import json import re @@ -8,8 +9,8 @@ import ldap import awx # Django +from django.utils import six from django.utils.translation import ugettext_lazy as _ -from django.core.exceptions import ValidationError # Django Auth LDAP import django_auth_ldap.config @@ -18,7 +19,8 @@ from django_auth_ldap.config import ( LDAPSearchUnion, ) -from rest_framework.fields import empty +from rest_framework.exceptions import ValidationError +from rest_framework.fields import empty, Field, SkipField # This must be imported so get_subclasses picks it up from awx.sso.ldap_group_types import PosixUIDGroupType # noqa @@ -74,6 +76,71 @@ class DependsOnMixin(): return res +class _Forbidden(Field): + default_error_messages = { + 'invalid': _('Invalid field.'), + } + + def run_validation(self, value): + self.fail('invalid') + + +class HybridDictField(fields.DictField): + """A DictField, but with defined fixed Fields for certain keys. + """ + + def __init__(self, *args, **kwargs): + self.allow_blank = kwargs.pop('allow_blank', False) + + fields = [ + (field_name, obj) + for field_name, obj in self.__class__.__dict__.items() + if isinstance(obj, Field) and field_name != 'child' + ] + fields.sort(key=lambda x: x[1]._creation_counter) + self._declared_fields = collections.OrderedDict(fields) + + super().__init__(*args, **kwargs) + + def to_representation(self, value): + fields = copy.deepcopy(self._declared_fields) + return { + key: field.to_representation(val) if val is not None else None + for key, val, field in ( + (six.text_type(key), val, fields.get(key, self.child)) + for key, val in value.items() + ) + if not field.write_only + } + + def run_child_validation(self, data): + result = {} + + if not data and self.allow_blank: + return result + + errors = collections.OrderedDict() + fields = copy.deepcopy(self._declared_fields) + keys = set(fields.keys()) | set(data.keys()) + + for key in keys: + value = data.get(key, empty) + key = six.text_type(key) + field = fields.get(key, self.child) + try: + if field.read_only: + continue # Ignore read_only fields, as Serializer seems to do. + result[key] = field.run_validation(value) + except ValidationError as e: + errors[key] = e.detail + except SkipField: + pass + + if not errors: + return result + raise ValidationError(errors) + + class AuthenticationBackendsField(fields.StringListField): # Mapping of settings that must be set in order to enable each @@ -459,70 +526,14 @@ class LDAPDNMapField(fields.StringListBooleanField): child = LDAPDNField() -class BaseDictWithChildField(fields.DictField): +class LDAPSingleOrganizationMapField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing key(s): {missing_keys}.'), - 'invalid_keys': _('Invalid key(s): {invalid_keys}.'), - } - child_fields = { - # 'key': fields.ChildField(), - } - allow_unknown_keys = False + admins = LDAPDNMapField(allow_null=True, required=False) + users = LDAPDNMapField(allow_null=True, required=False) + remove_admins = fields.BooleanField(required=False) + remove_users = fields.BooleanField(required=False) - def __init__(self, *args, **kwargs): - self.allow_blank = kwargs.pop('allow_blank', False) - super(BaseDictWithChildField, self).__init__(*args, **kwargs) - - def to_representation(self, value): - value = super(BaseDictWithChildField, self).to_representation(value) - for k, v in value.items(): - child_field = self.child_fields.get(k, None) - if child_field: - value[k] = child_field.to_representation(v) - elif self.allow_unknown_keys: - value[k] = v - return value - - def to_internal_value(self, data): - data = super(BaseDictWithChildField, self).to_internal_value(data) - missing_keys = set() - for key, child_field in self.child_fields.items(): - if not child_field.required: - continue - elif key not in data: - missing_keys.add(key) - missing_keys = sorted(list(missing_keys)) - if missing_keys and (data or not self.allow_blank): - missing_keys = sorted(list(missing_keys)) - keys_display = json.dumps(missing_keys).lstrip('[').rstrip(']') - self.fail('missing_keys', missing_keys=keys_display) - if not self.allow_unknown_keys: - invalid_keys = set(data.keys()) - set(self.child_fields.keys()) - if invalid_keys: - invalid_keys = sorted(list(invalid_keys)) - keys_display = json.dumps(invalid_keys).lstrip('[').rstrip(']') - self.fail('invalid_keys', invalid_keys=keys_display) - for k, v in data.items(): - child_field = self.child_fields.get(k, None) - if child_field: - data[k] = child_field.run_validation(v) - elif self.allow_unknown_keys: - data[k] = v - return data - - -class LDAPSingleOrganizationMapField(BaseDictWithChildField): - - default_error_messages = { - 'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'), - } - child_fields = { - 'admins': LDAPDNMapField(allow_null=True, required=False), - 'users': LDAPDNMapField(allow_null=True, required=False), - 'remove_admins': fields.BooleanField(required=False), - 'remove_users': fields.BooleanField(required=False), - } + child = _Forbidden() class LDAPOrganizationMapField(fields.DictField): @@ -530,17 +541,13 @@ class LDAPOrganizationMapField(fields.DictField): child = LDAPSingleOrganizationMapField() -class LDAPSingleTeamMapField(BaseDictWithChildField): +class LDAPSingleTeamMapField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing required key for team map: {invalid_keys}.'), - 'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'), - } - child_fields = { - 'organization': fields.CharField(), - 'users': LDAPDNMapField(allow_null=True, required=False), - 'remove': fields.BooleanField(required=False), - } + organization = fields.CharField() + users = LDAPDNMapField(allow_null=True, required=False) + remove = fields.BooleanField(required=False) + + child = _Forbidden() class LDAPTeamMapField(fields.DictField): @@ -614,17 +621,14 @@ class SocialMapField(fields.ListField): self.fail('type_error', input_type=type(data)) -class SocialSingleOrganizationMapField(BaseDictWithChildField): +class SocialSingleOrganizationMapField(HybridDictField): - default_error_messages = { - 'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'), - } - child_fields = { - 'admins': SocialMapField(allow_null=True, required=False), - 'users': SocialMapField(allow_null=True, required=False), - 'remove_admins': fields.BooleanField(required=False), - 'remove_users': fields.BooleanField(required=False), - } + admins = SocialMapField(allow_null=True, required=False) + users = SocialMapField(allow_null=True, required=False) + remove_admins = fields.BooleanField(required=False) + remove_users = fields.BooleanField(required=False) + + child = _Forbidden() class SocialOrganizationMapField(fields.DictField): @@ -632,17 +636,13 @@ class SocialOrganizationMapField(fields.DictField): child = SocialSingleOrganizationMapField() -class SocialSingleTeamMapField(BaseDictWithChildField): +class SocialSingleTeamMapField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing required key for team map: {missing_keys}.'), - 'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'), - } - child_fields = { - 'organization': fields.CharField(), - 'users': SocialMapField(allow_null=True, required=False), - 'remove': fields.BooleanField(required=False), - } + organization = fields.CharField() + users = SocialMapField(allow_null=True, required=False) + remove = fields.BooleanField(required=False) + + child = _Forbidden() class SocialTeamMapField(fields.DictField): @@ -650,17 +650,11 @@ class SocialTeamMapField(fields.DictField): child = SocialSingleTeamMapField() -class SAMLOrgInfoValueField(BaseDictWithChildField): +class SAMLOrgInfoValueField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing required key(s) for org info record: {missing_keys}.'), - } - child_fields = { - 'name': fields.CharField(), - 'displayname': fields.CharField(), - 'url': fields.URLField(), - } - allow_unknown_keys = True + name = fields.CharField() + displayname = fields.CharField() + url = fields.URLField() class SAMLOrgInfoField(fields.DictField): @@ -683,34 +677,22 @@ class SAMLOrgInfoField(fields.DictField): return data -class SAMLContactField(BaseDictWithChildField): +class SAMLContactField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing required key(s) for contact: {missing_keys}.'), - } - child_fields = { - 'givenName': fields.CharField(), - 'emailAddress': fields.EmailField(), - } - allow_unknown_keys = True + givenName = fields.CharField() + emailAddress = fields.EmailField() -class SAMLIdPField(BaseDictWithChildField): +class SAMLIdPField(HybridDictField): - default_error_messages = { - 'missing_keys': _('Missing required key(s) for IdP: {missing_keys}.'), - } - child_fields = { - 'entity_id': fields.CharField(), - 'url': fields.URLField(), - 'x509cert': fields.CharField(validators=[validate_certificate]), - 'attr_user_permanent_id': fields.CharField(required=False), - 'attr_first_name': fields.CharField(required=False), - 'attr_last_name': fields.CharField(required=False), - 'attr_username': fields.CharField(required=False), - 'attr_email': fields.CharField(required=False), - } - allow_unknown_keys = True + entity_id = fields.CharField() + url = fields.URLField() + x509cert = fields.CharField(validators=[validate_certificate]) + attr_user_permanent_id = fields.CharField(required=False) + attr_first_name = fields.CharField(required=False) + attr_last_name = fields.CharField(required=False) + attr_username = fields.CharField(required=False) + attr_email = fields.CharField(required=False) class SAMLEnabledIdPsField(fields.DictField): @@ -718,52 +700,49 @@ class SAMLEnabledIdPsField(fields.DictField): child = SAMLIdPField() -class SAMLSecurityField(BaseDictWithChildField): +class SAMLSecurityField(HybridDictField): - child_fields = { - 'nameIdEncrypted': fields.BooleanField(required=False), - 'authnRequestsSigned': fields.BooleanField(required=False), - 'logoutRequestSigned': fields.BooleanField(required=False), - 'logoutResponseSigned': fields.BooleanField(required=False), - 'signMetadata': fields.BooleanField(required=False), - 'wantMessagesSigned': fields.BooleanField(required=False), - 'wantAssertionsSigned': fields.BooleanField(required=False), - 'wantAssertionsEncrypted': fields.BooleanField(required=False), - 'wantNameId': fields.BooleanField(required=False), - 'wantNameIdEncrypted': fields.BooleanField(required=False), - 'wantAttributeStatement': fields.BooleanField(required=False), - 'requestedAuthnContext': fields.StringListBooleanField(required=False), - 'requestedAuthnContextComparison': fields.CharField(required=False), - 'metadataValidUntil': fields.CharField(allow_null=True, required=False), - 'metadataCacheDuration': fields.CharField(allow_null=True, required=False), - 'signatureAlgorithm': fields.CharField(allow_null=True, required=False), - 'digestAlgorithm': fields.CharField(allow_null=True, required=False), - } - allow_unknown_keys = True + nameIdEncrypted = fields.BooleanField(required=False) + authnRequestsSigned = fields.BooleanField(required=False) + logoutRequestSigned = fields.BooleanField(required=False) + logoutResponseSigned = fields.BooleanField(required=False) + signMetadata = fields.BooleanField(required=False) + wantMessagesSigned = fields.BooleanField(required=False) + wantAssertionsSigned = fields.BooleanField(required=False) + wantAssertionsEncrypted = fields.BooleanField(required=False) + wantNameId = fields.BooleanField(required=False) + wantNameIdEncrypted = fields.BooleanField(required=False) + wantAttributeStatement = fields.BooleanField(required=False) + requestedAuthnContext = fields.StringListBooleanField(required=False) + requestedAuthnContextComparison = fields.CharField(required=False) + metadataValidUntil = fields.CharField(allow_null=True, required=False) + metadataCacheDuration = fields.CharField(allow_null=True, required=False) + signatureAlgorithm = fields.CharField(allow_null=True, required=False) + digestAlgorithm = fields.CharField(allow_null=True, required=False) -class SAMLOrgAttrField(BaseDictWithChildField): +class SAMLOrgAttrField(HybridDictField): - child_fields = { - 'remove': fields.BooleanField(required=False), - 'saml_attr': fields.CharField(required=False, allow_null=True), - 'remove_admins': fields.BooleanField(required=False), - 'saml_admin_attr': fields.CharField(required=False, allow_null=True), - } + remove = fields.BooleanField(required=False) + saml_attr = fields.CharField(required=False, allow_null=True) + remove_admins = fields.BooleanField(required=False) + saml_admin_attr = fields.CharField(required=False, allow_null=True) + + child = _Forbidden() -class SAMLTeamAttrTeamOrgMapField(BaseDictWithChildField): +class SAMLTeamAttrTeamOrgMapField(HybridDictField): - child_fields = { - 'team': fields.CharField(required=True, allow_null=False), - 'organization': fields.CharField(required=True, allow_null=False), - } + team = fields.CharField(required=True, allow_null=False) + organization = fields.CharField(required=True, allow_null=False) + + child = _Forbidden() -class SAMLTeamAttrField(BaseDictWithChildField): +class SAMLTeamAttrField(HybridDictField): - child_fields = { - 'team_org_map': fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True), - 'remove': fields.BooleanField(required=False), - 'saml_attr': fields.CharField(required=False, allow_null=True), - } + team_org_map = fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True) + remove = fields.BooleanField(required=False) + saml_attr = fields.CharField(required=False, allow_null=True) + + child = _Forbidden() diff --git a/awx/sso/tests/unit/test_fields.py b/awx/sso/tests/unit/test_fields.py index 18afab5d31..95cd774b98 100644 --- a/awx/sso/tests/unit/test_fields.py +++ b/awx/sso/tests/unit/test_fields.py @@ -33,21 +33,23 @@ class TestSAMLOrgAttrField(): @pytest.mark.parametrize("data, expected", [ ({'remove': 'blah', 'saml_attr': 'foobar'}, - ValidationError('Must be a valid boolean.')), + {'remove': ['Must be a valid boolean.']}), ({'remove': True, 'saml_attr': False}, - ValidationError('Not a valid string.')), + {'saml_attr': ['Not a valid string.']}), ({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'}, - ValidationError('Invalid key(s): "foo", "gig".')), + {'saml_attr': ['Not a valid string.'], + 'foo': ['Invalid field.'], + 'gig': ['Invalid field.']}), ({'remove_admins': True, 'saml_admin_attr': False}, - ValidationError('Not a valid string.')), + {'saml_admin_attr': ['Not a valid string.']}), ({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'}, - ValidationError('Must be a valid boolean.')), + {'remove_admins': ['Must be a valid boolean.']}), ]) def test_internal_value_invalid(self, data, expected): field = SAMLOrgAttrField() - with pytest.raises(type(expected)) as e: + with pytest.raises(ValidationError) as e: field.to_internal_value(data) - assert str(e.value) == str(expected) + assert e.value.detail == expected class TestSAMLTeamAttrField(): @@ -77,36 +79,38 @@ class TestSAMLTeamAttrField(): @pytest.mark.parametrize("data, expected", [ ({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [ {'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'}, - ]}, ValidationError('Invalid key(s): "not_a_valid_key".')), + ]}, {'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}}), ({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [ {'organization': 'Ansible'}, - ]}, ValidationError('Missing key(s): "team".')), + ]}, {'team_org_map': {0: {'team': ['This field is required.']}}}), ({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [ {}, - ]}, ValidationError('Missing key(s): "organization", "team".')), + ]}, {'team_org_map': { + 0: {'organization': ['This field is required.'], + 'team': ['This field is required.']}}}), ]) def test_internal_value_invalid(self, data, expected): field = SAMLTeamAttrField() - with pytest.raises(type(expected)) as e: + with pytest.raises(ValidationError) as e: field.to_internal_value(data) - assert str(e.value) == str(expected) + assert e.value.detail == expected class TestLDAPGroupTypeParamsField(): @pytest.mark.parametrize("group_type, data, expected", [ ('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'}, - ValidationError('Invalid key(s): "bob", "scooter".')), + ['Invalid key(s): "bob", "scooter".']), ('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'}, - ValidationError('Invalid key(s): "bob", "scooter".')), + ['Invalid key(s): "bob", "scooter".']), ('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing', 'bob': ['a', 'b'], 'scooter': 'hello'}, - ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')), + ['Invalid key(s): "bob", "member_attr", "scooter".']), ]) def test_internal_value_invalid(self, group_type, data, expected): field = LDAPGroupTypeParamsField() field.get_depends_on = mock.MagicMock(return_value=group_type) - with pytest.raises(type(expected)) as e: + with pytest.raises(ValidationError) as e: field.to_internal_value(data) - assert str(e.value) == str(expected) + assert e.value.detail == expected