diff --git a/awx/main/routing.py b/awx/main/routing.py index 05edd2a802..bf2c06cfe3 100644 --- a/awx/main/routing.py +++ b/awx/main/routing.py @@ -12,6 +12,7 @@ from . import consumers logger = logging.getLogger('awx.main.routing') +_application = None class AWXProtocolTypeRouter(ProtocolTypeRouter): @@ -66,11 +67,52 @@ websocket_relay_urlpatterns = [ re_path(r'websocket/relay/$', consumers.RelayConsumer.as_asgi()), ] -application = AWXProtocolTypeRouter( - { - 'websocket': MultipleURLRouterAdapter( - URLRouter(websocket_relay_urlpatterns), - DrfAuthMiddlewareStack(URLRouter(websocket_urlpatterns)), - ) - } -) + +def application_func(cls=AWXProtocolTypeRouter) -> ProtocolTypeRouter: + return cls( + { + 'websocket': MultipleURLRouterAdapter( + URLRouter(websocket_relay_urlpatterns), + DrfAuthMiddlewareStack(URLRouter(websocket_urlpatterns)), + ) + } + ) + + +def __getattr__(name: str) -> ProtocolTypeRouter: + """ + Defer instantiating application. + For testing, we just need it to NOT run on import. + + https://peps.python.org/pep-0562/#specification + + Normally, someone would get application from this module via: + from awx.main.routing import application + + and do something with the application: + application.do_something() + + What does the callstack look like when the import runs? + ... + awx.main.routing.__getattribute__(...) # <-- we don't define this so NOOP as far as we are concerned + if '__getattr__' in awx.main.routing.__dict__: # <-- this triggers the function we are in + return awx.main.routing.__dict__.__getattr__("application") + + Why isn't this function simply implemented as: + def __getattr__(name): + if not _application: + _application = application_func() + return _application + + It could. I manually tested it and it passes test_routing.py. + + But my understanding after reading the PEP-0562 specification link above is that + performance would be a bit worse due to the extra __getattribute__ calls when + we reference non-global variables. + """ + if name == "application": + globs = globals() + if not globs['_application']: + globs['_application'] = application_func() + return globs['_application'] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/awx/main/tests/functional/test_routing.py b/awx/main/tests/functional/test_routing.py new file mode 100644 index 0000000000..a9d758da2b --- /dev/null +++ b/awx/main/tests/functional/test_routing.py @@ -0,0 +1,90 @@ +import pytest + +from django.contrib.auth.models import AnonymousUser + +from channels.routing import ProtocolTypeRouter +from channels.testing.websocket import WebsocketCommunicator + + +from awx.main.consumers import WebsocketSecretAuthHelper + + +@pytest.fixture +def application(): + # code in routing hits the db on import because .. settings cache + from awx.main.routing import application_func + + yield application_func(ProtocolTypeRouter) + + +@pytest.fixture +def websocket_server_generator(application): + def fn(endpoint): + return WebsocketCommunicator(application, endpoint) + + return fn + + +@pytest.mark.asyncio +@pytest.mark.django_db +class TestWebsocketRelay: + @pytest.fixture + def websocket_relay_secret_generator(self, settings): + def fn(secret, set_broadcast_websocket_secret=False): + secret_backup = settings.BROADCAST_WEBSOCKET_SECRET + settings.BROADCAST_WEBSOCKET_SECRET = 'foobar' + res = ('secret'.encode('utf-8'), WebsocketSecretAuthHelper.construct_secret().encode('utf-8')) + if set_broadcast_websocket_secret is False: + settings.BROADCAST_WEBSOCKET_SECRET = secret_backup + return res + + return fn + + @pytest.fixture + def websocket_relay_secret(self, settings, websocket_relay_secret_generator): + return websocket_relay_secret_generator('foobar', set_broadcast_websocket_secret=True) + + async def test_authorized(self, websocket_server_generator, websocket_relay_secret): + server = websocket_server_generator('/websocket/relay/') + + server.scope['headers'] = (websocket_relay_secret,) + connected, _ = await server.connect() + assert connected is True + + async def test_not_authorized(self, websocket_server_generator): + server = websocket_server_generator('/websocket/relay/') + connected, _ = await server.connect() + assert connected is False, "Connection to the relay websocket without auth. We expected the client to be denied." + + async def test_wrong_secret(self, websocket_server_generator, websocket_relay_secret_generator): + server = websocket_server_generator('/websocket/relay/') + + server.scope['headers'] = (websocket_relay_secret_generator('foobar', set_broadcast_websocket_secret=False),) + connected, _ = await server.connect() + assert connected is False + + +@pytest.mark.asyncio +@pytest.mark.django_db +class TestWebsocketEventConsumer: + async def test_unauthorized_anonymous(self, websocket_server_generator): + server = websocket_server_generator('/websocket/') + + server.scope['user'] = AnonymousUser() + connected, _ = await server.connect() + assert connected is False, "Anonymous user should NOT be allowed to login." + + @pytest.mark.skip(reason="Ran out of coding time.") + async def test_authorized(self, websocket_server_generator, application, admin): + server = websocket_server_generator('/websocket/') + + """ + I ran out of time. Here is what I was thinking ... + Inject a valid session into the cookies in the header + + server.scope['headers'] = ( + (b'cookie', ...), + ) + """ + connected, _ = await server.connect() + assert connected is True, "User should be allowed in via cookies auth via a session key in the cookies"