From e873bb13045e0f488061ce675d89831198f9a405 Mon Sep 17 00:00:00 2001 From: Hao Liu <44379968+TheRealHaoLiu@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:54:36 -0400 Subject: [PATCH] Fix wsrelay connection leak (#15113) - when re-establishing the connection to db, close old connection - re-initialize WebSocketRelayManager when restarting asyncio.run - log and ignore error in cleanup_offline_host (this might come back to bite us) - clean up connection when WebSocketRelayManager crashes --- awx/main/management/commands/run_wsrelay.py | 3 +- awx/main/wsrelay.py | 126 ++++++++++++-------- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/awx/main/management/commands/run_wsrelay.py b/awx/main/management/commands/run_wsrelay.py index a3165cd669..ee7cbca682 100644 --- a/awx/main/management/commands/run_wsrelay.py +++ b/awx/main/management/commands/run_wsrelay.py @@ -165,11 +165,10 @@ class Command(BaseCommand): return WebsocketsMetricsServer().start() - websocket_relay_manager = WebSocketRelayManager() while True: try: - asyncio.run(websocket_relay_manager.run()) + asyncio.run(WebSocketRelayManager().run()) except KeyboardInterrupt: logger.info('Shutting down Websocket Relayer') break diff --git a/awx/main/wsrelay.py b/awx/main/wsrelay.py index fdf6d0e454..a4c94bd7e6 100644 --- a/awx/main/wsrelay.py +++ b/awx/main/wsrelay.py @@ -285,6 +285,8 @@ class WebSocketRelayManager(object): except asyncio.CancelledError: # Handle the case where the task was already cancelled by the time we got here. 
pass + except Exception as e: + logger.warning(f"Failed to cancel relay connection for {hostname}: {e}") del self.relay_connections[hostname] @@ -295,6 +297,8 @@ class WebSocketRelayManager(object): self.stats_mgr.delete_remote_host_stats(hostname) except KeyError: pass + except Exception as e: + logger.warning(f"Failed to delete stats for {hostname}: {e}") async def run(self): event_loop = asyncio.get_running_loop() @@ -316,57 +320,77 @@ class WebSocketRelayManager(object): task = None + # Managing the async_conn here so that we can close it if we need to restart the connection + async_conn = None + # Establishes a websocket connection to /websocket/relay on all API servers - while True: - if not task or task.done(): + try: + while True: + if not task or task.done(): + try: + # Try to close the connection if it's open + if async_conn: + try: + await async_conn.close() + except Exception as e: + logger.warning(f"Failed to close connection to database for pg_notify: {e}") + + # and re-establish the connection + async_conn = await psycopg.AsyncConnection.connect( + dbname=database_conf['NAME'], + host=database_conf['HOST'], + user=database_conf['USER'], + port=database_conf['PORT'], + **database_conf.get("OPTIONS", {}), + ) + await async_conn.set_autocommit(True) + + # before creating the task that uses the connection + task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat") + logger.info("Creating `on_ws_heartbeat` task in event loop.") + + except Exception as e: + logger.warning(f"Failed to connect to database for pg_notify: {e}") + + future_remote_hosts = self.known_hosts.keys() + current_remote_hosts = self.relay_connections.keys() + deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts) + new_remote_hosts = set(future_remote_hosts) - set(current_remote_hosts) + + # This loop handles if we get an advertisement from a host we already know about but + # the advertisement has a different IP than we are currently 
connected to. + for hostname, address in self.known_hosts.items(): + if hostname not in self.relay_connections: + # We've picked up a new hostname that we don't know about yet. + continue + + if address != self.relay_connections[hostname].remote_host: + deleted_remote_hosts.add(hostname) + new_remote_hosts.add(hostname) + + # Delete any hosts with closed connections + for hostname, relay_conn in self.relay_connections.items(): + if not relay_conn.connected: + deleted_remote_hosts.add(hostname) + + if deleted_remote_hosts: + logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list") + await asyncio.gather(*[self.cleanup_offline_host(h) for h in deleted_remote_hosts]) + + if new_remote_hosts: + logger.info(f"Adding {new_remote_hosts} to websocket broadcast list") + + for h in new_remote_hosts: + stats = self.stats_mgr.new_remote_host_stats(h) + relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h]) + relay_connection.start() + self.relay_connections[h] = relay_connection + + await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS) + finally: + if async_conn: + logger.info("Shutting down db connection for wsrelay.") try: - async_conn = await psycopg.AsyncConnection.connect( - dbname=database_conf['NAME'], - host=database_conf['HOST'], - user=database_conf['USER'], - port=database_conf['PORT'], - **database_conf.get("OPTIONS", {}), - ) - await async_conn.set_autocommit(True) - - task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat") - logger.info("Creating `on_ws_heartbeat` task in event loop.") - + await async_conn.close() except Exception as e: - logger.warning(f"Failed to connect to database for pg_notify: {e}") - - future_remote_hosts = self.known_hosts.keys() - current_remote_hosts = self.relay_connections.keys() - deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts) - new_remote_hosts = 
set(future_remote_hosts) - set(current_remote_hosts) - - # This loop handles if we get an advertisement from a host we already know about but - # the advertisement has a different IP than we are currently connected to. - for hostname, address in self.known_hosts.items(): - if hostname not in self.relay_connections: - # We've picked up a new hostname that we don't know about yet. - continue - - if address != self.relay_connections[hostname].remote_host: - deleted_remote_hosts.add(hostname) - new_remote_hosts.add(hostname) - - # Delete any hosts with closed connections - for hostname, relay_conn in self.relay_connections.items(): - if not relay_conn.connected: - deleted_remote_hosts.add(hostname) - - if deleted_remote_hosts: - logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list") - await asyncio.gather(*[self.cleanup_offline_host(h) for h in deleted_remote_hosts]) - - if new_remote_hosts: - logger.info(f"Adding {new_remote_hosts} to websocket broadcast list") - - for h in new_remote_hosts: - stats = self.stats_mgr.new_remote_host_stats(h) - relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h]) - relay_connection.start() - self.relay_connections[h] = relay_connection - - await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS) + logger.info(f"Failed to close connection to database for pg_notify: {e}")