diff --git a/awx/main/tests/functional/api/test_settings.py b/awx/main/tests/functional/api/test_settings.py index 97effe0fa3..4e6852ce85 100644 --- a/awx/main/tests/functional/api/test_settings.py +++ b/awx/main/tests/functional/api/test_settings.py @@ -101,6 +101,42 @@ def test_ldap_settings(get, put, patch, delete, admin): patch(url, user=admin, data={'AUTH_LDAP_BIND_DN': u'cn=暴力膜,dc=大新闻,dc=真的粉丝'}, expect=200) +@pytest.mark.django_db +@pytest.mark.parametrize('value', [ + None, '', 'INVALID', 1, [1], ['INVALID'], +]) +def test_ldap_user_flags_by_group_invalid_dn(get, patch, admin, value): + url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'ldap'}) + patch(url, user=admin, + data={'AUTH_LDAP_USER_FLAGS_BY_GROUP': {'is_superuser': value}}, + expect=400) + + +@pytest.mark.django_db +def test_ldap_user_flags_by_group_string(get, patch, admin): + expected = 'CN=Admins,OU=Groups,DC=example,DC=com' + url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'ldap'}) + patch(url, user=admin, + data={'AUTH_LDAP_USER_FLAGS_BY_GROUP': {'is_superuser': expected}}, + expect=200) + resp = get(url, user=admin) + assert resp.data['AUTH_LDAP_USER_FLAGS_BY_GROUP']['is_superuser'] == [expected] + + +@pytest.mark.django_db +def test_ldap_user_flags_by_group_list(get, patch, admin): + expected = [ + 'CN=Admins,OU=Groups,DC=example,DC=com', + 'CN=Superadmins,OU=Groups,DC=example,DC=com' + ] + url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'ldap'}) + patch(url, user=admin, + data={'AUTH_LDAP_USER_FLAGS_BY_GROUP': {'is_superuser': expected}}, + expect=200) + resp = get(url, user=admin) + assert resp.data['AUTH_LDAP_USER_FLAGS_BY_GROUP']['is_superuser'] == expected + + @pytest.mark.parametrize('setting', [ 'AUTH_LDAP_USER_DN_TEMPLATE', 'AUTH_LDAP_REQUIRE_GROUP', diff --git a/awx/sso/fields.py b/awx/sso/fields.py index 0e7434f443..c31591beb7 100644 --- a/awx/sso/fields.py +++ b/awx/sso/fields.py @@ -220,6 +220,18 @@ class LDAPDNField(fields.CharField): return None if value == '' else value +class LDAPDNListField(fields.StringListField): + + def __init__(self, **kwargs): + super(LDAPDNListField, self).__init__(**kwargs) + self.validators.append(lambda dn: map(validate_ldap_dn, dn)) + + def run_validation(self, data=empty): + if not isinstance(data, (list, tuple)): + data = [data] + return super(LDAPDNListField, self).run_validation(data) + + class LDAPDNWithUserField(fields.CharField): def __init__(self, **kwargs): @@ -431,7 +443,7 @@ class LDAPUserFlagsField(fields.DictField): 'invalid_flag': _('Invalid user flag: "{invalid_flag}".'), } valid_user_flags = {'is_superuser', 'is_system_auditor'} - child = LDAPDNField() + child = LDAPDNListField() def to_internal_value(self, data): data = super(LDAPUserFlagsField, self).to_internal_value(data)