Refactor the SSO serializer fields to follow the DRF idioms more closely

and fix the tests to handle the newer nested validation checks properly.
This commit is contained in:
Jeff Bradberry 2019-05-30 18:17:37 -04:00
parent 76d4de24df
commit 2a81643308
2 changed files with 167 additions and 184 deletions

View File

@ -1,4 +1,5 @@
import collections
import copy
import inspect
import json
import re
@ -8,8 +9,8 @@ import ldap
import awx
# Django
from django.utils import six
from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ValidationError
# Django Auth LDAP
import django_auth_ldap.config
@ -18,7 +19,8 @@ from django_auth_ldap.config import (
LDAPSearchUnion,
)
from rest_framework.fields import empty
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, Field, SkipField
# This must be imported so get_subclasses picks it up
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
@ -74,6 +76,71 @@ class DependsOnMixin():
return res
class _Forbidden(Field):
default_error_messages = {
'invalid': _('Invalid field.'),
}
def run_validation(self, value):
self.fail('invalid')
class HybridDictField(fields.DictField):
"""A DictField, but with defined fixed Fields for certain keys.
"""
def __init__(self, *args, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
fields = [
(field_name, obj)
for field_name, obj in self.__class__.__dict__.items()
if isinstance(obj, Field) and field_name != 'child'
]
fields.sort(key=lambda x: x[1]._creation_counter)
self._declared_fields = collections.OrderedDict(fields)
super().__init__(*args, **kwargs)
def to_representation(self, value):
fields = copy.deepcopy(self._declared_fields)
return {
key: field.to_representation(val) if val is not None else None
for key, val, field in (
(six.text_type(key), val, fields.get(key, self.child))
for key, val in value.items()
)
if not field.write_only
}
def run_child_validation(self, data):
result = {}
if not data and self.allow_blank:
return result
errors = collections.OrderedDict()
fields = copy.deepcopy(self._declared_fields)
keys = set(fields.keys()) | set(data.keys())
for key in keys:
value = data.get(key, empty)
key = six.text_type(key)
field = fields.get(key, self.child)
try:
if field.read_only:
continue # Ignore read_only fields, as Serializer seems to do.
result[key] = field.run_validation(value)
except ValidationError as e:
errors[key] = e.detail
except SkipField:
pass
if not errors:
return result
raise ValidationError(errors)
class AuthenticationBackendsField(fields.StringListField):
# Mapping of settings that must be set in order to enable each
@ -459,70 +526,14 @@ class LDAPDNMapField(fields.StringListBooleanField):
child = LDAPDNField()
class BaseDictWithChildField(fields.DictField):
class LDAPSingleOrganizationMapField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing key(s): {missing_keys}.'),
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
}
child_fields = {
# 'key': fields.ChildField(),
}
allow_unknown_keys = False
admins = LDAPDNMapField(allow_null=True, required=False)
users = LDAPDNMapField(allow_null=True, required=False)
remove_admins = fields.BooleanField(required=False)
remove_users = fields.BooleanField(required=False)
def __init__(self, *args, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
super(BaseDictWithChildField, self).__init__(*args, **kwargs)
def to_representation(self, value):
value = super(BaseDictWithChildField, self).to_representation(value)
for k, v in value.items():
child_field = self.child_fields.get(k, None)
if child_field:
value[k] = child_field.to_representation(v)
elif self.allow_unknown_keys:
value[k] = v
return value
def to_internal_value(self, data):
data = super(BaseDictWithChildField, self).to_internal_value(data)
missing_keys = set()
for key, child_field in self.child_fields.items():
if not child_field.required:
continue
elif key not in data:
missing_keys.add(key)
missing_keys = sorted(list(missing_keys))
if missing_keys and (data or not self.allow_blank):
missing_keys = sorted(list(missing_keys))
keys_display = json.dumps(missing_keys).lstrip('[').rstrip(']')
self.fail('missing_keys', missing_keys=keys_display)
if not self.allow_unknown_keys:
invalid_keys = set(data.keys()) - set(self.child_fields.keys())
if invalid_keys:
invalid_keys = sorted(list(invalid_keys))
keys_display = json.dumps(invalid_keys).lstrip('[').rstrip(']')
self.fail('invalid_keys', invalid_keys=keys_display)
for k, v in data.items():
child_field = self.child_fields.get(k, None)
if child_field:
data[k] = child_field.run_validation(v)
elif self.allow_unknown_keys:
data[k] = v
return data
class LDAPSingleOrganizationMapField(BaseDictWithChildField):
default_error_messages = {
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'),
}
child_fields = {
'admins': LDAPDNMapField(allow_null=True, required=False),
'users': LDAPDNMapField(allow_null=True, required=False),
'remove_admins': fields.BooleanField(required=False),
'remove_users': fields.BooleanField(required=False),
}
child = _Forbidden()
class LDAPOrganizationMapField(fields.DictField):
@ -530,17 +541,13 @@ class LDAPOrganizationMapField(fields.DictField):
child = LDAPSingleOrganizationMapField()
class LDAPSingleTeamMapField(BaseDictWithChildField):
class LDAPSingleTeamMapField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing required key for team map: {invalid_keys}.'),
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'),
}
child_fields = {
'organization': fields.CharField(),
'users': LDAPDNMapField(allow_null=True, required=False),
'remove': fields.BooleanField(required=False),
}
organization = fields.CharField()
users = LDAPDNMapField(allow_null=True, required=False)
remove = fields.BooleanField(required=False)
child = _Forbidden()
class LDAPTeamMapField(fields.DictField):
@ -614,17 +621,14 @@ class SocialMapField(fields.ListField):
self.fail('type_error', input_type=type(data))
class SocialSingleOrganizationMapField(BaseDictWithChildField):
class SocialSingleOrganizationMapField(HybridDictField):
default_error_messages = {
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'),
}
child_fields = {
'admins': SocialMapField(allow_null=True, required=False),
'users': SocialMapField(allow_null=True, required=False),
'remove_admins': fields.BooleanField(required=False),
'remove_users': fields.BooleanField(required=False),
}
admins = SocialMapField(allow_null=True, required=False)
users = SocialMapField(allow_null=True, required=False)
remove_admins = fields.BooleanField(required=False)
remove_users = fields.BooleanField(required=False)
child = _Forbidden()
class SocialOrganizationMapField(fields.DictField):
@ -632,17 +636,13 @@ class SocialOrganizationMapField(fields.DictField):
child = SocialSingleOrganizationMapField()
class SocialSingleTeamMapField(BaseDictWithChildField):
class SocialSingleTeamMapField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing required key for team map: {missing_keys}.'),
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'),
}
child_fields = {
'organization': fields.CharField(),
'users': SocialMapField(allow_null=True, required=False),
'remove': fields.BooleanField(required=False),
}
organization = fields.CharField()
users = SocialMapField(allow_null=True, required=False)
remove = fields.BooleanField(required=False)
child = _Forbidden()
class SocialTeamMapField(fields.DictField):
@ -650,17 +650,11 @@ class SocialTeamMapField(fields.DictField):
child = SocialSingleTeamMapField()
class SAMLOrgInfoValueField(BaseDictWithChildField):
class SAMLOrgInfoValueField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing required key(s) for org info record: {missing_keys}.'),
}
child_fields = {
'name': fields.CharField(),
'displayname': fields.CharField(),
'url': fields.URLField(),
}
allow_unknown_keys = True
name = fields.CharField()
displayname = fields.CharField()
url = fields.URLField()
class SAMLOrgInfoField(fields.DictField):
@ -683,34 +677,22 @@ class SAMLOrgInfoField(fields.DictField):
return data
class SAMLContactField(BaseDictWithChildField):
class SAMLContactField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing required key(s) for contact: {missing_keys}.'),
}
child_fields = {
'givenName': fields.CharField(),
'emailAddress': fields.EmailField(),
}
allow_unknown_keys = True
givenName = fields.CharField()
emailAddress = fields.EmailField()
class SAMLIdPField(BaseDictWithChildField):
class SAMLIdPField(HybridDictField):
default_error_messages = {
'missing_keys': _('Missing required key(s) for IdP: {missing_keys}.'),
}
child_fields = {
'entity_id': fields.CharField(),
'url': fields.URLField(),
'x509cert': fields.CharField(validators=[validate_certificate]),
'attr_user_permanent_id': fields.CharField(required=False),
'attr_first_name': fields.CharField(required=False),
'attr_last_name': fields.CharField(required=False),
'attr_username': fields.CharField(required=False),
'attr_email': fields.CharField(required=False),
}
allow_unknown_keys = True
entity_id = fields.CharField()
url = fields.URLField()
x509cert = fields.CharField(validators=[validate_certificate])
attr_user_permanent_id = fields.CharField(required=False)
attr_first_name = fields.CharField(required=False)
attr_last_name = fields.CharField(required=False)
attr_username = fields.CharField(required=False)
attr_email = fields.CharField(required=False)
class SAMLEnabledIdPsField(fields.DictField):
@ -718,52 +700,49 @@ class SAMLEnabledIdPsField(fields.DictField):
child = SAMLIdPField()
class SAMLSecurityField(BaseDictWithChildField):
class SAMLSecurityField(HybridDictField):
child_fields = {
'nameIdEncrypted': fields.BooleanField(required=False),
'authnRequestsSigned': fields.BooleanField(required=False),
'logoutRequestSigned': fields.BooleanField(required=False),
'logoutResponseSigned': fields.BooleanField(required=False),
'signMetadata': fields.BooleanField(required=False),
'wantMessagesSigned': fields.BooleanField(required=False),
'wantAssertionsSigned': fields.BooleanField(required=False),
'wantAssertionsEncrypted': fields.BooleanField(required=False),
'wantNameId': fields.BooleanField(required=False),
'wantNameIdEncrypted': fields.BooleanField(required=False),
'wantAttributeStatement': fields.BooleanField(required=False),
'requestedAuthnContext': fields.StringListBooleanField(required=False),
'requestedAuthnContextComparison': fields.CharField(required=False),
'metadataValidUntil': fields.CharField(allow_null=True, required=False),
'metadataCacheDuration': fields.CharField(allow_null=True, required=False),
'signatureAlgorithm': fields.CharField(allow_null=True, required=False),
'digestAlgorithm': fields.CharField(allow_null=True, required=False),
}
allow_unknown_keys = True
nameIdEncrypted = fields.BooleanField(required=False)
authnRequestsSigned = fields.BooleanField(required=False)
logoutRequestSigned = fields.BooleanField(required=False)
logoutResponseSigned = fields.BooleanField(required=False)
signMetadata = fields.BooleanField(required=False)
wantMessagesSigned = fields.BooleanField(required=False)
wantAssertionsSigned = fields.BooleanField(required=False)
wantAssertionsEncrypted = fields.BooleanField(required=False)
wantNameId = fields.BooleanField(required=False)
wantNameIdEncrypted = fields.BooleanField(required=False)
wantAttributeStatement = fields.BooleanField(required=False)
requestedAuthnContext = fields.StringListBooleanField(required=False)
requestedAuthnContextComparison = fields.CharField(required=False)
metadataValidUntil = fields.CharField(allow_null=True, required=False)
metadataCacheDuration = fields.CharField(allow_null=True, required=False)
signatureAlgorithm = fields.CharField(allow_null=True, required=False)
digestAlgorithm = fields.CharField(allow_null=True, required=False)
class SAMLOrgAttrField(BaseDictWithChildField):
class SAMLOrgAttrField(HybridDictField):
child_fields = {
'remove': fields.BooleanField(required=False),
'saml_attr': fields.CharField(required=False, allow_null=True),
'remove_admins': fields.BooleanField(required=False),
'saml_admin_attr': fields.CharField(required=False, allow_null=True),
}
remove = fields.BooleanField(required=False)
saml_attr = fields.CharField(required=False, allow_null=True)
remove_admins = fields.BooleanField(required=False)
saml_admin_attr = fields.CharField(required=False, allow_null=True)
child = _Forbidden()
class SAMLTeamAttrTeamOrgMapField(BaseDictWithChildField):
class SAMLTeamAttrTeamOrgMapField(HybridDictField):
child_fields = {
'team': fields.CharField(required=True, allow_null=False),
'organization': fields.CharField(required=True, allow_null=False),
}
team = fields.CharField(required=True, allow_null=False)
organization = fields.CharField(required=True, allow_null=False)
child = _Forbidden()
class SAMLTeamAttrField(BaseDictWithChildField):
class SAMLTeamAttrField(HybridDictField):
child_fields = {
'team_org_map': fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True),
'remove': fields.BooleanField(required=False),
'saml_attr': fields.CharField(required=False, allow_null=True),
}
team_org_map = fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True)
remove = fields.BooleanField(required=False)
saml_attr = fields.CharField(required=False, allow_null=True)
child = _Forbidden()

