AAP-57817 Add Redis connection retry using redis-py 7.0+ built-in (#16176)

* AAP-57817 Add Redis connection retry using redis-py 7.0+ built-in mechanism

* Refactor Redis client helpers to use settings and eliminate code duplication

* Create awx/main/utils/redis.py and move Redis client functions to avoid circular imports

* Fix subsystem_metrics to share Redis connection pool between
  client and pipeline

* Cache Redis clients in RelayConsumer and RelayWebsocketStatsManager to avoid creating new connection pools on every call

* Add cap and base config

* Add Redis retry logic with exponential backoff to handle connection failures during long-running operations

* Add REDIS_BACKOFF_CAP and REDIS_BACKOFF_BASE settings to allow
  adjustment of retry timing in worst-case scenarios without code changes

* Simplify Redis retry tests by removing unnecessary reload logic
This commit is contained in:
Lila Yasin 2025-12-01 09:08:47 -05:00 committed by GitHub
parent 0d86874d5d
commit 4f41b50a09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 264 additions and 24 deletions

View File

@ -1,8 +1,6 @@
import datetime
import asyncio
import logging
import redis
import redis.asyncio
import re
from prometheus_client import (
@ -15,6 +13,7 @@ from prometheus_client import (
)
from django.conf import settings
from awx.main.utils.redis import get_redis_client, get_redis_client_async
BROADCAST_WEBSOCKET_REDIS_KEY_NAME = 'broadcast_websocket_stats'
@ -66,6 +65,8 @@ class FixedSlidingWindow:
class RelayWebsocketStatsManager:
_redis_client = None # Cached Redis client for get_stats_sync()
def __init__(self, local_hostname):
self._local_hostname = local_hostname
self._stats = dict()
@ -80,7 +81,7 @@ class RelayWebsocketStatsManager:
async def run_loop(self):
try:
redis_conn = await redis.asyncio.Redis.from_url(settings.BROKER_URL)
redis_conn = get_redis_client_async()
while True:
stats_data_str = ''.join(stat.serialize() for stat in self._stats.values())
await redis_conn.set(self._redis_key, stats_data_str)
@ -103,8 +104,10 @@ class RelayWebsocketStatsManager:
"""
Stringified version of all the stats
"""
redis_conn = redis.Redis.from_url(settings.BROKER_URL)
stats_str = redis_conn.get(BROADCAST_WEBSOCKET_REDIS_KEY_NAME) or b''
# Reuse cached Redis client to avoid creating new connection pools on every call
if cls._redis_client is None:
cls._redis_client = get_redis_client()
stats_str = cls._redis_client.get(BROADCAST_WEBSOCKET_REDIS_KEY_NAME) or b''
return parser.text_string_to_metric_families(stats_str.decode('UTF-8'))

View File

@ -14,6 +14,7 @@ from rest_framework.request import Request
from awx.main.consumers import emit_channel_notification
from awx.main.utils import is_testing
from awx.main.utils.redis import get_redis_client
root_key = settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX
logger = logging.getLogger('awx.main.analytics')
@ -198,8 +199,8 @@ class Metrics(MetricsNamespace):
def __init__(self, namespace, auto_pipe_execute=False, instance_name=None, metrics_have_changed=True, **kwargs):
MetricsNamespace.__init__(self, namespace)
self.pipe = redis.Redis.from_url(settings.BROKER_URL).pipeline()
self.conn = redis.Redis.from_url(settings.BROKER_URL)
self.conn = get_redis_client()
self.pipe = self.conn.pipeline()
self.last_pipe_execute = time.time()
# track if metrics have been modified since last saved to redis
# start with True so that we get an initial save to redis

View File

@ -3,7 +3,6 @@ import logging
import time
import hmac
import asyncio
import redis
from django.core.serializers.json import DjangoJSONEncoder
from django.conf import settings
@ -14,6 +13,8 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer
from channels.layers import get_channel_layer
from channels.db import database_sync_to_async
from awx.main.utils.redis import get_redis_client_async
logger = logging.getLogger('awx.main.consumers')
XRF_KEY = '_auth_user_xrf'
@ -94,6 +95,9 @@ class RelayConsumer(AsyncJsonWebsocketConsumer):
await self.channel_layer.group_add(settings.BROADCAST_WEBSOCKET_GROUP_NAME, self.channel_name)
logger.info(f"client '{self.channel_name}' joined the broadcast group.")
# Initialize Redis client once for reuse across all message handling
self._redis_conn = get_redis_client_async()
async def disconnect(self, code):
logger.info(f"client '{self.channel_name}' disconnected from the broadcast group.")
await self.channel_layer.group_discard(settings.BROADCAST_WEBSOCKET_GROUP_NAME, self.channel_name)
@ -105,8 +109,9 @@ class RelayConsumer(AsyncJsonWebsocketConsumer):
(group, message) = unwrap_broadcast_msg(data)
if group == "metrics":
message = json.loads(message['text'])
conn = redis.Redis.from_url(settings.BROKER_URL)
conn.set(settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX + "-" + message['metrics_namespace'] + "_instance_" + message['instance'], message['metrics'])
await self._redis_conn.set(
settings.SUBSYSTEM_METRICS_REDIS_KEY_PREFIX + "-" + message['metrics_namespace'] + "_instance_" + message['instance'], message['metrics']
)
else:
await self.channel_layer.group_send(group, message)

View File

@ -2,11 +2,10 @@ import logging
import uuid
import json
from django.conf import settings
from django.db import connection
import redis
from awx.main.dispatch import get_task_queuename
from awx.main.utils.redis import get_redis_client
from . import pg_bus_conn
@ -24,7 +23,7 @@ class Control(object):
self.queuename = host or get_task_queuename()
def status(self, *args, **kwargs):
r = redis.Redis.from_url(settings.BROKER_URL)
r = get_redis_client()
if self.service == 'dispatcher':
stats = r.get(f'awx_{self.service}_statistics') or b''
return stats.decode('utf-8')

View File

@ -19,6 +19,7 @@ import redis.exceptions
from ansible_base.lib.logging.runtime import log_excess_runtime
from awx.main.utils.redis import get_redis_client
from awx.main.dispatch.pool import WorkerPool
from awx.main.dispatch.periodic import Scheduler
from awx.main.dispatch import pg_bus_conn
@ -59,7 +60,7 @@ class AWXConsumerBase(object):
if pool is None:
self.pool = WorkerPool()
self.pool.init_workers(self.worker.work_loop)
self.redis = redis.Redis.from_url(settings.BROKER_URL)
self.redis = get_redis_client()
@property
def listening_on(self):

View File

@ -15,6 +15,7 @@ import psutil
import redis
from awx.main.utils.redis import get_redis_client
from awx.main.consumers import emit_channel_notification
from awx.main.models import JobEvent, AdHocCommandEvent, ProjectUpdateEvent, InventoryUpdateEvent, SystemJobEvent, UnifiedJob
from awx.main.constants import ACTIVE_STATES
@ -72,7 +73,7 @@ class CallbackBrokerWorker(BaseWorker):
def __init__(self):
self.buff = {}
self.redis = redis.Redis.from_url(settings.BROKER_URL)
self.redis = get_redis_client()
self.subsystem_metrics = s_metrics.CallbackReceiverMetrics(auto_pipe_execute=False)
self.queue_pop = 0
self.queue_name = settings.CALLBACK_QUEUE

View File

@ -33,6 +33,7 @@ from awx.main.models.rbac import (
)
from awx.main.models.unified_jobs import UnifiedJob
from awx.main.utils.common import get_corrected_cpu, get_cpu_effective_capacity, get_corrected_memory, get_mem_effective_capacity
from awx.main.utils.redis import get_redis_client
from awx.main.models.mixins import RelatedJobsMixin, ResourceMixin
from awx.main.models.receptor_address import ReceptorAddress
@ -397,7 +398,7 @@ class Instance(HasPolicyEditsMixin, BaseModel):
try:
# if redis is down for some reason, that means we can't persist
# playbook event data; we should consider this a zero capacity event
redis.Redis.from_url(settings.BROKER_URL).ping()
get_redis_client().ping()
except redis.ConnectionError:
errors = _('Failed to connect to Redis')

View File

@ -4,11 +4,13 @@
# Python
import json
import logging
import redis
# Django
from django.conf import settings
# AWX
from awx.main.utils.redis import get_redis_client
__all__ = ['CallbackQueueDispatcher']
@ -26,7 +28,7 @@ class CallbackQueueDispatcher(object):
def __init__(self):
self.queue = getattr(settings, 'CALLBACK_QUEUE', '')
self.logger = logging.getLogger('awx.main.queue.CallbackQueueDispatcher')
self.connection = redis.Redis.from_url(settings.BROKER_URL)
self.connection = get_redis_client()
def dispatch(self, obj):
self.connection.rpush(self.queue, json.dumps(obj, cls=AnsibleJSONEncoder))

View File

@ -8,6 +8,7 @@ from channels.routing import ProtocolTypeRouter, URLRouter
from ansible_base.lib.channels.middleware import DrfAuthMiddlewareStack
from awx.main.utils.redis import get_redis_client
from . import consumers
@ -18,7 +19,7 @@ _application = None
class AWXProtocolTypeRouter(ProtocolTypeRouter):
def __init__(self, *args, **kwargs):
try:
r = redis.Redis.from_url(settings.BROKER_URL)
r = get_redis_client()
for k in r.scan_iter('asgi:*', 500):
logger.debug(f"cleaning up Redis key {k}")
r.delete(k)

View File

@ -77,15 +77,34 @@ def swagger_autogen(requests=__SWAGGER_REQUESTS__):
class FakeRedis:
    """No-op stand-in for a Redis client that never touches the network.

    Every method accepts and ignores arbitrary positional and keyword
    arguments, so it can be called the same way as the real client, and
    returns a fixed "nothing stored" value.
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore all arguments to match redis.Redis signature
        pass

    def keys(self, *args, **kwargs):
        """Return an empty list: no keys are ever stored."""
        return []

    def set(self, *args, **kwargs):
        """Discard the value; returns None."""
        pass

    def get(self, *args, **kwargs):
        """Return None, as for a missing key."""
        return None

    def rpush(self, *args, **kwargs):
        """Pretend one element was pushed; always returns 1."""
        return 1

    def blpop(self, *args, **kwargs):
        """Return None immediately instead of blocking."""
        return None

    def delete(self, *args, **kwargs):
        """Do nothing; returns None."""
        pass

    def llen(self, *args, **kwargs):
        """Return 0: every list is empty."""
        return 0

    def scan_iter(self, *args, **kwargs):
        """Return an empty iterator."""
        return iter([])

    @classmethod
    def from_url(cls, *args, **kwargs):
        """Alternate constructor mirroring ``Redis.from_url``; arguments are ignored."""
        return cls()

View File

@ -1,6 +1,7 @@
import datetime
from unittest.mock import Mock, patch
from awx.main.analytics.broadcast_websocket import FixedSlidingWindow
from awx.main.analytics.broadcast_websocket import FixedSlidingWindow, RelayWebsocketStatsManager
from awx.main.analytics.broadcast_websocket import dt_to_seconds
@ -59,3 +60,70 @@ class TestFixedSlidingWindow:
assert 20 - i == fsw.render(self.ts(minute=1, second=i, microsecond=0)), "E. Sliding window where 1 record() should drop from the results each time"
assert 0 == fsw.render(self.ts(minute=1, second=20, microsecond=0)), "F. First second one minute after all record() calls"
class TestRelayWebsocketStatsManager:
    """Test Redis client caching in RelayWebsocketStatsManager."""

    # Dotted path of the factory that get_stats_sync() calls to build its client.
    GET_CLIENT_PATH = 'awx.main.analytics.broadcast_websocket.get_redis_client'

    def test_get_stats_sync_caches_redis_client(self):
        """Verify get_stats_sync caches Redis client to avoid creating new connection pools."""
        mock_redis = Mock()
        mock_redis.get.return_value = b''

        # patch.object starts each test with an empty class-level cache and restores the
        # previous value on exit, even when an assertion fails, so a Mock never leaks
        # into other tests.
        with patch.object(RelayWebsocketStatsManager, '_redis_client', None), patch(self.GET_CLIENT_PATH, return_value=mock_redis) as mock_get_client:
            # First call should create client
            RelayWebsocketStatsManager.get_stats_sync()
            assert mock_get_client.call_count == 1

            # Second call should reuse cached client
            RelayWebsocketStatsManager.get_stats_sync()
            assert mock_get_client.call_count == 1  # Still 1, not called again

            # Third call should still reuse cached client
            RelayWebsocketStatsManager.get_stats_sync()
            assert mock_get_client.call_count == 1

    def test_get_stats_sync_returns_parsed_metrics(self):
        """Verify get_stats_sync returns parsed metric families from Redis."""
        # Sample Prometheus metrics format
        sample_metrics = b'# HELP test_metric A test metric\n# TYPE test_metric gauge\ntest_metric 42\n'
        mock_redis = Mock()
        mock_redis.get.return_value = sample_metrics

        with patch.object(RelayWebsocketStatsManager, '_redis_client', None), patch(self.GET_CLIENT_PATH, return_value=mock_redis):
            result = list(RelayWebsocketStatsManager.get_stats_sync())

        # Should return parsed metric families
        assert len(result) > 0
        assert mock_redis.get.called

    def test_get_stats_sync_handles_empty_redis_data(self):
        """Verify get_stats_sync handles empty data from Redis gracefully."""
        mock_redis = Mock()
        mock_redis.get.return_value = None  # Redis returns None when key doesn't exist

        with patch.object(RelayWebsocketStatsManager, '_redis_client', None), patch(self.GET_CLIENT_PATH, return_value=mock_redis):
            result = list(RelayWebsocketStatsManager.get_stats_sync())

        # Should handle empty data gracefully
        assert result == []
        assert mock_redis.get.called

View File

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Ansible, Inc.
# All Rights Reserved
from django.test.utils import override_settings
from awx.main.utils.redis import get_redis_client, get_redis_client_async
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
from redis.backoff import ExponentialBackoff
class TestRedisRetryConfiguration:
    """Verify Redis retry configuration is applied to connection objects."""

    @staticmethod
    def _retry_config(client):
        """Return the ``(retry, retry_on_error)`` pair stored in the client's pool kwargs."""
        pool_kwargs = client.connection_pool.connection_kwargs
        return pool_kwargs['retry'], pool_kwargs['retry_on_error']

    def test_retry_configuration_applied_to_client(self, settings):
        """Verify all retry settings are applied to the connection pool."""
        # The sync and async clients are held to the same assertions so that a
        # regression in either factory is caught.
        clients = (('sync', get_redis_client()), ('async', get_redis_client_async()))
        for label, client in clients:
            retry, retry_errors = self._retry_config(client)
            backoff = retry._backoff

            # Assert provided values match values on the object
            assert retry._retries == settings.REDIS_RETRY_COUNT == 3, label
            assert isinstance(backoff, ExponentialBackoff), label
            assert backoff._base == settings.REDIS_BACKOFF_BASE == 0.5, label
            assert backoff._cap == settings.REDIS_BACKOFF_CAP == 1.0, label
            for error_cls in (BusyLoadingError, ConnectionError, TimeoutError):
                assert error_cls in retry_errors, f'{label}: {error_cls.__name__}'

    @override_settings(REDIS_RETRY_COUNT=5)
    def test_override_settings_applied_to_client(self):
        """Verify override_settings changes are applied to client object."""
        retry, _ = self._retry_config(get_redis_client())

        assert retry._retries == 5

    @override_settings(REDIS_BACKOFF_CAP=2.0, REDIS_BACKOFF_BASE=1.0)
    def test_override_backoff_settings_applied_to_client(self):
        """Verify override_settings for backoff parameters are applied to client object."""
        retry, _ = self._retry_config(get_redis_client())
        backoff = retry._backoff

        # Assert provided values match values on object
        assert backoff._cap == 2.0
        assert backoff._base == 1.0

View File

@ -3,6 +3,7 @@
# AWX
from awx.main.utils.common import * # noqa
from awx.main.utils.redis import get_redis_client, get_redis_client_async # noqa
from awx.main.utils.encryption import ( # noqa
get_encryption_key,
encrypt_field,

74
awx/main/utils/redis.py Normal file
View File

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Ansible, Inc.
# All Rights Reserved
"""Redis client utilities with automatic retry on connection errors."""
import redis
import redis.asyncio
from django.conf import settings
from redis.backoff import ExponentialBackoff
from redis.retry import Retry
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError
def _get_redis_pool_kwargs():
    """
    Build the retry-related keyword arguments shared by the Redis connection pools.

    The backoff cap, backoff base and retry count are read from the
    REDIS_BACKOFF_CAP, REDIS_BACKOFF_BASE and REDIS_RETRY_COUNT settings on
    every call.

    Returns:
        dict: ``retry`` (a Retry driven by an ExponentialBackoff) and
        ``retry_on_error`` (the list of exception classes to retry on),
        suitable for passing to ``ConnectionPool.from_url()``.
    """
    backoff_policy = ExponentialBackoff(cap=settings.REDIS_BACKOFF_CAP, base=settings.REDIS_BACKOFF_BASE)
    retry_policy = Retry(backoff_policy, retries=settings.REDIS_RETRY_COUNT)
    retryable_errors = [BusyLoadingError, ConnectionError, TimeoutError]
    return dict(retry=retry_policy, retry_on_error=retryable_errors)
def get_redis_client():
    """
    Build a synchronous Redis client whose connections retry on connection errors.

    A new connection pool is created for settings.BROKER_URL on every call,
    configured with the retry options from _get_redis_pool_kwargs(), and
    wrapped in a redis.Redis client. Retrying is delegated to redis-py's
    built-in mechanism rather than custom retry logic, so transient failures
    (broken pipes, timeouts, etc.) during long-running operations are retried.

    Based on PR feedback: https://github.com/ansible/awx/pull/16158#issuecomment-3486839154

    Returns:
        redis.Redis: A Redis client instance configured with retry logic

    Notes:
        - Uses exponential backoff with configurable retries (REDIS_RETRY_COUNT setting)
        - Retries on BusyLoadingError, ConnectionError, and TimeoutError
        - Requires redis-py 7.0+
    """
    broker_url = settings.BROKER_URL
    retry_options = _get_redis_pool_kwargs()
    connection_pool = redis.ConnectionPool.from_url(broker_url, **retry_options)
    return redis.Redis(connection_pool=connection_pool)
def get_redis_client_async():
    """
    Create an async Redis client with automatic retry on connection errors.

    This is the async version of get_redis_client() for use with asyncio code.
    A new connection pool is created for settings.BROKER_URL on every call.

    The shared pool kwargs carry the synchronous ``redis.retry.Retry``. Its
    ``call_with_retry`` is a plain function, so when an async connection awaits
    it the coroutine is returned un-awaited and any error is raised outside the
    retry loop; no retry or backoff ever runs. The ``retry`` entry is therefore
    replaced with ``redis.asyncio.retry.Retry``, built from the same settings.

    Returns:
        redis.asyncio.Redis: An async Redis client instance configured with retry logic

    Notes:
        - Uses exponential backoff with configurable retries (REDIS_RETRY_COUNT setting)
        - Retries on BusyLoadingError, ConnectionError, and TimeoutError
        - Requires redis-py 7.0+
    """
    # Local import: the async Retry shares its name with the sync Retry used at module level.
    from redis.asyncio.retry import Retry as AsyncRetry

    pool_kwargs = _get_redis_pool_kwargs()
    # Same backoff/retry settings as the sync client, but awaitable.
    pool_kwargs['retry'] = AsyncRetry(
        ExponentialBackoff(cap=settings.REDIS_BACKOFF_CAP, base=settings.REDIS_BACKOFF_BASE),
        retries=settings.REDIS_RETRY_COUNT,
    )
    pool = redis.asyncio.ConnectionPool.from_url(
        settings.BROKER_URL,
        **pool_kwargs,
    )
    return redis.asyncio.Redis(connection_pool=pool)

View File

@ -424,6 +424,9 @@ DISPATCHER_MOCK_PUBLISH = False
DISPATCHERD_DEBUGGING_SOCKFILE = os.path.join(BASE_DIR, 'dispatcherd.sock')
BROKER_URL = 'unix:///var/run/redis/redis.sock'
REDIS_RETRY_COUNT = 3 # Number of retries for Redis connection errors
REDIS_BACKOFF_CAP = 1.0 # Maximum backoff delay in seconds for Redis retries
REDIS_BACKOFF_BASE = 0.5 # Base for exponential backoff calculation for Redis retries
CELERYBEAT_SCHEDULE = {
'tower_scheduler': {'task': 'awx.main.tasks.system.awx_periodic_scheduler', 'schedule': timedelta(seconds=30), 'options': {'expires': 20}},
'cluster_heartbeat': {

View File

@ -56,7 +56,7 @@ pyyaml>=6.0.2 # require packing fix for cython 3 or higher
pyzstd # otel collector log file compression library
receptorctl
sqlparse>=0.4.4 # Required by django https://github.com/ansible/awx/security/dependabot/96
redis[hiredis]
redis[hiredis]>=7.0 # requires 7.0+ for retry functionality on connection errors
requests
slack-sdk
twilio

View File

@ -433,7 +433,7 @@ pyzstd==0.18.0
# via -r /awx_devel/requirements/requirements.in
receptorctl==1.6.0
# via -r /awx_devel/requirements/requirements.in
redis[hiredis]==6.4.0
redis[hiredis]==7.0.1
# via
# -r /awx_devel/requirements/requirements.in
# channels-redis