diff --git a/awx/main/consumers.py b/awx/main/consumers.py index ff55507939..39020099d1 100644 --- a/awx/main/consumers.py +++ b/awx/main/consumers.py @@ -1,16 +1,14 @@ import json import logging -import urllib from channels import Group, channel_layers -from channels.sessions import channel_session -from channels.handler import AsgiRequest +from channels.sessions import enforce_ordering, channel_session, channel_and_http_session from django.conf import settings from django.core.serializers.json import DjangoJSONEncoder from django.contrib.auth.models import User -from awx.main.models.organization import AuthToken +from django.contrib.sessions.models import Session logger = logging.getLogger('awx.main.consumers') @@ -22,24 +20,21 @@ def discard_groups(message): Group(group).discard(message.reply_channel) -@channel_session +@channel_and_http_session def ws_connect(message): - connect_text = {'accept':False, 'user':None} + if message.http_session.session_key is None: + raise ValueError('No valid session key to get auth from') - message.content['method'] = 'FAKE' - request = AsgiRequest(message) - token = request.COOKIES.get('token', None) - if token is not None: - token = urllib.unquote(token).strip('"') - try: - auth_token = AuthToken.objects.get(key=token) - if auth_token.in_valid_tokens: - message.channel_session['user_id'] = auth_token.user_id - connect_text['accept'] = True - connect_text['user'] = auth_token.user_id - except AuthToken.DoesNotExist: - logger.error("auth_token provided was invalid.") - message.reply_channel.send({"text": json.dumps(connect_text)}) + session = Session.objects.get(session_key=message.http_session.session_key) + session_data = session.get_decoded() + + try: + user = User.objects.get(pk=session_data['_auth_user_id']) + except User.DoesNotExist: + raise ValueError('No valid user for the session key') + + message.channel_session['user_id'] = user.pk + message.reply_channel.send({"text": json.dumps({'accept': True, 'user': user.pk})}) @channel_session @@ -47,6 +42,7 @@ def ws_disconnect(message): discard_groups(message) +@enforce_ordering @channel_session def ws_receive(message): from awx.main.access import consumer_access