diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx
index a66701ca67..1660d9d757 100644
--- a/app/src/pages/auth/LoginPage.tsx
+++ b/app/src/pages/auth/LoginPage.tsx
@@ -1,12 +1,14 @@
import React from "react";
+import { css } from "@emotion/react";
-import { Flex, View } from "@arizeai/components";
+import { Button, Flex, Form, View } from "@arizeai/components";
import { AuthLayout } from "./AuthLayout";
import { LoginForm } from "./LoginForm";
import { PhoenixLogo } from "./PhoenixLogo";
export function LoginPage() {
+ const oAuthIdps = window.Config.oAuthIdps;
return (
@@ -15,6 +17,37 @@ export function LoginPage() {
+ {oAuthIdps.map((idp) => (
+
+ ))}
);
}
+
+type OAuthLoginFormProps = {
+ idpId: string;
+ idpDisplayName: string;
+};
+export function OAuthLoginForm({ idpId, idpDisplayName }: OAuthLoginFormProps) {
+ return (
+
+ );
+}
diff --git a/app/src/pages/auth/oAuthCallbackLoader.ts b/app/src/pages/auth/oAuthCallbackLoader.ts
new file mode 100644
index 0000000000..251ed76915
--- /dev/null
+++ b/app/src/pages/auth/oAuthCallbackLoader.ts
@@ -0,0 +1,41 @@
+// import { redirect } from "react-router";
+import { LoaderFunctionArgs } from "react-router-dom";
+
+export async function oAuthCallbackLoader(args: LoaderFunctionArgs) {
+ const queryParameters = new URL(args.request.url).searchParams;
+ const authorizationCode = queryParameters.get("code");
+ const state = queryParameters.get("state");
+ const actualState = sessionStorage.getItem("oAuthState");
+ sessionStorage.removeItem("oAuthState");
+ if (
+ authorizationCode == undefined ||
+ state == undefined ||
+ actualState == undefined ||
+ state !== actualState
+ ) {
+ // todo: display error message
+ return null;
+ }
+ const origin = new URL(window.location.href).origin;
+ const redirectUri = `${origin}/oauth-callback`;
+ try {
+ const response = await fetch("/auth/oauth-tokens", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify({
+ authorization_code: authorizationCode,
+ redirect_uri: redirectUri,
+ }),
+ });
+ if (!response.ok) {
+ // todo: parse response body and display error message
+ return null;
+ }
+ } catch (error) {
+ // todo: display error
+ }
+ // redirect("/");
+ return null;
+}
diff --git a/app/src/window.d.ts b/app/src/window.d.ts
index f50eca98f9..08ab1cb825 100644
--- a/app/src/window.d.ts
+++ b/app/src/window.d.ts
@@ -1,5 +1,10 @@
export {};
+type OAuthIdp = {
+ id: string;
+ displayName: string;
+};
+
declare global {
interface Window {
Config: {
@@ -15,6 +20,7 @@ declare global {
nSamples: number;
};
authenticationEnabled: boolean;
+ oAuthIdps: OAuthIdp[];
};
}
}
diff --git a/pyproject.toml b/pyproject.toml
index 1eaea37c10..a296055c8a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,6 +61,7 @@ dependencies = [
"fastapi",
"pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting
"pyjwt",
+ "authlib",
]
dynamic = ["version"]
@@ -406,6 +407,7 @@ module = [
"grpc.*",
"py_grpc_prometheus.*",
"orjson", # suppress fastapi internal type errors
+ "authlib.*",
]
ignore_missing_imports = true
diff --git a/src/phoenix/config.py b/src/phoenix/config.py
index 78f8cc626f..c51393868a 100644
--- a/src/phoenix/config.py
+++ b/src/phoenix/config.py
@@ -1,15 +1,21 @@
import os
import re
import tempfile
+from dataclasses import dataclass
from datetime import timedelta
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pandas as pd
+from typing_extensions import TypeAlias
from phoenix.utilities.re import parse_env_headers
+IdpId: TypeAlias = str
+EnvVarName: TypeAlias = str
+EnvVarValue: TypeAlias = str
+
logger = getLogger(__name__)
# Phoenix environment variables
@@ -201,6 +207,72 @@ def get_env_refresh_token_expiry() -> timedelta:
)
+@dataclass(frozen=True)
+class OAuthClientConfig:
+ idp_id: str
+ display_name: str
+ client_id: str
+ client_secret: str
+ server_metadata_url: Optional[str] = None
+ authorize_url: Optional[str] = None
+ access_token_url: Optional[str] = None
+
+ @classmethod
+ def from_env(cls, idp_id: str) -> "OAuthClientConfig":
+ idp_id_upper = idp_id.upper()
+ if (
+ client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_ID")
+ ) is None:
+ raise ValueError(
+ f"A client id must be set for the {idp_id} OAuth IDP "
+ f"via the {client_id_env_var} environment variable"
+ )
+ if (
+ client_secret := os.getenv(
+ client_secret_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_SECRET"
+ )
+ ) is None:
+ raise ValueError(
+ f"A client secret must be set for the {idp_id} OAuth IDP "
+ f"via the {client_secret_env_var} environment variable"
+ )
+ return cls(
+ idp_id=idp_id,
+ display_name=os.getenv(
+ f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", get_default_idp_display_name(idp_id)
+ ),
+ client_id=client_id,
+ client_secret=client_secret,
+ server_metadata_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL"),
+ access_token_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_ACCESS_TOKEN_URL"),
+ authorize_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_AUTHORIZE_URL"),
+ )
+
+ def __post_init__(self) -> None:
+ assert self.idp_id
+ if not self.display_name:
+ raise ValueError(f"OAuth display name for {self.idp_id} cannot be empty")
+ if not self.client_id:
+ raise ValueError(f"OAuth client id for {self.idp_id} cannot be empty")
+ if not self.client_secret:
+ raise ValueError(f"OAuth client secret for {self.idp_id} cannot be empty")
+
+
+def get_env_oauth_settings() -> List[OAuthClientConfig]:
+ """
+ Get OAuth settings from environment variables.
+ """
+
+ idp_ids = set()
+ pattern = re.compile(
+ r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL|ACCESS_TOKEN_URL|AUTHORIZE_URL)$"
+ )
+ for env_var in os.environ:
+ if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()):
+ idp_ids.add(idp_id)
+ return [OAuthClientConfig.from_env(idp_id) for idp_id in sorted(idp_ids)]
+
+
def _parse_duration(duration_str: str) -> timedelta:
"""
Parses a duration string into a timedelta object, assuming the duration is
@@ -385,5 +457,9 @@ def get_web_base_url() -> str:
return get_base_url()
+def get_default_idp_display_name(ipd_id: IdpId) -> str:
+ return ipd_id.replace("_", " ").title()
+
+
DEFAULT_PROJECT_NAME = "default"
_KUBERNETES_PHOENIX_PORT_PATTERN = re.compile(r"^tcp://\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}:\d+$")
diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py
index 354d5d106d..8c65c0c768 100644
--- a/src/phoenix/server/api/routers/__init__.py
+++ b/src/phoenix/server/api/routers/__init__.py
@@ -1,9 +1,11 @@
from .auth import router as auth_router
from .embeddings import create_embeddings_router
+from .oauth import router as oauth_router
from .v1 import create_v1_router
__all__ = [
"auth_router",
"create_embeddings_router",
"create_v1_router",
+ "oauth_router",
]
diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py
new file mode 100644
index 0000000000..db4e82897e
--- /dev/null
+++ b/src/phoenix/server/api/routers/oauth.py
@@ -0,0 +1,39 @@
+from authlib.integrations.starlette_client import OAuthError
+from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient
+from fastapi import APIRouter, Depends, HTTPException, Request
+from starlette.responses import RedirectResponse
+from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND
+
+from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter
+
+rate_limiter = ServerRateLimiter(
+ per_second_rate_limit=0.2,
+ enforcement_window_seconds=30,
+ partition_seconds=60,
+ active_partitions=2,
+)
+login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"])
+router = APIRouter(
+ prefix="/oauth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]
+)
+
+
+@router.post("/{idp}/login")
+async def login(request: Request, idp: str) -> RedirectResponse:
+ if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient):
+ raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}")
+ redirect_uri = request.url_for("create_tokens", idp=idp)
+ response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri)
+ return response
+
+
+@router.get("/{idp}/tokens")
+async def create_tokens(request: Request, idp: str) -> RedirectResponse:
+ if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient):
+ raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}")
+ try:
+ token = await oauth_client.authorize_access_token(request)
+ except OAuthError as error:
+ raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error))
+ print(f"{token=}")
+ return RedirectResponse(url="/")
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index 2070e09f73..2338b578e8 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -20,7 +20,9 @@
List,
NamedTuple,
Optional,
+ Sequence,
Tuple,
+ TypedDict,
Union,
cast,
)
@@ -50,6 +52,7 @@
from phoenix.config import (
DEFAULT_PROJECT_NAME,
SERVER_DIR,
+ OAuthClientConfig,
get_env_host,
get_env_port,
server_instrumentation_is_enabled,
@@ -90,7 +93,12 @@
UserRolesDataLoader,
UsersDataLoader,
)
-from phoenix.server.api.routers import auth_router, create_embeddings_router, create_v1_router
+from phoenix.server.api.routers import (
+ auth_router,
+ create_embeddings_router,
+ create_v1_router,
+ oauth_router,
+)
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.schema import schema
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
@@ -98,6 +106,7 @@
from phoenix.server.dml_event_handler import DmlEventHandler
from phoenix.server.grpc_server import GrpcServer
from phoenix.server.jwt_store import JwtStore
+from phoenix.server.oauth import OAuthClients
from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider
from phoenix.server.types import (
CanGetLastUpdatedAt,
@@ -143,6 +152,11 @@
_Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]]
+class OAuthIdp(TypedDict):
+ id: str
+ displayName: str
+
+
class AppConfig(NamedTuple):
has_inferences: bool
""" Whether the model has inferences (e.g. a primary dataset) """
@@ -154,6 +168,7 @@ class AppConfig(NamedTuple):
web_manifest_path: Path
authentication_enabled: bool
""" Whether authentication is enabled """
+ oauth_idps: Sequence[OAuthIdp]
class Static(StaticFiles):
@@ -202,6 +217,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
"is_development": self._app_config.is_development,
"manifest": self._web_manifest,
"authentication_enabled": self._app_config.authentication_enabled,
+ "oauth_idps": self._app_config.oauth_idps,
},
)
except Exception as e:
@@ -608,6 +624,7 @@ def create_app(
access_token_expiry: Optional[timedelta] = None,
refresh_token_expiry: Optional[timedelta] = None,
scaffolder_config: Optional[ScaffolderConfig] = None,
+ oauth_client_configs: Optional[List[OAuthClientConfig]] = None,
) -> FastAPI:
startup_callbacks_list: List[_Callback] = list(startup_callbacks)
shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks)
@@ -720,9 +737,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
app.include_router(graphql_router)
if authentication_enabled:
app.include_router(auth_router)
+ app.include_router(oauth_router)
app.add_middleware(GZipMiddleware)
web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
if serve_ui and web_manifest_path.is_file():
+ oauth_idps = [
+ OAuthIdp(id=config.idp_id, displayName=config.display_name)
+ for config in oauth_client_configs or []
+ ]
app.mount(
"/",
app=Static(
@@ -736,6 +758,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
is_development=dev,
authentication_enabled=authentication_enabled,
web_manifest_path=web_manifest_path,
+ oauth_idps=oauth_idps,
),
),
name="static",
@@ -744,6 +767,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
app.state.export_path = export_path
app.state.access_token_expiry = access_token_expiry
app.state.refresh_token_expiry = refresh_token_expiry
+ app.state.oauth_clients = OAuthClients.from_configs(oauth_client_configs or [])
app.state.db = db
app = _add_get_secret_method(app=app, secret=secret)
app = _add_get_token_store_method(app=app, token_store=token_store)
diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py
index 399e635b4b..6289ffbb06 100644
--- a/src/phoenix/server/main.py
+++ b/src/phoenix/server/main.py
@@ -25,6 +25,7 @@
get_env_grpc_port,
get_env_host,
get_env_host_root_path,
+ get_env_oauth_settings,
get_env_port,
get_env_refresh_token_expiry,
get_pids_path,
@@ -392,6 +393,7 @@ def _get_pid_file() -> Path:
access_token_expiry=get_env_access_token_expiry(),
refresh_token_expiry=get_env_refresh_token_expiry(),
scaffolder_config=scaffolder_config,
+ oauth_client_configs=get_env_oauth_settings(),
)
server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore
Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py
new file mode 100644
index 0000000000..5ce659d7da
--- /dev/null
+++ b/src/phoenix/server/oauth.py
@@ -0,0 +1,145 @@
+from dataclasses import asdict, dataclass
+from datetime import datetime, timedelta
+from types import MappingProxyType
+from typing import Any, Dict, Generic, List, Optional, Tuple
+
+from authlib.integrations.starlette_client import OAuth
+from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient
+from typing_extensions import TypeAlias, TypeVar
+
+from phoenix.config import OAuthClientConfig
+
+IdpId: TypeAlias = str
+
+
+class OAuthClients:
+ def __init__(self) -> None:
+ self._clients: Dict[IdpId, OAuthClient] = {}
+ self._oauth = OAuth(cache=_OAuthClientTTLCache[str, Any]())
+
+ def add_client(self, config: OAuthClientConfig) -> None:
+ if (idp_id := config.idp_id) in self._clients:
+ raise ValueError(f"oauth client already registered: {idp_id}")
+ config = _apply_oauth_config_defaults(config)
+ server_metadata_url = config.server_metadata_url
+ authorize_url = config.authorize_url
+ access_token_url = config.access_token_url
+ if not (server_metadata_url or (authorize_url and access_token_url)):
+ raise ValueError(
+ f"{idp_id} OAuth client must have either a server metadata URL,"
+ " or authorize and access token URLs"
+ )
+ client = self._oauth.register(
+ idp_id,
+ client_id=config.client_id,
+ client_secret=config.client_secret,
+ server_metadata_url=server_metadata_url,
+ authorize_url=authorize_url,
+ access_token_url=access_token_url,
+ client_kwargs={"scope": "openid email profile"},
+ )
+ assert isinstance(client, OAuthClient)
+ self._clients[config.idp_id] = client
+
+ def get_client(self, idp_id: IdpId) -> OAuthClient:
+ if (client := self._clients.get(idp_id)) is None:
+ raise ValueError(f"unknown or unregistered oauth client: {idp_id}")
+ return client
+
+ @classmethod
+ def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients":
+ oauth_clients = cls()
+ for config in configs:
+ oauth_clients.add_client(config)
+ return oauth_clients
+
+
+@dataclass
+class OAuthClientDefaultConfig:
+ idp_id: IdpId
+ display_name: Optional[str] = None
+ server_metadata_url: Optional[str] = None
+ authorize_url: Optional[str] = None
+ access_token_url: Optional[str] = None
+
+
+def _apply_oauth_config_defaults(config: OAuthClientConfig) -> OAuthClientConfig:
+ if (default_config := _OAUTH_CLIENT_DEFAULT_CONFIGS.get(config.idp_id)) is None:
+ return config
+ return OAuthClientConfig(
+ **{
+ **{k: v for k, v in asdict(default_config).items() if v is not None},
+ **{k: v for k, v in asdict(config).items() if v is not None},
+ }
+ )
+
+
+_OAUTH_CLIENT_DEFAULT_CONFIGS = MappingProxyType(
+ {
+ config.idp_id: config
+ for config in (
+ OAuthClientDefaultConfig(
+ idp_id="google",
+ server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
+ ),
+ )
+ }
+)
+
+_CacheKey = TypeVar("_CacheKey")
+_CacheValue = TypeVar("_CacheValue")
+_Expiry: TypeAlias = datetime
+_MINUTE = timedelta(minutes=1)
+
+
+class _OAuthClientTTLCache(Generic[_CacheKey, _CacheValue]):
+ """
+ A TTL cache satisfying the interface required by the Authlib Starlette
+ integration. Provides an alternative to starlette session middleware.
+ """
+
+ def __init__(self, cleanup_interval: timedelta = 10 * _MINUTE) -> None:
+ self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {}
+ self._last_cleanup_time = datetime.now()
+ self._cleanup_interval = cleanup_interval
+
+ async def get(self, key: _CacheKey) -> Optional[_CacheValue]:
+ """
+ Retrieves the value associated with the given key if it exists and has
+ not expired, otherwise, returns None.
+ """
+ if (value_and_expiry := self._data.get(key)) is None:
+ return None
+ value, expiry = value_and_expiry
+ if datetime.now() < expiry:
+ return value
+ self._data.pop(key, None)
+ return None
+
+ async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None:
+ """
+ Sets the value associated with the given key to the provided value with
+ the given expiry time in seconds.
+ """
+ self._remove_expired_keys_if_cleanup_interval_exceeded()
+ expiry = datetime.now() + timedelta(seconds=expires)
+ self._data[key] = (value, expiry)
+
+ async def delete(self, key: _CacheKey) -> None:
+ """
+ Removes the value associated with the given key if it exists.
+ """
+ self._remove_expired_keys_if_cleanup_interval_exceeded()
+ self._data.pop(key, None)
+
+ def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None:
+ time_since_last_cleanup = datetime.now() - self._last_cleanup_time
+ if time_since_last_cleanup > self._cleanup_interval:
+ self._remove_expired_keys()
+
+ def _remove_expired_keys(self) -> None:
+ current_time = datetime.now()
+ delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time]
+ for key in delete_keys:
+ self._data.pop(key, None)
+ self._last_cleanup_time = current_time
diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html
index 4ef66f7942..748288cbe8 100644
--- a/src/phoenix/server/templates/index.html
+++ b/src/phoenix/server/templates/index.html
@@ -87,6 +87,7 @@
nSamples: parseInt("{{n_samples}}"),
},
authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"),
+ oAuthIdps: {{ oauth_idps | tojson }},
}),
writable: false
});