Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): ensure consistent origin url #4748

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 54 additions & 33 deletions src/phoenix/server/api/routers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from datetime import timedelta
from random import randrange
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, TypedDict
from urllib.parse import unquote, urlparse

from authlib.common.security import generate_token
Expand All @@ -18,7 +18,7 @@
from starlette.responses import RedirectResponse
from starlette.routing import Router
from starlette.status import HTTP_302_FOUND
from typing_extensions import Annotated
from typing_extensions import Annotated, NotRequired, TypeGuard

from phoenix.auth import (
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
Expand Down Expand Up @@ -87,10 +87,13 @@ async def login(
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
):
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
origin_url = _get_origin_url(request)
authorization_url_data = await oauth2_client.create_authorization_url(
redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
redirect_uri=_get_create_tokens_endpoint(
request=request, origin_url=origin_url, idp_name=idp_name
),
state=_generate_state_for_oauth2_authorization_code_flow(
secret=secret, return_url=return_url
secret=secret, origin_url=origin_url, return_url=return_url
),
)
assert isinstance(authorization_url := authorization_url_data.get("url"), str)
Expand Down Expand Up @@ -122,12 +125,13 @@ async def create_tokens(
secret = request.app.state.get_secret()
if state != stored_state:
return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
signature_is_valid, return_url = _validate_signature_and_parse_return_url(
secret=secret, state=state
)
if not signature_is_valid:
try:
payload = _parse_state_payload(secret=secret, state=state)
except JoseError:
return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE)
if return_url is not None and not _is_relative_url(unquote(return_url)):
if (return_url := payload.get("return_url")) is not None and not _is_relative_url(
unquote(return_url)
):
return _redirect_to_login(error="Attempting login with unsafe return URL.")
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
Expand All @@ -140,7 +144,9 @@ async def create_tokens(
token_data = await oauth2_client.fetch_access_token(
state=state,
code=authorization_code,
redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
redirect_uri=_get_create_tokens_endpoint(
request=request, origin_url=payload["origin_url"], idp_name=idp_name
),
)
except OAuthError as error:
return _redirect_to_login(error=str(error))
Expand Down Expand Up @@ -357,17 +363,17 @@ def _redirect_to_login(*, error: str) -> RedirectResponse:
return response


def _get_create_tokens_endpoint(*, request: Request, idp_name: str) -> str:
def _get_create_tokens_endpoint(*, request: Request, origin_url: str, idp_name: str) -> str:
"""
Gets the endpoint for create tokens route.
"""
router: Router = request.scope["router"]
url_path = router.url_path_for(create_tokens.__name__, idp_name=idp_name)
return str(url_path.make_absolute_url(base_url=_get_origin_url(request)))
return str(url_path.make_absolute_url(base_url=origin_url))


def _generate_state_for_oauth2_authorization_code_flow(
*, secret: str, return_url: Optional[str]
*, secret: str, origin_url: str, return_url: Optional[str]
) -> str:
"""
Generates a JWT whose payload contains both an OAuth2 state (generated using
Expand All @@ -377,30 +383,34 @@ def _generate_state_for_oauth2_authorization_code_flow(
maintain state.
"""
header = {"alg": _JWT_ALGORITHM}
payload = {"state": generate_token()}
payload = _OAuth2StatePayload(
random=generate_token(),
origin_url=origin_url,
)
if return_url is not None:
payload[_RETURN_URL] = return_url
payload["return_url"] = return_url
jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=secret)
return jwt_bytes.decode()


def _validate_signature_and_parse_return_url(
*, secret: str, state: str
) -> Tuple[bool, Optional[str]]:
class _OAuth2StatePayload(TypedDict):
"""
Represents the OAuth2 state payload.
"""

random: str
origin_url: str
return_url: NotRequired[str]


def _parse_state_payload(*, secret: str, state: str) -> _OAuth2StatePayload:
"""
Validates the JWT signature and parses the return URL from the OAuth2 state.
"""
signature_is_valid: bool
return_url: Optional[str]
try:
payload = jwt.decode(s=state, key=secret)
signature_is_valid = True
return_url = payload.get(_RETURN_URL)
assert isinstance(return_url, str) or return_url is None
except JoseError:
signature_is_valid = False
return_url = None
return signature_is_valid, return_url
payload = jwt.decode(s=state, key=secret)
if _is_oauth2_state_payload(payload):
return payload
raise ValueError("Invalid OAuth2 state payload.")


def _is_relative_url(url: str) -> bool:
Expand All @@ -417,17 +427,28 @@ def _with_random_suffix(string: str) -> str:
return f"{string}-{randrange(10_000, 100_000)}"


def _get_origin_url(request: Request) -> URL:
def _get_origin_url(request: Request) -> str:
"""
Infers the origin URL from the request.
"""
if (referer := request.headers.get("referer")) is None:
return request.base_url
return str(request.base_url)
parsed_url = urlparse(referer)
return URL(f"{parsed_url.scheme}://{parsed_url.netloc}")
return f"{parsed_url.scheme}://{parsed_url.netloc}"


def _is_oauth2_state_payload(maybe_state_payload: Any) -> TypeGuard[_OAuth2StatePayload]:
"""
Determines whether the given object is an OAuth2 state payload.
"""

return (
isinstance(maybe_state_payload, dict)
and {"random", "origin_url"}.issubset((keys := set(maybe_state_payload.keys())))
and keys.issubset({"random", "origin_url", "return_url"})
)


_RETURN_URL = "return_url"
_JWT_ALGORITHM = "HS256"
_INVALID_OAUTH2_STATE_MESSAGE = (
"Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}."
Expand Down
Loading