diff --git a/docs/how_to_guides/index.rst b/docs/how_to_guides/index.rst index bccdd541..9f4bf2d7 100644 --- a/docs/how_to_guides/index.rst +++ b/docs/how_to_guides/index.rst @@ -11,6 +11,7 @@ How to guides dispatch_apps.rst http_https_redirect.rst logging.rst + proxy_fix.rst server_names.rst statsd.rst wsgi_apps.rst diff --git a/docs/how_to_guides/proxy_fix.rst b/docs/how_to_guides/proxy_fix.rst new file mode 100644 index 00000000..dd8d080f --- /dev/null +++ b/docs/how_to_guides/proxy_fix.rst @@ -0,0 +1,33 @@ +Fixing proxy headers +==================== + +If you are serving Hypercorn behind a proxy, e.g. a load balancer, the +client-address, scheme, and host-header will match that of the +connection between the proxy and Hypercorn rather than the user-agent +(client). However, most proxies provide headers with the original +user-agent (client) values which can be used to "fix" the headers to +these values. + +Modern proxies should provide this information via a ``Forwarded`` +header from `RFC 7239 +<https://datatracker.ietf.org/doc/html/rfc7239>`_. However, this is +rare in practice with legacy proxies using a combination of +``X-Forwarded-For``, ``X-Forwarded-Proto`` and +``X-Forwarded-Host``. It is important that you choose the correct mode +(legacy, or modern) based on the proxy you use. + +To use the proxy fix middleware behind a single legacy proxy simply +wrap your app and serve the wrapped app, + +.. code-block:: python + + from hypercorn.middleware import ProxyFixMiddleware + + fixed_app = ProxyFixMiddleware(app, mode="legacy", trusted_hops=1) + +.. warning:: + + The mode and number of trusted hops must match your setup or the + user-agent (client) may be trusted and hence able to set + alternative for, proto, and host values. This can, depending on + your usage in the app, lead to security vulnerabilities. 
class ProxyFixMiddleware:
    """ASGI middleware that restores client details hidden by a proxy.

    For ``http`` and ``websocket`` scopes the middleware reads either the
    ``Forwarded`` header (``mode="modern"``) or the ``X-Forwarded-For``,
    ``X-Forwarded-Proto`` and ``X-Forwarded-Host`` headers
    (``mode="legacy"``) and uses the trusted entry to overwrite the
    scope's ``client``, ``scheme`` and ``Host`` header before calling the
    wrapped app. Other scope types are passed through untouched.

    Args:
        app: The ASGI application to wrap.
        mode: Which family of proxy headers to read.
        trusted_hops: How many proxies, counted from the end of the
            header's value list, are trusted. Values of zero or less
            disable the rewriting entirely.
    """

    def __init__(
        self,
        app: ASGIFramework,
        mode: Literal["legacy", "modern"] = "legacy",
        trusted_hops: int = 1,
    ) -> None:
        self.app = app
        self.mode = mode
        self.trusted_hops = trusted_hops

    async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
        """Fix up *scope* from the proxy headers, then call the wrapped app."""
        if scope["type"] in {"http", "websocket"}:
            # Work on a copy so the scope object owned by the caller is
            # never mutated.
            scope = deepcopy(scope)
            headers = scope["headers"]  # type: ignore
            client: Optional[str] = None
            scheme: Optional[str] = None
            host: Optional[str] = None

            if (
                self.mode == "modern"
                and (value := _get_trusted_value(b"forwarded", headers, self.trusted_hops))
                is not None
            ):
                forwarded = _parse_forwarded(value)
                client = forwarded.get("for")
                host = forwarded.get("host")
                scheme = forwarded.get("proto")
            else:
                client = _get_trusted_value(b"x-forwarded-for", headers, self.trusted_hops)
                scheme = _get_trusted_value(b"x-forwarded-proto", headers, self.trusted_hops)
                host = _get_trusted_value(b"x-forwarded-host", headers, self.trusted_hops)

            if client is not None:
                # The original port is not carried by the headers read
                # here, so 0 is used as a placeholder.
                scope["client"] = (client, 0)  # type: ignore

            if scheme is not None:
                scope["scheme"] = scheme  # type: ignore

            if host is not None:
                headers = [
                    (name, header_value)
                    for name, header_value in headers
                    if name.lower() != b"host"
                ]
                # Header values are bytes; the value was decoded as latin-1
                # when read, so encoding it the same way round-trips it.
                headers.append((b"host", host.encode("latin1")))
                scope["headers"] = headers  # type: ignore

        await self.app(scope, receive, send)


def _parse_forwarded(element: str) -> dict[str, str]:
    """Return the ``for``, ``host`` and ``proto`` values of one element.

    Parameter names are compared case-insensitively and a value wrapped
    in double quotes is unquoted. When a parameter is repeated the last
    occurrence wins; unknown parameters are ignored.
    """
    parsed: dict[str, str] = {}
    for part in element.split(";"):
        name, separator, raw_value = part.strip().partition("=")
        if not separator:
            continue
        key = name.lower()
        if key not in {"for", "host", "proto"}:
            continue
        value = raw_value.strip()
        if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
            value = value[1:-1]
        parsed[key] = value
    return parsed


def _get_trusted_value(
    name: bytes, headers: Iterable[Tuple[bytes, bytes]], trusted_hops: int
) -> Optional[str]:
    """Return the value of header *name* added by the outermost trusted hop.

    All occurrences of the header are split on commas and flattened in
    order; the entry ``trusted_hops`` from the end is returned. ``None``
    is returned when ``trusted_hops`` is not positive or when fewer than
    ``trusted_hops`` entries are present.
    """
    # A negative count would index from the front of the list, i.e. pick
    # an entry supplied by the client, so treat it like "trust nothing".
    if trusted_hops <= 0:
        return None

    values = [
        value.decode("latin1").strip()
        for header_name, header_value in headers
        if header_name.lower() == name
        for value in header_value.split(b",")
    ]

    if len(values) >= trusted_hops:
        return values[-trusted_hops]

    return None
@pytest.mark.asyncio
async def test_proxy_fix_modern() -> None:
    """Modern mode takes client and scheme from the last ``Forwarded`` element."""
    downstream = AsyncMock()
    middleware = ProxyFixMiddleware(downstream, mode="modern")
    # Two comma-separated elements: with the default single trusted hop
    # only the final one should be honoured.
    forwarded_header = (b"forwarded", b"for=127.0.0.1;proto=http,for=127.0.0.2;proto=https")
    scope: HTTPScope = {
        "type": "http",
        "asgi": {},
        "http_version": "2",
        "method": "GET",
        "scheme": "http",
        "path": "/",
        "raw_path": b"/",
        "query_string": b"",
        "root_path": "",
        "headers": [forwarded_header],
        "client": ("127.0.0.3", 80),
        "server": None,
        "extensions": {},
    }

    await middleware(scope, None, None)

    downstream.assert_called()
    forwarded_scope = downstream.call_args[0][0]
    assert forwarded_scope["client"] == ("127.0.0.2", 0)
    assert forwarded_scope["scheme"] == "https"