move code linting to a stricter pep8-esque auto-formatting tool, black

This commit is contained in:
Ryan Petrello
2021-03-19 12:44:51 -04:00
parent 9b702e46fe
commit c2ef0a6500
671 changed files with 20538 additions and 21924 deletions

View File

@@ -40,11 +40,7 @@ logger = logging.getLogger('awx.sso.backends')
class LDAPSettings(BaseLDAPSettings):
defaults = dict(list(BaseLDAPSettings.defaults.items()) + list({
'ORGANIZATION_MAP': {},
'TEAM_MAP': {},
'GROUP_TYPE_PARAMS': {},
}.items()))
defaults = dict(list(BaseLDAPSettings.defaults.items()) + list({'ORGANIZATION_MAP': {}, 'TEAM_MAP': {}, 'GROUP_TYPE_PARAMS': {}}.items()))
def __init__(self, prefix='AUTH_LDAP_', defaults={}):
super(LDAPSettings, self).__init__(prefix, defaults)
@@ -72,9 +68,9 @@ class LDAPSettings(BaseLDAPSettings):
class LDAPBackend(BaseLDAPBackend):
'''
"""
Custom LDAP backend for AWX.
'''
"""
settings_prefix = 'AUTH_LDAP_'
@@ -117,14 +113,9 @@ class LDAPBackend(BaseLDAPBackend):
pass
try:
for setting_name, type_ in [
('GROUP_SEARCH', 'LDAPSearch'),
('GROUP_TYPE', 'LDAPGroupType'),
]:
for setting_name, type_ in [('GROUP_SEARCH', 'LDAPSearch'), ('GROUP_TYPE', 'LDAPGroupType')]:
if getattr(self.settings, setting_name) is None:
raise ImproperlyConfigured(
"{} must be an {} instance.".format(setting_name, type_)
)
raise ImproperlyConfigured("{} must be an {} instance.".format(setting_name, type_))
return super(LDAPBackend, self).authenticate(request, username, password)
except Exception:
logger.exception("Encountered an error authenticating to LDAP")
@@ -184,8 +175,7 @@ def _get_or_set_enterprise_user(username, password, provider):
except User.DoesNotExist:
user = User(username=username)
enterprise_auth = _decorate_enterprise_user(user, provider)
logger.debug("Created enterprise user %s via %s backend." %
(username, enterprise_auth.get_provider_display()))
logger.debug("Created enterprise user %s via %s backend." % (username, enterprise_auth.get_provider_display()))
created = True
if created or user.is_in_enterprise_category(provider):
return user
@@ -193,9 +183,9 @@ def _get_or_set_enterprise_user(username, password, provider):
class RADIUSBackend(BaseRADIUSBackend):
'''
"""
Custom Radius backend to verify license status
'''
"""
def authenticate(self, request, username, password):
if not django_settings.RADIUS_SERVER:
@@ -214,9 +204,9 @@ class RADIUSBackend(BaseRADIUSBackend):
class TACACSPlusBackend(object):
'''
"""
Custom TACACS+ auth backend for AWX
'''
"""
def authenticate(self, request, username, password):
if not django_settings.TACACSPLUS_HOST:
@@ -228,10 +218,7 @@ class TACACSPlusBackend(object):
django_settings.TACACSPLUS_PORT,
django_settings.TACACSPLUS_SECRET,
timeout=django_settings.TACACSPLUS_SESSION_TIMEOUT,
).authenticate(
username, password,
authen_type=tacacs_plus.TAC_PLUS_AUTHEN_TYPES[django_settings.TACACSPLUS_AUTH_PROTOCOL],
)
).authenticate(username, password, authen_type=tacacs_plus.TAC_PLUS_AUTHEN_TYPES[django_settings.TACACSPLUS_AUTH_PROTOCOL])
except Exception as e:
logger.exception("TACACS+ Authentication Error: %s" % str(e))
return None
@@ -248,9 +235,9 @@ class TACACSPlusBackend(object):
class TowerSAMLIdentityProvider(BaseSAMLIdentityProvider):
'''
"""
Custom Identity Provider to make attributes to what we expect.
'''
"""
def get_user_permanent_id(self, attributes):
uid = attributes[self.conf.get('attr_user_permanent_id', OID_USERID)]
@@ -270,26 +257,37 @@ class TowerSAMLIdentityProvider(BaseSAMLIdentityProvider):
if isinstance(value, (list, tuple)):
value = value[0]
if conf_key in ('attr_first_name', 'attr_last_name', 'attr_username', 'attr_email') and value is None:
logger.warn("Could not map user detail '%s' from SAML attribute '%s'; "
"update SOCIAL_AUTH_SAML_ENABLED_IDPS['%s']['%s'] with the correct SAML attribute.",
conf_key[5:], key, self.name, conf_key)
logger.warn(
"Could not map user detail '%s' from SAML attribute '%s'; " "update SOCIAL_AUTH_SAML_ENABLED_IDPS['%s']['%s'] with the correct SAML attribute.",
conf_key[5:],
key,
self.name,
conf_key,
)
return str(value) if value is not None else value
class SAMLAuth(BaseSAMLAuth):
'''
"""
Custom SAMLAuth backend to verify license status
'''
"""
def get_idp(self, idp_name):
idp_config = self.setting('ENABLED_IDPS')[idp_name]
return TowerSAMLIdentityProvider(idp_name, **idp_config)
def authenticate(self, request, *args, **kwargs):
if not all([django_settings.SOCIAL_AUTH_SAML_SP_ENTITY_ID, django_settings.SOCIAL_AUTH_SAML_SP_PUBLIC_CERT,
django_settings.SOCIAL_AUTH_SAML_SP_PRIVATE_KEY, django_settings.SOCIAL_AUTH_SAML_ORG_INFO,
django_settings.SOCIAL_AUTH_SAML_TECHNICAL_CONTACT, django_settings.SOCIAL_AUTH_SAML_SUPPORT_CONTACT,
django_settings.SOCIAL_AUTH_SAML_ENABLED_IDPS]):
if not all(
[
django_settings.SOCIAL_AUTH_SAML_SP_ENTITY_ID,
django_settings.SOCIAL_AUTH_SAML_SP_PUBLIC_CERT,
django_settings.SOCIAL_AUTH_SAML_SP_PRIVATE_KEY,
django_settings.SOCIAL_AUTH_SAML_ORG_INFO,
django_settings.SOCIAL_AUTH_SAML_TECHNICAL_CONTACT,
django_settings.SOCIAL_AUTH_SAML_SUPPORT_CONTACT,
django_settings.SOCIAL_AUTH_SAML_ENABLED_IDPS,
]
):
return None
user = super(SAMLAuth, self).authenticate(request, *args, **kwargs)
# Comes from https://github.com/omab/python-social-auth/blob/v0.2.21/social/backends/base.py#L91
@@ -300,18 +298,25 @@ class SAMLAuth(BaseSAMLAuth):
return user
def get_user(self, user_id):
if not all([django_settings.SOCIAL_AUTH_SAML_SP_ENTITY_ID, django_settings.SOCIAL_AUTH_SAML_SP_PUBLIC_CERT,
django_settings.SOCIAL_AUTH_SAML_SP_PRIVATE_KEY, django_settings.SOCIAL_AUTH_SAML_ORG_INFO,
django_settings.SOCIAL_AUTH_SAML_TECHNICAL_CONTACT, django_settings.SOCIAL_AUTH_SAML_SUPPORT_CONTACT,
django_settings.SOCIAL_AUTH_SAML_ENABLED_IDPS]):
if not all(
[
django_settings.SOCIAL_AUTH_SAML_SP_ENTITY_ID,
django_settings.SOCIAL_AUTH_SAML_SP_PUBLIC_CERT,
django_settings.SOCIAL_AUTH_SAML_SP_PRIVATE_KEY,
django_settings.SOCIAL_AUTH_SAML_ORG_INFO,
django_settings.SOCIAL_AUTH_SAML_TECHNICAL_CONTACT,
django_settings.SOCIAL_AUTH_SAML_SUPPORT_CONTACT,
django_settings.SOCIAL_AUTH_SAML_ENABLED_IDPS,
]
):
return None
return super(SAMLAuth, self).get_user(user_id)
def _update_m2m_from_groups(user, ldap_user, related, opts, remove=True):
'''
"""
Hepler function to update m2m relationship based on LDAP group membership.
'''
"""
should_add = False
if opts is None:
return
@@ -337,11 +342,12 @@ def _update_m2m_from_groups(user, ldap_user, related, opts, remove=True):
@receiver(populate_user, dispatch_uid='populate-ldap-user')
def on_populate_user(sender, **kwargs):
'''
"""
Handle signal from LDAP backend to populate the user object. Update user
organization/team memberships according to their LDAP groups.
'''
"""
from awx.main.models import Organization, Team
user = kwargs['user']
ldap_user = kwargs['ldap_user']
backend = ldap_user.backend
@@ -356,9 +362,7 @@ def on_populate_user(sender, **kwargs):
field_len = len(getattr(user, field))
if field_len > max_len:
setattr(user, field, getattr(user, field)[:max_len])
logger.warn(
'LDAP user {} has {} > max {} characters'.format(user.username, field, max_len)
)
logger.warn('LDAP user {} has {} > max {} characters'.format(user.username, field, max_len))
# Update organization membership based on group memberships.
org_map = getattr(backend.settings, 'ORGANIZATION_MAP', {})
@@ -367,16 +371,13 @@ def on_populate_user(sender, **kwargs):
remove = bool(org_opts.get('remove', True))
admins_opts = org_opts.get('admins', None)
remove_admins = bool(org_opts.get('remove_admins', remove))
_update_m2m_from_groups(user, ldap_user, org.admin_role.members, admins_opts,
remove_admins)
_update_m2m_from_groups(user, ldap_user, org.admin_role.members, admins_opts, remove_admins)
auditors_opts = org_opts.get('auditors', None)
remove_auditors = bool(org_opts.get('remove_auditors', remove))
_update_m2m_from_groups(user, ldap_user, org.auditor_role.members, auditors_opts,
remove_auditors)
_update_m2m_from_groups(user, ldap_user, org.auditor_role.members, auditors_opts, remove_auditors)
users_opts = org_opts.get('users', None)
remove_users = bool(org_opts.get('remove_users', remove))
_update_m2m_from_groups(user, ldap_user, org.member_role.members, users_opts,
remove_users)
_update_m2m_from_groups(user, ldap_user, org.member_role.members, users_opts, remove_users)
# Update team membership based on group memberships.
team_map = getattr(backend.settings, 'TEAM_MAP', {})
@@ -387,8 +388,7 @@ def on_populate_user(sender, **kwargs):
team, created = Team.objects.get_or_create(name=team_name, organization=org)
users_opts = team_opts.get('users', None)
remove = bool(team_opts.get('remove', True))
_update_m2m_from_groups(user, ldap_user, team.member_role.members, users_opts,
remove)
_update_m2m_from_groups(user, ldap_user, team.member_role.members, users_opts, remove)
# Update user profile to store LDAP DN.
user.save()

File diff suppressed because it is too large Load Diff

View File

@@ -14,10 +14,7 @@ from django.utils.translation import ugettext_lazy as _
# Django Auth LDAP
import django_auth_ldap.config
from django_auth_ldap.config import (
LDAPSearch,
LDAPSearchUnion,
)
from django_auth_ldap.config import LDAPSearch, LDAPSearchUnion
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, Field, SkipField
@@ -46,9 +43,9 @@ def get_subclasses(cls):
def find_class_in_modules(class_name):
'''
"""
Used to find ldap subclasses by string
'''
"""
module_search_space = [django_auth_ldap.config, awx.sso.ldap_group_types]
for m in module_search_space:
cls = getattr(m, class_name, None)
@@ -57,7 +54,7 @@ def find_class_in_modules(class_name):
return None
class DependsOnMixin():
class DependsOnMixin:
def get_depends_on(self):
"""
Get the value of the dependent field.
@@ -65,38 +62,34 @@ class DependsOnMixin():
Then fall back to the raw value from the setting in the DB.
"""
from django.conf import settings
dependent_key = next(iter(self.depends_on))
if self.context:
request = self.context.get('request', None)
if request and request.data and \
request.data.get(dependent_key, None):
if request and request.data and request.data.get(dependent_key, None):
return request.data.get(dependent_key)
res = settings._get_local(dependent_key, validate=False)
return res
class _Forbidden(Field):
default_error_messages = {
'invalid': _('Invalid field.'),
}
default_error_messages = {'invalid': _('Invalid field.')}
def run_validation(self, value):
self.fail('invalid')
class HybridDictField(fields.DictField):
"""A DictField, but with defined fixed Fields for certain keys.
"""
"""A DictField, but with defined fixed Fields for certain keys."""
def __init__(self, *args, **kwargs):
self.allow_blank = kwargs.pop('allow_blank', False)
fields = [
sorted(
((field_name, obj) for field_name, obj in cls.__dict__.items()
if isinstance(obj, Field) and field_name != 'child'),
key=lambda x: x[1]._creation_counter
((field_name, obj) for field_name, obj in cls.__dict__.items() if isinstance(obj, Field) and field_name != 'child'),
key=lambda x: x[1]._creation_counter,
)
for cls in reversed(self.__class__.__mro__)
]
@@ -108,10 +101,7 @@ class HybridDictField(fields.DictField):
fields = copy.deepcopy(self._declared_fields)
return {
key: field.to_representation(val) if val is not None else None
for key, val, field in (
(six.text_type(key), val, fields.get(key, self.child))
for key, val in value.items()
)
for key, val, field in ((six.text_type(key), val, fields.get(key, self.child)) for key, val in value.items())
if not field.write_only
}
@@ -147,81 +137,67 @@ class AuthenticationBackendsField(fields.StringListField):
# Mapping of settings that must be set in order to enable each
# authentication backend.
REQUIRED_BACKEND_SETTINGS = collections.OrderedDict([
('awx.sso.backends.LDAPBackend', [
'AUTH_LDAP_SERVER_URI',
]),
('awx.sso.backends.LDAPBackend1', [
'AUTH_LDAP_1_SERVER_URI',
]),
('awx.sso.backends.LDAPBackend2', [
'AUTH_LDAP_2_SERVER_URI',
]),
('awx.sso.backends.LDAPBackend3', [
'AUTH_LDAP_3_SERVER_URI',
]),
('awx.sso.backends.LDAPBackend4', [
'AUTH_LDAP_4_SERVER_URI',
]),
('awx.sso.backends.LDAPBackend5', [
'AUTH_LDAP_5_SERVER_URI',
]),
('awx.sso.backends.RADIUSBackend', [
'RADIUS_SERVER',
]),
('social_core.backends.google.GoogleOAuth2', [
'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY',
'SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET',
]),
('social_core.backends.github.GithubOAuth2', [
'SOCIAL_AUTH_GITHUB_KEY',
'SOCIAL_AUTH_GITHUB_SECRET',
]),
('social_core.backends.github.GithubOrganizationOAuth2', [
'SOCIAL_AUTH_GITHUB_ORG_KEY',
'SOCIAL_AUTH_GITHUB_ORG_SECRET',
'SOCIAL_AUTH_GITHUB_ORG_NAME',
]),
('social_core.backends.github.GithubTeamOAuth2', [
'SOCIAL_AUTH_GITHUB_TEAM_KEY',
'SOCIAL_AUTH_GITHUB_TEAM_SECRET',
'SOCIAL_AUTH_GITHUB_TEAM_ID',
]),
('social_core.backends.github_enterprise.GithubEnterpriseOAuth2', [
'SOCIAL_AUTH_GITHUB_ENTERPRISE_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_SECRET',
]),
('social_core.backends.github_enterprise.GithubEnterpriseOrganizationOAuth2', [
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_SECRET',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_NAME',
]),
('social_core.backends.github_enterprise.GithubEnterpriseTeamOAuth2', [
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_SECRET',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_ID',
]),
('social_core.backends.azuread.AzureADOAuth2', [
'SOCIAL_AUTH_AZUREAD_OAUTH2_KEY',
'SOCIAL_AUTH_AZUREAD_OAUTH2_SECRET',
]),
('awx.sso.backends.SAMLAuth', [
'SOCIAL_AUTH_SAML_SP_ENTITY_ID',
'SOCIAL_AUTH_SAML_SP_PUBLIC_CERT',
'SOCIAL_AUTH_SAML_SP_PRIVATE_KEY',
'SOCIAL_AUTH_SAML_ORG_INFO',
'SOCIAL_AUTH_SAML_TECHNICAL_CONTACT',
'SOCIAL_AUTH_SAML_SUPPORT_CONTACT',
'SOCIAL_AUTH_SAML_ENABLED_IDPS',
]),
('django.contrib.auth.backends.ModelBackend', []),
])
REQUIRED_BACKEND_SETTINGS = collections.OrderedDict(
[
('awx.sso.backends.LDAPBackend', ['AUTH_LDAP_SERVER_URI']),
('awx.sso.backends.LDAPBackend1', ['AUTH_LDAP_1_SERVER_URI']),
('awx.sso.backends.LDAPBackend2', ['AUTH_LDAP_2_SERVER_URI']),
('awx.sso.backends.LDAPBackend3', ['AUTH_LDAP_3_SERVER_URI']),
('awx.sso.backends.LDAPBackend4', ['AUTH_LDAP_4_SERVER_URI']),
('awx.sso.backends.LDAPBackend5', ['AUTH_LDAP_5_SERVER_URI']),
('awx.sso.backends.RADIUSBackend', ['RADIUS_SERVER']),
('social_core.backends.google.GoogleOAuth2', ['SOCIAL_AUTH_GOOGLE_OAUTH2_KEY', 'SOCIAL_AUTH_GOOGLE_OAUTH2_SECRET']),
('social_core.backends.github.GithubOAuth2', ['SOCIAL_AUTH_GITHUB_KEY', 'SOCIAL_AUTH_GITHUB_SECRET']),
(
'social_core.backends.github.GithubOrganizationOAuth2',
['SOCIAL_AUTH_GITHUB_ORG_KEY', 'SOCIAL_AUTH_GITHUB_ORG_SECRET', 'SOCIAL_AUTH_GITHUB_ORG_NAME'],
),
('social_core.backends.github.GithubTeamOAuth2', ['SOCIAL_AUTH_GITHUB_TEAM_KEY', 'SOCIAL_AUTH_GITHUB_TEAM_SECRET', 'SOCIAL_AUTH_GITHUB_TEAM_ID']),
(
'social_core.backends.github_enterprise.GithubEnterpriseOAuth2',
[
'SOCIAL_AUTH_GITHUB_ENTERPRISE_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_SECRET',
],
),
(
'social_core.backends.github_enterprise.GithubEnterpriseOrganizationOAuth2',
[
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_SECRET',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_ORG_NAME',
],
),
(
'social_core.backends.github_enterprise.GithubEnterpriseTeamOAuth2',
[
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_API_URL',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_KEY',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_SECRET',
'SOCIAL_AUTH_GITHUB_ENTERPRISE_TEAM_ID',
],
),
('social_core.backends.azuread.AzureADOAuth2', ['SOCIAL_AUTH_AZUREAD_OAUTH2_KEY', 'SOCIAL_AUTH_AZUREAD_OAUTH2_SECRET']),
(
'awx.sso.backends.SAMLAuth',
[
'SOCIAL_AUTH_SAML_SP_ENTITY_ID',
'SOCIAL_AUTH_SAML_SP_PUBLIC_CERT',
'SOCIAL_AUTH_SAML_SP_PRIVATE_KEY',
'SOCIAL_AUTH_SAML_ORG_INFO',
'SOCIAL_AUTH_SAML_TECHNICAL_CONTACT',
'SOCIAL_AUTH_SAML_SUPPORT_CONTACT',
'SOCIAL_AUTH_SAML_ENABLED_IDPS',
],
),
('django.contrib.auth.backends.ModelBackend', []),
]
)
@classmethod
def get_all_required_settings(cls):
@@ -236,6 +212,7 @@ class AuthenticationBackendsField(fields.StringListField):
def _default_from_required_settings(self):
from django.conf import settings
try:
backends = settings._awx_conf_settings._get_default('AUTHENTICATION_BACKENDS')
except AttributeError:
@@ -252,7 +229,6 @@ class AuthenticationBackendsField(fields.StringListField):
class LDAPServerURIField(fields.URLField):
def __init__(self, **kwargs):
kwargs.setdefault('schemes', ('ldap', 'ldaps'))
kwargs.setdefault('allow_plain_hostname', True)
@@ -266,9 +242,7 @@ class LDAPServerURIField(fields.URLField):
class LDAPConnectionOptionsField(fields.DictField):
default_error_messages = {
'invalid_options': _('Invalid connection option(s): {invalid_options}.'),
}
default_error_messages = {'invalid_options': _('Invalid connection option(s): {invalid_options}.')}
def to_representation(self, value):
value = value or {}
@@ -296,7 +270,6 @@ class LDAPConnectionOptionsField(fields.DictField):
class LDAPDNField(fields.CharField):
def __init__(self, **kwargs):
super(LDAPDNField, self).__init__(**kwargs)
self.validators.append(validate_ldap_dn)
@@ -309,7 +282,6 @@ class LDAPDNField(fields.CharField):
class LDAPDNListField(fields.StringListField):
def __init__(self, **kwargs):
super(LDAPDNListField, self).__init__(**kwargs)
self.validators.append(lambda dn: list(map(validate_ldap_dn, dn)))
@@ -321,7 +293,6 @@ class LDAPDNListField(fields.StringListField):
class LDAPDNWithUserField(fields.CharField):
def __init__(self, **kwargs):
super(LDAPDNWithUserField, self).__init__(**kwargs)
self.validators.append(validate_ldap_dn_with_user)
@@ -334,27 +305,20 @@ class LDAPDNWithUserField(fields.CharField):
class LDAPFilterField(fields.CharField):
def __init__(self, **kwargs):
super(LDAPFilterField, self).__init__(**kwargs)
self.validators.append(validate_ldap_filter)
class LDAPFilterWithUserField(fields.CharField):
def __init__(self, **kwargs):
super(LDAPFilterWithUserField, self).__init__(**kwargs)
self.validators.append(validate_ldap_filter_with_user)
class LDAPScopeField(fields.ChoiceField):
def __init__(self, choices=None, **kwargs):
choices = choices or [
('SCOPE_BASE', _('Base')),
('SCOPE_ONELEVEL', _('One Level')),
('SCOPE_SUBTREE', _('Subtree')),
]
choices = choices or [('SCOPE_BASE', _('Base')), ('SCOPE_ONELEVEL', _('One Level')), ('SCOPE_SUBTREE', _('Subtree'))]
super(LDAPScopeField, self).__init__(choices, **kwargs)
def to_representation(self, value):
@@ -394,9 +358,7 @@ class LDAPSearchField(fields.ListField):
if len(data) != 3:
self.fail('invalid_length', length=len(data))
return LDAPSearch(
LDAPDNField().run_validation(data[0]),
LDAPScopeField().run_validation(data[1]),
self.ldap_filter_field_class().run_validation(data[2]),
LDAPDNField().run_validation(data[0]), LDAPScopeField().run_validation(data[1]), self.ldap_filter_field_class().run_validation(data[2])
)
@@ -407,9 +369,7 @@ class LDAPSearchWithUserField(LDAPSearchField):
class LDAPSearchUnionField(fields.ListField):
default_error_messages = {
'type_error': _('Expected an instance of LDAPSearch or LDAPSearchUnion but got {input_type} instead.'),
}
default_error_messages = {'type_error': _('Expected an instance of LDAPSearch or LDAPSearchUnion but got {input_type} instead.')}
ldap_search_field_class = LDAPSearchWithUserField
def to_representation(self, value):
@@ -432,8 +392,7 @@ class LDAPSearchUnionField(fields.ListField):
search_args = []
for i in range(len(data)):
if not isinstance(data[i], list):
raise ValidationError('In order to ultilize LDAP Union, input element No. %d'
' should be a search query array.' % (i + 1))
raise ValidationError('In order to ultilize LDAP Union, input element No. %d' ' should be a search query array.' % (i + 1))
try:
search_args.append(self.ldap_search_field_class().run_validation(data[i]))
except Exception as e:
@@ -445,15 +404,13 @@ class LDAPSearchUnionField(fields.ListField):
class LDAPUserAttrMapField(fields.DictField):
default_error_messages = {
'invalid_attrs': _('Invalid user attribute(s): {invalid_attrs}.'),
}
default_error_messages = {'invalid_attrs': _('Invalid user attribute(s): {invalid_attrs}.')}
valid_user_attrs = {'first_name', 'last_name', 'email'}
child = fields.CharField()
def to_internal_value(self, data):
data = super(LDAPUserAttrMapField, self).to_internal_value(data)
invalid_attrs = (set(data.keys()) - self.valid_user_attrs)
invalid_attrs = set(data.keys()) - self.valid_user_attrs
if invalid_attrs:
invalid_attrs = sorted(list(invalid_attrs))
attrs_display = json.dumps(invalid_attrs).lstrip('[').rstrip(']')
@@ -466,7 +423,7 @@ class LDAPGroupTypeField(fields.ChoiceField, DependsOnMixin):
default_error_messages = {
'type_error': _('Expected an instance of LDAPGroupType but got {input_type} instead.'),
'missing_parameters': _('Missing required parameters in {dependency}.'),
'invalid_parameters': _('Invalid group_type parameters. Expected instance of dict but got {parameters_type} instead.')
'invalid_parameters': _('Invalid group_type parameters. Expected instance of dict but got {parameters_type} instead.'),
}
def __init__(self, choices=None, **kwargs):
@@ -515,9 +472,7 @@ class LDAPGroupTypeField(fields.ChoiceField, DependsOnMixin):
class LDAPGroupTypeParamsField(fields.DictField, DependsOnMixin):
default_error_messages = {
'invalid_keys': _('Invalid key(s): {invalid_keys}.'),
}
default_error_messages = {'invalid_keys': _('Invalid key(s): {invalid_keys}.')}
def to_internal_value(self, value):
value = super(LDAPGroupTypeParamsField, self).to_internal_value(value)
@@ -541,15 +496,13 @@ class LDAPGroupTypeParamsField(fields.DictField, DependsOnMixin):
class LDAPUserFlagsField(fields.DictField):
default_error_messages = {
'invalid_flag': _('Invalid user flag: "{invalid_flag}".'),
}
default_error_messages = {'invalid_flag': _('Invalid user flag: "{invalid_flag}".')}
valid_user_flags = {'is_superuser', 'is_system_auditor'}
child = LDAPDNListField()
def to_internal_value(self, data):
data = super(LDAPUserFlagsField, self).to_internal_value(data)
invalid_flags = (set(data.keys()) - self.valid_user_flags)
invalid_flags = set(data.keys()) - self.valid_user_flags
if invalid_flags:
self.fail('invalid_flag', invalid_flag=list(invalid_flags)[0])
return data
@@ -592,7 +545,6 @@ class LDAPTeamMapField(fields.DictField):
class SocialMapStringRegexField(fields.CharField):
def to_representation(self, value):
if isinstance(value, type(re.compile(''))):
flags = []
@@ -623,9 +575,7 @@ class SocialMapStringRegexField(fields.CharField):
class SocialMapField(fields.ListField):
default_error_messages = {
'type_error': _('Expected None, True, False, a string or list of strings but got {input_type} instead.'),
}
default_error_messages = {'type_error': _('Expected None, True, False, a string or list of strings but got {input_type} instead.')}
child = SocialMapStringRegexField()
def to_representation(self, value):
@@ -695,9 +645,7 @@ class SAMLOrgInfoValueField(HybridDictField):
class SAMLOrgInfoField(fields.DictField):
default_error_messages = {
'invalid_lang_code': _('Invalid language code(s) for org info: {invalid_lang_codes}.'),
}
default_error_messages = {'invalid_lang_code': _('Invalid language code(s) for org info: {invalid_lang_codes}.')}
child = SAMLOrgInfoValueField()
def to_internal_value(self, data):

View File

@@ -12,7 +12,6 @@ from django_auth_ldap.config import LDAPGroupType
class PosixUIDGroupType(LDAPGroupType):
def __init__(self, name_attr='cn', ldap_group_user_attr='uid'):
self.ldap_group_user_attr = ldap_group_user_attr
super(PosixUIDGroupType, self).__init__(name_attr)
@@ -20,6 +19,7 @@ class PosixUIDGroupType(LDAPGroupType):
"""
An LDAPGroupType subclass that handles non-standard DS.
"""
def user_groups(self, ldap_user, group_search):
"""
Searches for any group that is either the user's primary or contains the
@@ -34,12 +34,10 @@ class PosixUIDGroupType(LDAPGroupType):
user_gid = ldap_user.attrs['gidNumber'][0]
filterstr = u'(|(gidNumber=%s)(memberUid=%s))' % (
self.ldap.filter.escape_filter_chars(user_gid),
self.ldap.filter.escape_filter_chars(user_uid)
)
else:
filterstr = u'(memberUid=%s)' % (
self.ldap.filter.escape_filter_chars(user_uid),
)
else:
filterstr = u'(memberUid=%s)' % (self.ldap.filter.escape_filter_chars(user_uid),)
search = group_search.search_with_additional_term_string(filterstr)
search.attrlist = [str(self.name_attr)]

View File

@@ -17,7 +17,6 @@ from social_django.middleware import SocialAuthExceptionMiddleware
class SocialAuthMiddleware(SocialAuthExceptionMiddleware):
def process_request(self, request):
if request.path.startswith('/sso'):
# See upgrade blocker note in requirements/README.md

View File

@@ -7,9 +7,7 @@ from django.conf import settings
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
operations = [
migrations.CreateModel(
@@ -20,8 +18,5 @@ class Migration(migrations.Migration):
('user', models.ForeignKey(related_name='enterprise_auth', on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL)),
],
),
migrations.AlterUniqueTogether(
name='userenterpriseauth',
unique_together=set([('user', 'provider')]),
),
migrations.AlterUniqueTogether(name='userenterpriseauth', unique_together=set([('user', 'provider')])),
]

View File

@@ -6,14 +6,12 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('sso', '0001_initial'),
]
dependencies = [('sso', '0001_initial')]
operations = [
migrations.AlterField(
model_name='userenterpriseauth',
name='provider',
field=models.CharField(max_length=32, choices=[('radius', 'RADIUS'), ('tacacs+', 'TACACS+'), ('saml', 'SAML')]),
),
)
]

View File

@@ -10,18 +10,10 @@ from django.utils.translation import ugettext_lazy as _
class UserEnterpriseAuth(models.Model):
"""Tower Enterprise Auth association model"""
PROVIDER_CHOICES = (
('radius', _('RADIUS')),
('tacacs+', _('TACACS+')),
('saml', _('SAML')),
)
PROVIDER_CHOICES = (('radius', _('RADIUS')), ('tacacs+', _('TACACS+')), ('saml', _('SAML')))
class Meta:
unique_together = ('user', 'provider')
user = models.ForeignKey(
User, related_name='enterprise_auth', on_delete=models.CASCADE
)
provider = models.CharField(
max_length=32, choices=PROVIDER_CHOICES
)
user = models.ForeignKey(User, related_name='enterprise_auth', on_delete=models.CASCADE)
provider = models.CharField(max_length=32, choices=PROVIDER_CHOICES)

View File

@@ -19,7 +19,6 @@ logger = logging.getLogger('awx.sso.pipeline')
class AuthNotFound(AuthException):
def __init__(self, backend, email_or_uid, *args, **kwargs):
self.email_or_uid = email_or_uid
super(AuthNotFound, self).__init__(backend, *args, **kwargs)
@@ -29,7 +28,6 @@ class AuthNotFound(AuthException):
class AuthInactive(AuthException):
def __str__(self):
return _('Your account is inactive')
@@ -52,10 +50,10 @@ def prevent_inactive_login(backend, details, user=None, *args, **kwargs):
def _update_m2m_from_expression(user, related, expr, remove=True):
'''
"""
Helper function to update m2m relationship based on user matching one or
more expressions.
'''
"""
should_add = False
if expr is None:
return
@@ -98,31 +96,28 @@ def _update_org_from_attr(user, related, attr, remove, remove_admins, remove_aud
getattr(org, related).members.add(user)
if remove:
[o.member_role.members.remove(user) for o in
Organization.objects.filter(Q(member_role__members=user) & ~Q(id__in=org_ids))]
[o.member_role.members.remove(user) for o in Organization.objects.filter(Q(member_role__members=user) & ~Q(id__in=org_ids))]
if remove_admins:
[o.admin_role.members.remove(user) for o in
Organization.objects.filter(Q(admin_role__members=user) & ~Q(id__in=org_ids))]
[o.admin_role.members.remove(user) for o in Organization.objects.filter(Q(admin_role__members=user) & ~Q(id__in=org_ids))]
if remove_auditors:
[o.auditor_role.members.remove(user) for o in
Organization.objects.filter(Q(auditor_role__members=user) & ~Q(id__in=org_ids))]
[o.auditor_role.members.remove(user) for o in Organization.objects.filter(Q(auditor_role__members=user) & ~Q(id__in=org_ids))]
def update_user_orgs(backend, details, user=None, *args, **kwargs):
'''
"""
Update organization memberships for the given user based on mapping rules
defined in settings.
'''
"""
if not user:
return
from awx.main.models import Organization
org_map = backend.setting('ORGANIZATION_MAP') or {}
for org_name, org_opts in org_map.items():
org = Organization.objects.get_or_create(name=org_name)[0]
# Update org admins from expression(s).
remove = bool(org_opts.get('remove', True))
admins_expr = org_opts.get('admins', None)
@@ -136,13 +131,14 @@ def update_user_orgs(backend, details, user=None, *args, **kwargs):
def update_user_teams(backend, details, user=None, *args, **kwargs):
'''
"""
Update team memberships for the given user based on mapping rules defined
in settings.
'''
"""
if not user:
return
from awx.main.models import Organization, Team
team_map = backend.setting('TEAM_MAP') or {}
for team_name, team_opts in team_map.items():
# Get or create the org to update.
@@ -150,7 +146,6 @@ def update_user_teams(backend, details, user=None, *args, **kwargs):
continue
org = Organization.objects.get_or_create(name=team_opts['organization'])[0]
# Update team members from expression(s).
team = Team.objects.get_or_create(name=team_name, organization=org)[0]
users_expr = team_opts.get('users', None)
@@ -162,6 +157,7 @@ def update_user_orgs_by_saml_attr(backend, details, user=None, *args, **kwargs):
if not user:
return
from django.conf import settings
org_map = settings.SOCIAL_AUTH_SAML_ORGANIZATION_ATTR
if org_map.get('saml_attr') is None and org_map.get('saml_admin_attr') is None and org_map.get('saml_auditor_attr') is None:
return
@@ -184,14 +180,12 @@ def update_user_teams_by_saml_attr(backend, details, user=None, *args, **kwargs)
return
from awx.main.models import Organization, Team
from django.conf import settings
team_map = settings.SOCIAL_AUTH_SAML_TEAM_ATTR
if team_map.get('saml_attr') is None:
return
saml_team_names = set(kwargs
.get('response', {})
.get('attributes', {})
.get(team_map['saml_attr'], []))
saml_team_names = set(kwargs.get('response', {}).get('attributes', {}).get(team_map['saml_attr'], []))
team_ids = []
for team_name_map in team_map.get('team_org_map', []):
@@ -230,5 +224,4 @@ def update_user_teams_by_saml_attr(backend, details, user=None, *args, **kwargs)
team.member_role.members.add(user)
if team_map.get('remove', True):
[t.member_role.members.remove(user) for t in
Team.objects.filter(Q(member_role__members=user) & ~Q(id__in=team_ids))]
[t.member_role.members.remove(user) for t in Team.objects.filter(Q(member_role__members=user) & ~Q(id__in=team_ids))]

View File

@@ -19,9 +19,7 @@ def test_fetch_user_if_exist(existing_tacacsplus_user):
def test_create_user_if_not_exist(existing_tacacsplus_user):
with mock.patch('awx.sso.backends.logger') as mocked_logger:
new_user = _get_or_set_enterprise_user("bar", "password", "tacacs+")
mocked_logger.debug.assert_called_once_with(
u'Created enterprise user bar via TACACS+ backend.'
)
mocked_logger.debug.assert_called_once_with(u'Created enterprise user bar via TACACS+ backend.')
assert new_user != existing_tacacsplus_user
@@ -35,7 +33,5 @@ def test_created_user_has_no_usable_password():
def test_non_enterprise_user_does_not_get_pass(existing_normal_user):
with mock.patch('awx.sso.backends.logger') as mocked_logger:
new_user = _get_or_set_enterprise_user("alice", "password", "tacacs+")
mocked_logger.warn.assert_called_once_with(
u'Enterprise user alice already defined in Tower.'
)
mocked_logger.warn.assert_called_once_with(u'Enterprise user alice already defined in Tower.')
assert new_user is None

View File

@@ -5,20 +5,15 @@ import pytest
from awx.sso.backends import LDAPSettings
@override_settings(AUTH_LDAP_CONNECTION_OPTIONS = {ldap.OPT_NETWORK_TIMEOUT: 60})
@override_settings(AUTH_LDAP_CONNECTION_OPTIONS={ldap.OPT_NETWORK_TIMEOUT: 60})
@pytest.mark.django_db
def test_ldap_with_custom_timeout():
settings = LDAPSettings()
assert settings.CONNECTION_OPTIONS == {
ldap.OPT_NETWORK_TIMEOUT: 60
}
assert settings.CONNECTION_OPTIONS == {ldap.OPT_NETWORK_TIMEOUT: 60}
@override_settings(AUTH_LDAP_CONNECTION_OPTIONS = {ldap.OPT_REFERRALS: 0})
@override_settings(AUTH_LDAP_CONNECTION_OPTIONS={ldap.OPT_REFERRALS: 0})
@pytest.mark.django_db
def test_ldap_with_missing_timeout():
settings = LDAPSettings()
assert settings.CONNECTION_OPTIONS == {
ldap.OPT_REFERRALS: 0,
ldap.OPT_NETWORK_TIMEOUT: 30
}
assert settings.CONNECTION_OPTIONS == {ldap.OPT_REFERRALS: 0, ldap.OPT_NETWORK_TIMEOUT: 30}

View File

@@ -1,20 +1,10 @@
import pytest
import re
from unittest import mock
from awx.sso.pipeline import (
update_user_orgs,
update_user_teams,
update_user_orgs_by_saml_attr,
update_user_teams_by_saml_attr,
)
from awx.sso.pipeline import update_user_orgs, update_user_teams, update_user_orgs_by_saml_attr, update_user_teams_by_saml_attr
from awx.main.models import (
User,
Team,
Organization
)
from awx.main.models import User, Team, Organization
@pytest.fixture
@@ -26,33 +16,13 @@ def users():
@pytest.mark.django_db
class TestSAMLMap():
class TestSAMLMap:
@pytest.fixture
def backend(self):
class Backend:
s = {
'ORGANIZATION_MAP': {
'Default': {
'remove': True,
'admins': 'foobar',
'remove_admins': True,
'users': 'foo',
'remove_users': True,
}
},
'TEAM_MAP': {
'Blue': {
'organization': 'Default',
'remove': True,
'users': '',
},
'Red': {
'organization': 'Default',
'remove': True,
'users': '',
}
}
'ORGANIZATION_MAP': {'Default': {'remove': True, 'admins': 'foobar', 'remove_admins': True, 'users': 'foo', 'remove_users': True}},
'TEAM_MAP': {'Blue': {'organization': 'Default', 'remove': True, 'users': ''}, 'Red': {'organization': 'Default', 'remove': True, 'users': ''}},
}
def setting(self, key):
@@ -132,17 +102,13 @@ class TestSAMLMap():
@pytest.mark.django_db
class TestSAMLAttr():
class TestSAMLAttr:
@pytest.fixture
def kwargs(self):
return {
'username': u'cmeyers@redhat.com',
'uid': 'idp:cmeyers@redhat.com',
'request': {
u'SAMLResponse': [],
u'RelayState': [u'idp']
},
'request': {u'SAMLResponse': [], u'RelayState': [u'idp']},
'is_new': False,
'response': {
'session_index': '_0728f0e0-b766-0135-75fa-02842b07c044',
@@ -156,14 +122,14 @@ class TestSAMLAttr():
'User.LastName': ['Meyers'],
'name_id': 'cmeyers@redhat.com',
'User.FirstName': ['Chris'],
'PersonImmutableID': []
}
'PersonImmutableID': [],
},
},
#'social': <UserSocialAuth: cmeyers@redhat.com>,
'social': None,
#'strategy': <awx.sso.strategies.django_strategy.AWXDjangoStrategy object at 0x8523a10>,
'strategy': None,
'new_association': False
'new_association': False,
}
@pytest.fixture
@@ -181,7 +147,7 @@ class TestSAMLAttr():
else:
autocreate = True
class MockSettings():
class MockSettings:
SAML_AUTO_CREATE_OBJECTS = autocreate
SOCIAL_AUTH_SAML_ORGANIZATION_ATTR = {
'saml_attr': 'memberOf',
@@ -200,12 +166,10 @@ class TestSAMLAttr():
{'team': 'Red', 'organization': 'Default1'},
{'team': 'Green', 'organization': 'Default1'},
{'team': 'Green', 'organization': 'Default3'},
{
'team': 'Yellow', 'team_alias': 'Yellow_Alias',
'organization': 'Default4', 'organization_alias': 'Default4_Alias'
},
]
{'team': 'Yellow', 'team_alias': 'Yellow_Alias', 'organization': 'Default4', 'organization_alias': 'Default4_Alias'},
],
}
return MockSettings()
def test_update_user_orgs_by_saml_attr(self, orgs, users, kwargs, mock_settings):
@@ -308,8 +272,7 @@ class TestSAMLAttr():
assert Team.objects.filter(name='Yellow', organization__name='Default4').count() == 0
assert Team.objects.filter(name='Yellow_Alias', organization__name='Default4_Alias').count() == 1
assert Team.objects.get(
name='Yellow_Alias', organization__name='Default4_Alias').member_role.members.count() == 1
assert Team.objects.get(name='Yellow_Alias', organization__name='Default4_Alias').member_role.members.count() == 1
@pytest.mark.fixture_args(autocreate=False)
def test_autocreate_disabled(self, users, kwargs, mock_settings):

View File

@@ -1,5 +1,3 @@
# Ensure that our autouse overwrites are working
def test_cache(settings):
assert settings.CACHES['default']['BACKEND'] == 'django.core.cache.backends.locmem.LocMemCache'

View File

@@ -1,51 +1,48 @@
import pytest
from unittest import mock
from rest_framework.exceptions import ValidationError
from awx.sso.fields import (
SAMLOrgAttrField,
SAMLTeamAttrField,
LDAPGroupTypeParamsField,
LDAPServerURIField
)
from awx.sso.fields import SAMLOrgAttrField, SAMLTeamAttrField, LDAPGroupTypeParamsField, LDAPServerURIField
class TestSAMLOrgAttrField():
@pytest.mark.parametrize("data, expected", [
({}, {}),
({'remove': True, 'saml_attr': 'foobar'}, {'remove': True, 'saml_attr': 'foobar'}),
({'remove': True, 'saml_attr': 1234}, {'remove': True, 'saml_attr': '1234'}),
({'remove': True, 'saml_attr': 3.14}, {'remove': True, 'saml_attr': '3.14'}),
({'saml_attr': 'foobar'}, {'saml_attr': 'foobar'}),
({'remove': True}, {'remove': True}),
({'remove': True, 'saml_admin_attr': 'foobar'}, {'remove': True, 'saml_admin_attr': 'foobar'}),
({'saml_admin_attr': 'foobar'}, {'saml_admin_attr': 'foobar'}),
({'remove_admins': True, 'saml_admin_attr': 'foobar'}, {'remove_admins': True, 'saml_admin_attr': 'foobar'}),
({'remove': True, 'saml_attr': 'foo', 'remove_admins': True, 'saml_admin_attr': 'bar'},
{'remove': True, 'saml_attr': 'foo', 'remove_admins': True, 'saml_admin_attr': 'bar'}),
])
class TestSAMLOrgAttrField:
@pytest.mark.parametrize(
"data, expected",
[
({}, {}),
({'remove': True, 'saml_attr': 'foobar'}, {'remove': True, 'saml_attr': 'foobar'}),
({'remove': True, 'saml_attr': 1234}, {'remove': True, 'saml_attr': '1234'}),
({'remove': True, 'saml_attr': 3.14}, {'remove': True, 'saml_attr': '3.14'}),
({'saml_attr': 'foobar'}, {'saml_attr': 'foobar'}),
({'remove': True}, {'remove': True}),
({'remove': True, 'saml_admin_attr': 'foobar'}, {'remove': True, 'saml_admin_attr': 'foobar'}),
({'saml_admin_attr': 'foobar'}, {'saml_admin_attr': 'foobar'}),
({'remove_admins': True, 'saml_admin_attr': 'foobar'}, {'remove_admins': True, 'saml_admin_attr': 'foobar'}),
(
{'remove': True, 'saml_attr': 'foo', 'remove_admins': True, 'saml_admin_attr': 'bar'},
{'remove': True, 'saml_attr': 'foo', 'remove_admins': True, 'saml_admin_attr': 'bar'},
),
],
)
def test_internal_value_valid(self, data, expected):
field = SAMLOrgAttrField()
res = field.to_internal_value(data)
assert res == expected
@pytest.mark.parametrize("data, expected", [
({'remove': 'blah', 'saml_attr': 'foobar'},
{'remove': ['Must be a valid boolean.']}),
({'remove': True, 'saml_attr': False},
{'saml_attr': ['Not a valid string.']}),
({'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'},
{'saml_attr': ['Not a valid string.'],
'foo': ['Invalid field.'],
'gig': ['Invalid field.']}),
({'remove_admins': True, 'saml_admin_attr': False},
{'saml_admin_attr': ['Not a valid string.']}),
({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'},
{'remove_admins': ['Must be a valid boolean.']}),
])
@pytest.mark.parametrize(
"data, expected",
[
({'remove': 'blah', 'saml_attr': 'foobar'}, {'remove': ['Must be a valid boolean.']}),
({'remove': True, 'saml_attr': False}, {'saml_attr': ['Not a valid string.']}),
(
{'remove': True, 'saml_attr': False, 'foo': 'bar', 'gig': 'ity'},
{'saml_attr': ['Not a valid string.'], 'foo': ['Invalid field.'], 'gig': ['Invalid field.']},
),
({'remove_admins': True, 'saml_admin_attr': False}, {'saml_admin_attr': ['Not a valid string.']}),
({'remove_admins': 'blah', 'saml_admin_attr': 'foobar'}, {'remove_admins': ['Must be a valid boolean.']}),
],
)
def test_internal_value_invalid(self, data, expected):
field = SAMLOrgAttrField()
with pytest.raises(ValidationError) as e:
@@ -53,51 +50,64 @@ class TestSAMLOrgAttrField():
assert e.value.detail == expected
class TestSAMLTeamAttrField():
@pytest.mark.parametrize("data", [
{},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': []},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'Engineering', 'organization': 'Ansible'}
]},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'Engineering', 'organization': 'Ansible'},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
]},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'Engineering', 'organization': 'Ansible'},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
]},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
class TestSAMLTeamAttrField:
@pytest.mark.parametrize(
"data",
[
{},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': []},
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [{'team': 'Engineering', 'organization': 'Ansible'}]},
{
'team': 'Engineering', 'team_alias': 'Engineering Team',
'organization': 'Ansible', 'organization_alias': 'Awesome Org'
'remove': True,
'saml_attr': 'foobar',
'team_org_map': [
{'team': 'Engineering', 'organization': 'Ansible'},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
],
},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
]},
])
{
'remove': True,
'saml_attr': 'foobar',
'team_org_map': [
{'team': 'Engineering', 'organization': 'Ansible'},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
],
},
{
'remove': True,
'saml_attr': 'foobar',
'team_org_map': [
{'team': 'Engineering', 'team_alias': 'Engineering Team', 'organization': 'Ansible', 'organization_alias': 'Awesome Org'},
{'team': 'Engineering', 'organization': 'Ansible2'},
{'team': 'Engineering2', 'organization': 'Ansible'},
],
},
],
)
def test_internal_value_valid(self, data):
field = SAMLTeamAttrField()
res = field.to_internal_value(data)
assert res == data
@pytest.mark.parametrize("data, expected", [
({'remove': True, 'saml_attr': 'foobar', 'team_org_map': [
{'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'},
]}, {'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{'organization': 'Ansible'},
]}, {'team_org_map': {0: {'team': ['This field is required.']}}}),
({'remove': False, 'saml_attr': 'foobar', 'team_org_map': [
{},
]}, {'team_org_map': {
0: {'organization': ['This field is required.'],
'team': ['This field is required.']}}}),
])
@pytest.mark.parametrize(
"data, expected",
[
(
{'remove': True, 'saml_attr': 'foobar', 'team_org_map': [{'team': 'foobar', 'not_a_valid_key': 'blah', 'organization': 'Ansible'}]},
{'team_org_map': {0: {'not_a_valid_key': ['Invalid field.']}}},
),
(
{'remove': False, 'saml_attr': 'foobar', 'team_org_map': [{'organization': 'Ansible'}]},
{'team_org_map': {0: {'team': ['This field is required.']}}},
),
(
{'remove': False, 'saml_attr': 'foobar', 'team_org_map': [{}]},
{'team_org_map': {0: {'organization': ['This field is required.'], 'team': ['This field is required.']}}},
),
],
)
def test_internal_value_invalid(self, data, expected):
field = SAMLTeamAttrField()
with pytest.raises(ValidationError) as e:
@@ -105,17 +115,19 @@ class TestSAMLTeamAttrField():
assert e.value.detail == expected
class TestLDAPGroupTypeParamsField():
@pytest.mark.parametrize("group_type, data, expected", [
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'},
['Invalid key(s): "bob", "scooter".']),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'},
['Invalid key(s): "bob", "scooter".']),
('PosixUIDGroupType', {'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing',
'bob': ['a', 'b'], 'scooter': 'hello'},
['Invalid key(s): "bob", "member_attr", "scooter".']),
])
class TestLDAPGroupTypeParamsField:
@pytest.mark.parametrize(
"group_type, data, expected",
[
('LDAPGroupType', {'name_attr': 'user', 'bob': ['a', 'b'], 'scooter': 'hello'}, ['Invalid key(s): "bob", "scooter".']),
('MemberDNGroupType', {'name_attr': 'user', 'member_attr': 'west', 'bob': ['a', 'b'], 'scooter': 'hello'}, ['Invalid key(s): "bob", "scooter".']),
(
'PosixUIDGroupType',
{'name_attr': 'user', 'member_attr': 'west', 'ldap_group_user_attr': 'legacyThing', 'bob': ['a', 'b'], 'scooter': 'hello'},
['Invalid key(s): "bob", "member_attr", "scooter".'],
),
],
)
def test_internal_value_invalid(self, group_type, data, expected):
field = LDAPGroupTypeParamsField()
field.get_depends_on = mock.MagicMock(return_value=group_type)
@@ -125,14 +137,16 @@ class TestLDAPGroupTypeParamsField():
assert e.value.detail == expected
class TestLDAPServerURIField():
@pytest.mark.parametrize("ldap_uri, exception, expected", [
(r'ldap://servername.com:444', None, r'ldap://servername.com:444'),
(r'ldap://servername.so3:444', None, r'ldap://servername.so3:444'),
(r'ldaps://servername3.s300:344', None, r'ldaps://servername3.s300:344'),
(r'ldap://servername.-so3:444', ValidationError, None),
])
class TestLDAPServerURIField:
@pytest.mark.parametrize(
"ldap_uri, exception, expected",
[
(r'ldap://servername.com:444', None, r'ldap://servername.com:444'),
(r'ldap://servername.so3:444', None, r'ldap://servername.so3:444'),
(r'ldaps://servername3.s300:344', None, r'ldaps://servername3.s300:344'),
(r'ldap://servername.-so3:444', ValidationError, None),
],
)
def test_run_validators_valid(self, ldap_uri, exception, expected):
field = LDAPServerURIField()
if exception is None:

View File

@@ -10,17 +10,15 @@ def test_empty_host_fails_auth(tacacsplus_backend):
def test_client_raises_exception(tacacsplus_backend):
client = mock.MagicMock()
client.authenticate.side_effect=Exception("foo")
with mock.patch('awx.sso.backends.django_settings') as settings,\
mock.patch('awx.sso.backends.logger') as logger,\
mock.patch('tacacs_plus.TACACSClient', return_value=client):
client.authenticate.side_effect = Exception("foo")
with mock.patch('awx.sso.backends.django_settings') as settings, mock.patch('awx.sso.backends.logger') as logger, mock.patch(
'tacacs_plus.TACACSClient', return_value=client
):
settings.TACACSPLUS_HOST = 'localhost'
settings.TACACSPLUS_AUTH_PROTOCOL = 'ascii'
ret_user = tacacsplus_backend.authenticate(None, u"user", u"pass")
assert ret_user is None
logger.exception.assert_called_once_with(
"TACACS+ Authentication Error: foo"
)
logger.exception.assert_called_once_with("TACACS+ Authentication Error: foo")
def test_client_return_invalid_fails_auth(tacacsplus_backend):
@@ -28,8 +26,7 @@ def test_client_return_invalid_fails_auth(tacacsplus_backend):
auth.valid = False
client = mock.MagicMock()
client.authenticate.return_value = auth
with mock.patch('awx.sso.backends.django_settings') as settings,\
mock.patch('tacacs_plus.TACACSClient', return_value=client):
with mock.patch('awx.sso.backends.django_settings') as settings, mock.patch('tacacs_plus.TACACSClient', return_value=client):
settings.TACACSPLUS_HOST = 'localhost'
settings.TACACSPLUS_AUTH_PROTOCOL = 'ascii'
ret_user = tacacsplus_backend.authenticate(None, u"user", u"pass")
@@ -43,9 +40,9 @@ def test_client_return_valid_passes_auth(tacacsplus_backend):
client.authenticate.return_value = auth
user = mock.MagicMock()
user.has_usable_password = mock.MagicMock(return_value=False)
with mock.patch('awx.sso.backends.django_settings') as settings,\
mock.patch('tacacs_plus.TACACSClient', return_value=client),\
mock.patch('awx.sso.backends._get_or_set_enterprise_user', return_value=user):
with mock.patch('awx.sso.backends.django_settings') as settings, mock.patch('tacacs_plus.TACACSClient', return_value=client), mock.patch(
'awx.sso.backends._get_or_set_enterprise_user', return_value=user
):
settings.TACACSPLUS_HOST = 'localhost'
settings.TACACSPLUS_AUTH_PROTOCOL = 'ascii'
ret_user = tacacsplus_backend.authenticate(None, u"user", u"pass")

View File

@@ -2,12 +2,7 @@
# All Rights Reserved.
from django.conf.urls import url
from awx.sso.views import (
sso_complete,
sso_error,
sso_inactive,
saml_metadata,
)
from awx.sso.views import sso_complete, sso_error, sso_inactive, saml_metadata
app_name = 'sso'

View File

@@ -8,10 +8,14 @@ import ldap
from django.core.exceptions import ValidationError
from django.utils.translation import ugettext_lazy as _
__all__ = ['validate_ldap_dn', 'validate_ldap_dn_with_user',
'validate_ldap_bind_dn', 'validate_ldap_filter',
'validate_ldap_filter_with_user',
'validate_tacacsplus_disallow_nonascii']
__all__ = [
'validate_ldap_dn',
'validate_ldap_dn_with_user',
'validate_ldap_bind_dn',
'validate_ldap_filter',
'validate_ldap_filter_with_user',
'validate_tacacsplus_disallow_nonascii',
]
def validate_ldap_dn(value, with_user=False):
@@ -32,8 +36,9 @@ def validate_ldap_dn_with_user(value):
def validate_ldap_bind_dn(value):
if not re.match(r'^[A-Za-z][A-Za-z0-9._-]*?\\[A-Za-z0-9 ._-]+?$', value.strip()) and \
not re.match(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', value.strip()):
if not re.match(r'^[A-Za-z][A-Za-z0-9._-]*?\\[A-Za-z0-9 ._-]+?$', value.strip()) and not re.match(
r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', value.strip()
):
validate_ldap_dn(value)

View File

@@ -37,7 +37,6 @@ sso_inactive = BaseRedirectView.as_view()
class CompleteView(BaseRedirectView):
def dispatch(self, request, *args, **kwargs):
response = super(CompleteView, self).dispatch(request, *args, **kwargs)
if self.request.user and self.request.user.is_authenticated:
@@ -54,16 +53,12 @@ sso_complete = CompleteView.as_view()
class MetadataView(View):
def get(self, request, *args, **kwargs):
from social_django.utils import load_backend, load_strategy
complete_url = reverse('social:complete', args=('saml', ))
complete_url = reverse('social:complete', args=('saml',))
try:
saml_backend = load_backend(
load_strategy(request),
'saml',
redirect_uri=complete_url,
)
saml_backend = load_backend(load_strategy(request), 'saml', redirect_uri=complete_url)
metadata, errors = saml_backend.generate_metadata_xml()
except Exception as e:
logger.exception('unable to generate SAML metadata')