Use Django's own logic to invalidate sessions of users when changing passwords

The key is django.contrib.auth.update_session_auth_hash(), which knows
how to inject a recalculated session hash back into the session if the
requesting user is changing their own password, in order to keep that
user logged in.
This commit is contained in:
Jeff Bradberry 2019-03-26 17:22:16 -04:00
parent 2129f12085
commit f2be4de544
5 changed files with 22 additions and 19 deletions

View File

@ -16,6 +16,7 @@ from oauthlib.common import generate_token
# Django
from django.conf import settings
from django.contrib.auth import update_session_auth_hash
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist, ValidationError as DjangoValidationError
@ -933,8 +934,12 @@ class UserSerializer(BaseSerializer):
if new_password:
obj.set_password(new_password)
obj.save(update_fields=['password'])
if self.context['request'].user != obj:
UserSessionMembership.clear_session_for_user(obj)
# Cycle the session key, but if the requesting user is the same
# as the modified user then inject a session key derived from
# the updated user to prevent logout. This is the logic used by
# the Django admin's own user_change_password view.
update_session_auth_hash(self.context['request'], obj)
elif not obj.password:
obj.set_unusable_password()
obj.save(update_fields=['password'])

View File

@ -29,9 +29,9 @@ class Command(BaseCommand):
# with consideration for timezones.
start = timezone.now()
sessions = Session.objects.filter(expire_date__gte=start).iterator()
request = HttpRequest()
for session in sessions:
user_id = session.get_decoded().get('_auth_user_id')
if (user is None) or (user_id and user.id == int(user_id)):
request.session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key)
logout(request)
session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key)
# Log out the session, but without the need for a request object.
session.flush()

View File

@ -127,8 +127,8 @@ class SessionTimeoutMiddleware(object):
def process_response(self, request, response):
should_skip = 'HTTP_X_WS_SESSION_QUIET' in request.META
req_session = getattr(request, 'session', None)
if req_session and not req_session.is_empty() and should_skip is False:
# Only update the session if it hasn't been flushed by being forced to log out.
if request.session and not request.session.is_empty() and not should_skip:
expiry = int(settings.SESSION_COOKIE_AGE)
request.session.set_expiry(expiry)
response['Session-Timeout'] = expiry

View File

@ -183,12 +183,6 @@ class UserSessionMembership(BaseModel):
non_expire_memberships = [x for x in query_set if x.session.expire_date > now]
return non_expire_memberships[settings.SESSIONS_PER_USER:]
@staticmethod
def clear_session_for_user(user):
query_set = UserSessionMembership.objects.select_related('session').filter(user=user)
sessions_to_delete = [obj.session.pk for obj in query_set]
Session.objects.filter(pk__in=sessions_to_delete).delete()
# Add get_absolute_url method to User model if not present.
if not hasattr(User, 'get_absolute_url'):

View File

@ -1,5 +1,8 @@
import pytest
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import Client
from awx.api.versioning import reverse
@ -19,7 +22,7 @@ EXAMPLE_USER_DATA = {
@pytest.mark.django_db
def test_user_create(post, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201
assert not response.data['is_superuser']
assert not response.data['is_system_auditor']
@ -27,21 +30,22 @@ def test_user_create(post, admin):
@pytest.mark.django_db
def test_fail_double_create_user(post, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 400
@pytest.mark.django_db
def test_create_delete_create_user(post, delete, admin):
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
assert response.status_code == 201
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin)
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin,
middleware=SessionMiddleware())
assert response.status_code == 204
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
print(response.data)
assert response.status_code == 201