Commit eff9bb0

fink: enhance ProcessLinksMiddleware with base URL handling and link transformation
- Added `get_base_url` utility to reconstruct the client's base URL from forwarded headers.
- Updated `ProcessLinksMiddleware` to utilize the new utility for transforming links in responses.
- Improved link transformation logic to handle various scenarios, including different hostnames and ports.
- Refactored tests for `ProcessLinksMiddleware` to cover new functionality and edge cases.
1 parent 9cea266 commit eff9bb0

4 files changed (+693 / -106 lines)


src/stac_auth_proxy/middleware/ProcessLinksMiddleware.py

Lines changed: 24 additions & 4 deletions
@@ -11,6 +11,7 @@
 from starlette.types import ASGIApp, Scope
 
 from ..utils.middleware import JsonResponseMiddleware
+from ..utils.requests import get_base_url
 from ..utils.stac import get_links
 
 logger = logging.getLogger(__name__)
@@ -40,6 +41,11 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
 
     def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
         """Update links in the response to include root_path."""
+        # Get the client's actual base URL (accounting for load balancers/proxies)
+        req_base_url = get_base_url(request)
+        parsed_req_url = urlparse(req_base_url)
+        parsed_upstream_url = urlparse(self.upstream_url)
+
         for link in get_links(data):
             href = link.get("href")
             if not href:
@@ -48,12 +54,25 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
             try:
                 parsed_link = urlparse(href)
 
-                # Ignore links that are not for this proxy
-                if parsed_link.netloc != request.headers.get("host"):
+                if parsed_link.netloc not in [
+                    parsed_req_url.netloc,
+                    parsed_upstream_url.netloc,
+                ]:
+                    logger.warning(
+                        "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)",
+                        href,
+                        parsed_req_url.netloc,
+                        parsed_upstream_url.netloc,
+                    )
                     continue
 
-                # Remove the upstream_url path from the link if it exists
-                parsed_upstream_url = urlparse(self.upstream_url)
+                if parsed_link.netloc == parsed_upstream_url.netloc:
+                    # Replace the upstream host with the client's host
+                    parsed_link = parsed_link._replace(
+                        netloc=parsed_req_url.netloc
+                    )._replace(scheme=parsed_req_url.scheme)
+
+                # Rewrite the link path
                 if parsed_upstream_url.path != "/" and parsed_link.path.startswith(
                     parsed_upstream_url.path
                 ):
@@ -68,6 +87,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
                 )
 
             link["href"] = urlunparse(parsed_link)
+
         except Exception as e:
             logger.error(
                 "Failed to parse link href %r, (ignoring): %s", href, str(e)

src/stac_auth_proxy/utils/requests.py

Lines changed: 113 additions & 1 deletion
@@ -1,13 +1,18 @@
 """Utility functions for working with HTTP requests."""
 
 import json
+import logging
 import re
 from dataclasses import dataclass, field
-from typing import Sequence
+from typing import Dict, Sequence
 from urllib.parse import urlparse
 
+from starlette.requests import Request
+
 from ..config import EndpointMethods
 
+logger = logging.getLogger(__name__)
+
 
 def extract_variables(url: str) -> dict:
     """
@@ -80,3 +85,110 @@ class MatchResult:
 
     is_private: bool
     required_scopes: Sequence[str] = field(default_factory=list)
+
+
+def parse_forwarded_header(forwarded_header: str) -> Dict[str, str]:
+    """
+    Parse the Forwarded header according to RFC 7239.
+
+    Args:
+        forwarded_header: The Forwarded header value
+
+    Returns:
+        Dictionary containing parsed forwarded information (proto, host, for, by, etc.)
+
+    Example:
+        >>> parse_forwarded_header("for=192.0.2.43; by=203.0.113.60; proto=https; host=api.example.com")
+        {'for': '192.0.2.43', 'by': '203.0.113.60', 'proto': 'https', 'host': 'api.example.com'}
+
+    """
+    # Forwarded header format: "for=192.0.2.43, for=198.51.100.17; by=203.0.113.60; proto=https; host=example.com"
+    # The format is: for=value1, for=value2; by=value; proto=value; host=value
+    # We need to parse all the key=value pairs, taking the first 'for' value
+    forwarded_info = {}
+
+    try:
+        # Parse all key=value pairs separated by semicolons
+        for pair in forwarded_header.split(";"):
+            pair = pair.strip()
+            if "=" in pair:
+                key, value = pair.split("=", 1)
+                key = key.strip()
+                value = value.strip().strip('"')
+
+                # For 'for' field, only take the first value if there are multiple
+                if key == "for" and key not in forwarded_info:
+                    # Extract the first for value (before comma if present)
+                    first_for_value = value.split(",")[0].strip()
+                    forwarded_info[key] = first_for_value
+                elif key != "for":
+                    # For other fields, just use the value as-is
+                    forwarded_info[key] = value
+    except Exception as e:
+        logger.warning(f"Failed to parse Forwarded header '{forwarded_header}': {e}")
+        return {}
+
+    return forwarded_info
+
+
+def get_base_url(request: Request) -> str:
+    """
+    Get the request's base URL, accounting for forwarded headers from load balancers/proxies.
+
+    This function handles both the standard Forwarded header (RFC 7239) and legacy
+    X-Forwarded-* headers to reconstruct the original client URL when the service
+    is deployed behind load balancers or reverse proxies.
+
+    Args:
+        request: The Starlette request object
+
+    Returns:
+        The reconstructed client base URL
+
+    Example:
+        >>> # With Forwarded header
+        >>> request.headers = {"Forwarded": "for=192.0.2.43; proto=https; host=api.example.com"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+        >>> # With X-Forwarded-* headers
+        >>> request.headers = {"X-Forwarded-Host": "api.example.com", "X-Forwarded-Proto": "https"}
+        >>> get_base_url(request)
+        "https://api.example.com/"
+
+    """
+    # Check for standard Forwarded header first (RFC 7239)
+    forwarded_header = request.headers.get("Forwarded")
+    if forwarded_header:
+        try:
+            forwarded_info = parse_forwarded_header(forwarded_header)
+            # Only use Forwarded header if we successfully parsed it and got useful info
+            if forwarded_info and (
+                "proto" in forwarded_info or "host" in forwarded_info
+            ):
+                scheme = forwarded_info.get("proto", request.url.scheme)
+                host = forwarded_info.get("host", request.url.netloc)
+                # Note: Forwarded header doesn't include path, so we use request.base_url.path
+                path = request.base_url.path
+                return f"{scheme}://{host}{path}"
+        except Exception as e:
+            logger.warning(f"Failed to parse Forwarded header: {e}")
+
+    # Fall back to legacy X-Forwarded-* headers
+    forwarded_host = request.headers.get("X-Forwarded-Host")
+    forwarded_proto = request.headers.get("X-Forwarded-Proto")
+    forwarded_path = request.headers.get("X-Forwarded-Path")
+
+    if forwarded_host:
+        # Use forwarded headers to reconstruct the original client URL
+        scheme = forwarded_proto or request.url.scheme
+        netloc = forwarded_host
+        # Use forwarded path if available, otherwise use request base URL path
+        path = forwarded_path or request.base_url.path
+    else:
+        # Fall back to the request's base URL if no forwarded headers
+        scheme = request.url.scheme
+        netloc = request.url.netloc
+        path = request.base_url.path
+
+    return f"{scheme}://{netloc}{path}"
