diff --git a/awx/main/wsrelay.py b/awx/main/wsrelay.py index 1f97690c6d..735386a48a 100644 --- a/awx/main/wsrelay.py +++ b/awx/main/wsrelay.py @@ -12,6 +12,8 @@ from channels.db import database_sync_to_async from django.conf import settings from django.apps import apps +import psycopg + from awx.main.analytics.broadcast_websocket import ( RelayWebsocketStats, RelayWebsocketStatsManager, @@ -20,27 +22,12 @@ import awx.main.analytics.subsystem_metrics as s_metrics logger = logging.getLogger('awx.main.wsrelay') - def wrap_broadcast_msg(group, message: str): # TODO: Maybe wrap as "group","message" so that we don't need to # encode/decode as json. return dict(group=group, message=message) -@sync_to_async -def get_broadcast_hosts(): - Instance = apps.get_model('main', 'Instance') - instances = ( - Instance.objects.exclude(hostname=Instance.objects.my_hostname()) - .exclude(node_type='execution') - .exclude(node_type='hop') - .order_by('hostname') - .values('hostname', 'ip_address') - .distinct() - ) - return {i['hostname']: i['ip_address'] or i['hostname'] for i in instances} - - def get_local_host(): Instance = apps.get_model('main', 'Instance') return Instance.objects.my_hostname() @@ -198,41 +185,91 @@ class WebsocketRelayConnection: class WebSocketRelayManager(object): def __init__(self): - self.relay_connections = dict() self.local_hostname = get_local_host() - self.event_loop = asyncio.get_event_loop() - self.stats_mgr = RelayWebsocketStatsManager(self.event_loop, self.local_hostname) + self.relay_connections = dict() + # hostname -> ip + self.known_hosts: Dict[str, str] = dict() + + async def pg_consumer(self, conn): + """Consume pg_notify messages on `wsrelay_rx_from_web`: record hosts announcing 'online' in self.known_hosts and drop hosts announcing 'offline' (an 'offline' for an unknown host is ignored).""" + try: + await conn.execute("LISTEN wsrelay_rx_from_web") + async for notif in conn.notifies(): + if notif is not None and notif.channel == "wsrelay_rx_from_web": + try: + payload = json.loads(notif.payload) + except json.JSONDecodeError: + logmsg = "Failed to decode message from pg_notify channel `wsrelay_rx_from_web`" + if 
logger.isEnabledFor(logging.DEBUG): + logmsg = "{} {}".format(logmsg, notif.payload) + logger.warning(logmsg) + continue + + if payload.get("action") == "online": + hostname = payload["hostname"] + ip = payload["ip"] + self.known_hosts[hostname] = ip + logger.info(f"Web host {hostname} ({ip}) is online.") + elif payload.get("action") == "offline": + hostname = payload["hostname"] + ip = self.known_hosts.pop(hostname, None) + logger.info(f"Web host {hostname} ({ip}) is offline.") + except Exception: + # This catch-all is the same as the one above. asyncio will NOT log exceptions anywhere, so we need to do so ourselves. + logger.exception("pg_consumer exception") async def run(self): - self.stats_mgr.start() + event_loop = asyncio.get_running_loop() + + stats_mgr = RelayWebsocketStatsManager(event_loop, self.local_hostname) + stats_mgr.start() + + # Set up a pg_notify consumer for allowing web nodes to "provision" and "deprovision" themselves gracefully. + database_conf = settings.DATABASES['default'] + async_conn = await psycopg.AsyncConnection.connect( + dbname=database_conf['NAME'], + host=database_conf['HOST'], + user=database_conf['USER'], + password=database_conf['PASSWORD'], + port=database_conf['PORT'], + **database_conf.get("OPTIONS", {}), + ) + await async_conn.set_autocommit(True) + # Keep a reference to the task: the event loop only holds weak references to tasks. + self.pg_consumer_task = event_loop.create_task(self.pg_consumer(async_conn)) # Establishes a websocket connection to /websocket/relay on all API servers while True: - known_hosts = await get_broadcast_hosts() - future_remote_hosts = known_hosts.keys() + logger.info("Current known hosts: {}".format(self.known_hosts)) + future_remote_hosts = self.known_hosts.keys() current_remote_hosts = self.relay_connections.keys() deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts) new_remote_hosts = set(future_remote_hosts) - set(current_remote_hosts) remote_addresses = {k: v.remote_host for k, v in self.relay_connections.items()} - for hostname, address in known_hosts.items(): + for hostname, address in 
self.known_hosts.items(): if hostname in self.relay_connections and address != remote_addresses[hostname]: deleted_remote_hosts.add(hostname) new_remote_hosts.add(hostname) if deleted_remote_hosts: - logger.warning(f"Removing {deleted_remote_hosts} from websocket broadcast list") + logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list") + if new_remote_hosts: - logger.warning(f"Adding {new_remote_hosts} to websocket broadcast list") + logger.info(f"Adding {new_remote_hosts} to websocket broadcast list") for h in deleted_remote_hosts: self.relay_connections[h].cancel() del self.relay_connections[h] - self.stats_mgr.delete_remote_host_stats(h) + stats_mgr.delete_remote_host_stats(h) + logger.debug(f"New remote hosts: {new_remote_hosts}") for h in new_remote_hosts: - stats = self.stats_mgr.new_remote_host_stats(h) - relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=known_hosts[h]) + logger.debug(f"Creating relay stats for {h}") + stats = stats_mgr.new_remote_host_stats(h) + logger.debug(f"Created relay stats for {h}") + logger.info(f"Starting relay connection to {h}") + relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h]) relay_connection.start() self.relay_connections[h] = relay_connection diff --git a/requirements/requirements.in b/requirements/requirements.in index e66ce702cc..e7d7722f28 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -36,6 +36,7 @@ openshift pexpect==4.7.0 # see library notes prometheus_client psycopg2 +psycopg psutil pygerduty pyparsing==2.4.6 # Upgrading to v3 of pyparsing introduce errors on smart host filtering: Expected 'or' term, found 'or' (at char 15), (line:1, col:16) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a8bd03a801..f45e0acb87 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -265,6 +265,8 @@ prometheus-client==0.15.0 # via -r 
/awx_devel/requirements/requirements.in psutil==5.9.4 # via -r /awx_devel/requirements/requirements.in +psycopg==3.1.4 + # via -r /awx_devel/requirements/requirements.in psycopg2==2.9.5 # via -r /awx_devel/requirements/requirements.in ptyprocess==0.7.0 @@ -425,7 +427,7 @@ txaio==22.2.1 typing-extensions==4.4.0 # via # azure-core - # pydantic + # psycopg # setuptools-rust # setuptools-scm # twisted