Refactor the SSO serializer fields to follow the DRF idioms more closely

and fix the tests to handle the newer nested validation checks properly.
This commit is contained in:
Jeff Bradberry
2019-05-30 18:17:37 -04:00
parent 76d4de24df
commit 2a81643308
2 changed files with 167 additions and 184 deletions

View File

@@ -1,4 +1,5 @@
import collections import collections
import copy
import inspect import inspect
import json import json
import re import re
@@ -8,8 +9,8 @@ import ldap
import awx import awx
# Django # Django
from django.utils import six
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ValidationError
# Django Auth LDAP # Django Auth LDAP
import django_auth_ldap.config import django_auth_ldap.config
@@ -18,7 +19,8 @@ from django_auth_ldap.config import (
LDAPSearchUnion, LDAPSearchUnion,
) )
from rest_framework.fields import empty from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, Field, SkipField
# This must be imported so get_subclasses picks it up # This must be imported so get_subclasses picks it up
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
@@ -74,6 +76,71 @@ class DependsOnMixin():
return res return res
class _Forbidden(Field):
default_error_messages = {
'invalid': _('Invalid field.'),
}
def run_validation(self, value):
self.fail('invalid')
class HybridDictField(fields.DictField):
"""A DictField, but with defined fixed Fields for certain keys.
"""
def __init__(self, *args, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
fields = [
(field_name, obj)
for field_name, obj in self.__class__.__dict__.items()
if isinstance(obj, Field) and field_name != 'child'
]
fields.sort(key=lambda x: x[1]._creation_counter)
self._declared_fields = collections.OrderedDict(fields)
super().__init__(*args, **kwargs)
def to_representation(self, value):
fields = copy.deepcopy(self._declared_fields)
return {
key: field.to_representation(val) if val is not None else None
for key, val, field in (
(six.text_type(key), val, fields.get(key, self.child))
for key, val in value.items()
)
if not field.write_only
}
def run_child_validation(self, data):
result = {}
if not data and self.allow_blank:
return result
errors = collections.OrderedDict()
fields = copy.deepcopy(self._declared_fields)
keys = set(fields.keys()) | set(data.keys())
for key in keys:
value = data.get(key, empty)
key = six.text_type(key)
field = fields.get(key, self.child)
try:
if field.read_only:
continue # Ignore read_only fields, as Serializer seems to do.
result[key] = field.run_validation(value)
except ValidationError as e:
errors[key] = e.detail
except SkipField:
pass
if not errors:
return result
raise ValidationError(errors)
class AuthenticationBackendsField(fields.StringListField): class AuthenticationBackendsField(fields.StringListField):
# Mapping of settings that must be set in order to enable each # Mapping of settings that must be set in order to enable each
@@ -459,70 +526,14 @@ class LDAPDNMapField(fields.StringListBooleanField):
child = LDAPDNField() child = LDAPDNField()
class BaseDictWithChildField(fields.DictField): class LDAPSingleOrganizationMapField(HybridDictField):
default_error_messages = { admins = LDAPDNMapField(allow_null=True, required=False)
'missing_keys': _('Missing key(s): {missing_keys}.'), users = LDAPDNMapField(allow_null=True, required=False)
'invalid_keys': _('Invalid key(s): {invalid_keys}.'), remove_admins = fields.BooleanField(required=False)
} remove_users = fields.BooleanField(required=False)
child_fields = {
# 'key': fields.ChildField(),
}
allow_unknown_keys = False
def __init__(self, *args, **kwargs): child = _Forbidden()
self.allow_blank = kwargs.pop('allow_blank', False)
super(BaseDictWithChildField, self).__init__(*args, **kwargs)
def to_representation(self, value):
value = super(BaseDictWithChildField, self).to_representation(value)
for k, v in value.items():
child_field = self.child_fields.get(k, None)
if child_field:
value[k] = child_field.to_representation(v)
elif self.allow_unknown_keys:
value[k] = v
return value
def to_internal_value(self, data):
data = super(BaseDictWithChildField, self).to_internal_value(data)
missing_keys = set()
for key, child_field in self.child_fields.items():
if not child_field.required:
continue
elif key not in data:
missing_keys.add(key)
missing_keys = sorted(list(missing_keys))
if missing_keys and (data or not self.allow_blank):
missing_keys = sorted(list(missing_keys))
keys_display = json.dumps(missing_keys).lstrip('[').rstrip(']')
self.fail('missing_keys', missing_keys=keys_display)
if not self.allow_unknown_keys:
invalid_keys = set(data.keys()) - set(self.child_fields.keys())
if invalid_keys:
invalid_keys = sorted(list(invalid_keys))
keys_display = json.dumps(invalid_keys).lstrip('[').rstrip(']')
self.fail('invalid_keys', invalid_keys=keys_display)
for k, v in data.items():
child_field = self.child_fields.get(k, None)
if child_field:
data[k] = child_field.run_validation(v)
elif self.allow_unknown_keys:
data[k] = v
return data
class LDAPSingleOrganizationMapField(BaseDictWithChildField):
default_error_messages = {
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'),
}
child_fields = {
'admins': LDAPDNMapField(allow_null=True, required=False),
'users': LDAPDNMapField(allow_null=True, required=False),
'remove_admins': fields.BooleanField(required=False),
'remove_users': fields.BooleanField(required=False),
}
class LDAPOrganizationMapField(fields.DictField): class LDAPOrganizationMapField(fields.DictField):
@@ -530,17 +541,13 @@ class LDAPOrganizationMapField(fields.DictField):
child = LDAPSingleOrganizationMapField() child = LDAPSingleOrganizationMapField()
class LDAPSingleTeamMapField(BaseDictWithChildField): class LDAPSingleTeamMapField(HybridDictField):
default_error_messages = { organization = fields.CharField()
'missing_keys': _('Missing required key for team map: {invalid_keys}.'), users = LDAPDNMapField(allow_null=True, required=False)
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'), remove = fields.BooleanField(required=False)
}
child_fields = { child = _Forbidden()
'organization': fields.CharField(),
'users': LDAPDNMapField(allow_null=True, required=False),
'remove': fields.BooleanField(required=False),
}
class LDAPTeamMapField(fields.DictField): class LDAPTeamMapField(fields.DictField):
@@ -614,17 +621,14 @@ class SocialMapField(fields.ListField):
self.fail('type_error', input_type=type(data)) self.fail('type_error', input_type=type(data))
class SocialSingleOrganizationMapField(BaseDictWithChildField): class SocialSingleOrganizationMapField(HybridDictField):
default_error_messages = { admins = SocialMapField(allow_null=True, required=False)
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'), users = SocialMapField(allow_null=True, required=False)
} remove_admins = fields.BooleanField(required=False)
child_fields = { remove_users = fields.BooleanField(required=False)
'admins': SocialMapField(allow_null=True, required=False),
'users': SocialMapField(allow_null=True, required=False), child = _Forbidden()
'remove_admins': fields.BooleanField(required=False),
'remove_users': fields.BooleanField(required=False),
}
class SocialOrganizationMapField(fields.DictField): class SocialOrganizationMapField(fields.DictField):
@@ -632,17 +636,13 @@ class SocialOrganizationMapField(fields.DictField):
child = SocialSingleOrganizationMapField() child = SocialSingleOrganizationMapField()
class SocialSingleTeamMapField(BaseDictWithChildField): class SocialSingleTeamMapField(HybridDictField):
default_error_messages = { organization = fields.CharField()
'missing_keys': _('Missing required key for team map: {missing_keys}.'), users = SocialMapField(allow_null=True, required=False)
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'), remove = fields.BooleanField(required=False)
}
child_fields = { child = _Forbidden()
'organization': fields.CharField(),
'users': SocialMapField(allow_null=True, required=False),
'remove': fields.BooleanField(required=False),
}
class SocialTeamMapField(fields.DictField): class SocialTeamMapField(fields.DictField):
@@ -650,17 +650,11 @@ class SocialTeamMapField(fields.DictField):
child = SocialSingleTeamMapField() child = SocialSingleTeamMapField()
class SAMLOrgInfoValueField(BaseDictWithChildField): class SAMLOrgInfoValueField(HybridDictField):
default_error_messages = { name = fields.CharField()
'missing_keys': _('Missing required key(s) for org info record: {missing_keys}.'), displayname = fields.CharField()
} url = fields.URLField()
child_fields = {
'name': fields.CharField(),
'displayname': fields.CharField(),
'url': fields.URLField(),
}
allow_unknown_keys = True
class SAMLOrgInfoField(fields.DictField): class SAMLOrgInfoField(fields.DictField):
@@ -683,34 +677,22 @@ class SAMLOrgInfoField(fields.DictField):
return data return data
class SAMLContactField(BaseDictWithChildField): class SAMLContactField(HybridDictField):
default_error_messages = { givenName = fields.CharField()
'missing_keys': _('Missing required key(s) for contact: {missing_keys}.'), emailAddress = fields.EmailField()
}
child_fields = {
'givenName': fields.CharField(),
'emailAddress': fields.EmailField(),
}
allow_unknown_keys = True
class SAMLIdPField(BaseDictWithChildField): class SAMLIdPField(HybridDictField):
default_error_messages = { entity_id = fields.CharField()
'missing_keys': _('Missing required key(s) for IdP: {missing_keys}.'), url = fields.URLField()
} x509cert = fields.CharField(validators=[validate_certificate])
child_fields = { attr_user_permanent_id = fields.CharField(required=False)
'entity_id': fields.CharField(), attr_first_name = fields.CharField(required=False)
'url': fields.URLField(), attr_last_name = fields.CharField(required=False)
'x509cert': fields.CharField(validators=[validate_certificate]), attr_username = fields.CharField(required=False)
'attr_user_permanent_id': fields.CharField(required=False), attr_email = fields.CharField(required=False)
'attr_first_name': fields.CharField(required=False),
'attr_last_name': fields.CharField(required=False),
'attr_username': fields.CharField(required=False),
'attr_email': fields.CharField(required=False),
}
allow_unknown_keys = True
class SAMLEnabledIdPsField(fields.DictField): class SAMLEnabledIdPsField(fields.DictField):
@@ -718,52 +700,49 @@ class SAMLEnabledIdPsField(fields.DictField):
child = SAMLIdPField() child = SAMLIdPField()
class SAMLSecurityField(BaseDictWithChildField): class SAMLSecurityField(HybridDictField):
child_fields = { nameIdEncrypted = fields.BooleanField(required=False)
'nameIdEncrypted': fields.BooleanField(required=False), authnRequestsSigned = fields.BooleanField(required=False)
'authnRequestsSigned': fields.BooleanField(required=False), logoutRequestSigned = fields.BooleanField(required=False)
'logoutRequestSigned': fields.BooleanField(required=False), logoutResponseSigned = fields.BooleanField(required=False)
'logoutResponseSigned': fields.BooleanField(required=False), signMetadata = fields.BooleanField(required=False)
'signMetadata': fields.BooleanField(required=False), wantMessagesSigned = fields.BooleanField(required=False)
'wantMessagesSigned': fields.BooleanField(required=False), wantAssertionsSigned = fields.BooleanField(required=False)
'wantAssertionsSigned': fields.BooleanField(required=False), wantAssertionsEncrypted = fields.BooleanField(required=False)
'wantAssertionsEncrypted': fields.BooleanField(required=False), wantNameId = fields.BooleanField(required=False)
'wantNameId': fields.BooleanField(required=False), wantNameIdEncrypted = fields.BooleanField(required=False)
'wantNameIdEncrypted': fields.BooleanField(required=False), wantAttributeStatement = fields.BooleanField(required=False)
'wantAttributeStatement': fields.BooleanField(required=False), requestedAuthnContext = fields.StringListBooleanField(required=False)
'requestedAuthnContext': fields.StringListBooleanField(required=False), requestedAuthnContextComparison = fields.CharField(required=False)
'requestedAuthnContextComparison': fields.CharField(required=False), metadataValidUntil = fields.CharField(allow_null=True, required=False)
'metadataValidUntil': fields.CharField(allow_null=True, required=False), metadataCacheDuration = fields.CharField(allow_null=True, required=False)
'metadataCacheDuration': fields.CharField(allow_null=True, required=False), signatureAlgorithm = fields.CharField(allow_null=True, required=False)
'signatureAlgorithm': fields.CharField(allow_null=True, required=False), digestAlgorithm = fields.CharField(allow_null=True, required=False)
'digestAlgorithm': fields.CharField(allow_null=True, required=False),
}
allow_unknown_keys = True
class SAMLOrgAttrField(BaseDictWithChildField): class SAMLOrgAttrField(HybridDictField):
child_fields = { remove = fields.BooleanField(required=False)
'remove': fields.BooleanField(required=False), saml_attr = fields.CharField(required=False, allow_null=True)
'saml_attr': fields.CharField(required=False, allow_null=True), remove_admins = fields.BooleanField(required=False)
'remove_admins': fields.BooleanField(required=False), saml_admin_attr = fields.CharField(required=False, allow_null=True)
'saml_admin_attr': fields.CharField(required=False, allow_null=True),
} child = _Forbidden()
class SAMLTeamAttrTeamOrgMapField(BaseDictWithChildField): class SAMLTeamAttrTeamOrgMapField(HybridDictField):
child_fields = { team = fields.CharField(required=True, allow_null=False)
'team': fields.CharField(required=True, allow_null=False), organization = fields.CharField(required=True, allow_null=False)
'organization': fields.CharField(required=True, allow_null=False),
} child = _Forbidden()
class SAMLTeamAttrField(BaseDictWithChildField): class SAMLTeamAttrField(HybridDictField):
child_fields = { team_org_map = fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True)
'team_org_map': fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True), remove = fields.BooleanField(required=False)
'remove': fields.BooleanField(required=False), saml_attr = fields.CharField(required=False, allow_null=True)
'saml_attr': fields.CharField(required=False, allow_null=True),
} child = _Forbidden()

