Files
awx/awx/main/tasks/signals.py
Alan Rominger fc0b58fd04 Fix bug that prevented dispatcher exit with downed DB (#14469)
* Separate handling of original SIGTERM and SIGINT
2023-10-26 14:34:25 -04:00

87 lines
2.4 KiB
Python

import signal
import functools
import logging
# Module-level logger, named with this module's dotted path.
logger = logging.getLogger('awx.main.tasks.signals')
# Names exported by `from ... import *`: the decorator and the polling helper.
__all__ = ['with_signal_handling', 'signal_callback']
class SignalExit(Exception):
    """Raised from a signal handler when SignalState.raise_exception is set."""
class SignalState:
    """Record SIGTERM / SIGINT delivery while temporary handlers are installed.

    Attributes:
        sigterm_flag: True once SIGTERM has been seen since the last reset.
        sigint_flag: True once SIGINT has been seen since the last reset.
        is_active: True while this object's handlers are installed; lets
            nested users detect that an outer caller already connected them.
        original_sigterm / original_sigint: handlers that were in place
            before connect_signals() ran, as returned by signal.getsignal().
        raise_exception: when True, the next signal received raises
            SignalExit from inside the handler (one time only). Nothing in
            this class turns it on; it is set by outside code.
    """

    def reset(self):
        """Put every field back to its idle (handlers not installed) value."""
        self.sigterm_flag = False
        self.sigint_flag = False
        self.is_active = False  # for nested context managers
        self.original_sigterm = None
        self.original_sigint = None
        self.raise_exception = False

    def __init__(self):
        self.reset()

    def raise_if_needed(self):
        """Raise SignalExit once if raise_exception has been requested.

        Raises:
            SignalExit: when raise_exception is True; the flag is cleared first.
        """
        if self.raise_exception:
            self.raise_exception = False  # so it is not raised a second time in error handling
            raise SignalExit()

    def set_sigterm_flag(self, *args):
        """SIGTERM handler: remember the signal, then raise if requested.

        Accepts and ignores the (signum, frame) arguments Python passes.
        """
        self.sigterm_flag = True
        self.raise_if_needed()

    def set_sigint_flag(self, *args):
        """SIGINT handler: remember the signal, then raise if requested.

        Accepts and ignores the (signum, frame) arguments Python passes.
        """
        self.sigint_flag = True
        self.raise_if_needed()

    def connect_signals(self):
        """Save the current SIGTERM/SIGINT handlers and install the flag setters."""
        self.original_sigterm = signal.getsignal(signal.SIGTERM)
        self.original_sigint = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGTERM, self.set_sigterm_flag)
        signal.signal(signal.SIGINT, self.set_sigint_flag)
        self.is_active = True

    def restore_signals(self):
        """Reinstall the saved handlers and replay any signal seen meanwhile.

        If a signal arrived while the flag setters were installed, the saved
        handler for that signal is invoked (when it is callable, i.e. not
        SIG_DFL / SIG_IGN / None) so the surrounding program still reacts.

        Any exception raised by a replayed handler (for example
        KeyboardInterrupt from the default SIGINT handler) propagates to the
        caller; the state is reset regardless.
        """
        try:
            signal.signal(signal.SIGTERM, self.original_sigterm)
            signal.signal(signal.SIGINT, self.original_sigint)
            # if we got a signal while context manager was active, call parent methods.
            # Handlers are invoked with the standard (signum, frame) pair; no
            # frame from the original delivery is kept, so None is passed.
            if self.sigterm_flag and callable(self.original_sigterm):
                self.original_sigterm(signal.SIGTERM, None)
            if self.sigint_flag and callable(self.original_sigint):
                self.original_sigint(signal.SIGINT, None)
        finally:
            # Always clear the flags and is_active, even when a replayed
            # handler raises; otherwise later callers would see a stale
            # "active" state and never reconnect the handlers.
            self.reset()
# Single module-level SignalState shared by signal_callback() and with_signal_handling().
signal_state = SignalState()
def signal_callback():
    """Report whether SIGTERM or SIGINT has been flagged on the shared state.

    Returns:
        bool: True if either flag on the module-level ``signal_state`` is set.
    """
    state = signal_state
    if state.sigterm_flag:
        return True
    return bool(state.sigint_flag)
def with_signal_handling(f):
    """
    Decorator: while ``f`` runs, SIGTERM / SIGINT only set flags on the shared
    ``signal_state`` so that ``signal_callback()`` returns True.

    Only the outermost decorated call installs the handlers and restores the
    previous ones afterwards; nested decorated calls run ``f`` unchanged.
    """

    @functools.wraps(f)
    def _wrapped(*args, **kwargs):
        # An outer caller already installed the handlers: nothing to manage.
        if signal_state.is_active:
            return f(*args, **kwargs)

        # We are the outermost caller, so we own connect + restore.
        signal_state.connect_signals()
        try:
            return f(*args, **kwargs)
        finally:
            signal_state.restore_signals()

    return _wrapped