Skip to content

Commit

Permalink
implement authorization code flow
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Sep 15, 2024
1 parent 2891b49 commit 98b3b7e
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 2 deletions.
35 changes: 34 additions & 1 deletion app/src/pages/auth/LoginPage.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import React from "react";
import { css } from "@emotion/react";

import { Flex, View } from "@arizeai/components";
import { Button, Flex, Form, View } from "@arizeai/components";

import { AuthLayout } from "./AuthLayout";
import { LoginForm } from "./LoginForm";
import { PhoenixLogo } from "./PhoenixLogo";

export function LoginPage() {
const oAuthIdps = window.Config.oAuthIdps;
return (
<AuthLayout>
<Flex direction="column" gap="size-200" alignItems="center">
Expand All @@ -15,6 +17,37 @@ export function LoginPage() {
</View>
</Flex>
<LoginForm />
{oAuthIdps.map((idp) => (
<OAuthLoginForm
key={idp.id}
idpId={idp.id}
idpDisplayName={idp.displayName}
/>
))}
</AuthLayout>
);
}

type OAuthLoginFormProps = {
idpId: string;
idpDisplayName: string;
};
export function OAuthLoginForm({ idpId, idpDisplayName }: OAuthLoginFormProps) {
return (
<Form key={idpId} action={`/oauth/${idpId}/login`} method="post">
<div
css={css`
margin-top: var(--ac-global-dimension-size-400);
margin-bottom: var(--ac-global-dimension-size-50);
button {
width: 100%;
}
`}
>
<Button variant="primary" type="submit">
Login with {idpDisplayName}
</Button>
</div>
</Form>
);
}
41 changes: 41 additions & 0 deletions app/src/pages/auth/oAuthCallbackLoader.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// import { redirect } from "react-router";
import { LoaderFunctionArgs } from "react-router-dom";

export async function oAuthCallbackLoader(args: LoaderFunctionArgs) {
const queryParameters = new URL(args.request.url).searchParams;
const authorizationCode = queryParameters.get("code");
const state = queryParameters.get("state");
const actualState = sessionStorage.getItem("oAuthState");
sessionStorage.removeItem("oAuthState");
if (
authorizationCode == undefined ||
state == undefined ||
actualState == undefined ||
state !== actualState
) {
// todo: display error message
return null;
}
const origin = new URL(window.location.href).origin;
const redirectUri = `${origin}/oauth-callback`;
try {
const response = await fetch("/auth/oauth-tokens", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
authorization_code: authorizationCode,
redirect_uri: redirectUri,
}),
});
if (!response.ok) {
// todo: parse response body and display error message
return null;
}
} catch (error) {
// todo: display error
}
// redirect("/");
return null;
}
6 changes: 6 additions & 0 deletions app/src/window.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
export {};

type OAuthIdp = {
id: string;
displayName: string;
};

declare global {
interface Window {
Config: {
Expand All @@ -15,6 +20,7 @@ declare global {
nSamples: number;
};
authenticationEnabled: boolean;
oAuthIdps: OAuthIdp[];
};
}
}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"fastapi",
"pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting
"pyjwt",
"authlib",
]
dynamic = ["version"]

Expand Down Expand Up @@ -406,6 +407,7 @@ module = [
"grpc.*",
"py_grpc_prometheus.*",
"orjson", # suppress fastapi internal type errors
"authlib.*",
]
ignore_missing_imports = true

Expand Down
76 changes: 76 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import os
import re
import tempfile
from dataclasses import dataclass
from datetime import timedelta
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
from typing_extensions import TypeAlias

from phoenix.utilities.re import parse_env_headers

IdpId: TypeAlias = str
EnvVarName: TypeAlias = str
EnvVarValue: TypeAlias = str

logger = getLogger(__name__)

# Phoenix environment variables
Expand Down Expand Up @@ -201,6 +207,72 @@ def get_env_refresh_token_expiry() -> timedelta:
)


@dataclass(frozen=True)
class OAuthClientConfig:
idp_id: str
display_name: str
client_id: str
client_secret: str
server_metadata_url: Optional[str] = None
authorize_url: Optional[str] = None
access_token_url: Optional[str] = None

@classmethod
def from_env(cls, idp_id: str) -> "OAuthClientConfig":
idp_id_upper = idp_id.upper()
if (
client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_ID")
) is None:
raise ValueError(
f"A client id must be set for the {idp_id} OAuth IDP "
f"via the {client_id_env_var} environment variable"
)
if (
client_secret := os.getenv(
client_secret_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_SECRET"
)
) is None:
raise ValueError(
f"A client secret must be set for the {idp_id} OAuth IDP "
f"via the {client_secret_env_var} environment variable"
)
return cls(
idp_id=idp_id,
display_name=os.getenv(
f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", get_default_idp_display_name(idp_id)
),
client_id=client_id,
client_secret=client_secret,
server_metadata_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL"),
access_token_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_ACCESS_TOKEN_URL"),
authorize_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_AUTHORIZE_URL"),
)

def __post_init__(self) -> None:
assert self.idp_id
if not self.display_name:
raise ValueError(f"OAuth display name for {self.idp_id} cannot be empty")
if not self.client_id:
raise ValueError(f"OAuth client id for {self.idp_id} cannot be empty")
if not self.client_secret:
raise ValueError(f"OAuth client secret for {self.idp_id} cannot be empty")


def get_env_oauth_settings() -> List[OAuthClientConfig]:
"""
Get OAuth settings from environment variables.
"""

idp_ids = set()
pattern = re.compile(
r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL|ACCESS_TOKEN_URL|AUTHORIZE_URL)$"
)
for env_var in os.environ:
if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()):
idp_ids.add(idp_id)
return [OAuthClientConfig.from_env(idp_id) for idp_id in sorted(idp_ids)]


def _parse_duration(duration_str: str) -> timedelta:
"""
Parses a duration string into a timedelta object, assuming the duration is
Expand Down Expand Up @@ -385,5 +457,9 @@ def get_web_base_url() -> str:
return get_base_url()


def get_default_idp_display_name(ipd_id: IdpId) -> str:
return ipd_id.replace("_", " ").title()


DEFAULT_PROJECT_NAME = "default"
_KUBERNETES_PHOENIX_PORT_PATTERN = re.compile(r"^tcp://\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}:\d+$")
2 changes: 2 additions & 0 deletions src/phoenix/server/api/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .auth import router as auth_router
from .embeddings import create_embeddings_router
from .oauth import router as oauth_router
from .v1 import create_v1_router

__all__ = [
"auth_router",
"create_embeddings_router",
"create_v1_router",
"oauth_router",
]
39 changes: 39 additions & 0 deletions src/phoenix/server/api/routers/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from authlib.integrations.starlette_client import OAuthError
from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient
from fastapi import APIRouter, Depends, HTTPException, Request
from starlette.responses import RedirectResponse
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND

from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter

rate_limiter = ServerRateLimiter(
per_second_rate_limit=0.2,
enforcement_window_seconds=30,
partition_seconds=60,
active_partitions=2,
)
login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"])
router = APIRouter(
prefix="/oauth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]
)


@router.post("/{idp}/login")
async def login(request: Request, idp: str) -> RedirectResponse:
if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient):
raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}")
redirect_uri = request.url_for("create_tokens", idp=idp)
response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri)
return response


@router.get("/{idp}/tokens")
async def create_tokens(request: Request, idp: str) -> RedirectResponse:
if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient):
raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}")
try:
token = await oauth_client.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error))
print(f"{token=}")
return RedirectResponse(url="/")
Loading

0 comments on commit 98b3b7e

Please sign in to comment.