Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add middleware for checking Cross-Site Request Forgery (CSRF) when trusted origins are specified via environment variable #4916

Merged
merged 11 commits into from
Oct 9, 2024
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ dependencies = [
"litellm>=1.0.3",
"openai>=1.0.0",
"tenacity",
"protobuf==3.20", # version minimum (for tests)
"protobuf==3.20.2", # version minimum (for tests)
"grpc-interceptor[testing]",
"responses",
"tiktoken",
Expand Down
21 changes: 21 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@
"""
The duration, in minutes, before password reset tokens expire.
"""
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS = "PHOENIX_CSRF_TRUSTED_ORIGINS"
"""
A comma-separated list of origins that are allowed to bypass Cross-Site Request Forgery (CSRF)
protection.
"""

# SMTP settings
ENV_PHOENIX_SMTP_HOSTNAME = "PHOENIX_SMTP_HOSTNAME"
Expand Down Expand Up @@ -321,6 +326,22 @@ def get_env_refresh_token_expiry() -> timedelta:
return timedelta(minutes=minutes)


def get_env_csrf_trusted_origins() -> List[str]:
origins: List[str] = []
if not (csrf_trusted_origins := os.getenv(ENV_PHOENIX_CSRF_TRUSTED_ORIGINS)):
return origins
for origin in csrf_trusted_origins.split(","):
if not origin:
continue
if not urlparse(origin).hostname:
raise ValueError(
f"The environment variable `{ENV_PHOENIX_CSRF_TRUSTED_ORIGINS}` contains a url "
f"with missing hostname. Please ensure that each url has a valid hostname."
)
origins.append(origin)
return sorted(set(origins))


def get_env_smtp_username() -> str:
return os.getenv(ENV_PHOENIX_SMTP_USERNAME) or ""

Expand Down
31 changes: 31 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Union,
cast,
)
from urllib.parse import urlparse

import strawberry
from fastapi import APIRouter, Depends, FastAPI
Expand All @@ -42,6 +43,7 @@
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.staticfiles import StaticFiles
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.templating import Jinja2Templates
from starlette.types import Scope, StatefulLifespan
from strawberry.extensions import SchemaExtension
Expand All @@ -53,8 +55,10 @@
import phoenix.trace.v1 as pb
from phoenix.config import (
DEFAULT_PROJECT_NAME,
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS,
SERVER_DIR,
OAuth2ClientConfig,
get_env_csrf_trusted_origins,
get_env_host,
get_env_port,
server_instrumentation_is_enabled,
Expand Down Expand Up @@ -226,6 +230,25 @@
return response


class RequestOriginHostnameValidator(BaseHTTPMiddleware):
def __init__(self, trusted_hostnames: List[str], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._trusted_hostnames = trusted_hostnames

async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
headers = request.headers
for key in "origin", "referer":
if not (url := headers.get(key)):
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
continue
if urlparse(url).hostname not in self._trusted_hostnames:
return Response(f"untrusted {key}", status_code=HTTP_401_UNAUTHORIZED)
return await call_next(request)


class HeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -660,6 +683,14 @@
)
last_updated_at = LastUpdatedAt()
middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
if origins := get_env_csrf_trusted_origins():
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
elif email_sender or oauth2_client_configs:
logger.warning(
"CSRF protection can be enabled by listing trusted origins via "
f"the `{ENV_PHOENIX_CSRF_TRUSTED_ORIGINS}` environment variable."
Fixed Show fixed Hide fixed
)
if authentication_enabled and secret:
token_store = JwtStore(db, secret)
middlewares.append(
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from faker import Faker
from phoenix.auth import DEFAULT_SECRET_LENGTH
from phoenix.config import (
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS,
ENV_PHOENIX_DISABLE_RATE_LIMIT,
ENV_PHOENIX_ENABLE_AUTH,
ENV_PHOENIX_SECRET,
Expand Down Expand Up @@ -52,6 +53,7 @@ def _app(
(ENV_PHOENIX_SMTP_PASSWORD, "test"),
(ENV_PHOENIX_SMTP_MAIL_FROM, _fake.email()),
(ENV_PHOENIX_SMTP_VALIDATE_CERTS, "false"),
(ENV_PHOENIX_CSRF_TRUSTED_ORIGINS, ",http://localhost,"),
)
with ExitStack() as stack:
stack.enter_context(mock.patch.dict(os.environ, values))
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import (
Any,
ContextManager,
DefaultDict,
Dict,
Expand Down Expand Up @@ -53,6 +54,7 @@
_grpc_span_exporter,
_Headers,
_http_span_exporter,
_httpx_client,
_initiate_password_reset,
_log_in,
_log_out,
Expand All @@ -73,6 +75,29 @@
_TokenT = TypeVar("_TokenT", _AccessToken, _RefreshToken)


class TestOriginAndReferer:
@pytest.mark.parametrize(
"headers,expectation",
[
[dict(), _OK],
[dict(origin="http://localhost"), _OK],
[dict(referer="http://localhost/xyz"), _OK],
[dict(origin="http://xyz.com"), _EXPECTATION_401],
[dict(referer="http://xyz.com/xyz"), _EXPECTATION_401],
[dict(origin="http://xyz.com", referer="http://localhost/xyz"), _EXPECTATION_401],
[dict(origin="http://localhost", referer="http://xyz.com/xyz"), _EXPECTATION_401],
],
)
def test_csrf_origin_validation(
self,
headers: Dict[str, str],
expectation: ContextManager[Any],
) -> None:
resp = _httpx_client(headers=headers).get("/healthz")
with expectation:
resp.raise_for_status()


class TestLogIn:
@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN])
def test_can_log_in(
Expand Down
Loading