Add support for x-trusted-proxy

* Increase the surface area of the set of headers that the proxy list
  feature looks at for the remote proxy IF x-trusted-proxy is valid.
This commit is contained in:
Chris Meyers
2024-06-17 14:45:47 -04:00
committed by Chris Meyers
parent b93aa20362
commit 8645fe5c57
5 changed files with 171 additions and 38 deletions

View File

@@ -36,6 +36,7 @@ from ansible_base.lib.utils.models import get_all_field_names
from ansible_base.lib.utils.requests import get_remote_host from ansible_base.lib.utils.requests import get_remote_host
from ansible_base.rbac.models import RoleEvaluation, RoleDefinition from ansible_base.rbac.models import RoleEvaluation, RoleDefinition
from ansible_base.rbac.permission_registry import permission_registry from ansible_base.rbac.permission_registry import permission_registry
from ansible_base.jwt_consumer.common.util import validate_x_trusted_proxy_header
# AWX # AWX
from awx.main.models import UnifiedJob, UnifiedJobTemplate, User, Role, Credential, WorkflowJobTemplateNode, WorkflowApprovalTemplate from awx.main.models import UnifiedJob, UnifiedJobTemplate, User, Role, Credential, WorkflowJobTemplateNode, WorkflowApprovalTemplate
@@ -43,6 +44,7 @@ from awx.main.models.rbac import give_creator_permissions
from awx.main.access import optimize_queryset from awx.main.access import optimize_queryset
from awx.main.utils import camelcase_to_underscore, get_search_fields, getattrd, get_object_or_400, decrypt_field, get_awx_version from awx.main.utils import camelcase_to_underscore, get_search_fields, getattrd, get_object_or_400, decrypt_field, get_awx_version
from awx.main.utils.licensing import server_product_name from awx.main.utils.licensing import server_product_name
from awx.main.utils.proxy import is_proxy_in_headers, delete_headers
from awx.main.views import ApiErrorView from awx.main.views import ApiErrorView
from awx.api.serializers import ResourceAccessListElementSerializer, CopySerializer from awx.api.serializers import ResourceAccessListElementSerializer, CopySerializer
from awx.api.versioning import URLPathVersioning from awx.api.versioning import URLPathVersioning
@@ -153,22 +155,23 @@ class APIView(views.APIView):
Store the Django REST Framework Request object as an attribute on the Store the Django REST Framework Request object as an attribute on the
normal Django request, store time the request started. normal Django request, store time the request started.
""" """
remote_headers = ['REMOTE_ADDR', 'REMOTE_HOST']
self.time_started = time.time() self.time_started = time.time()
if getattr(settings, 'SQL_DEBUG', False): if getattr(settings, 'SQL_DEBUG', False):
self.queries_before = len(connection.queries) self.queries_before = len(connection.queries)
if 'HTTP_X_TRUSTED_PROXY' in request.environ:
if validate_x_trusted_proxy_header(request.environ['HTTP_X_TRUSTED_PROXY']):
remote_headers = settings.REMOTE_HOST_HEADERS
else:
logger.warning("Request appeared to be a trusted upstream proxy but failed to provide a matching shared secret.")
# If there are any custom headers in REMOTE_HOST_HEADERS, make sure # If there are any custom headers in REMOTE_HOST_HEADERS, make sure
# they respect the allowed proxy list # they respect the allowed proxy list
if all( if settings.PROXY_IP_ALLOWED_LIST:
[ if not is_proxy_in_headers(self.request, settings.PROXY_IP_ALLOWED_LIST, remote_headers):
settings.PROXY_IP_ALLOWED_LIST, delete_headers(request, settings.REMOTE_HOST_HEADERS)
request.environ.get('REMOTE_ADDR') not in settings.PROXY_IP_ALLOWED_LIST,
request.environ.get('REMOTE_HOST') not in settings.PROXY_IP_ALLOWED_LIST,
]
):
for custom_header in settings.REMOTE_HOST_HEADERS:
if custom_header.startswith('HTTP_'):
request.environ.pop(custom_header, None)
drf_request = super(APIView, self).initialize_request(request, *args, **kwargs) drf_request = super(APIView, self).initialize_request(request, *args, **kwargs)
request.drf_request = drf_request request.drf_request = drf_request

View File

