mirror of
https://github.com/ansible/awx.git
synced 2026-01-13 11:00:03 -03:30
Refactor the SSO serializer fields to follow the DRF idioms more closely
and fix the tests to handle the newer nested validation checks properly.
This commit is contained in:
parent
76d4de24df
commit
2a81643308
@ -1,4 +1,5 @@
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
@ -8,8 +9,8 @@ import ldap
|
||||
import awx
|
||||
|
||||
# Django
|
||||
from django.utils import six
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
# Django Auth LDAP
|
||||
import django_auth_ldap.config
|
||||
@ -18,7 +19,8 @@ from django_auth_ldap.config import (
|
||||
LDAPSearchUnion,
|
||||
)
|
||||
|
||||
from rest_framework.fields import empty
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.fields import empty, Field, SkipField
|
||||
|
||||
# This must be imported so get_subclasses picks it up
|
||||
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
|
||||
@ -74,6 +76,71 @@ class DependsOnMixin():
|
||||
return res
|
||||
|
||||
|
||||
class _Forbidden(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _('Invalid field.'),
|
||||
}
|
||||
|
||||
def run_validation(self, value):
|
||||
self.fail('invalid')
|
||||
|
||||
|
||||
class HybridDictField(fields.DictField):
|
||||
"""A DictField, but with defined fixed Fields for certain keys.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.allow_blank = kwargs.pop('allow_blank', False)
|
||||
|
||||
fields = [
|
||||
(field_name, obj)
|
||||
for field_name, obj in self.__class__.__dict__.items()
|
||||
if isinstance(obj, Field) and field_name != 'child'
|
||||
]
|
||||
fields.sort(key=lambda x: x[1]._creation_counter)
|
||||
self._declared_fields = collections.OrderedDict(fields)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def to_representation(self, value):
|
||||
fields = copy.deepcopy(self._declared_fields)
|
||||
return {
|
||||
key: field.to_representation(val) if val is not None else None
|
||||
for key, val, field in (
|
||||
(six.text_type(key), val, fields.get(key, self.child))
|
||||
for key, val in value.items()
|
||||
)
|
||||
if not field.write_only
|
||||
}
|
||||
|
||||
def run_child_validation(self, data):
|
||||
result = {}
|
||||
|
||||
if not data and self.allow_blank:
|
||||
return result
|
||||
|
||||
errors = collections.OrderedDict()
|
||||
fields = copy.deepcopy(self._declared_fields)
|
||||
keys = set(fields.keys()) | set(data.keys())
|
||||
|
||||
for key in keys:
|
||||
value = data.get(key, empty)
|
||||
key = six.text_type(key)
|
||||
field = fields.get(key, self.child)
|
||||
try:
|
||||
if field.read_only:
|
||||
continue # Ignore read_only fields, as Serializer seems to do.
|
||||
result[key] = field.run_validation(value)
|
||||
except ValidationError as e:
|
||||
errors[key] = e.detail
|
||||
except SkipField:
|
||||
pass
|
||||
|
||||
if not errors:
|
||||
return result
|
||||
raise ValidationError(errors)
|
||||
|
||||
|
||||
class AuthenticationBackendsField(fields.StringListField):
|
||||
|
||||
# Mapping of settings that must be set in order to enable each
|
||||
@ -459,70 +526,14 @@ class LDAPDNMapField(fields.StringListBooleanField):
|
||||
child = LDAPDNField()
|
||||
|
||||
|
||||
class BaseDictWithChildField(fields.DictField):
|
||||
class LDAPSingleOrganizationMapField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing key(s): {missing_keys}.'),
|
||||
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
# 'key': fields.ChildField(),
|
||||
}
|
||||
allow_unknown_keys = False
|
||||
admins = LDAPDNMapField(allow_null=True, required=False)
|
||||
users = LDAPDNMapField(allow_null=True, required=False)
|
||||
remove_admins = fields.BooleanField(required=False)
|
||||
remove_users = fields.BooleanField(required=False)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.allow_blank = kwargs.pop('allow_blank', False)
|
||||
super(BaseDictWithChildField, self).__init__(*args, **kwargs)
|
||||
|
||||
def to_representation(self, value):
|
||||
value = super(BaseDictWithChildField, self).to_representation(value)
|
||||
for k, v in value.items():
|
||||
child_field = self.child_fields.get(k, None)
|
||||
if child_field:
|
||||
value[k] = child_field.to_representation(v)
|
||||
elif self.allow_unknown_keys:
|
||||
value[k] = v
|
||||
return value
|
||||
|
||||
def to_internal_value(self, data):
|
||||
data = super(BaseDictWithChildField, self).to_internal_value(data)
|
||||
missing_keys = set()
|
||||
for key, child_field in self.child_fields.items():
|
||||
if not child_field.required:
|
||||
continue
|
||||
elif key not in data:
|
||||
missing_keys.add(key)
|
||||
missing_keys = sorted(list(missing_keys))
|
||||
if missing_keys and (data or not self.allow_blank):
|
||||
missing_keys = sorted(list(missing_keys))
|
||||
keys_display = json.dumps(missing_keys).lstrip('[').rstrip(']')
|
||||
self.fail('missing_keys', missing_keys=keys_display)
|
||||
if not self.allow_unknown_keys:
|
||||
invalid_keys = set(data.keys()) - set(self.child_fields.keys())
|
||||
if invalid_keys:
|
||||
invalid_keys = sorted(list(invalid_keys))
|
||||
keys_display = json.dumps(invalid_keys).lstrip('[').rstrip(']')
|
||||
self.fail('invalid_keys', invalid_keys=keys_display)
|
||||
for k, v in data.items():
|
||||
child_field = self.child_fields.get(k, None)
|
||||
if child_field:
|
||||
data[k] = child_field.run_validation(v)
|
||||
elif self.allow_unknown_keys:
|
||||
data[k] = v
|
||||
return data
|
||||
|
||||
|
||||
class LDAPSingleOrganizationMapField(BaseDictWithChildField):
|
||||
|
||||
default_error_messages = {
|
||||
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'admins': LDAPDNMapField(allow_null=True, required=False),
|
||||
'users': LDAPDNMapField(allow_null=True, required=False),
|
||||
'remove_admins': fields.BooleanField(required=False),
|
||||
'remove_users': fields.BooleanField(required=False),
|
||||
}
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class LDAPOrganizationMapField(fields.DictField):
|
||||
@ -530,17 +541,13 @@ class LDAPOrganizationMapField(fields.DictField):
|
||||
child = LDAPSingleOrganizationMapField()
|
||||
|
||||
|
||||
class LDAPSingleTeamMapField(BaseDictWithChildField):
|
||||
class LDAPSingleTeamMapField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing required key for team map: {invalid_keys}.'),
|
||||
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'organization': fields.CharField(),
|
||||
'users': LDAPDNMapField(allow_null=True, required=False),
|
||||
'remove': fields.BooleanField(required=False),
|
||||
}
|
||||
organization = fields.CharField()
|
||||
users = LDAPDNMapField(allow_null=True, required=False)
|
||||
remove = fields.BooleanField(required=False)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class LDAPTeamMapField(fields.DictField):
|
||||
@ -614,17 +621,14 @@ class SocialMapField(fields.ListField):
|
||||
self.fail('type_error', input_type=type(data))
|
||||
|
||||
|
||||
class SocialSingleOrganizationMapField(BaseDictWithChildField):
|
||||
class SocialSingleOrganizationMapField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'invalid_keys': _('Invalid key(s) for organization map: {invalid_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'admins': SocialMapField(allow_null=True, required=False),
|
||||
'users': SocialMapField(allow_null=True, required=False),
|
||||
'remove_admins': fields.BooleanField(required=False),
|
||||
'remove_users': fields.BooleanField(required=False),
|
||||
}
|
||||
admins = SocialMapField(allow_null=True, required=False)
|
||||
users = SocialMapField(allow_null=True, required=False)
|
||||
remove_admins = fields.BooleanField(required=False)
|
||||
remove_users = fields.BooleanField(required=False)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class SocialOrganizationMapField(fields.DictField):
|
||||
@ -632,17 +636,13 @@ class SocialOrganizationMapField(fields.DictField):
|
||||
child = SocialSingleOrganizationMapField()
|
||||
|
||||
|
||||
class SocialSingleTeamMapField(BaseDictWithChildField):
|
||||
class SocialSingleTeamMapField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing required key for team map: {missing_keys}.'),
|
||||
'invalid_keys': _('Invalid key(s) for team map: {invalid_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'organization': fields.CharField(),
|
||||
'users': SocialMapField(allow_null=True, required=False),
|
||||
'remove': fields.BooleanField(required=False),
|
||||
}
|
||||
organization = fields.CharField()
|
||||
users = SocialMapField(allow_null=True, required=False)
|
||||
remove = fields.BooleanField(required=False)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class SocialTeamMapField(fields.DictField):
|
||||
@ -650,17 +650,11 @@ class SocialTeamMapField(fields.DictField):
|
||||
child = SocialSingleTeamMapField()
|
||||
|
||||
|
||||
class SAMLOrgInfoValueField(BaseDictWithChildField):
|
||||
class SAMLOrgInfoValueField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing required key(s) for org info record: {missing_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'name': fields.CharField(),
|
||||
'displayname': fields.CharField(),
|
||||
'url': fields.URLField(),
|
||||
}
|
||||
allow_unknown_keys = True
|
||||
name = fields.CharField()
|
||||
displayname = fields.CharField()
|
||||
url = fields.URLField()
|
||||
|
||||
|
||||
class SAMLOrgInfoField(fields.DictField):
|
||||
@ -683,34 +677,22 @@ class SAMLOrgInfoField(fields.DictField):
|
||||
return data
|
||||
|
||||
|
||||
class SAMLContactField(BaseDictWithChildField):
|
||||
class SAMLContactField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing required key(s) for contact: {missing_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'givenName': fields.CharField(),
|
||||
'emailAddress': fields.EmailField(),
|
||||
}
|
||||
allow_unknown_keys = True
|
||||
givenName = fields.CharField()
|
||||
emailAddress = fields.EmailField()
|
||||
|
||||
|
||||
class SAMLIdPField(BaseDictWithChildField):
|
||||
class SAMLIdPField(HybridDictField):
|
||||
|
||||
default_error_messages = {
|
||||
'missing_keys': _('Missing required key(s) for IdP: {missing_keys}.'),
|
||||
}
|
||||
child_fields = {
|
||||
'entity_id': fields.CharField(),
|
||||
'url': fields.URLField(),
|
||||
'x509cert': fields.CharField(validators=[validate_certificate]),
|
||||
'attr_user_permanent_id': fields.CharField(required=False),
|
||||
'attr_first_name': fields.CharField(required=False),
|
||||
'attr_last_name': fields.CharField(required=False),
|
||||
'attr_username': fields.CharField(required=False),
|
||||
'attr_email': fields.CharField(required=False),
|
||||
}
|
||||
allow_unknown_keys = True
|
||||
entity_id = fields.CharField()
|
||||
url = fields.URLField()
|
||||
x509cert = fields.CharField(validators=[validate_certificate])
|
||||
attr_user_permanent_id = fields.CharField(required=False)
|
||||
attr_first_name = fields.CharField(required=False)
|
||||
attr_last_name = fields.CharField(required=False)
|
||||
attr_username = fields.CharField(required=False)
|
||||
attr_email = fields.CharField(required=False)
|
||||
|
||||
|
||||
class SAMLEnabledIdPsField(fields.DictField):
|
||||
@ -718,52 +700,49 @@ class SAMLEnabledIdPsField(fields.DictField):
|
||||
child = SAMLIdPField()
|
||||
|
||||
|
||||
class SAMLSecurityField(BaseDictWithChildField):
|
||||
class SAMLSecurityField(HybridDictField):
|
||||
|
||||
child_fields = {
|
||||
'nameIdEncrypted': fields.BooleanField(required=False),
|
||||
'authnRequestsSigned': fields.BooleanField(required=False),
|
||||
'logoutRequestSigned': fields.BooleanField(required=False),
|
||||
'logoutResponseSigned': fields.BooleanField(required=False),
|
||||
'signMetadata': fields.BooleanField(required=False),
|
||||
'wantMessagesSigned': fields.BooleanField(required=False),
|
||||
'wantAssertionsSigned': fields.BooleanField(required=False),
|
||||
'wantAssertionsEncrypted': fields.BooleanField(required=False),
|
||||
'wantNameId': fields.BooleanField(required=False),
|
||||
'wantNameIdEncrypted': fields.BooleanField(required=False),
|
||||
'wantAttributeStatement': fields.BooleanField(required=False),
|
||||
'requestedAuthnContext': fields.StringListBooleanField(required=False),
|
||||
'requestedAuthnContextComparison': fields.CharField(required=False),
|
||||
'metadataValidUntil': fields.CharField(allow_null=True, required=False),
|
||||
'metadataCacheDuration': fields.CharField(allow_null=True, required=False),
|
||||
'signatureAlgorithm': fields.CharField(allow_null=True, required=False),
|
||||
'digestAlgorithm': fields.CharField(allow_null=True, required=False),
|
||||
}
|
||||
allow_unknown_keys = True
|
||||
nameIdEncrypted = fields.BooleanField(required=False)
|
||||
authnRequestsSigned = fields.BooleanField(required=False)
|
||||
logoutRequestSigned = fields.BooleanField(required=False)
|
||||
logoutResponseSigned = fields.BooleanField(required=False)
|
||||
signMetadata = fields.BooleanField(required=False)
|
||||
wantMessagesSigned = fields.BooleanField(required=False)
|
||||
wantAssertionsSigned = fields.BooleanField(required=False)
|
||||
wantAssertionsEncrypted = fields.BooleanField(required=False)
|
||||
wantNameId = fields.BooleanField(required=False)
|
||||
wantNameIdEncrypted = fields.BooleanField(required=False)
|
||||
wantAttributeStatement = fields.BooleanField(required=False)
|
||||
requestedAuthnContext = fields.StringListBooleanField(required=False)
|
||||
requestedAuthnContextComparison = fields.CharField(required=False)
|
||||
metadataValidUntil = fields.CharField(allow_null=True, required=False)
|
||||
metadataCacheDuration = fields.CharField(allow_null=True, required=False)
|
||||
signatureAlgorithm = fields.CharField(allow_null=True, required=False)
|
||||
digestAlgorithm = fields.CharField(allow_null=True, required=False)
|
||||
|
||||
|
||||
class SAMLOrgAttrField(BaseDictWithChildField):
|
||||
class SAMLOrgAttrField(HybridDictField):
|
||||
|
||||
child_fields = {
|
||||
'remove': fields.BooleanField(required=False),
|
||||
'saml_attr': fields.CharField(required=False, allow_null=True),
|
||||
'remove_admins': fields.BooleanField(required=False),
|
||||
'saml_admin_attr': fields.CharField(required=False, allow_null=True),
|
||||
}
|
||||
remove = fields.BooleanField(required=False)
|
||||
saml_attr = fields.CharField(required=False, allow_null=True)
|
||||
remove_admins = fields.BooleanField(required=False)
|
||||
saml_admin_attr = fields.CharField(required=False, allow_null=True)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class SAMLTeamAttrTeamOrgMapField(BaseDictWithChildField):
|
||||
class SAMLTeamAttrTeamOrgMapField(HybridDictField):
|
||||
|
||||
child_fields = {
|
||||
'team': fields.CharField(required=True, allow_null=False),
|
||||
'organization': fields.CharField(required=True, allow_null=False),
|
||||
}
|
||||
team = fields.CharField(required=True, allow_null=False)
|
||||
organization = fields.CharField(required=True, allow_null=False)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
|
||||
class SAMLTeamAttrField(BaseDictWithChildField):
|
||||
class SAMLTeamAttrField(HybridDictField):
|
||||
|
||||
child_fields = {
|
||||
'team_org_map': fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True),
|
||||
'remove': fields.BooleanField(required=False),
|
||||
'saml_attr': fields.CharField(required=False, allow_null=True),
|
||||
}
|
||||
team_org_map = fields.ListField(required=False, child=SAMLTeamAttrTeamOrgMapField(), allow_null=True)
|
||||
remove = fields.BooleanField(required=False)
|
||||
saml_attr = fields.CharField(required=False, allow_null=True)
|
||||
|
||||
child = _Forbidden()
|
||||
|
||||
@ -33,21 +33,23 @@ class TestSAMLOrgAttrField():
|
||||
|
||||
@pytest.mark.parametrize("data, expected", [
|
||||
({'remove': 'blah', 'saml_attr': 'foobar'},
|
||||
ValidationError('Must be a valid boolean.')),
|
||||
{'remove': ['Must be a valid boolean.']}),
|
||||
({'remove': True, 'saml_attr': False},
|
||||
ValidationError('Not a valid string.')),
|
||||
{'saml_attr': ['Not a valid string.']}),
|
||||
({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'},
|
||||
ValidationError('Invalid key(s): "foo", "gig".')),
|
||||
{'saml_attr': ['Not a valid string.'],
|
||||
'foo': ['Invalid field.'],
|
||||
'gig': ['Invalid field.']}),
|
||||
({'remove_admins': True, 'saml_admin_attr': False},
|
||||
ValidationError('Not a valid string.')),
|
||||
{'saml_admin_attr': ['Not a valid string.']}),
|
||||
({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'},
|
||||
ValidationError('Must be a valid boolean.')),
|
||||
{'remove_admins': ['Must be a valid boolean.']}),
|
||||
])
|
||||
def test_internal_value_invalid(self, data, expected):
|
||||
field = SAMLOrgAttrField()
|
||||
with pytest.raises(type(expected)) as e:
|
||||
with pytest.raises(ValidationError) as e:
|
||||
field.to_internal_value(data)
|
||||
assert str(e.value) == str(expected)
|
||||
assert e.value.detail == expected
|
||||
|
||||
|
||||
class TestSAMLTeamAttrField():
|
||||
@ -77,36 +79,38 @@ class TestSAMLTeamAttrField():
|
||||
@pytest.mark.parametrize("data, expected", [
|
||||
({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
|
||||
{'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'},
|
||||
]}, ValidationError('Invalid key(s): "not_a_valid_key".')),
|
||||
]}, {'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}}),
|
||||
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
|
||||
{'organization': 'Ansible'},
|
||||
]}, ValidationError('Missing key(s): "team".')),
|
||||
]}, {'team_org_map': {0: {'team': ['This field is required.']}}}),
|
||||
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
|
||||
{},
|
||||
]}, ValidationError('Missing key(s): "organization", "team".')),
|
||||
]}, {'team_org_map': {
|
||||
0: {'organization': ['This field is required.'],
|
||||
'team': ['This field is required.']}}}),
|
||||
])
|
||||
def test_internal_value_invalid(self, data, expected):
|
||||
field = SAMLTeamAttrField()
|
||||
with pytest.raises(type(expected)) as e:
|
||||
with pytest.raises(ValidationError) as e:
|
||||
field.to_internal_value(data)
|
||||
assert str(e.value) == str(expected)
|
||||
assert e.value.detail == expected
|
||||
|
||||
|
||||
class TestLDAPGroupTypeParamsField():
|
||||
|
||||
@pytest.mark.parametrize("group_type, data, expected", [
|
||||
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
|
||||
ValidationError('Invalid key(s): "bob", "scooter".')),
|
||||
['Invalid key(s): "bob", "scooter".']),
|
||||
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
|
||||
ValidationError('Invalid key(s): "bob", "scooter".')),
|
||||
['Invalid key(s): "bob", "scooter".']),
|
||||
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
|
||||
'bob': ['a', 'b'], 'scooter': 'hello'},
|
||||
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')),
|
||||
['Invalid key(s): "bob", "member_attr", "scooter".']),
|
||||
])
|
||||
def test_internal_value_invalid(self, group_type, data, expected):
|
||||
field = LDAPGroupTypeParamsField()
|
||||
field.get_depends_on = mock.MagicMock(return_value=group_type)
|
||||
|
||||
with pytest.raises(type(expected)) as e:
|
||||
with pytest.raises(ValidationError) as e:
|
||||
field.to_internal_value(data)
|
||||
assert str(e.value) == str(expected)
|
||||
assert e.value.detail == expected
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user