diff --git a/awx/api/generics.py b/awx/api/generics.py index 2597e2452a..024b33b79d 100644 --- a/awx/api/generics.py +++ b/awx/api/generics.py @@ -36,6 +36,7 @@ from ansible_base.lib.utils.models import get_all_field_names from ansible_base.lib.utils.requests import get_remote_host from ansible_base.rbac.models import RoleEvaluation, RoleDefinition from ansible_base.rbac.permission_registry import permission_registry +from ansible_base.jwt_consumer.common.util import validate_x_trusted_proxy_header # AWX from awx.main.models import UnifiedJob, UnifiedJobTemplate, User, Role, Credential, WorkflowJobTemplateNode, WorkflowApprovalTemplate @@ -43,6 +44,7 @@ from awx.main.models.rbac import give_creator_permissions from awx.main.access import optimize_queryset from awx.main.utils import camelcase_to_underscore, get_search_fields, getattrd, get_object_or_400, decrypt_field, get_awx_version from awx.main.utils.licensing import server_product_name +from awx.main.utils.proxy import is_proxy_in_headers, delete_headers from awx.main.views import ApiErrorView from awx.api.serializers import ResourceAccessListElementSerializer, CopySerializer from awx.api.versioning import URLPathVersioning @@ -153,22 +155,23 @@ class APIView(views.APIView): Store the Django REST Framework Request object as an attribute on the normal Django request, store time the request started. """ + remote_headers = ['REMOTE_ADDR', 'REMOTE_HOST'] + self.time_started = time.time() if getattr(settings, 'SQL_DEBUG', False): self.queries_before = len(connection.queries) + if 'HTTP_X_TRUSTED_PROXY' in request.environ: + if validate_x_trusted_proxy_header(request.environ['HTTP_X_TRUSTED_PROXY']): + remote_headers = settings.REMOTE_HOST_HEADERS + else: + logger.warning("Request appeared to be a trusted upstream proxy but failed to provide a matching shared secret.") + # If there are any custom headers in REMOTE_HOST_HEADERS, make sure # they respect the allowed proxy list - if all( - [ - settings.PROXY_IP_ALLOWED_LIST, - request.environ.get('REMOTE_ADDR') not in settings.PROXY_IP_ALLOWED_LIST, - request.environ.get('REMOTE_HOST') not in settings.PROXY_IP_ALLOWED_LIST, - ] - ): - for custom_header in settings.REMOTE_HOST_HEADERS: - if custom_header.startswith('HTTP_'): - request.environ.pop(custom_header, None) + if settings.PROXY_IP_ALLOWED_LIST: + if not is_proxy_in_headers(self.request, settings.PROXY_IP_ALLOWED_LIST, remote_headers): + delete_headers(request, settings.REMOTE_HOST_HEADERS) drf_request = super(APIView, self).initialize_request(request, *args, **kwargs) request.drf_request = drf_request diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index 92db0f87f2..fe116afb98 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -61,6 +61,7 @@ import pytz from wsgiref.util import FileWrapper # django-ansible-base +from ansible_base.lib.utils.requests import get_remote_hosts from ansible_base.rbac.models import RoleEvaluation, ObjectRole from ansible_base.resource_registry.shared_types import OrganizationType, TeamType, UserType @@ -2770,12 +2771,7 @@ class JobTemplateCallback(GenericAPIView): host for the current request. """ # Find the list of remote host names/IPs to check. - remote_hosts = set() - for header in settings.REMOTE_HOST_HEADERS: - for value in self.request.META.get(header, '').split(','): - value = value.strip() - if value: - remote_hosts.add(value) + remote_hosts = set(get_remote_hosts(self.request)) # Add the reverse lookup of IP addresses. for rh in list(remote_hosts): try: diff --git a/awx/main/tests/functional/api/test_generic.py b/awx/main/tests/functional/api/test_generic.py index 814ee4d1bc..87571c8eea 100644 --- a/awx/main/tests/functional/api/test_generic.py +++ b/awx/main/tests/functional/api/test_generic.py @@ -1,4 +1,5 @@ import pytest +from unittest import mock from awx.api.versioning import reverse @@ -52,21 +53,6 @@ def test_proxy_ip_allowed(get, patch, admin): assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' -@pytest.mark.django_db -def test_x_trusted_proxy(get, patch, admin, rsa_keypair): # noqa: F811 - url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'system'}) - patch(url, user=admin, data={'REMOTE_HOST_HEADERS': ['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST']}) - - # Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted - middleware = HeaderTrackingMiddleware() - headers = { - 'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private), - 'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip', - } - get(url, user=admin, middleware=middleware, **headers) - assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' - - @pytest.mark.django_db class TestTrustedProxyAllowListIntegration: @pytest.fixture @@ -81,23 +67,25 @@ class TestTrustedProxyAllowListIntegration: return HeaderTrackingMiddleware() def test_x_trusted_proxy_valid_signature(self, get, admin, rsa_keypair, url, middleware): # noqa: F811 - # Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted + # Headers should NOT get deleted headers = { 'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private), 'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip', } - with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): - get(url, user=admin, middleware=middleware, **headers) + with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None): + with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public, PROXY_IP_ALLOWED_LIST=[]): + get(url, user=admin, middleware=middleware, **headers) assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' - def test_x_trusted_proxy_invalid_signature(self, get, admin, url, middleware): - # Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted + def test_x_trusted_proxy_invalid_signature(self, get, admin, url, patch, middleware): + # Headers should NOT get deleted headers = { 'HTTP_X_TRUSTED_PROXY': 'DEAD-BEEF', 'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip', } - get(url, user=admin, middleware=middleware, **headers) - assert 'HTTP_X_FROM_THE_LOAD_BALANCER' not in middleware.environ + with override_settings(PROXY_IP_ALLOWED_LIST=[]): + get(url, user=admin, middleware=middleware, **headers) + assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' def test_x_trusted_proxy_invalid_signature_valid_proxy(self, get, admin, url, middleware): # A valid explicit proxy SHOULD result in sensitive headers NOT being deleted, regardless of the trusted proxy signature results diff --git a/awx/main/tests/functional/api/test_job_template.py b/awx/main/tests/functional/api/test_job_template.py index 3e154766f8..4cdf43d37b 100644 --- a/awx/main/tests/functional/api/test_job_template.py +++ b/awx/main/tests/functional/api/test_job_template.py @@ -1,4 +1,5 @@ import pytest +from unittest import mock # AWX from awx.api.serializers import JobTemplateSerializer @@ -8,10 +9,15 @@ from awx.main.migrations import _save_password_keys as save_password_keys # Django from django.apps import apps +from django.test.utils import override_settings # DRF from rest_framework.exceptions import ValidationError +# DAB +from ansible_base.jwt_consumer.common.util import generate_x_trusted_proxy_header +from ansible_base.lib.testing.fixtures import rsa_keypair_factory, rsa_keypair # noqa: F401; pylint: disable=unused-import + @pytest.mark.django_db @pytest.mark.parametrize( @@ -369,3 +375,113 @@ def test_job_template_missing_inventory(project, inventory, admin_user, post): ) assert r.status_code == 400 assert "Cannot start automatically, an inventory is required." in str(r.data) + + +@pytest.mark.django_db +class TestJobTemplateCallbackProxyIntegration: + """ + Test the interaction of provision job template callback feature and: + settings.PROXY_IP_ALLOWED_LIST + x-trusted-proxy http header + """ + + @pytest.fixture + def job_template(self, inventory, project): + jt = JobTemplate.objects.create(name='test-jt', inventory=inventory, project=project, playbook='helloworld.yml', host_config_key='abcd') + return jt + + @override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org']) + def test_host_not_found(self, job_template, admin_user, post, rsa_keypair): # noqa: F811 + job_template.inventory.hosts.create(name='foobar') + + headers = { + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'baz', + 'REMOTE_HOST': 'baz', + 'REMOTE_ADDR': 'baz', + } + r = post( + url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=400, **headers + ) + assert r.data['msg'] == 'No matching host could be found!' + + @pytest.mark.parametrize( + 'headers, expected', + ( + pytest.param( + { + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar', + 'REMOTE_HOST': 'my.proxy.example.org', + }, + 201, + ), + pytest.param( + { + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar', + 'REMOTE_HOST': 'not-my-proxy.org', + }, + 400, + ), + ), + ) + @override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org']) + def test_proxy_ip_allowed_list(self, job_template, admin_user, post, headers, expected): # noqa: F811 + job_template.inventory.hosts.create(name='my.proxy.example.org') + + post( + url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), + data={'host_config_key': 'abcd'}, + user=admin_user, + expect=expected, + **headers + ) + + @override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[]) + def test_no_proxy_trust_all_headers(self, job_template, admin_user, post): + job_template.inventory.hosts.create(name='foobar') + + headers = { + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar', + 'REMOTE_ADDR': 'bar', + 'REMOTE_HOST': 'baz', + } + post(url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=201, **headers) + + @override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org']) + def test_trusted_proxy(self, job_template, admin_user, post, rsa_keypair): # noqa: F811 + job_template.inventory.hosts.create(name='foobar') + + headers = { + 'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private), + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar, my.proxy.example.org', + } + + with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None): + with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): + post( + url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), + data={'host_config_key': 'abcd'}, + user=admin_user, + expect=201, + **headers + ) + + @override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org']) + def test_trusted_proxy_host_not_found(self, job_template, admin_user, post, rsa_keypair): # noqa: F811 + job_template.inventory.hosts.create(name='foobar') + + headers = { + 'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private), + 'HTTP_X_FROM_THE_LOAD_BALANCER': 'baz, my.proxy.example.org', + 'REMOTE_ADDR': 'bar', + 'REMOTE_HOST': 'baz', + } + + with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None): + with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): + post( + url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), + data={'host_config_key': 'abcd'}, + user=admin_user, + expect=400, + **headers + ) diff --git a/awx/main/utils/proxy.py b/awx/main/utils/proxy.py new file mode 100644 index 0000000000..1f96455ed0 --- /dev/null +++ b/awx/main/utils/proxy.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Ansible, Inc. +# All Rights Reserved. + + +# DRF +from rest_framework.request import Request + + +""" +Note that these methods operate on request.environ. This data is from uwsgi. +It is the source data from which request.headers (read-only) is constructed. +""" + + +def is_proxy_in_headers(request: Request, proxy_list: list[str], headers: list[str]) -> bool: + remote_hosts = set() + + for header in headers: + for value in request.environ.get(header, '').split(','): + value = value.strip() + if value: + remote_hosts.add(value) + + return bool(remote_hosts.intersection(set(proxy_list))) + + +def delete_headers(request: Request, headers: list[str]): + for header in headers: + if header.startswith('HTTP_'): + request.environ.pop(header, None)