View File

@@ -33,21 +33,23 @@ class TestSAMLOrgAttrField():
@pytest.mark.parametrize("data, expected", [ @pytest.mark.parametrize("data, expected", [
({'remove': 'blah', 'saml_attr': 'foobar'}, ({'remove': 'blah', 'saml_attr': 'foobar'},
ValidationError('Must be a valid boolean.')), {'remove': ['Must be a valid boolean.']}),
({'remove': True, 'saml_attr': False}, ({'remove': True, 'saml_attr': False},
ValidationError('Not a valid string.')), {'saml_attr': ['Not a valid string.']}),
({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'}, ({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'},
ValidationError('Invalid key(s): "foo", "gig".')), {'saml_attr': ['Not a valid string.'],
'foo': ['Invalid field.'],
'gig': ['Invalid field.']}),
({'remove_admins': True, 'saml_admin_attr': False}, ({'remove_admins': True, 'saml_admin_attr': False},
ValidationError('Not a valid string.')), {'saml_admin_attr': ['Not a valid string.']}),
({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'}, ({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'},
ValidationError('Must be a valid boolean.')), {'remove_admins': ['Must be a valid boolean.']}),
]) ])
def test_internal_value_invalid(self, data, expected): def test_internal_value_invalid(self, data, expected):
field = SAMLOrgAttrField() field = SAMLOrgAttrField()
with pytest.raises(type(expected)) as e: with pytest.raises(ValidationError) as e:
field.to_internal_value(data) field.to_internal_value(data)
assert str(e.value) == str(expected) assert e.value.detail == expected
class TestSAMLTeamAttrField(): class TestSAMLTeamAttrField():
@@ -77,36 +79,38 @@ class TestSAMLTeamAttrField():
@pytest.mark.parametrize("data, expected", [ @pytest.mark.parametrize("data, expected", [
({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [ ({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'}, {'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'},
]}, ValidationError('Invalid key(s): "not_a_valid_key".')), ]}, {'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [ ({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{'organization': 'Ansible'}, {'organization': 'Ansible'},
]}, ValidationError('Missing key(s): "team".')), ]}, {'team_org_map': {0: {'team': ['This field is required.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [ ({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{}, {},
]}, ValidationError('Missing key(s): "organization", "team".')), ]}, {'team_org_map': {
0: {'organization': ['This field is required.'],
'team': ['This field is required.']}}}),
]) ])
def test_internal_value_invalid(self, data, expected): def test_internal_value_invalid(self, data, expected):
field = SAMLTeamAttrField() field = SAMLTeamAttrField()
with pytest.raises(type(expected)) as e: with pytest.raises(ValidationError) as e:
field.to_internal_value(data) field.to_internal_value(data)
assert str(e.value) == str(expected) assert e.value.detail == expected
class TestLDAPGroupTypeParamsField(): class TestLDAPGroupTypeParamsField():
@pytest.mark.parametrize("group_type, data, expected", [ @pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'}, ('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')), ['Invalid key(s): "bob", "scooter".']),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'}, ('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')), ['Invalid key(s): "bob", "scooter".']),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing', ('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'}, 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')), ['Invalid key(s): "bob", "member_attr", "scooter".']),
]) ])
def test_internal_value_invalid(self, group_type, data, expected): def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField() field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type) field.get_depends_on = mock.MagicMock(return_value=group_type)
with pytest.raises(type(expected)) as e: with pytest.raises(ValidationError) as e:
field.to_internal_value(data) field.to_internal_value(data)
assert str(e.value) == str(expected) assert e.value.detail == expected