Fix wsrelay connection leak (#15113)

- when re-establishing the connection to the db, close the old connection
- re-initialize WebSocketRelayManager when restarting asyncio.run
- log and ignore errors in cleanup_offline_host (this might come back to bite us)
- clean up the connection when WebSocketRelayManager crashes
This commit is contained in:
Hao Liu
2024-04-16 14:54:36 -04:00
committed by GitHub
parent 672f1eb745
commit e873bb1304
2 changed files with 76 additions and 53 deletions

View File

@@ -165,11 +165,10 @@ class Command(BaseCommand):
return return
WebsocketsMetricsServer().start() WebsocketsMetricsServer().start()
websocket_relay_manager = WebSocketRelayManager()
while True: while True:
try: try:
asyncio.run(websocket_relay_manager.run()) asyncio.run(WebSocketRelayManager().run())
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info('Shutting down Websocket Relayer') logger.info('Shutting down Websocket Relayer')
break break

View File

@@ -285,6 +285,8 @@ class WebSocketRelayManager(object):
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle the case where the task was already cancelled by the time we got here. # Handle the case where the task was already cancelled by the time we got here.
pass pass
except Exception as e:
logger.warning(f"Failed to cancel relay connection for {hostname}: {e}")
del self.relay_connections[hostname] del self.relay_connections[hostname]
@@ -295,6 +297,8 @@ class WebSocketRelayManager(object):
self.stats_mgr.delete_remote_host_stats(hostname) self.stats_mgr.delete_remote_host_stats(hostname)
except KeyError: except KeyError:
pass pass
except Exception as e:
logger.warning(f"Failed to delete stats for {hostname}: {e}")
async def run(self): async def run(self):
event_loop = asyncio.get_running_loop() event_loop = asyncio.get_running_loop()
@@ -316,10 +320,22 @@ class WebSocketRelayManager(object):
task = None task = None
# Managing the async_conn here so that we can close it if we need to restart the connection
async_conn = None
# Establishes a websocket connection to /websocket/relay on all API servers # Establishes a websocket connection to /websocket/relay on all API servers
try:
while True: while True:
if not task or task.done(): if not task or task.done():
try: try:
# Try to close the connection if it's open
if async_conn:
try:
await async_conn.close()
except Exception as e:
logger.warning(f"Failed to close connection to database for pg_notify: {e}")
# and re-establish the connection
async_conn = await psycopg.AsyncConnection.connect( async_conn = await psycopg.AsyncConnection.connect(
dbname=database_conf['NAME'], dbname=database_conf['NAME'],
host=database_conf['HOST'], host=database_conf['HOST'],
@@ -329,6 +345,7 @@ class WebSocketRelayManager(object):
) )
await async_conn.set_autocommit(True) await async_conn.set_autocommit(True)
# before creating the task that uses the connection
task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat") task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat")
logger.info("Creating `on_ws_heartbeat` task in event loop.") logger.info("Creating `on_ws_heartbeat` task in event loop.")
@@ -370,3 +387,10 @@ class WebSocketRelayManager(object):
self.relay_connections[h] = relay_connection self.relay_connections[h] = relay_connection
await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS) await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS)
finally:
if async_conn:
logger.info("Shutting down db connection for wsrelay.")
try:
await async_conn.close()
except Exception as e:
logger.info(f"Failed to close connection to database for pg_notify: {e}")