diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py
index 4f3e0751809..999801c2577 100644
--- a/src/phoenix/auth.py
+++ b/src/phoenix/auth.py
@@ -71,31 +71,53 @@ def validate_password_format(password: str) -> None:
 def set_access_token_cookie(
     *, response: ResponseType, access_token: str, max_age: timedelta
 ) -> ResponseType:
-    return _set_token_cookie(
+    return _set_cookie(
         response=response,
         cookie_name=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
         cookie_max_age=max_age,
-        token=access_token,
+        value=access_token,
     )
 
 
 def set_refresh_token_cookie(
     *, response: ResponseType, refresh_token: str, max_age: timedelta
 ) -> ResponseType:
-    return _set_token_cookie(
+    return _set_cookie(
         response=response,
         cookie_name=PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
         cookie_max_age=max_age,
-        token=refresh_token,
+        value=refresh_token,
     )
 
 
-def _set_token_cookie(
-    response: ResponseType, cookie_name: str, cookie_max_age: timedelta, token: str
+def set_oauth2_state_cookie(
+    *, response: ResponseType, state: str, max_age: timedelta
+) -> ResponseType:
+    return _set_cookie(
+        response=response,
+        cookie_name=PHOENIX_OAUTH2_STATE_COOKIE_NAME,
+        cookie_max_age=max_age,
+        value=state,
+    )
+
+
+def set_oauth2_nonce_cookie(
+    *, response: ResponseType, nonce: str, max_age: timedelta
+) -> ResponseType:
+    return _set_cookie(
+        response=response,
+        cookie_name=PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
+        cookie_max_age=max_age,
+        value=nonce,
+    )
+
+
+def _set_cookie(
+    response: ResponseType, cookie_name: str, cookie_max_age: timedelta, value: str
 ) -> ResponseType:
     response.set_cookie(
         key=cookie_name,
-        value=token,
+        value=value,
         secure=get_env_phoenix_use_secure_cookies(),
         httponly=True,
         samesite="strict",
@@ -104,16 +126,26 @@ def _set_token_cookie(
     return response
 
 
-def delete_access_token_cookie(response: Response) -> Response:
+def delete_access_token_cookie(response: ResponseType) -> ResponseType:
     response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
     return response
 
 
-def delete_refresh_token_cookie(response: Response) -> Response:
+def delete_refresh_token_cookie(response: ResponseType) -> ResponseType:
     response.delete_cookie(key=PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
     return response
 
 
+def delete_oauth2_state_cookie(response: ResponseType) -> ResponseType:
+    response.delete_cookie(key=PHOENIX_OAUTH2_STATE_COOKIE_NAME)
+    return response
+
+
+def delete_oauth2_nonce_cookie(response: ResponseType) -> ResponseType:
+    response.delete_cookie(key=PHOENIX_OAUTH2_NONCE_COOKIE_NAME)
+    return response
+
+
 @dataclass(frozen=True)
 class _PasswordRequirements:
     """
@@ -206,6 +238,16 @@ def validate(
 """The name of the cookie that stores the Phoenix access token."""
 PHOENIX_REFRESH_TOKEN_COOKIE_NAME = "phoenix-refresh-token"
 """The name of the cookie that stores the Phoenix refresh token."""
+PHOENIX_OAUTH2_STATE_COOKIE_NAME = "phoenix-oauth2-state"
+"""The name of the cookie that stores the state used for the OAuth2 authorization code flow."""
+PHOENIX_OAUTH2_NONCE_COOKIE_NAME = "phoenix-oauth2-nonce"
+"""The name of the cookie that stores the nonce used for the OAuth2 authorization code flow."""
+DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES = 15
+"""
+The default amount of time in minutes that can elapse between the initial
+redirect to the IDP and the invocation of the callback URL during the OAuth2
+authorization code flow.
+"""
 
 
 class Token(str): ...
diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py
index a1b158b991a..b859cf1bbf3 100644
--- a/src/phoenix/server/api/routers/auth.py
+++ b/src/phoenix/server/api/routers/auth.py
@@ -12,6 +12,8 @@
     PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
     Token,
     delete_access_token_cookie,
+    delete_oauth2_nonce_cookie,
+    delete_oauth2_state_cookie,
     delete_refresh_token_cookie,
     is_valid_password,
     set_access_token_cookie,
@@ -97,6 +99,8 @@ async def logout(
     response = Response(status_code=HTTP_204_NO_CONTENT)
     response = delete_access_token_cookie(response)
     response = delete_refresh_token_cookie(response)
+    response = delete_oauth2_state_cookie(response)
+    response = delete_oauth2_nonce_cookie(response)
     return response
 
 
diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py
index 33b2f3e1c60..1107069ba71 100644
--- a/src/phoenix/server/api/routers/oauth2.py
+++ b/src/phoenix/server/api/routers/oauth2.py
@@ -2,28 +2,36 @@
 from datetime import timedelta
 from typing import Any, Dict, Optional
 
+from authlib.common.security import generate_token
 from authlib.integrations.starlette_client import OAuthError
-from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client
-from fastapi import APIRouter, Path, Request
+from fastapi import APIRouter, Cookie, Path, Query, Request
 from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import joinedload
 from starlette.datastructures import URL
 from starlette.responses import RedirectResponse
+from starlette.status import HTTP_302_FOUND
 from typing_extensions import Annotated
 
 from phoenix.auth import (
+    DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
+    PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
+    PHOENIX_OAUTH2_STATE_COOKIE_NAME,
+    delete_oauth2_nonce_cookie,
+    delete_oauth2_state_cookie,
     set_access_token_cookie,
+    set_oauth2_nonce_cookie,
+    set_oauth2_state_cookie,
     set_refresh_token_cookie,
 )
 from phoenix.db import models
 from phoenix.db.enums import UserRole
 from phoenix.server.bearer_auth import create_access_and_refresh_tokens
 from phoenix.server.jwt_store import JwtStore
+from phoenix.server.oauth2 import OAuth2Client
 
 _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
 
-
 router = APIRouter(prefix="/oauth2", include_in_schema=False)
 
 
@@ -36,8 +44,24 @@ async def login(
         oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
     ):
         return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
-    redirect_uri = request.url_for("create_tokens", idp_name=idp_name)
-    response: RedirectResponse = await oauth2_client.authorize_redirect(request, redirect_uri)
+    authorization_url_data = await oauth2_client.create_authorization_url(
+        redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
+        state=generate_token(),
+    )
+    assert isinstance(authorization_url := authorization_url_data.get("url"), str)
+    assert isinstance(state := authorization_url_data.get("state"), str)
+    assert isinstance(nonce := authorization_url_data.get("nonce"), str)
+    response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
+    response = set_oauth2_state_cookie(
+        response=response,
+        state=state,
+        max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
+    )
+    response = set_oauth2_nonce_cookie(
+        response=response,
+        nonce=nonce,
+        max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
+    )
     return response
 
 
@@ -45,7 +69,18 @@ async def login(
 async def create_tokens(
     request: Request,
     idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
+    state: str = Query(),
+    authorization_code: str = Query(alias="code"),
+    stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
+    stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
 ) -> RedirectResponse:
+    if state != stored_state:
+        return _redirect_to_login(
+            error=(
+                "Received invalid state parameter during "
+                f"OAuth2 authorization code flow for IDP {idp_name}."
+            )
+        )
     assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
     assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
     token_store: JwtStore = request.app.state.get_token_store()
@@ -54,13 +89,20 @@ async def create_tokens(
     ):
         return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
     try:
-        token_data = await oauth2_client.authorize_access_token(request)
+        token_data = await oauth2_client.fetch_access_token(
+            state=state,
+            code=authorization_code,
+            redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
+        )
     except OAuthError as error:
         return _redirect_to_login(error=str(error))
-    if (user_info := _get_user_info(token_data)) is None:
+    _validate_token_data(token_data)
+    if "id_token" not in token_data:
         return _redirect_to_login(
             error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
         )
+    user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
+    user_info = _parse_user_info(user_info)
     try:
         async with request.app.state.db() as session:
             user = await _ensure_user_exists_and_is_up_to_date(
@@ -76,13 +118,15 @@ async def create_tokens(
         access_token_expiry=access_token_expiry,
         refresh_token_expiry=refresh_token_expiry,
     )
-    response = RedirectResponse(url="/")  # todo: sanitize a return url
+    response = RedirectResponse(url="/", status_code=HTTP_302_FOUND)  # todo: sanitize a return url
     response = set_access_token_cookie(
         response=response, access_token=access_token, max_age=access_token_expiry
     )
     response = set_refresh_token_cookie(
         response=response, refresh_token=refresh_token, max_age=refresh_token_expiry
     )
+    response = delete_oauth2_state_cookie(response)
+    response = delete_oauth2_nonce_cookie(response)
     return response
 
 
@@ -94,15 +138,19 @@ class UserInfo:
     profile_picture_url: Optional[str]
 
 
-def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]:
+def _validate_token_data(token_data: Dict[str, Any]) -> None:
     """
-    Parses token data and extracts user info if available.
+    Performs basic validations on the token data returned by the IDP.
     """
     assert isinstance(token_data.get("access_token"), str)
     assert isinstance(token_type := token_data.get("token_type"), str)
     assert token_type.lower() == "bearer"
-    if (user_info := token_data.get("userinfo")) is None:
-        return None
+
+
+def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo:
+    """
+    Parses user info from the IDP's ID token.
+    """
     assert isinstance(subject := user_info.get("sub"), (str, int))
     idp_user_id = str(subject)
     assert isinstance(email := user_info.get("email"), str)
@@ -278,4 +326,15 @@ def _redirect_to_login(*, error: str) -> RedirectResponse:
     """
     Creates a RedirectResponse to the login page to display an error message.
     """
-    return RedirectResponse(url=URL("/login").include_query_params(error=error))
+    url = URL("/login").include_query_params(error=error)
+    response = RedirectResponse(url=url)
+    response = delete_oauth2_state_cookie(response)
+    response = delete_oauth2_nonce_cookie(response)
+    return response
+
+
+def _get_create_tokens_endpoint(*, request: Request, idp_name: str) -> str:
+    """
+    Gets the endpoint for create tokens route.
+    """
+    return str(request.url_for(create_tokens.__name__, idp_name=idp_name))
diff --git a/src/phoenix/server/oauth2.py b/src/phoenix/server/oauth2.py
index 8bff39929b5..3ab521a3e6e 100644
--- a/src/phoenix/server/oauth2.py
+++ b/src/phoenix/server/oauth2.py
@@ -1,23 +1,35 @@
-from datetime import datetime, timedelta
-from typing import Any, Dict, Generic, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable
 
-from authlib.integrations.starlette_client import OAuth
-from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client
-from typing_extensions import TypeAlias, TypeVar
+from authlib.integrations.base_client import BaseApp
+from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
+from authlib.integrations.base_client.async_openid import AsyncOpenIDMixin
+from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncHttpxOAuth2Client
 
 from phoenix.config import OAuth2ClientConfig
 
 
+class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp):  # type:ignore[misc]
+    """
+    An OAuth2 client class that supports OpenID Connect. Adapted from authlib's
+    `StarletteOAuth2App` to be useable without integration with Starlette.
+
+    https://github.com/lepture/authlib/blob/904d66bebd79bf39fb8814353a22bab7d3e092c4/authlib/integrations/starlette_client/apps.py#L58
+    """
+
+    client_cls = AsyncHttpxOAuth2Client
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(framework=None, *args, **kwargs)
+
+
 class OAuth2Clients:
     def __init__(self) -> None:
         self._clients: Dict[str, OAuth2Client] = {}
-        self._oauth = OAuth(cache=_OAuth2ClientTTLCache[str, Any]())
 
     def add_client(self, config: OAuth2ClientConfig) -> None:
         if (idp_name := config.idp_name) in self._clients:
             raise ValueError(f"oauth client already registered: {idp_name}")
-        client = self._oauth.register(
-            idp_name,
+        client = OAuth2Client(
             client_id=config.client_id,
             client_secret=config.client_secret,
             server_metadata_url=config.server_metadata_url,
@@ -37,63 +49,3 @@ def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients":
         for config in configs:
             oauth2_clients.add_client(config)
         return oauth2_clients
-
-
-_CacheKey = TypeVar("_CacheKey")
-_CacheValue = TypeVar("_CacheValue")
-_Expiry: TypeAlias = datetime
-_MINUTE = timedelta(minutes=1)
-
-
-class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]):
-    """
-    A TTL cache satisfying the interface required by the Authlib Starlette
-    integration. Provides an alternative to starlette session middleware.
-    """
-
-    def __init__(self, cleanup_interval: timedelta = 1 * _MINUTE) -> None:
-        self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {}
-        self._last_cleanup_time = datetime.now()
-        self._cleanup_interval = cleanup_interval
-
-    async def get(self, key: _CacheKey) -> Optional[_CacheValue]:
-        """
-        Retrieves the value associated with the given key if it exists and has
-        not expired, otherwise, returns None.
-        """
-        self._remove_expired_keys_if_cleanup_interval_exceeded()
-        if (value_and_expiry := self._data.get(key)) is None:
-            return None
-        value, expiry = value_and_expiry
-        if datetime.now() < expiry:
-            return value
-        self._data.pop(key, None)
-        return None
-
-    async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None:
-        """
-        Sets the value associated with the given key to the provided value with
-        the given expiry time in seconds.
-        """
-        self._remove_expired_keys_if_cleanup_interval_exceeded()
-        expiry = datetime.now() + timedelta(seconds=expires)
-        self._data[key] = (value, expiry)
-
-    async def delete(self, key: _CacheKey) -> None:
-        """
-        Removes the value associated with the given key if it exists.
-        """
-        self._remove_expired_keys_if_cleanup_interval_exceeded()
-        self._data.pop(key, None)
-
-    def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None:
-        time_since_last_cleanup = datetime.now() - self._last_cleanup_time
-        if time_since_last_cleanup > self._cleanup_interval:
-            self._remove_expired_keys()
-
-    def _remove_expired_keys(self) -> None:
-        current_time = datetime.now()
-        delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time]
-        for key in delete_keys:
-            self._data.pop(key, None)
-        self._last_cleanup_time = current_time