Use Django's own logic to invalidate sessions of users when changing passwords

The key is django.contrib.auth.update_session_auth_hash(), which knows
how to inject a recalculated session hash back into the session if the
requesting user is changing their own password, in order to keep that
user logged in.
This commit is contained in:
Jeff Bradberry
2019-03-26 17:22:16 -04:00
parent 2129f12085
commit f2be4de544
5 changed files with 22 additions and 19 deletions

View File

@@ -16,6 +16,7 @@ from oauthlib.common import generate_token
# Django # Django
from django.conf import settings from django.conf import settings
from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist, ValidationError as DjangoValidationError from django.core.exceptions import ObjectDoesNotExist, ValidationError as DjangoValidationError
@@ -933,8 +934,12 @@ class UserSerializer(BaseSerializer):
if new_password: if new_password:
obj.set_password(new_password) obj.set_password(new_password)
obj.save(update_fields=['password']) obj.save(update_fields=['password'])
if self.context['request'].user != obj:
UserSessionMembership.clear_session_for_user(obj) # Cycle the session key, but if the requesting user is the same
# as the modified user then inject a session key derived from
# the updated user to prevent logout. This is the logic used by
# the Django admin's own user_change_password view.
update_session_auth_hash(self.context['request'], obj)
elif not obj.password: elif not obj.password:
obj.set_unusable_password() obj.set_unusable_password()
obj.save(update_fields=['password']) obj.save(update_fields=['password'])

View File

@@ -29,9 +29,9 @@ class Command(BaseCommand):
# with consideration for timezones. # with consideration for timezones.
start = timezone.now() start = timezone.now()
sessions = Session.objects.filter(expire_date__gte=start).iterator() sessions = Session.objects.filter(expire_date__gte=start).iterator()
request = HttpRequest()
for session in sessions: for session in sessions:
user_id = session.get_decoded().get('_auth_user_id') user_id = session.get_decoded().get('_auth_user_id')
if (user is None) or (user_id and user.id == int(user_id)): if (user is None) or (user_id and user.id == int(user_id)):
request.session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key) session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key)
logout(request) # Log out the session, but without the need for a request object.
session.flush()

View File

@@ -127,8 +127,8 @@ class SessionTimeoutMiddleware(object):
def process_response(self, request, response): def process_response(self, request, response):
should_skip = 'HTTP_X_WS_SESSION_QUIET' in request.META should_skip = 'HTTP_X_WS_SESSION_QUIET' in request.META
req_session = getattr(request, 'session', None) # Only update the session if it hasn't been flushed by being forced to log out.
if req_session and not req_session.is_empty() and should_skip is False: if request.session and not request.session.is_empty() and not should_skip:
expiry = int(settings.SESSION_COOKIE_AGE) expiry = int(settings.SESSION_COOKIE_AGE)
request.session.set_expiry(expiry) request.session.set_expiry(expiry)
response['Session-Timeout'] = expiry response['Session-Timeout'] = expiry

View File

@@ -183,12 +183,6 @@ class UserSessionMembership(BaseModel):
non_expire_memberships = [x for x in query_set if x.session.expire_date > now] non_expire_memberships = [x for x in query_set if x.session.expire_date > now]
return non_expire_memberships[settings.SESSIONS_PER_USER:] return non_expire_memberships[settings.SESSIONS_PER_USER:]
@staticmethod
def clear_session_for_user(user):
query_set = UserSessionMembership.objects.select_related('session').filter(user=user)
sessions_to_delete = [obj.session.pk for obj in query_set]
Session.objects.filter(pk__in=sessions_to_delete).delete()
# Add get_absolute_url method to User model if not present. # Add get_absolute_url method to User model if not present.
if not hasattr(User, 'get_absolute_url'): if not hasattr(User, 'get_absolute_url'):

View File

@@ -1,5 +1,8 @@
import pytest import pytest
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import Client
from awx.api.versioning import reverse from awx.api.versioning import reverse
@@ -19,7 +22,7 @@ EXAMPLE_USER_DATA = {
@pytest.mark.django_db @pytest.mark.django_db
def test_user_create(post, admin): def test_user_create(post, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201 assert response.status_code == 201
assert not response.data['is_superuser'] assert not response.data['is_superuser']
assert not response.data['is_system_auditor'] assert not response.data['is_system_auditor']
@@ -27,21 +30,22 @@ def test_user_create(post, admin):
@pytest.mark.django_db @pytest.mark.django_db
def test_fail_double_create_user(post, admin): def test_fail_double_create_user(post, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201 assert response.status_code == 201
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 400 assert response.status_code == 400
@pytest.mark.django_db @pytest.mark.django_db
def test_create_delete_create_user(post, delete, admin): def test_create_delete_create_user(post, delete, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201 assert response.status_code == 201
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin) response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin,
middleware=SessionMiddleware())
assert response.status_code == 204 assert response.status_code == 204
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin) response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
print(response.data) print(response.data)
assert response.status_code == 201 assert response.status_code == 201