diff --git a/awx/api/generics.py b/awx/api/generics.py index 2597e2452a..12ee6eb7b7 100644 --- a/awx/api/generics.py +++ b/awx/api/generics.py @@ -36,6 +36,7 @@ from ansible_base.lib.utils.models import get_all_field_names from ansible_base.lib.utils.requests import get_remote_host from ansible_base.rbac.models import RoleEvaluation, RoleDefinition from ansible_base.rbac.permission_registry import permission_registry +from ansible_base.jwt_consumer.common.util import validate_x_trusted_proxy_header # AWX from awx.main.models import UnifiedJob, UnifiedJobTemplate, User, Role, Credential, WorkflowJobTemplateNode, WorkflowApprovalTemplate @@ -153,13 +154,21 @@ class APIView(views.APIView): Store the Django REST Framework Request object as an attribute on the normal Django request, store time the request started. """ + is_trusted_proxy = False + self.time_started = time.time() if getattr(settings, 'SQL_DEBUG', False): self.queries_before = len(connection.queries) + if 'HTTP_X_TRUSTED_PROXY' in request.META: + if validate_x_trusted_proxy_header(request.META['HTTP_X_TRUSTED_PROXY']): + is_trusted_proxy = True + else: + logger.warning("Request appeared to be a trusted upstream proxy but failed to provide a matching shared secret.") + # If there are any custom headers in REMOTE_HOST_HEADERS, make sure # they respect the allowed proxy list - if all( + if not is_trusted_proxy and all( [ settings.PROXY_IP_ALLOWED_LIST, request.environ.get('REMOTE_ADDR') not in settings.PROXY_IP_ALLOWED_LIST, diff --git a/awx/api/views/__init__.py b/awx/api/views/__init__.py index 92db0f87f2..844287c9f2 100644 --- a/awx/api/views/__init__.py +++ b/awx/api/views/__init__.py @@ -61,6 +61,7 @@ import pytz from wsgiref.util import FileWrapper # django-ansible-base +from ansible_base.lib.utils.requests import get_remote_hosts from ansible_base.rbac.models import RoleEvaluation, ObjectRole from ansible_base.resource_registry.shared_types import OrganizationType, TeamType, UserType @@ -2770,12 +2771,8 @@ class JobTemplateCallback(GenericAPIView): host for the current request. """ # Find the list of remote host names/IPs to check. - remote_hosts = set() - for header in settings.REMOTE_HOST_HEADERS: - for value in self.request.META.get(header, '').split(','): - value = value.strip() - if value: - remote_hosts.add(value) + + remote_hosts = set(get_remote_hosts(self.request)) # Add the reverse lookup of IP addresses. for rh in list(remote_hosts): try: