validate group type params

This commit is contained in:
chris meyers
2018-03-19 11:32:40 -04:00
parent 17795f82e8
commit 1c578cdd74
3 changed files with 79 additions and 10 deletions

View File

@@ -305,7 +305,7 @@ class SettingsWrapper(UserSettingsHolder):
settings_to_cache['_awx_conf_preload_expires'] = self._awx_conf_preload_expires settings_to_cache['_awx_conf_preload_expires'] = self._awx_conf_preload_expires
self.cache.set_many(settings_to_cache, timeout=SETTING_CACHE_TIMEOUT) self.cache.set_many(settings_to_cache, timeout=SETTING_CACHE_TIMEOUT)
def _get_local(self, name): def _get_local(self, name, validate=True):
self._preload_cache() self._preload_cache()
cache_key = Setting.get_cache_key(name) cache_key = Setting.get_cache_key(name)
try: try:
@@ -368,7 +368,10 @@ class SettingsWrapper(UserSettingsHolder):
field.run_validators(internal_value) field.run_validators(internal_value)
return internal_value return internal_value
else: else:
return field.run_validation(value) if validate:
return field.run_validation(value)
else:
return value
except Exception: except Exception:
logger.warning( logger.warning(
'The current value "%r" for setting "%s" is invalid.', 'The current value "%r" for setting "%s" is invalid.',

View File

@@ -8,7 +8,11 @@ from django.core.exceptions import ValidationError
# Django Auth LDAP # Django Auth LDAP
import django_auth_ldap.config import django_auth_ldap.config
from django_auth_ldap.config import LDAPSearch, LDAPSearchUnion from django_auth_ldap.config import (
LDAPSearch,
LDAPSearchUnion,
LDAPGroupType,
)
# This must be imported so get_subclasses picks it up # This must be imported so get_subclasses picks it up
from awx.sso.ldap_group_types import PosixUIDGroupType # noqa from awx.sso.ldap_group_types import PosixUIDGroupType # noqa
@@ -28,6 +32,25 @@ def get_subclasses(cls):
yield subclass yield subclass
class DependsOnMixin():
def get_depends_on(self):
"""
Get the value of the dependent field.
First try to find the value in the request.
Then fall back to the raw value from the setting in the DB.
"""
from django.conf import settings
dependent_key = iter(self.depends_on).next()
if self.context:
request = self.context.get('request', None)
if request and request.data and \
request.data.get(dependent_key, None):
return request.data.get(dependent_key)
res = settings._get_local(dependent_key, validate=False)
return res
class AuthenticationBackendsField(fields.StringListField): class AuthenticationBackendsField(fields.StringListField):
# Mapping of settings that must be set in order to enable each # Mapping of settings that must be set in order to enable each
@@ -326,7 +349,15 @@ class LDAPUserAttrMapField(fields.DictField):
return data return data
class LDAPGroupTypeField(fields.ChoiceField): VALID_GROUP_TYPE_PARAMS_MAP = {
'LDAPGroupType': ['name_attr'],
'MemberDNGroupType': ['name_attr', 'member_attr'],
'PosixUIDGroupType': ['name_attr', 'ldap_group_user_attr'],
}
class LDAPGroupTypeField(fields.ChoiceField, DependsOnMixin):
default_error_messages = { default_error_messages = {
'type_error': _('Expected an instance of LDAPGroupType but got {input_type} instead.'), 'type_error': _('Expected an instance of LDAPGroupType but got {input_type} instead.'),
@@ -357,8 +388,7 @@ class LDAPGroupTypeField(fields.ChoiceField):
if not data: if not data:
return None return None
from django.conf import settings params = self.get_depends_on() or {}
params = getattr(settings, iter(self.depends_on).next(), None) or {}
cls = find_class_in_modules(data) cls = find_class_in_modules(data)
if not cls: if not cls:
return None return None
@@ -370,8 +400,9 @@ class LDAPGroupTypeField(fields.ChoiceField):
# took a parameter. # took a parameter.
params_sanitized = dict() params_sanitized = dict()
if isinstance(cls, LDAPGroupType): if isinstance(cls, LDAPGroupType):
if 'name_attr' in params: for k in VALID_GROUP_TYPE_PARAMS_MAP['LDAPGroupType']:
params_sanitized['name_attr'] = params['name_attr'] if k in params:
params_sanitized['name_attr'] = params['name_attr']
if data.endswith('MemberDNGroupType'): if data.endswith('MemberDNGroupType'):
params.setdefault('member_attr', 'member') params.setdefault('member_attr', 'member')
@@ -383,8 +414,22 @@ class LDAPGroupTypeField(fields.ChoiceField):
return cls(**params_sanitized) return cls(**params_sanitized)
class LDAPGroupTypeParamsField(fields.DictField): class LDAPGroupTypeParamsField(fields.DictField, DependsOnMixin):
pass default_error_messages = {
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
}
def to_internal_value(self, value):
value = super(LDAPGroupTypeParamsField, self).to_internal_value(value)
if not value:
return value
group_type_str = self.get_depends_on()
group_type_str = group_type_str or ''
invalid_keys = (set(value.keys()) - set(VALID_GROUP_TYPE_PARAMS_MAP.get(group_type_str, 'LDAPGroupType')))
if invalid_keys:
keys_display = json.dumps(list(invalid_keys)).lstrip('[').rstrip(']')
self.fail('invalid_keys', invalid_keys=keys_display)
return value
class LDAPUserFlagsField(fields.DictField): class LDAPUserFlagsField(fields.DictField):

View File

@@ -1,11 +1,13 @@
import pytest import pytest
import mock
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from awx.sso.fields import ( from awx.sso.fields import (
SAMLOrgAttrField, SAMLOrgAttrField,
SAMLTeamAttrField, SAMLTeamAttrField,
LDAPGroupTypeParamsField,
) )
@@ -80,3 +82,22 @@ class TestSAMLTeamAttrField():
field.to_internal_value(data) field.to_internal_value(data)
assert str(e.value) == str(expected) assert str(e.value) == str(expected)
class TestLDAPGroupTypeParamsField():
@pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "scooter".')),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'},
ValidationError('Invalid key(s): "bob", "member_attr", "scooter".')),
])
def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type)
with pytest.raises(type(expected)) as e:
field.to_internal_value(data)
assert str(e.value) == str(expected)