Skip to content

Commit

Permalink
update username (#4714)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Sep 21, 2024
1 parent f6adb34 commit 946b9c5
Showing 1 changed file with 33 additions and 49 deletions.
82 changes: 33 additions & 49 deletions src/phoenix/server/api/routers/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from dataclasses import dataclass
from datetime import timedelta
from random import randrange
from typing import Any, Dict, Optional, Tuple
from urllib.parse import unquote

Expand All @@ -12,6 +13,7 @@
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
from starlette.routing import Router
Expand Down Expand Up @@ -156,7 +158,7 @@ async def create_tokens(
oauth2_client_id=str(oauth2_client.client_id),
user_info=user_info,
)
except (EmailAlreadyInUse, UsernameAlreadyInUse) as error:
except EmailAlreadyInUse as error:
return _redirect_to_login(error=str(error))
access_token, refresh_token = await create_access_and_refresh_tokens(
user=user,
Expand Down Expand Up @@ -223,8 +225,8 @@ async def _ensure_user_exists_and_is_up_to_date(
)
if user is None:
user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info)
elif not _user_is_up_to_date(user=user, user_info=user_info):
user = await _update_user(session, user_id=user.id, user_info=user_info)
elif user.email != user_info.email:
user = await _update_user_email(session, user_id=user.id, email=user_info.email)
return user


Expand Down Expand Up @@ -258,11 +260,13 @@ async def _create_user(
"""
Creates a new user with the user info from the IDP.
"""
await _ensure_email_and_username_are_not_in_use(
email_exists, username_exists = await _email_and_username_exist(
session,
email=user_info.email,
username=user_info.username,
email=(email := user_info.email),
username=(username := user_info.username),
)
if email_exists:
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
member_role_id = (
select(models.UserRole.id)
.where(models.UserRole.name == UserRole.MEMBER.value)
Expand All @@ -275,8 +279,8 @@ async def _create_user(
user_role_id=member_role_id,
oauth2_client_id=oauth2_client_id,
oauth2_user_id=user_info.idp_user_id,
username=user_info.username,
email=user_info.email,
username=_with_random_suffix(username) if username and username_exists else username,
email=email,
profile_picture_url=user_info.profile_picture_url,
reset_password=False,
)
Expand All @@ -289,39 +293,31 @@ async def _create_user(
return user


async def _update_user(
session: AsyncSession, /, *, user_id: int, user_info: UserInfo
) -> models.User:
async def _update_user_email(session: AsyncSession, /, *, user_id: int, email: str) -> models.User:
"""
Updates an existing user with user info from the IDP.
Updates an existing user's email.
"""
await _ensure_email_and_username_are_not_in_use(
session,
email=user_info.email,
username=user_info.username,
)
await session.execute(
update(models.User)
.where(models.User.id == user_id)
.values(
username=user_info.username,
email=user_info.email,
profile_picture_url=user_info.profile_picture_url,
try:
await session.execute(
update(models.User)
.where(models.User.id == user_id)
.values(email=email)
.options(joinedload(models.User.role))
)
.options(joinedload(models.User.role))
)
except IntegrityError:
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
user = await session.scalar(
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
) # query user again for joined load
assert isinstance(user, models.User)
return user


async def _ensure_email_and_username_are_not_in_use(
async def _email_and_username_exist(
session: AsyncSession, /, *, email: str, username: Optional[str]
) -> None:
) -> Tuple[bool, bool]:
"""
Raises an error if the email or username are already in use.
Checks whether the email and username are already in use.
"""
[(email_exists, username_exists)] = (
await session.execute(
Expand All @@ -343,32 +339,13 @@ async def _ensure_email_and_username_are_not_in_use(
).where(or_(models.User.email == email, models.User.username == username))
)
).all()
if email_exists:
raise EmailAlreadyInUse(f"An account for {email} is already in use.")
if username_exists:
raise UsernameAlreadyInUse(f'An account already exists with username "{username}".')


def _user_is_up_to_date(*, user: models.User, user_info: UserInfo) -> bool:
"""
Determines whether the user's tuple in the database is up-to-date with the
IDP's user info.
"""
return (
user.email == user_info.email
and user.username == user_info.username
and user.profile_picture_url == user_info.profile_picture_url
)
return email_exists, username_exists


class EmailAlreadyInUse(Exception):
pass


class UsernameAlreadyInUse(Exception):
pass


def _redirect_to_login(*, error: str) -> RedirectResponse:
"""
Creates a RedirectResponse to the login page to display an error message.
Expand Down Expand Up @@ -433,6 +410,13 @@ def _is_relative_url(url: str) -> bool:
return bool(_RELATIVE_URL_PATTERN.match(url))


def _with_random_suffix(string: str) -> str:
"""
Appends a random suffix.
"""
return f"{string}-{randrange(10_000, 100_000)}"


_RETURN_URL = "return_url"
_JWT_ALGORITHM = "HS256"
_INVALID_OAUTH2_STATE_MESSAGE = (
Expand Down

0 comments on commit 946b9c5

Please sign in to comment.