mirror of
https://github.com/ansible/awx.git
synced 2026-03-07 11:41:08 -03:30
Use Django's own logic to invalidate sessions of users when changing passwords
The key is django.contrib.auth.update_session_auth_hash(), which knows how to inject a recalculated session hash back into the session if the requesting user is changing their own password, in order to keep that user logged in.
This commit is contained in:
@@ -16,6 +16,7 @@ from oauthlib.common import generate_token
|
|||||||
|
|
||||||
# Django
|
# Django
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
from django.contrib.auth import update_session_auth_hash
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.core.exceptions import ObjectDoesNotExist, ValidationError as DjangoValidationError
|
from django.core.exceptions import ObjectDoesNotExist, ValidationError as DjangoValidationError
|
||||||
@@ -933,8 +934,12 @@ class UserSerializer(BaseSerializer):
|
|||||||
if new_password:
|
if new_password:
|
||||||
obj.set_password(new_password)
|
obj.set_password(new_password)
|
||||||
obj.save(update_fields=['password'])
|
obj.save(update_fields=['password'])
|
||||||
if self.context['request'].user != obj:
|
|
||||||
UserSessionMembership.clear_session_for_user(obj)
|
# Cycle the session key, but if the requesting user is the same
|
||||||
|
# as the modified user then inject a session key derived from
|
||||||
|
# the updated user to prevent logout. This is the logic used by
|
||||||
|
# the Django admin's own user_change_password view.
|
||||||
|
update_session_auth_hash(self.context['request'], obj)
|
||||||
elif not obj.password:
|
elif not obj.password:
|
||||||
obj.set_unusable_password()
|
obj.set_unusable_password()
|
||||||
obj.save(update_fields=['password'])
|
obj.save(update_fields=['password'])
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ class Command(BaseCommand):
|
|||||||
# with consideration for timezones.
|
# with consideration for timezones.
|
||||||
start = timezone.now()
|
start = timezone.now()
|
||||||
sessions = Session.objects.filter(expire_date__gte=start).iterator()
|
sessions = Session.objects.filter(expire_date__gte=start).iterator()
|
||||||
request = HttpRequest()
|
|
||||||
for session in sessions:
|
for session in sessions:
|
||||||
user_id = session.get_decoded().get('_auth_user_id')
|
user_id = session.get_decoded().get('_auth_user_id')
|
||||||
if (user is None) or (user_id and user.id == int(user_id)):
|
if (user is None) or (user_id and user.id == int(user_id)):
|
||||||
request.session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key)
|
session = import_module(settings.SESSION_ENGINE).SessionStore(session.session_key)
|
||||||
logout(request)
|
# Log out the session, but without the need for a request object.
|
||||||
|
session.flush()
|
||||||
|
|||||||
@@ -127,8 +127,8 @@ class SessionTimeoutMiddleware(object):
|
|||||||
|
|
||||||
def process_response(self, request, response):
|
def process_response(self, request, response):
|
||||||
should_skip = 'HTTP_X_WS_SESSION_QUIET' in request.META
|
should_skip = 'HTTP_X_WS_SESSION_QUIET' in request.META
|
||||||
req_session = getattr(request, 'session', None)
|
# Only update the session if it hasn't been flushed by being forced to log out.
|
||||||
if req_session and not req_session.is_empty() and should_skip is False:
|
if request.session and not request.session.is_empty() and not should_skip:
|
||||||
expiry = int(settings.SESSION_COOKIE_AGE)
|
expiry = int(settings.SESSION_COOKIE_AGE)
|
||||||
request.session.set_expiry(expiry)
|
request.session.set_expiry(expiry)
|
||||||
response['Session-Timeout'] = expiry
|
response['Session-Timeout'] = expiry
|
||||||
|
|||||||
@@ -183,12 +183,6 @@ class UserSessionMembership(BaseModel):
|
|||||||
non_expire_memberships = [x for x in query_set if x.session.expire_date > now]
|
non_expire_memberships = [x for x in query_set if x.session.expire_date > now]
|
||||||
return non_expire_memberships[settings.SESSIONS_PER_USER:]
|
return non_expire_memberships[settings.SESSIONS_PER_USER:]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def clear_session_for_user(user):
|
|
||||||
query_set = UserSessionMembership.objects.select_related('session').filter(user=user)
|
|
||||||
sessions_to_delete = [obj.session.pk for obj in query_set]
|
|
||||||
Session.objects.filter(pk__in=sessions_to_delete).delete()
|
|
||||||
|
|
||||||
|
|
||||||
# Add get_absolute_url method to User model if not present.
|
# Add get_absolute_url method to User model if not present.
|
||||||
if not hasattr(User, 'get_absolute_url'):
|
if not hasattr(User, 'get_absolute_url'):
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from django.contrib.sessions.middleware import SessionMiddleware
|
||||||
|
from django.test import Client
|
||||||
|
|
||||||
from awx.api.versioning import reverse
|
from awx.api.versioning import reverse
|
||||||
|
|
||||||
|
|
||||||
@@ -19,7 +22,7 @@ EXAMPLE_USER_DATA = {
|
|||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_user_create(post, admin):
|
def test_user_create(post, admin):
|
||||||
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
|
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
assert not response.data['is_superuser']
|
assert not response.data['is_superuser']
|
||||||
assert not response.data['is_system_auditor']
|
assert not response.data['is_system_auditor']
|
||||||
@@ -27,21 +30,22 @@ def test_user_create(post, admin):
|
|||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_fail_double_create_user(post, admin):
|
def test_fail_double_create_user(post, admin):
|
||||||
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
|
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
|
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_create_delete_create_user(post, delete, admin):
|
def test_create_delete_create_user(post, delete, admin):
|
||||||
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
|
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|
||||||
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin)
|
response = delete(reverse('api:user_detail', kwargs={'pk': response.data['id']}), admin,
|
||||||
|
middleware=SessionMiddleware())
|
||||||
assert response.status_code == 204
|
assert response.status_code == 204
|
||||||
|
|
||||||
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin)
|
response = post(reverse('api:user_list'), EXAMPLE_USER_DATA, admin, middleware=SessionMiddleware())
|
||||||
print(response.data)
|
print(response.data)
|
||||||
assert response.status_code == 201
|
assert response.status_code == 201
|
||||||
|
|||||||
Reference in New Issue
Block a user