Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Sep 18, 2024
1 parent d16e796 commit b8729b5
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 121 deletions.
2 changes: 1 addition & 1 deletion app/src/pages/auth/LoginPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { Flex, View } from "@arizeai/components";

import { AuthLayout } from "./AuthLayout";
import { LoginForm } from "./LoginForm";
import { OAuth2Login } from "./Oauth2Login";
import { OAuth2Login } from "./OAuth2Login";
import { PhoenixLogo } from "./PhoenixLogo";

const separatorCSS = css`
Expand Down
File renamed without changes.
41 changes: 0 additions & 41 deletions app/src/pages/auth/oAuthCallbackLoader.ts

This file was deleted.

14 changes: 8 additions & 6 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@
from urllib.parse import urlparse

import pandas as pd
from typing_extensions import TypeAlias

from phoenix.utilities.re import parse_env_headers

EnvVarName: TypeAlias = str
EnvVarValue: TypeAlias = str

logger = getLogger(__name__)

# Phoenix environment variables
Expand Down Expand Up @@ -211,7 +207,7 @@ def get_env_refresh_token_expiry() -> timedelta:
@dataclass(frozen=True)
class OAuth2ClientConfig:
idp_name: str
display_name: str
idp_display_name: str
client_id: str
client_secret: str
server_metadata_url: str
Expand Down Expand Up @@ -257,7 +253,7 @@ def from_env(cls, idp_name: str) -> "OAuth2ClientConfig":
)
return cls(
idp_name=idp_name,
display_name=os.getenv(
idp_display_name=os.getenv(
f"PHOENIX_OAUTH2_{idp_name_upper}_DISPLAY_NAME",
_get_default_idp_display_name(idp_name),
),
Expand Down Expand Up @@ -473,6 +469,9 @@ class OAuth2Idp(Enum):


def _get_default_idp_display_name(idp_name: str) -> str:
"""
Get the default display name for an OAuth2 IDP.
"""
if idp_name == OAuth2Idp.AWS_COGNITO.value:
return "AWS Cognito"
if idp_name == OAuth2Idp.MICROSOFT_ENTRA_ID.value:
Expand All @@ -481,6 +480,9 @@ def _get_default_idp_display_name(idp_name: str) -> str:


def _get_default_server_metadata_url(idp_name: str) -> Optional[str]:
"""
Gets the default server metadata URL for an OAuth2 IDP.
"""
if idp_name == OAuth2Idp.GOOGLE.value:
return "https://accounts.google.com/.well-known/openid-configuration"
return None
Expand Down
4 changes: 0 additions & 4 deletions src/phoenix/db/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ class UserRole(Enum):
MEMBER = "MEMBER"


class IdentityProviderName(Enum):
LOCAL = "local"


COLUMN_ENUMS: Mapping[InstrumentedAttribute[str], Type[Enum]] = {
models.UserRole.name: UserRole,
}
135 changes: 73 additions & 62 deletions src/phoenix/server/api/routers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from authlib.integrations.starlette_client import OAuthError
from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client
from fastapi import APIRouter, Depends, Path, Request
from fastapi import APIRouter, Path, Request
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
Expand All @@ -20,26 +20,17 @@
from phoenix.db.enums import UserRole
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
from phoenix.server.jwt_store import JwtStore
from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter

ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"
_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"

rate_limiter = ServerRateLimiter(
per_second_rate_limit=0.2,
enforcement_window_seconds=30,
partition_seconds=60,
active_partitions=2,
)
login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"])
router = APIRouter(
prefix="/oauth2", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]
)

router = APIRouter(prefix="/oauth2", include_in_schema=False)


@router.post("/{idp_name}/login")
async def login(
request: Request,
idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)],
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
) -> RedirectResponse:
if not isinstance(
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
Expand All @@ -53,7 +44,7 @@ async def login(
@router.get("/{idp_name}/tokens")
async def create_tokens(
request: Request,
idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)],
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
) -> RedirectResponse:
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
Expand All @@ -63,10 +54,10 @@ async def create_tokens(
):
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
try:
token = await oauth2_client.authorize_access_token(request)
token_data = await oauth2_client.authorize_access_token(request)
except OAuthError as error:
return _redirect_to_login(error=str(error))
if (user_info := _get_user_info(token)) is None:
if (user_info := _get_user_info(token_data)) is None:
return _redirect_to_login(
error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
)
Expand Down Expand Up @@ -103,11 +94,14 @@ class UserInfo:
profile_picture_url: Optional[str]


def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]:
assert isinstance(token.get("access_token"), str)
assert isinstance(token_type := token.get("token_type"), str)
def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]:
"""
Parses token data and extracts user info if available.
"""
assert isinstance(token_data.get("access_token"), str)
assert isinstance(token_type := token_data.get("token_type"), str)
assert token_type.lower() == "bearer"
if (user_info := token.get("userinfo")) is None:
if (user_info := token_data.get("userinfo")) is None:
return None
assert isinstance(subject := user_info.get("sub"), (str, int))
idp_user_id = str(subject)
Expand Down Expand Up @@ -135,14 +129,18 @@ async def _ensure_user_exists_and_is_up_to_date(
)
if user is None:
user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
elif _db_user_is_outdated(user=user, user_info=user_info):
elif not _user_is_up_to_date(user=user, user_info=user_info):
user = await _update_user(session, user_id=user.id, user_info=user_info)
return user


async def _get_user(
session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str
) -> Optional[models.User]:
"""
Retrieves the user uniquely identified by the given OAuth2 client ID and IDP
user ID.
"""
user = await session.scalar(
select(models.User)
.where(
Expand All @@ -156,43 +154,16 @@ async def _get_user(
return user


async def _ensure_email_and_username_are_not_in_use(
session: AsyncSession, /, *, email: str, username: Optional[str]
) -> None:
[(email_exists, username_exists)] = (
await session.execute(
select(
cast(
func.coalesce(
func.max(case((models.User.email == email, 1), else_=0)),
0,
),
Boolean,
).label("email_exists"),
cast(
func.coalesce(
func.max(case((models.User.username == username, 1), else_=0)),
0,
),
Boolean,
).label("username_exists"),
).where(or_(models.User.email == email, models.User.username == username))
)
).all()
if email_exists:
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
if username_exists:
raise UsernameAlreadyInUse(f'An account already exists with username "{username}".')
return None


async def _create_user(
session: AsyncSession,
/,
*,
oauth2_client_id: str,
user_info: UserInfo,
) -> models.User:
"""
Creates a new user with the user info from the IDP.
"""
await _ensure_email_and_username_are_not_in_use(
session,
email=user_info.email,
Expand All @@ -213,22 +184,27 @@ async def _create_user(
username=user_info.username,
email=user_info.email,
profile_picture_url=user_info.profile_picture_url,
password_hash=None,
password_salt=None,
reset_password=False,
)
)
assert isinstance(user_id, int)
user = await session.scalar(
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
) # query user for joined load
) # query user again for joined load
assert isinstance(user, models.User)
return user


async def _update_user(
session: AsyncSession, /, *, user_id: int, user_info: UserInfo
) -> models.User:
"""
Updates an existing user with user info from the IDP.
"""
await _ensure_email_and_username_are_not_in_use(
session,
email=user_info.email,
username=user_info.username,
)
await session.execute(
update(models.User)
.where(models.User.id == user_id)
Expand All @@ -239,19 +215,54 @@ async def _update_user(
)
.options(joinedload(models.User.role))
)
assert isinstance(user_id, int)
user = await session.scalar(
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
) # query user for joined load
) # query user again for joined load
assert isinstance(user, models.User)
return user


def _db_user_is_outdated(*, user: models.User, user_info: UserInfo) -> bool:
async def _ensure_email_and_username_are_not_in_use(
session: AsyncSession, /, *, email: str, username: Optional[str]
) -> None:
"""
Raises an error if the email or username are already in use.
"""
[(email_exists, username_exists)] = (
await session.execute(
select(
cast(
func.coalesce(
func.max(case((models.User.email == email, 1), else_=0)),
0,
),
Boolean,
).label("email_exists"),
cast(
func.coalesce(
func.max(case((models.User.username == username, 1), else_=0)),
0,
),
Boolean,
).label("username_exists"),
).where(or_(models.User.email == email, models.User.username == username))
)
).all()
if email_exists:
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
if username_exists:
raise UsernameAlreadyInUse(f'An account already exists with username "{username}".')


def _user_is_up_to_date(*, user: models.User, user_info: UserInfo) -> bool:
"""
Determines whether the user's tuple in the database is up-to-date with the
IDP's user info.
"""
return (
user.email != user_info.email
or user.username != user_info.username
or user.profile_picture_url != user_info.profile_picture_url
user.email == user_info.email
and user.username == user_info.username
and user.profile_picture_url == user_info.profile_picture_url
)


Expand Down
2 changes: 1 addition & 1 deletion src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json"
if serve_ui and web_manifest_path.is_file():
oauth2_idps = [
OAuth2Idp(name=config.idp_name, displayName=config.display_name)
OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name)
for config in oauth2_client_configs or []
]
app.mount(
Expand Down
5 changes: 3 additions & 2 deletions src/phoenix/server/bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ async def create_access_and_refresh_tokens(
refresh_token_expiry: timedelta,
) -> Tuple[AccessToken, RefreshToken]:
issued_at = datetime.now(timezone.utc)
user_id = UserId(user.id)
user_role = UserRole(user.role.name)
refresh_token_claims = RefreshTokenClaims(
subject=UserId(user.id),
subject=user_id,
issued_at=issued_at,
expiration_time=issued_at + refresh_token_expiry,
attributes=RefreshTokenAttributes(
Expand All @@ -148,7 +149,7 @@ async def create_access_and_refresh_tokens(
)
refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims)
access_token_claims = AccessTokenClaims(
subject=UserId(user.id),
subject=user_id,
issued_at=issued_at,
expiration_time=issued_at + access_token_expiry,
attributes=AccessTokenAttributes(
Expand Down
Loading

0 comments on commit b8729b5

Please sign in to comment.