@@ -61,6 +61,7 @@ import pytz
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
# django-ansible-base # django-ansible-base
from ansible_base.lib.utils.requests import get_remote_hosts
from ansible_base.rbac.models import RoleEvaluation, ObjectRole from ansible_base.rbac.models import RoleEvaluation, ObjectRole
from ansible_base.resource_registry.shared_types import OrganizationType, TeamType, UserType from ansible_base.resource_registry.shared_types import OrganizationType, TeamType, UserType
@@ -2770,12 +2771,7 @@ class JobTemplateCallback(GenericAPIView):
host for the current request. host for the current request.
""" """
# Find the list of remote host names/IPs to check. # Find the list of remote host names/IPs to check.
remote_hosts = set() remote_hosts = set(get_remote_hosts(self.request))
for header in settings.REMOTE_HOST_HEADERS:
for value in self.request.META.get(header, '').split(','):
value = value.strip()
if value:
remote_hosts.add(value)
# Add the reverse lookup of IP addresses. # Add the reverse lookup of IP addresses.
for rh in list(remote_hosts): for rh in list(remote_hosts):
try: try:

View File

@@ -1,4 +1,5 @@
import pytest import pytest
from unittest import mock
from awx.api.versioning import reverse from awx.api.versioning import reverse
@@ -52,21 +53,6 @@ def test_proxy_ip_allowed(get, patch, admin):
assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip'
@pytest.mark.django_db
def test_x_trusted_proxy(get, patch, admin, rsa_keypair): # noqa: F811
url = reverse('api:setting_singleton_detail', kwargs={'category_slug': 'system'})
patch(url, user=admin, data={'REMOTE_HOST_HEADERS': ['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST']})
# Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted
middleware = HeaderTrackingMiddleware()
headers = {
'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private),
'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip',
}
get(url, user=admin, middleware=middleware, **headers)
assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip'
@pytest.mark.django_db @pytest.mark.django_db
class TestTrustedProxyAllowListIntegration: class TestTrustedProxyAllowListIntegration:
@pytest.fixture @pytest.fixture
@@ -81,23 +67,25 @@ class TestTrustedProxyAllowListIntegration:
return HeaderTrackingMiddleware() return HeaderTrackingMiddleware()
def test_x_trusted_proxy_valid_signature(self, get, admin, rsa_keypair, url, middleware): # noqa: F811 def test_x_trusted_proxy_valid_signature(self, get, admin, rsa_keypair, url, middleware): # noqa: F811
# Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted # Headers should NOT get deleted
headers = { headers = {
'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private), 'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private),
'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip', 'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip',
} }
with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public): with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None):
get(url, user=admin, middleware=middleware, **headers) with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public, PROXY_IP_ALLOWED_LIST=[]):
get(url, user=admin, middleware=middleware, **headers)
assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip' assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip'
def test_x_trusted_proxy_invalid_signature(self, get, admin, url, middleware): def test_x_trusted_proxy_invalid_signature(self, get, admin, url, patch, middleware):
# Invalid x_trusted_proxy value SHOULD result in sensitive headers deleted # Headers should NOT get deleted
headers = { headers = {
'HTTP_X_TRUSTED_PROXY': 'DEAD-BEEF', 'HTTP_X_TRUSTED_PROXY': 'DEAD-BEEF',
'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip', 'HTTP_X_FROM_THE_LOAD_BALANCER': 'some-actual-ip',
} }
get(url, user=admin, middleware=middleware, **headers) with override_settings(PROXY_IP_ALLOWED_LIST=[]):
assert 'HTTP_X_FROM_THE_LOAD_BALANCER' not in middleware.environ get(url, user=admin, middleware=middleware, **headers)
assert middleware.environ['HTTP_X_FROM_THE_LOAD_BALANCER'] == 'some-actual-ip'
def test_x_trusted_proxy_invalid_signature_valid_proxy(self, get, admin, url, middleware): def test_x_trusted_proxy_invalid_signature_valid_proxy(self, get, admin, url, middleware):
# A valid explicit proxy SHOULD result in sensitive headers NOT being deleted, regardless of the trusted proxy signature results # A valid explicit proxy SHOULD result in sensitive headers NOT being deleted, regardless of the trusted proxy signature results

View File

