import json
import logging
import asyncio
from typing import Dict

import aiohttp
from aiohttp import client_exceptions

from channels.layers import get_channel_layer

from django.conf import settings
from django.apps import apps

import psycopg

from awx.main.analytics.broadcast_websocket import (
    RelayWebsocketStats,
    RelayWebsocketStatsManager,
)
import awx.main.analytics.subsystem_metrics as s_metrics

logger = logging.getLogger('awx.main.wsrelay')


def wrap_broadcast_msg(group, message: str):
    # TODO: Maybe wrap as "group","message" so that we don't need to
    # encode/decode as json.
    return dict(group=group, message=message)
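
# For illustration only (editorial comment, not from the original source; the
# payload key shown is hypothetical):
#     wrap_broadcast_msg('jobs', {'text': '...'})
# returns
#     {'group': 'jobs', 'message': {'text': '...'}}
# run_producer() passes this dict to websocket.send_json(), which serializes it.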


def get_local_host():
    Instance = apps.get_model('main', 'Instance')
    return Instance.objects.my_hostname()


class WebsocketRelayConnection:
    def __init__(
        self,
        name,
        stats: RelayWebsocketStats,
        remote_host: str,
        remote_port: int = settings.BROADCAST_WEBSOCKET_PORT,
        protocol: str = settings.BROADCAST_WEBSOCKET_PROTOCOL,
        verify_ssl: bool = settings.BROADCAST_WEBSOCKET_VERIFY_CERT,
    ):
        self.name = name
        self.event_loop = asyncio.get_event_loop()
        self.stats = stats
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.protocol = protocol
        self.verify_ssl = verify_ssl
        self.channel_layer = None
        self.subsystem_metrics = s_metrics.Metrics(instance_name=name)
        self.producers = dict()
        self.connected = False
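
    # Editorial note (not from the original source): self.producers maps a
    # producer name to {"task": <asyncio.Task>, "subscriptions": <set of origin
    # channels>}; see run_connection() and run_producer() below.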

    async def run_loop(self, websocket: aiohttp.ClientWebSocketResponse):
        raise RuntimeError("Implement me")

    async def connect(self):
        from awx.main.consumers import WebsocketSecretAuthHelper  # noqa

        logger.debug(f"Connection attempt from {self.name} to {self.remote_host}")

        '''
        We cannot call get_channel_layer() in __init__ because this code is in
        the init path of channel layers, i.e. RedisChannelLayer() calls our
        init code.
        '''
        if not self.channel_layer:
            self.channel_layer = get_channel_layer()

        uri = f"{self.protocol}://{self.remote_host}:{self.remote_port}/websocket/relay/"
        timeout = aiohttp.ClientTimeout(total=10)

        secret_val = WebsocketSecretAuthHelper.construct_secret()
        try:
            async with aiohttp.ClientSession(headers={'secret': secret_val}, timeout=timeout) as session:
                async with session.ws_connect(uri, ssl=self.verify_ssl, heartbeat=20) as websocket:
                    logger.info(f"Connection from {self.name} to {self.remote_host} established.")
                    self.stats.record_connection_established()
                    self.connected = True
                    await self.run_connection(websocket)
        except asyncio.CancelledError:
            # TODO: Check if connected and disconnect
            # Possibly use run_until_complete() if disconnect is async
            logger.warning(f"Connection from {self.name} to {self.remote_host} cancelled.")
        except client_exceptions.ClientConnectorError as e:
            logger.warning(f"Connection from {self.name} to {self.remote_host} failed: '{e}'.", exc_info=True)
        except asyncio.TimeoutError:
            logger.warning(f"Connection from {self.name} to {self.remote_host} timed out.")
        except Exception as e:
            # Early on, this is our canary. I'm not sure what exceptions we can really encounter.
            logger.warning(f"Connection from {self.name} to {self.remote_host} failed for unknown reason: '{e}'.", exc_info=True)
        else:
            logger.debug(f"Connection from {self.name} to {self.remote_host} lost, but no exception was raised.")
        finally:
            self.connected = False
            self.stats.record_connection_lost()

    def start(self):
        self.async_task = self.event_loop.create_task(self.connect())
        return self.async_task

    def cancel(self):
        self.async_task.cancel()

    async def run_connection(self, websocket: aiohttp.ClientWebSocketResponse):
        # Create a dedicated producer to handle local subsystem metrics
        # messages. The "metrics" group is not subscribed to in the typical
        # fashion, so we just explicitly create it.
        producer = self.event_loop.create_task(self.run_producer("metrics", websocket, "metrics"))
        self.producers["metrics"] = {"task": producer, "subscriptions": {"metrics"}}

        async for msg in websocket:
            self.stats.record_message_received()

            if msg.type == aiohttp.WSMsgType.ERROR:
                break
            elif msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    payload = json.loads(msg.data)
                except json.JSONDecodeError:
                    logmsg = "Failed to decode message from web node"
                    if logger.isEnabledFor(logging.DEBUG):
                        # payload is not bound when decoding fails; log the raw frame
                        logmsg = "{} {}".format(logmsg, msg.data)
                    logger.warning(logmsg)
                    continue

                if payload.get("type") == "consumer.subscribe":
                    for group in payload['groups']:
                        name = f"{self.remote_host}-{group}"
                        origin_channel = payload['origin_channel']
                        if not self.producers.get(name):
                            producer = self.event_loop.create_task(self.run_producer(name, websocket, group))
                            self.producers[name] = {"task": producer, "subscriptions": {origin_channel}}
                            logger.debug(f"Producer {name} started.")
                        else:
                            self.producers[name]["subscriptions"].add(origin_channel)
                            logger.debug(f"Connection from {self.name} to {self.remote_host} added subscription to {group}.")

                if payload.get("type") == "consumer.unsubscribe":
                    for group in payload['groups']:
                        name = f"{self.remote_host}-{group}"
                        origin_channel = payload['origin_channel']
                        try:
                            self.producers[name]["subscriptions"].remove(origin_channel)
                            logger.debug(f"Unsubscribed {origin_channel} from {name}")
                        except KeyError:
                            logger.warning(f"Producer {name} not found.")

    async def run_producer(self, name, websocket, group):
        try:
            logger.info(f"Starting producer for {name}")

            consumer_channel = await self.channel_layer.new_channel()
            await self.channel_layer.group_add(group, consumer_channel)
            logger.debug(f"Producer {name} added to group {group} and is now awaiting messages.")

            while True:
                try:
                    msg = await asyncio.wait_for(self.channel_layer.receive(consumer_channel), timeout=10)
                    if not msg.get("needs_relay"):
                        # This is added in by emit_channel_notification(). It prevents us from looping
                        # in the event that we are sharing a redis with a web instance. We'll see the
                        # message once (it'll have needs_relay=True), we'll delete that, and then forward
                        # the message along. The web instance will add it back to the same channels group,
                        # but it won't have needs_relay=True, so we'll ignore it.
                        continue

                    # We need to copy the message because we're going to delete the needs_relay key
                    # and we don't want to modify the original message because other producers may
                    # still need to act on it. It seems weird, but it's necessary.
                    msg = dict(msg)
                    del msg["needs_relay"]
                except asyncio.TimeoutError:
                    current_subscriptions = self.producers[name]["subscriptions"]
                    if len(current_subscriptions) == 0:
                        logger.info(f"Producer {name} has no subscribers, shutting down.")
                        return

                    continue

                await websocket.send_json(wrap_broadcast_msg(group, msg))
        except ConnectionResetError:
            # This can be hit when a web node is scaling down and we try to write to it.
            # There's really nothing to do in this case and it's a fairly typical thing to happen.
            # We'll log it as debug, but it's not really a problem.
            logger.debug(f"Producer {name} connection reset.")
        except Exception:
            # Note, this is very intentional and important since we do not otherwise
            # ever check the result of this future. Without this line you will not see an error if
            # something goes wrong in here.
            logger.exception(f"Event relay producer {name} crashed")
        finally:
            await self.channel_layer.group_discard(group, consumer_channel)
            del self.producers[name]
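
    # Editorial sketch of the needs_relay round trip described above (key names
    # other than "needs_relay" are hypothetical):
    #     emit_channel_notification() publishes {"text": "...", "needs_relay": True};
    #     run_producer() strips needs_relay and forwards the copy to the web node;
    #     the web node republishes it locally *without* needs_relay, so a relay
    #     sharing that redis skips it instead of forwarding it again.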


class WebSocketRelayManager(object):
    def __init__(self):
        self.local_hostname = get_local_host()
        self.relay_connections = dict()
        # hostname -> ip
        self.known_hosts: Dict[str, str] = dict()

    async def pg_consumer(self, conn):
        try:
            await conn.execute("LISTEN web_heartbeet")
            async for notif in conn.notifies():
                if notif is not None and notif.channel == "web_heartbeet":
                    try:
                        payload = json.loads(notif.payload)
                    except json.JSONDecodeError:
                        logmsg = "Failed to decode message from pg_notify channel `web_heartbeet`"
                        if logger.isEnabledFor(logging.DEBUG):
                            # payload is not bound when decoding fails; log the raw notification
                            logmsg = "{} {}".format(logmsg, notif.payload)
                        logger.warning(logmsg)
                        continue
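
                    # Example heartbeat payload (editorial; field names are taken
                    # from the handling below, the values are hypothetical):
                    #     {"hostname": "awx-web-1", "ip": "10.0.0.5", "action": "online"}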

                    # Skip if the message comes from the same host we are running on.
                    # In this case, we'll be sharing a redis, no need to relay.
                    if payload.get("hostname") == self.local_hostname:
                        continue

                    if payload.get("action") == "online":
                        hostname = payload["hostname"]
                        ip = payload["ip"]
                        if ip is None:
                            # If we don't get an IP, just try the hostname, maybe it resolves
                            ip = hostname
                        self.known_hosts[hostname] = ip
                        logger.debug(f"Web host {hostname} ({ip}) online heartbeat received.")
                    elif payload.get("action") == "offline":
                        hostname = payload["hostname"]
                        # pop() removes the entry and returns the IP for logging; it
                        # also tolerates hosts we never saw come online.
                        ip = self.known_hosts.pop(hostname, None)
                        logger.debug(f"Web host {hostname} ({ip}) offline heartbeat received.")
        except Exception as e:
            # This catch-all is the same as the one above. asyncio will eat the exception
            # but we want to know about it.
            logger.exception(f"pg_consumer exception: {e}")

    async def run(self):
        event_loop = asyncio.get_running_loop()

        stats_mgr = RelayWebsocketStatsManager(event_loop, self.local_hostname)
        stats_mgr.start()

        # Set up a pg_notify consumer for allowing web nodes to "provision" and "deprovision" themselves gracefully.
        database_conf = settings.DATABASES['default']
        async_conn = await psycopg.AsyncConnection.connect(
            dbname=database_conf['NAME'],
            host=database_conf['HOST'],
            user=database_conf['USER'],
            password=database_conf['PASSWORD'],
            port=database_conf['PORT'],
            **database_conf.get("OPTIONS", {}),
        )
        await async_conn.set_autocommit(True)
        event_loop.create_task(self.pg_consumer(async_conn))
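
        # A web node advertises itself with something like the following
        # (editorial sketch; the payload values are hypothetical):
        #     SELECT pg_notify('web_heartbeet',
        #         '{"hostname": "awx-web-1", "ip": "10.0.0.5", "action": "online"}');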

        # Establishes a websocket connection to /websocket/relay/ on all API servers
        while True:
            # logger.info("Current known hosts: {}".format(self.known_hosts))
            future_remote_hosts = self.known_hosts.keys()
            current_remote_hosts = self.relay_connections.keys()
            deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts)
            new_remote_hosts = set(future_remote_hosts) - set(current_remote_hosts)

            # This loop handles the case where we get an advertisement from a host we
            # already know about, but with a different IP than the one we are currently
            # connected to.
            for hostname, address in self.known_hosts.items():
                if hostname not in self.relay_connections:
                    # We've picked up a new hostname that we don't know about yet.
                    continue

                if address != self.relay_connections[hostname].remote_host:
                    deleted_remote_hosts.add(hostname)
                    new_remote_hosts.add(hostname)

            # Delete any hosts with closed connections
            for hostname, relay_conn in self.relay_connections.items():
                if not relay_conn.connected:
                    deleted_remote_hosts.add(hostname)

            if deleted_remote_hosts:
                logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list")

            if new_remote_hosts:
                logger.info(f"Adding {new_remote_hosts} to websocket broadcast list")

            for h in deleted_remote_hosts:
                self.relay_connections[h].cancel()
                del self.relay_connections[h]
                if h not in new_remote_hosts:
                    # Keep the entry if this host is about to be re-added with a
                    # new IP; otherwise forget it. pop() tolerates hosts already
                    # removed by an offline heartbeat.
                    self.known_hosts.pop(h, None)
                stats_mgr.delete_remote_host_stats(h)

            for h in new_remote_hosts:
                stats = stats_mgr.new_remote_host_stats(h)
                relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h])
                relay_connection.start()
                self.relay_connections[h] = relay_connection

            await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS)
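

# A minimal sketch of how this module is driven (editorial; in AWX this is the
# job of a management command, e.g. `run_wsrelay`, which is not shown here):
#
#     import asyncio
#
#     relay_manager = WebSocketRelayManager()
#     asyncio.run(relay_manager.run())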