diff --git a/awx/main/tasks/signals.py b/awx/main/tasks/signals.py index 95610548b9..7b4e4ba47a 100644 --- a/awx/main/tasks/signals.py +++ b/awx/main/tasks/signals.py @@ -16,7 +16,9 @@ class SignalExit(Exception): class SignalState: def reset(self): self.sigterm_flag = False - self.is_active = False + self.sigint_flag = False + + self.is_active = False # for nested context managers self.original_sigterm = None self.original_sigint = None self.raise_exception = False @@ -24,23 +26,36 @@ class SignalState: def __init__(self): self.reset() - def set_flag(self, *args): - """Method to pass into the python signal.signal method to receive signals""" - self.sigterm_flag = True + def raise_if_needed(self): if self.raise_exception: self.raise_exception = False # so it is not raised a second time in error handling raise SignalExit() + def set_sigterm_flag(self, *args): + self.sigterm_flag = True + self.raise_if_needed() + + def set_sigint_flag(self, *args): + self.sigint_flag = True + self.raise_if_needed() + def connect_signals(self): self.original_sigterm = signal.getsignal(signal.SIGTERM) self.original_sigint = signal.getsignal(signal.SIGINT) - signal.signal(signal.SIGTERM, self.set_flag) - signal.signal(signal.SIGINT, self.set_flag) + signal.signal(signal.SIGTERM, self.set_sigterm_flag) + signal.signal(signal.SIGINT, self.set_sigint_flag) self.is_active = True def restore_signals(self): signal.signal(signal.SIGTERM, self.original_sigterm) signal.signal(signal.SIGINT, self.original_sigint) + # if we got a signal while context manager was active, call parent methods. + if self.sigterm_flag: + if callable(self.original_sigterm): + self.original_sigterm() + if self.sigint_flag: + if callable(self.original_sigint): + self.original_sigint() self.reset() @@ -48,7 +63,7 @@ signal_state = SignalState() def signal_callback(): - return signal_state.sigterm_flag + return bool(signal_state.sigterm_flag or signal_state.sigint_flag) def with_signal_handling(f): diff --git a/awx/main/tests/unit/tasks/test_signals.py b/awx/main/tests/unit/tasks/test_signals.py index a435b8a660..75915504c5 100644 --- a/awx/main/tests/unit/tasks/test_signals.py +++ b/awx/main/tests/unit/tasks/test_signals.py @@ -1,8 +1,43 @@ import signal +import functools from awx.main.tasks.signals import signal_state, signal_callback, with_signal_handling +def pytest_sigint(): + pytest_sigint.called_count += 1 + + +def pytest_sigterm(): + pytest_sigterm.called_count += 1 + + +def tmp_signals_for_test(func): + """ + When we run our internal signal handlers, it will call the original signal + handlers when its own work is finished. + This would crash the test runners normally, because those methods will + shut down the process. + So this is a decorator to safely replace existing signal handlers + with new signal handlers that do nothing so that tests do not crash. + """ + + @functools.wraps(func) + def wrapper(): + original_sigterm = signal.getsignal(signal.SIGTERM) + original_sigint = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGTERM, pytest_sigterm) + signal.signal(signal.SIGINT, pytest_sigint) + pytest_sigterm.called_count = 0 + pytest_sigint.called_count = 0 + func() + signal.signal(signal.SIGTERM, original_sigterm) + signal.signal(signal.SIGINT, original_sigint) + + return wrapper + + +@tmp_signals_for_test def test_outer_inner_signal_handling(): """ Even if the flag is set in the outer context, its value should persist in the inner context @@ -15,17 +50,22 @@ def test_outer_inner_signal_handling(): @with_signal_handling def f1(): assert signal_callback() is False - signal_state.set_flag() + signal_state.set_sigterm_flag() assert signal_callback() f2() original_sigterm = signal.getsignal(signal.SIGTERM) assert signal_callback() is False + assert pytest_sigterm.called_count == 0 + assert pytest_sigint.called_count == 0 f1() assert signal_callback() is False assert signal.getsignal(signal.SIGTERM) is original_sigterm + assert pytest_sigterm.called_count == 1 + assert pytest_sigint.called_count == 0 +@tmp_signals_for_test def test_inner_outer_signal_handling(): """ Even if the flag is set in the inner context, its value should persist in the outer context @@ -34,7 +74,7 @@ def test_inner_outer_signal_handling(): @with_signal_handling def f2(): assert signal_callback() is False - signal_state.set_flag() + signal_state.set_sigint_flag() assert signal_callback() @with_signal_handling @@ -45,6 +85,10 @@ def test_inner_outer_signal_handling(): original_sigterm = signal.getsignal(signal.SIGTERM) assert signal_callback() is False + assert pytest_sigterm.called_count == 0 + assert pytest_sigint.called_count == 0 f1() assert signal_callback() is False assert signal.getsignal(signal.SIGTERM) is original_sigterm + assert pytest_sigterm.called_count == 0 + assert pytest_sigint.called_count == 1