mirror of
https://github.com/ansible/awx.git
synced 2026-05-19 14:57:39 -02:30
prevent cross site request forgery in websockets w/ the CSRF token
now that we have the CSRF middleware, we have a reliable token available to us which we can use to verify individual ws_receive payloads; this is _simpler_ than making sure you've properly configured trusted origins, and it's also more secure than Origin header checks see: https://github.com/ansible/tower/issues/2661
This commit is contained in:
@@ -3,15 +3,13 @@ import logging
|
||||
|
||||
from channels import Group
|
||||
from channels.auth import channel_session_user_from_http, channel_session_user
|
||||
from channels.exceptions import DenyConnection
|
||||
from six.moves.urllib.parse import urlparse
|
||||
|
||||
from django.utils.http import is_same_domain
|
||||
from django.conf import settings
|
||||
from django.http.cookie import parse_cookie
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
|
||||
|
||||
logger = logging.getLogger('awx.main.consumers')
|
||||
XRF_KEY = '_auth_user_xrf'
|
||||
|
||||
|
||||
def discard_groups(message):
|
||||
@@ -20,47 +18,22 @@ def discard_groups(message):
|
||||
Group(group).discard(message.reply_channel)
|
||||
|
||||
|
||||
def origin_is_valid(message, trusted_values):
|
||||
origin = dict(message.content.get('headers', {})).get('origin', '')
|
||||
for trusted in trusted_values:
|
||||
try:
|
||||
client = urlparse(origin)
|
||||
trusted = urlparse(trusted)
|
||||
except (AttributeError, ValueError):
|
||||
# if we can't parse a hostname, consider it invalid and try the
|
||||
# next one
|
||||
pass
|
||||
else:
|
||||
# if we _can_ parse the origin header, verify that it's trusted
|
||||
if (
|
||||
trusted.scheme == client.scheme and
|
||||
is_same_domain(client.netloc, trusted.netloc)
|
||||
):
|
||||
# the provided Origin matches at least _one_ whitelisted host,
|
||||
# return True
|
||||
return True
|
||||
logger.error((
|
||||
"ws:// origin header mismatch {} not in {}; consider adding {} to "
|
||||
"settings.WEBSOCKET_ORIGIN_WHITELIST if it's a trusted host."
|
||||
).format(origin, trusted_values, origin))
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@channel_session_user_from_http
|
||||
def ws_connect(message):
|
||||
if not origin_is_valid(
|
||||
message,
|
||||
[settings.TOWER_URL_BASE] + settings.WEBSOCKET_ORIGIN_WHITELIST
|
||||
):
|
||||
raise DenyConnection()
|
||||
|
||||
headers = dict(message.content.get('headers', ''))
|
||||
message.reply_channel.send({"accept": True})
|
||||
message.content['method'] = 'FAKE'
|
||||
if message.user.is_authenticated():
|
||||
message.reply_channel.send(
|
||||
{"text": json.dumps({"accept": True, "user": message.user.id})}
|
||||
)
|
||||
# store the valid CSRF token from the cookie so we can compare it later
|
||||
# on ws_receive
|
||||
cookie_token = parse_cookie(
|
||||
headers.get('cookie')
|
||||
).get('csrftoken')
|
||||
if cookie_token:
|
||||
message.channel_session[XRF_KEY] = cookie_token
|
||||
else:
|
||||
logger.error("Request user is not authenticated to use websocket.")
|
||||
message.reply_channel.send({"close": True})
|
||||
@@ -79,6 +52,20 @@ def ws_receive(message):
|
||||
raw_data = message.content['text']
|
||||
data = json.loads(raw_data)
|
||||
|
||||
xrftoken = data.get('xrftoken')
|
||||
if (
|
||||
not xrftoken or
|
||||
XRF_KEY not in message.channel_session or
|
||||
xrftoken != message.channel_session[XRF_KEY]
|
||||
):
|
||||
logger.error(
|
||||
"access denied to channel, XRF mismatch for {}".format(user.username)
|
||||
)
|
||||
message.reply_channel.send({
|
||||
"text": json.dumps({"error": "access denied to channel"})
|
||||
})
|
||||
return
|
||||
|
||||
if 'groups' in data:
|
||||
discard_groups(message)
|
||||
groups = data['groups']
|
||||
|
||||
Reference in New Issue
Block a user