diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index a66701ca67..1660d9d757 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -1,12 +1,14 @@ import React from "react"; +import { css } from "@emotion/react"; -import { Flex, View } from "@arizeai/components"; +import { Button, Flex, Form, View } from "@arizeai/components"; import { AuthLayout } from "./AuthLayout"; import { LoginForm } from "./LoginForm"; import { PhoenixLogo } from "./PhoenixLogo"; export function LoginPage() { + const oAuthIdps = window.Config.oAuthIdps; return ( @@ -15,6 +17,37 @@ export function LoginPage() { + {oAuthIdps.map((idp) => ( + + ))} ); } + +type OAuthLoginFormProps = { + idpId: string; + idpDisplayName: string; +}; +export function OAuthLoginForm({ idpId, idpDisplayName }: OAuthLoginFormProps) { + return ( +
+
+ +
+
+ ); +} diff --git a/app/src/pages/auth/oAuthCallbackLoader.ts b/app/src/pages/auth/oAuthCallbackLoader.ts new file mode 100644 index 0000000000..251ed76915 --- /dev/null +++ b/app/src/pages/auth/oAuthCallbackLoader.ts @@ -0,0 +1,41 @@ +// import { redirect } from "react-router"; +import { LoaderFunctionArgs } from "react-router-dom"; + +export async function oAuthCallbackLoader(args: LoaderFunctionArgs) { + const queryParameters = new URL(args.request.url).searchParams; + const authorizationCode = queryParameters.get("code"); + const state = queryParameters.get("state"); + const actualState = sessionStorage.getItem("oAuthState"); + sessionStorage.removeItem("oAuthState"); + if ( + authorizationCode == undefined || + state == undefined || + actualState == undefined || + state !== actualState + ) { + // todo: display error message + return null; + } + const origin = new URL(window.location.href).origin; + const redirectUri = `${origin}/oauth-callback`; + try { + const response = await fetch("/auth/oauth-tokens", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + authorization_code: authorizationCode, + redirect_uri: redirectUri, + }), + }); + if (!response.ok) { + // todo: parse response body and display error message + return null; + } + } catch (error) { + // todo: display error + } + // redirect("/"); + return null; +} diff --git a/app/src/window.d.ts b/app/src/window.d.ts index f50eca98f9..08ab1cb825 100644 --- a/app/src/window.d.ts +++ b/app/src/window.d.ts @@ -1,5 +1,10 @@ export {}; +type OAuthIdp = { + id: string; + displayName: string; +}; + declare global { interface Window { Config: { @@ -15,6 +20,7 @@ declare global { nSamples: number; }; authenticationEnabled: boolean; + oAuthIdps: OAuthIdp[]; }; } } diff --git a/pyproject.toml b/pyproject.toml index 1eaea37c10..a296055c8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "fastapi", "pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting "pyjwt", + "authlib", ] dynamic = ["version"] @@ -406,6 +407,7 @@ module = [ "grpc.*", "py_grpc_prometheus.*", "orjson", # suppress fastapi internal type errors + "authlib.*", ] ignore_missing_imports = true diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 78f8cc626f..c51393868a 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -1,15 +1,21 @@ import os import re import tempfile +from dataclasses import dataclass from datetime import timedelta from logging import getLogger from pathlib import Path from typing import Dict, List, Optional, Tuple import pandas as pd +from typing_extensions import TypeAlias from phoenix.utilities.re import parse_env_headers +IdpId: TypeAlias = str +EnvVarName: TypeAlias = str +EnvVarValue: TypeAlias = str + logger = getLogger(__name__) # Phoenix environment variables @@ -201,6 +207,72 @@ def get_env_refresh_token_expiry() -> timedelta: ) +@dataclass(frozen=True) +class OAuthClientConfig: + idp_id: str + display_name: str + client_id: str + client_secret: str + server_metadata_url: Optional[str] = None + authorize_url: Optional[str] = None + access_token_url: Optional[str] = None + + @classmethod + def from_env(cls, idp_id: str) -> "OAuthClientConfig": + idp_id_upper = idp_id.upper() + if ( + client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_ID") + ) is None: + raise ValueError( + f"A client id must be set for the {idp_id} OAuth IDP " + f"via the {client_id_env_var} environment variable" + ) + if ( + client_secret := os.getenv( + client_secret_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_SECRET" + ) + ) is None: + raise ValueError( + f"A client secret must be set for the {idp_id} OAuth IDP " + f"via the {client_secret_env_var} environment variable" + ) + return cls( + idp_id=idp_id, + display_name=os.getenv( + f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", get_default_idp_display_name(idp_id) + ), + client_id=client_id, + client_secret=client_secret, + server_metadata_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL"), + access_token_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_ACCESS_TOKEN_URL"), + authorize_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_AUTHORIZE_URL"), + ) + + def __post_init__(self) -> None: + assert self.idp_id + if not self.display_name: + raise ValueError(f"OAuth display name for {self.idp_id} cannot be empty") + if not self.client_id: + raise ValueError(f"OAuth client id for {self.idp_id} cannot be empty") + if not self.client_secret: + raise ValueError(f"OAuth client secret for {self.idp_id} cannot be empty") + + +def get_env_oauth_settings() -> List[OAuthClientConfig]: + """ + Get OAuth settings from environment variables. + """ + + idp_ids = set() + pattern = re.compile( + r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL|ACCESS_TOKEN_URL|AUTHORIZE_URL)$" + ) + for env_var in os.environ: + if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()): + idp_ids.add(idp_id) + return [OAuthClientConfig.from_env(idp_id) for idp_id in sorted(idp_ids)] + + def _parse_duration(duration_str: str) -> timedelta: """ Parses a duration string into a timedelta object, assuming the duration is @@ -385,5 +457,9 @@ def get_web_base_url() -> str: return get_base_url() +def get_default_idp_display_name(ipd_id: IdpId) -> str: + return ipd_id.replace("_", " ").title() + + DEFAULT_PROJECT_NAME = "default" _KUBERNETES_PHOENIX_PORT_PATTERN = re.compile(r"^tcp://\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}:\d+$") diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py index 354d5d106d..8c65c0c768 100644 --- a/src/phoenix/server/api/routers/__init__.py +++ b/src/phoenix/server/api/routers/__init__.py @@ -1,9 +1,11 @@ from .auth import router as auth_router from .embeddings import create_embeddings_router +from .oauth import router as oauth_router from .v1 import create_v1_router __all__ = [ "auth_router", "create_embeddings_router", "create_v1_router", + "oauth_router", ] diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py new file mode 100644 index 0000000000..db4e82897e --- /dev/null +++ b/src/phoenix/server/api/routers/oauth.py @@ -0,0 +1,39 @@ +from authlib.integrations.starlette_client import OAuthError +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from fastapi import APIRouter, Depends, HTTPException, Request +from starlette.responses import RedirectResponse +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND + +from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter + +rate_limiter = ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, +) +login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) +router = APIRouter( + prefix="/oauth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] +) + + +@router.post("/{idp}/login") +async def login(request: Request, idp: str) -> RedirectResponse: + if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") + redirect_uri = request.url_for("create_tokens", idp=idp) + response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri) + return response + + +@router.get("/{idp}/tokens") +async def create_tokens(request: Request, idp: str) -> RedirectResponse: + if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") + try: + token = await oauth_client.authorize_access_token(request) + except OAuthError as error: + raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) + print(f"{token=}") + return RedirectResponse(url="/") diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2070e09f73..2338b578e8 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -20,7 +20,9 @@ List, NamedTuple, Optional, + Sequence, Tuple, + TypedDict, Union, cast, ) @@ -50,6 +52,7 @@ from phoenix.config import ( DEFAULT_PROJECT_NAME, SERVER_DIR, + OAuthClientConfig, get_env_host, get_env_port, server_instrumentation_is_enabled, @@ -90,7 +93,12 @@ UserRolesDataLoader, UsersDataLoader, ) -from phoenix.server.api.routers import auth_router, create_embeddings_router, create_v1_router +from phoenix.server.api.routers import ( + auth_router, + create_embeddings_router, + create_v1_router, + oauth_router, +) from phoenix.server.api.routers.v1 import REST_API_VERSION from phoenix.server.api.schema import schema from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated @@ -98,6 +106,7 @@ from phoenix.server.dml_event_handler import DmlEventHandler from phoenix.server.grpc_server import GrpcServer from phoenix.server.jwt_store import JwtStore +from phoenix.server.oauth import OAuthClients from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider from phoenix.server.types import ( CanGetLastUpdatedAt, @@ -143,6 +152,11 @@ _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]] +class OAuthIdp(TypedDict): + id: str + displayName: str + + class AppConfig(NamedTuple): has_inferences: bool """ Whether the model has inferences (e.g. a primary dataset) """ @@ -154,6 +168,7 @@ class AppConfig(NamedTuple): web_manifest_path: Path authentication_enabled: bool """ Whether authentication is enabled """ + oauth_idps: Sequence[OAuthIdp] class Static(StaticFiles): @@ -202,6 +217,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: "is_development": self._app_config.is_development, "manifest": self._web_manifest, "authentication_enabled": self._app_config.authentication_enabled, + "oauth_idps": self._app_config.oauth_idps, }, ) except Exception as e: @@ -608,6 +624,7 @@ def create_app( access_token_expiry: Optional[timedelta] = None, refresh_token_expiry: Optional[timedelta] = None, scaffolder_config: Optional[ScaffolderConfig] = None, + oauth_client_configs: Optional[List[OAuthClientConfig]] = None, ) -> FastAPI: startup_callbacks_list: List[_Callback] = list(startup_callbacks) shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks) @@ -720,9 +737,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.include_router(graphql_router) if authentication_enabled: app.include_router(auth_router) + app.include_router(oauth_router) app.add_middleware(GZipMiddleware) web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): + oauth_idps = [ + OAuthIdp(id=config.idp_id, displayName=config.display_name) + for config in oauth_client_configs or [] + ] app.mount( "/", app=Static( @@ -736,6 +758,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: is_development=dev, authentication_enabled=authentication_enabled, web_manifest_path=web_manifest_path, + oauth_idps=oauth_idps, ), ), name="static", @@ -744,6 +767,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.state.export_path = export_path app.state.access_token_expiry = access_token_expiry app.state.refresh_token_expiry = refresh_token_expiry + app.state.oauth_clients = OAuthClients.from_configs(oauth_client_configs or []) app.state.db = db app = _add_get_secret_method(app=app, secret=secret) app = _add_get_token_store_method(app=app, token_store=token_store) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 399e635b4b..6289ffbb06 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -25,6 +25,7 @@ get_env_grpc_port, get_env_host, get_env_host_root_path, + get_env_oauth_settings, get_env_port, get_env_refresh_token_expiry, get_pids_path, @@ -392,6 +393,7 @@ def _get_pid_file() -> Path: access_token_expiry=get_env_access_token_expiry(), refresh_token_expiry=get_env_refresh_token_expiry(), scaffolder_config=scaffolder_config, + oauth_client_configs=get_env_oauth_settings(), ) server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py new file mode 100644 index 0000000000..5ce659d7da --- /dev/null +++ b/src/phoenix/server/oauth.py @@ -0,0 +1,145 @@ +from dataclasses import asdict, dataclass +from datetime import datetime, timedelta +from types import MappingProxyType +from typing import Any, Dict, Generic, List, Optional, Tuple + +from authlib.integrations.starlette_client import OAuth +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from typing_extensions import TypeAlias, TypeVar + +from phoenix.config import OAuthClientConfig + +IdpId: TypeAlias = str + + +class OAuthClients: + def __init__(self) -> None: + self._clients: Dict[IdpId, OAuthClient] = {} + self._oauth = OAuth(cache=_OAuthClientTTLCache[str, Any]()) + + def add_client(self, config: OAuthClientConfig) -> None: + if (idp_id := config.idp_id) in self._clients: + raise ValueError(f"oauth client already registered: {idp_id}") + config = _apply_oauth_config_defaults(config) + server_metadata_url = config.server_metadata_url + authorize_url = config.authorize_url + access_token_url = config.access_token_url + if not (server_metadata_url or (authorize_url and access_token_url)): + raise ValueError( + f"{idp_id} OAuth client must have either a server metadata URL," + " or authorize and access token URLs" + ) + client = self._oauth.register( + idp_id, + client_id=config.client_id, + client_secret=config.client_secret, + server_metadata_url=server_metadata_url, + authorize_url=authorize_url, + access_token_url=access_token_url, + client_kwargs={"scope": "openid email profile"}, + ) + assert isinstance(client, OAuthClient) + self._clients[config.idp_id] = client + + def get_client(self, idp_id: IdpId) -> OAuthClient: + if (client := self._clients.get(idp_id)) is None: + raise ValueError(f"unknown or unregistered oauth client: {idp_id}") + return client + + @classmethod + def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients": + oauth_clients = cls() + for config in configs: + oauth_clients.add_client(config) + return oauth_clients + + +@dataclass +class OAuthClientDefaultConfig: + idp_id: IdpId + display_name: Optional[str] = None + server_metadata_url: Optional[str] = None + authorize_url: Optional[str] = None + access_token_url: Optional[str] = None + + +def _apply_oauth_config_defaults(config: OAuthClientConfig) -> OAuthClientConfig: + if (default_config := _OAUTH_CLIENT_DEFAULT_CONFIGS.get(config.idp_id)) is None: + return config + return OAuthClientConfig( + **{ + **{k: v for k, v in asdict(default_config).items() if v is not None}, + **{k: v for k, v in asdict(config).items() if v is not None}, + } + ) + + +_OAUTH_CLIENT_DEFAULT_CONFIGS = MappingProxyType( + { + config.idp_id: config + for config in ( + OAuthClientDefaultConfig( + idp_id="google", + server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + ), + ) + } +) + +_CacheKey = TypeVar("_CacheKey") +_CacheValue = TypeVar("_CacheValue") +_Expiry: TypeAlias = datetime +_MINUTE = timedelta(minutes=1) + + +class _OAuthClientTTLCache(Generic[_CacheKey, _CacheValue]): + """ + A TTL cache satisfying the interface required by the Authlib Starlette + integration. Provides an alternative to starlette session middleware. + """ + + def __init__(self, cleanup_interval: timedelta = 10 * _MINUTE) -> None: + self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {} + self._last_cleanup_time = datetime.now() + self._cleanup_interval = cleanup_interval + + async def get(self, key: _CacheKey) -> Optional[_CacheValue]: + """ + Retrieves the value associated with the given key if it exists and has + not expired, otherwise, returns None. + """ + if (value_and_expiry := self._data.get(key)) is None: + return None + value, expiry = value_and_expiry + if datetime.now() < expiry: + return value + self._data.pop(key, None) + return None + + async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None: + """ + Sets the value associated with the given key to the provided value with + the given expiry time in seconds. + """ + self._remove_expired_keys_if_cleanup_interval_exceeded() + expiry = datetime.now() + timedelta(seconds=expires) + self._data[key] = (value, expiry) + + async def delete(self, key: _CacheKey) -> None: + """ + Removes the value associated with the given key if it exists. + """ + self._remove_expired_keys_if_cleanup_interval_exceeded() + self._data.pop(key, None) + + def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None: + time_since_last_cleanup = datetime.now() - self._last_cleanup_time + if time_since_last_cleanup > self._cleanup_interval: + self._remove_expired_keys() + + def _remove_expired_keys(self) -> None: + current_time = datetime.now() + delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time] + for key in delete_keys: + self._data.pop(key, None) + self._last_cleanup_time = current_time diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html index 4ef66f7942..748288cbe8 100644 --- a/src/phoenix/server/templates/index.html +++ b/src/phoenix/server/templates/index.html @@ -87,6 +87,7 @@ nSamples: parseInt("{{n_samples}}"), }, authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"), + oAuthIdps: {{ oauth_idps | tojson }}, }), writable: false });