@@ -1,4 +1,5 @@
import pytest import pytest
from unittest import mock
# AWX # AWX
from awx.api.serializers import JobTemplateSerializer from awx.api.serializers import JobTemplateSerializer
@@ -8,10 +9,15 @@ from awx.main.migrations import _save_password_keys as save_password_keys
# Django # Django
from django.apps import apps from django.apps import apps
from django.test.utils import override_settings
# DRF # DRF
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
# DAB
from ansible_base.jwt_consumer.common.util import generate_x_trusted_proxy_header
from ansible_base.lib.testing.fixtures import rsa_keypair_factory, rsa_keypair # noqa: F401; pylint: disable=unused-import
@pytest.mark.django_db @pytest.mark.django_db
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -369,3 +375,113 @@ def test_job_template_missing_inventory(project, inventory, admin_user, post):
) )
assert r.status_code == 400 assert r.status_code == 400
assert "Cannot start automatically, an inventory is required." in str(r.data) assert "Cannot start automatically, an inventory is required." in str(r.data)
@pytest.mark.django_db
class TestJobTemplateCallbackProxyIntegration:
"""
Test the interaction of provision job template callback feature and:
settings.PROXY_IP_ALLOWED_LIST
x-trusted-proxy http header
"""
@pytest.fixture
def job_template(self, inventory, project):
jt = JobTemplate.objects.create(name='test-jt', inventory=inventory, project=project, playbook='helloworld.yml', host_config_key='abcd')
return jt
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org'])
def test_host_not_found(self, job_template, admin_user, post, rsa_keypair): # noqa: F811
job_template.inventory.hosts.create(name='foobar')
headers = {
'HTTP_X_FROM_THE_LOAD_BALANCER': 'baz',
'REMOTE_HOST': 'baz',
'REMOTE_ADDR': 'baz',
}
r = post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=400, **headers
)
assert r.data['msg'] == 'No matching host could be found!'
@pytest.mark.parametrize(
'headers, expected',
(
pytest.param(
{
'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar',
'REMOTE_HOST': 'my.proxy.example.org',
},
201,
),
pytest.param(
{
'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar',
'REMOTE_HOST': 'not-my-proxy.org',
},
400,
),
),
)
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org'])
def test_proxy_ip_allowed_list(self, job_template, admin_user, post, headers, expected): # noqa: F811
job_template.inventory.hosts.create(name='my.proxy.example.org')
post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}),
data={'host_config_key': 'abcd'},
user=admin_user,
expect=expected,
**headers
)
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=[])
def test_no_proxy_trust_all_headers(self, job_template, admin_user, post):
job_template.inventory.hosts.create(name='foobar')
headers = {
'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar',
'REMOTE_ADDR': 'bar',
'REMOTE_HOST': 'baz',
}
post(url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}), data={'host_config_key': 'abcd'}, user=admin_user, expect=201, **headers)
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org'])
def test_trusted_proxy(self, job_template, admin_user, post, rsa_keypair): # noqa: F811
job_template.inventory.hosts.create(name='foobar')
headers = {
'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private),
'HTTP_X_FROM_THE_LOAD_BALANCER': 'foobar, my.proxy.example.org',
}
with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None):
with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public):
post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}),
data={'host_config_key': 'abcd'},
user=admin_user,
expect=201,
**headers
)
@override_settings(REMOTE_HOST_HEADERS=['HTTP_X_FROM_THE_LOAD_BALANCER', 'REMOTE_ADDR', 'REMOTE_HOST'], PROXY_IP_ALLOWED_LIST=['my.proxy.example.org'])
def test_trusted_proxy_host_not_found(self, job_template, admin_user, post, rsa_keypair): # noqa: F811
job_template.inventory.hosts.create(name='foobar')
headers = {
'HTTP_X_TRUSTED_PROXY': generate_x_trusted_proxy_header(rsa_keypair.private),
'HTTP_X_FROM_THE_LOAD_BALANCER': 'baz, my.proxy.example.org',
'REMOTE_ADDR': 'bar',
'REMOTE_HOST': 'baz',
}
with mock.patch('ansible_base.jwt_consumer.common.cache.JWTCache.get_key_from_cache', lambda self: None):
with override_settings(ANSIBLE_BASE_JWT_KEY=rsa_keypair.public):
post(
url=reverse('api:job_template_callback', kwargs={'pk': job_template.pk}),
data={'host_config_key': 'abcd'},
user=admin_user,
expect=400,
**headers
)

30
awx/main/utils/proxy.py Normal file
View File

@@ -0,0 +1,30 @@
# Copyright (c) 2024 Ansible, Inc.
# All Rights Reserved.
# DRF
from rest_framework.request import Request
"""
Note that these methods operate on request.environ. This data is from uwsgi.
It is the source data from which request.headers (read-only) is constructed.
"""
def is_proxy_in_headers(request: Request, proxy_list: list[str], headers: list[str]) -> bool:
remote_hosts = set()
for header in headers:
for value in request.environ.get(header, '').split(','):
value = value.strip()
if value:
remote_hosts.add(value)
return bool(remote_hosts.intersection(set(proxy_list)))
def delete_headers(request: Request, headers: list[str]):
for header in headers:
if header.startswith('HTTP_'):
request.environ.pop(header, None)