View File

@ -33,21 +33,23 @@ class TestSAMLOrgAttrField():
@pytest.mark.parametrize("data, expected", [
({'remove': 'blah', 'saml_attr': 'foobar'},
ValidationError('Must be a valid boolean.')),
{'remove': ['Must be a valid boolean.']}),
({'remove': True, 'saml_attr': False},
ValidationError('Not a valid string.')),
{'saml_attr': ['Not a valid string.']}),
({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'},
ValidationError('Invalid key(s): "foo", "gig".')),
{'saml_attr': ['Not a valid string.'],
'foo': ['Invalid field.'],
'gig': ['Invalid field.']}),
({'remove_admins': True, 'saml_admin_attr': False},
ValidationError('Not a valid string.')),
{'saml_admin_attr': ['Not a valid string.']}),
({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'},
ValidationError('Must be a valid boolean.')),
{'remove_admins': ['Must be a valid boolean.']}),
])
def test_internal_value_invalid(self, data, expected):
field = SAMLOrgAttrField()
with pytest.raises(type(expected)) as e:
with pytest.raises(ValidationError) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)
assert e.value.detail == expected
class TestSAMLTeamAttrField():
@ -77,36 +79,38 @@ class TestSAMLTeamAttrField():
@pytest.mark.parametrize("data, expected", [
({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'},
]}, ValidationError('Invalid key(s): "not_a_valid_key".')),
]}, {'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{'organization': 'Ansible'},
]}, ValidationError('Missing key(s): "team".')),
]}, {'team_org_map': {0: {'team': ['This field is required.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{},
]}, ValidationError('Missing key(s): "organization", "team".')),
]}, {'team_org_map': {
0: {'organization': ['This field is required.'],
'team': ['This field is required.']}}}),
])
def test_internal_value_invalid(self, data, expected):
field = SAMLTeamAttrField()
with pytest.raises(type(expected)) as e:
with pytest.raises(ValidationError) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)
assert e.value.detail == expected
class TestLDAPGroupTypeParamsField():
@pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
['Invalid key(s): "bob", "scooter".']),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
['Invalid key(s): "bob", "scooter".']),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')),
['Invalid key(s): "bob", "member_attr", "scooter".']),
])
def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type)
with pytest.raises(type(expected)) as e:
with pytest.raises(ValidationError) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)
assert e.value.detail == expected