diff --git a/awxkit/awxkit/awx/utils.py b/awxkit/awxkit/awx/utils.py index 845f95d4ad..6700a37d8f 100644 --- a/awxkit/awxkit/awx/utils.py +++ b/awxkit/awxkit/awx/utils.py @@ -85,15 +85,23 @@ def as_user(v, username, password=None): if config.use_sessions: session_id = None domain = None + cookie_name = connection.session_cookie_name # requests doesn't provide interface for retrieving # domain segregated cookies other than iterating. for cookie in connection.session.cookies: - if cookie.name == connection.session_cookie_name: + if cookie.name == cookie_name: session_id = cookie.value domain = cookie.domain break + if session_id is None and cookie_name != 'gateway_sessionid': + for cookie in connection.session.cookies: + if cookie.name == 'gateway_sessionid': + session_id = cookie.value + domain = cookie.domain + cookie_name = 'gateway_sessionid' + break if session_id: - del connection.session.cookies[connection.session_cookie_name] + del connection.session.cookies[cookie_name] kwargs = connection.get_session_requirements() else: previous_auth = connection.session.auth @@ -102,9 +110,11 @@ def as_user(v, username, password=None): yield finally: if config.use_sessions: - del connection.session.cookies[connection.session_cookie_name] + for name in {connection.session_cookie_name, cookie_name}: + with suppress(KeyError): + del connection.session.cookies[name] if session_id: - connection.session.cookies.set(connection.session_cookie_name, session_id, domain=domain) + connection.session.cookies.set(cookie_name, session_id, domain=domain) else: connection.session.auth = previous_auth diff --git a/awxkit/test/test_as_user.py b/awxkit/test/test_as_user.py new file mode 100644 index 0000000000..64d79e8583 --- /dev/null +++ b/awxkit/test/test_as_user.py @@ -0,0 +1,151 @@ +from http.cookiejar import Cookie +from unittest import mock + +import pytest + +from awxkit.api.client import Connection +from awxkit.awx.utils import as_user +from awxkit.config import config + + +def _make_cookie(name, value, domain='.example.com'): + return Cookie( + version=0, + name=name, + value=value, + port=None, + port_specified=False, + domain=domain, + domain_specified=True, + domain_initial_dot=True, + path='/', + path_specified=True, + secure=False, + expires=None, + discard=True, + comment=None, + comment_url=None, + rest={}, + rfc2109=False, + ) + + +class FakeCookieJar: + def __init__(self): + self._cookies = {} + + def __iter__(self): + return iter(self._cookies.values()) + + def get(self, name, default=None): + c = self._cookies.get(name) + return c.value if c else default + + def set(self, name, value, domain='.example.com'): + self._cookies[name] = _make_cookie(name, value, domain) + + def __delitem__(self, name): + del self._cookies[name] + + +@pytest.fixture +def connection(): + conn = mock.MagicMock(spec=Connection) + conn.session = mock.MagicMock() + conn.session.cookies = FakeCookieJar() + conn.session_cookie_name = 'sessionid' + conn.get_session_requirements.return_value = {'next': '/api/controller/'} + yield conn + + +class TestAsUserSessionAuth: + """Tests for as_user() with session-based authentication.""" + + def setup_method(self): + self._orig = config.use_sessions + config.use_sessions = True + + def teardown_method(self): + config.use_sessions = self._orig + + def test_swaps_sessionid_cookie(self, connection): + connection.session.cookies.set('sessionid', 'admin_session') + + with as_user(connection, 'testuser', 'testpass'): + connection.login.assert_called_once_with('testuser', 'testpass', next='/api/controller/') + + assert connection.session.cookies.get('sessionid') == 'admin_session' + + def test_gateway_sessionid_fallback(self, connection): + """When session_cookie_name is 'sessionid' but actual cookie is 'gateway_sessionid', + as_user() should find and swap the gateway cookie.""" + connection.session.cookies.set('gateway_sessionid', 'admin_gw_session') + + with as_user(connection, 'testuser', 'testpass'): + connection.login.assert_called_once_with('testuser', 'testpass', next='/api/controller/') + assert connection.session.cookies.get('gateway_sessionid') is None + + assert connection.session.cookies.get('gateway_sessionid') == 'admin_gw_session' + + def test_gateway_fallback_not_triggered_when_sessionid_exists(self, connection): + """When sessionid cookie exists, gateway_sessionid fallback should not trigger.""" + connection.session.cookies.set('sessionid', 'admin_session') + connection.session.cookies.set('gateway_sessionid', 'admin_gw_session') + + with as_user(connection, 'testuser', 'testpass'): + pass + + assert connection.session.cookies.get('sessionid') == 'admin_session' + assert connection.session.cookies.get('gateway_sessionid') == 'admin_gw_session' + + def test_accepts_user_object(self, connection): + from awxkit.api import User + + user = mock.MagicMock(spec=User) + user.username = 'bob' + user.password = 'secret' + connection.session.cookies.set('sessionid', 'admin_session') + + with as_user(connection, user): + connection.login.assert_called_once_with('bob', 'secret', next='/api/controller/') + + def test_restores_gateway_cookie_after_exception(self, connection): + connection.session.cookies.set('gateway_sessionid', 'admin_gw_session') + + with pytest.raises(RuntimeError): + with as_user(connection, 'testuser', 'testpass'): + raise RuntimeError('boom') + + assert connection.session.cookies.get('gateway_sessionid') == 'admin_gw_session' + + def test_no_session_cookie_at_all(self, connection): + with as_user(connection, 'testuser', 'testpass'): + connection.login.assert_called_once() + + +class TestAsUserBasicAuth: + """Tests for as_user() with basic authentication.""" + + def setup_method(self): + self._orig = config.use_sessions + config.use_sessions = False + + def teardown_method(self): + config.use_sessions = self._orig + + def test_swaps_basic_auth(self, connection): + connection.session.auth = ('admin', 'adminpass') + + with as_user(connection, 'testuser', 'testpass'): + connection.login.assert_called_once_with('testuser', 'testpass') + + assert connection.session.auth == ('admin', 'adminpass') + + def test_restores_basic_auth_after_exception(self, connection): + connection.session.auth = ('admin', 'adminpass') + + with pytest.raises(RuntimeError): + with as_user(connection, 'testuser', 'testpass'): + raise RuntimeError('boom') + + assert connection.session.auth == ('admin', 'adminpass')