diff --git a/awx/main/access.py b/awx/main/access.py index c7e0ad2a07..97d3131bb5 100644 --- a/awx/main/access.py +++ b/awx/main/access.py @@ -25,7 +25,7 @@ from awx.main.task_engine import TaskEnhancer from awx.conf.license import LicenseForbids __all__ = ['get_user_queryset', 'check_user_access', 'check_user_access_with_errors', - 'user_accessible_objects', + 'user_accessible_objects', 'consumer_access', 'user_admin_role', 'StateConflict',] PERMISSION_TYPES = [ @@ -164,6 +164,17 @@ def check_superuser(func): return wrapper +def consumer_access(group_name): + ''' + consumer_access returns the proper Access class based on group_name + for a channels consumer. + ''' + class_map = {'job_events': JobAccess, + 'workflow_events': WorkflowJobAccess, + 'ad_hoc_command_events': AdHocCommandAccess} + return class_map.get(group_name) + + class BaseAccess(object): ''' Base class for checking user access to a given model. Subclasses should diff --git a/awx/main/consumers.py b/awx/main/consumers.py index 2cb1f450f2..6483efb8e9 100644 --- a/awx/main/consumers.py +++ b/awx/main/consumers.py @@ -1,12 +1,14 @@ import json -import urlparse import logging +import urllib from channels import Group -from channels.sessions import channel_session +from channels.sessions import channel_session, http_session +from channels.handler import AsgiRequest + +from django.core.serializers.json import DjangoJSONEncoder from django.contrib.auth.models import User -from django.core.serializers.json import DjangoJSONEncoder from awx.main.models.organization import AuthToken @@ -19,31 +21,25 @@ def discard_groups(message): Group(group).discard(message.reply_channel) -def validate_token(token): - try: - auth_token = AuthToken.objects.get(key=token) - if not auth_token.in_valid_tokens: - return None - except AuthToken.DoesNotExist: - return None - return auth_token - - -def user_from_token(auth_token): - try: - return User.objects.get(pk=auth_token.user_id) - except User.DoesNotExist: - return None - - +@http_session @channel_session def ws_connect(message): - token = None - qs = urlparse.parse_qs(message['query_string']) - if 'token' in qs: - if len(qs['token']) > 0: - token = qs['token'].pop() - message.channel_session['token'] = token + connect_text = {'accept':False, 'user':None} + + if message.http_session: + request = AsgiRequest(message) + token = request.COOKIES.get('token', None) + if token is not None: + token = urllib.unquote(token).strip('"') + try: + auth_token = AuthToken.objects.get(key=token) + if auth_token.in_valid_tokens: + message.channel_session['user_id'] = auth_token.user_id + connect_text['accept'] = True + connect_text['user'] = auth_token.user_id + except AuthToken.DoesNotExist: + logger.error("auth_token provided was invalid.") + message.reply_channel.send({"text": json.dumps(connect_text)}) @channel_session @@ -53,20 +49,15 @@ def ws_disconnect(message): @channel_session def ws_receive(message): - token = message.channel_session.get('token') + from awx.main.access import consumer_access - auth_token = validate_token(token) - if auth_token is None: - logger.error("Authentication Failure validating user") - message.reply_channel.send({"text": json.dumps({"error": "invalid auth token"})}) - return None - - user = user_from_token(auth_token) - if user is None: - logger.error("No valid user corresponding to submitted auth_token") + user_id = message.channel_session.get('user_id', None) + if user_id is None: + logger.error("No valid user found for websocket.") message.reply_channel.send({"text": json.dumps({"error": "no valid user"})}) return None + user = User.objects.get(pk=user_id) raw_data = message.content['text'] data = json.loads(raw_data) @@ -78,6 +69,12 @@ def ws_receive(message): if type(v) is list: for oid in v: name = '{}-{}'.format(group_name, oid) + access_cls = consumer_access(group_name) + if access_cls is not None: + user_access = access_cls(user) + if not user_access.get_queryset().filter(pk=oid).exists(): + message.reply_channel.send({"text": json.dumps({"error": "access denied to channel {0} for resource id {1}".format(group_name, oid)})}) + continue current_groups.append(name) Group(name).add(message.reply_channel) else: