From 946b9c5218c030ba06e344d156ae39dd21a71757 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Sat, 21 Sep 2024 15:54:24 -0700 Subject: [PATCH] update username (#4714) --- src/phoenix/server/api/routers/oauth2.py | 82 ++++++++++-------------- 1 file changed, 33 insertions(+), 49 deletions(-) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index b720efc4f7..c15d244c80 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -1,6 +1,7 @@ import re from dataclasses import dataclass from datetime import timedelta +from random import randrange from typing import Any, Dict, Optional, Tuple from urllib.parse import unquote @@ -12,6 +13,7 @@ from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped] from starlette.datastructures import URL from starlette.responses import RedirectResponse from starlette.routing import Router @@ -156,7 +158,7 @@ async def create_tokens( oauth2_client_id=str(oauth2_client.client_id), user_info=user_info, ) - except (EmailAlreadyInUse, UsernameAlreadyInUse) as error: + except EmailAlreadyInUse as error: return _redirect_to_login(error=str(error)) access_token, refresh_token = await create_access_and_refresh_tokens( user=user, @@ -223,8 +225,8 @@ async def _ensure_user_exists_and_is_up_to_date( ) if user is None: user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info) - elif not _user_is_up_to_date(user=user, user_info=user_info): - user = await _update_user(session, user_id=user.id, user_info=user_info) + elif user.email != user_info.email: + user = await _update_user_email(session, user_id=user.id, email=user_info.email) return user @@ -258,11 +260,13 @@ async def _create_user( """ Creates a new user with the user info from the IDP. """ - await _ensure_email_and_username_are_not_in_use( + email_exists, username_exists = await _email_and_username_exist( session, - email=user_info.email, - username=user_info.username, + email=(email := user_info.email), + username=(username := user_info.username), ) + if email_exists: + raise EmailAlreadyInUse(f"An account for {email} is already in use.") member_role_id = ( select(models.UserRole.id) .where(models.UserRole.name == UserRole.MEMBER.value) @@ -275,8 +279,8 @@ async def _create_user( user_role_id=member_role_id, oauth2_client_id=oauth2_client_id, oauth2_user_id=user_info.idp_user_id, - username=user_info.username, - email=user_info.email, + username=_with_random_suffix(username) if username and username_exists else username, + email=email, profile_picture_url=user_info.profile_picture_url, reset_password=False, ) @@ -289,27 +293,19 @@ async def _create_user( return user -async def _update_user( - session: AsyncSession, /, *, user_id: int, user_info: UserInfo -) -> models.User: +async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User: """ - Updates an existing user with user info from the IDP. + Updates an existing user's email. """ - await _ensure_email_and_username_are_not_in_use( - session, - email=user_info.email, - username=user_info.username, - ) - await session.execute( - update(models.User) - .where(models.User.id == user_id) - .values( - username=user_info.username, - email=user_info.email, - profile_picture_url=user_info.profile_picture_url, + try: + await session.execute( + update(models.User) + .where(models.User.id == user_id) + .values(email=email) + .options(joinedload(models.User.role)) ) - .options(joinedload(models.User.role)) - ) + except IntegrityError: + raise EmailAlreadyInUse(f"An account for {email} is already in use.") user = await session.scalar( select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) ) # query user again for joined load @@ -317,11 +313,11 @@ async def _update_user( return user -async def _ensure_email_and_username_are_not_in_use( +async def _email_and_username_exist( session: AsyncSession, /, *, email: str, username: Optional[str] -) -> None: +) -> Tuple[bool, bool]: """ - Raises an error if the email or username are already in use. + Checks whether the email and username are already in use. """ [(email_exists, username_exists)] = ( await session.execute( @@ -343,32 +339,13 @@ async def _ensure_email_and_username_are_not_in_use( ).where(or_(models.User.email == email, models.User.username == username)) ) ).all() - if email_exists: - raise EmailAlreadyInUse(f"An account for {email} is already in use.") - if username_exists: - raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') - - -def _user_is_up_to_date(*, user: models.User, user_info: UserInfo) -> bool: - """ - Determines whether the user's tuple in the database is up-to-date with the - IDP's user info. - """ - return ( - user.email == user_info.email - and user.username == user_info.username - and user.profile_picture_url == user_info.profile_picture_url - ) + return email_exists, username_exists class EmailAlreadyInUse(Exception): pass -class UsernameAlreadyInUse(Exception): - pass - - def _redirect_to_login(*, error: str) -> RedirectResponse: """ Creates a RedirectResponse to the login page to display an error message. @@ -433,6 +410,13 @@ def _is_relative_url(url: str) -> bool: return bool(_RELATIVE_URL_PATTERN.match(url)) +def _with_random_suffix(string: str) -> str: + """ + Appends a random suffix. + """ + return f"{string}-{randrange(10_000, 100_000)}" + + _RETURN_URL = "return_url" _JWT_ALGORITHM = "HS256" _INVALID_OAUTH2_STATE_MESSAGE = (