Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(auth): hard-delete users #4715

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ def upgrade() -> None:
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.Column(
"deleted_at",
sa.TIMESTAMP(timezone=True),
nullable=True,
),
sa.CheckConstraint(
"(password_hash IS NULL) = (password_salt IS NULL)",
name="password_hash_and_salt",
Expand Down
1 change: 0 additions & 1 deletion src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ class User(Base):
updated_at: Mapped[datetime] = mapped_column(
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)
deleted_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp)
password_reset_token: Mapped[Optional["PasswordResetToken"]] = relationship(
"PasswordResetToken",
back_populates="user",
Expand Down
6 changes: 2 additions & 4 deletions src/phoenix/server/api/dataloaders/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional

from sqlalchemy import and_, select
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

Expand All @@ -25,9 +25,7 @@ async def _load_fn(self, keys: List[Key]) -> List[Result]:
users_by_id: DefaultDict[Key, Result] = defaultdict(None)
async with self._db() as session:
data = await session.stream_scalars(
select(models.User).where(
and_(models.User.id.in_(user_ids), models.User.deleted_at.is_(None))
)
select(models.User).where(models.User.id.in_(user_ids))
)
async for user in data:
users_by_id[user.id] = user
Expand Down
6 changes: 2 additions & 4 deletions src/phoenix/server/api/mutations/api_key_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import strawberry
from sqlalchemy import and_, select
from sqlalchemy import select
from strawberry import UNSET
from strawberry.relay import GlobalID
from strawberry.types import Info
Expand Down Expand Up @@ -71,9 +71,7 @@ async def create_system_api_key(
system_user = await session.scalar(
select(models.User)
.join(models.UserRole) # Join User with UserRole
.where(
and_(models.UserRole.name == user_role.value, models.User.deleted_at.is_(None))
) # Filter where role is SYSTEM
.where(models.UserRole.name == user_role.value) # Filter where role is SYSTEM
.order_by(models.User.id)
.limit(1)
)
Expand Down
31 changes: 9 additions & 22 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Literal, Optional, Tuple

import strawberry
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select, update
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
from sqlalchemy.orm import joinedload
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry import UNSET
Expand Down Expand Up @@ -258,7 +258,6 @@ async def delete_users(
.where(
and_(
models.User.id.in_(user_ids),
models.User.deleted_at.is_(None),
models.User.user_role_id != system_user_role_id,
)
)
Expand All @@ -271,40 +270,30 @@ async def delete_users(
password_reset_token_ids = [
PasswordResetTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.PasswordResetToken)
.where(models.PasswordResetToken.user_id.in_(user_ids))
.returning(models.PasswordResetToken.id)
select(models.PasswordResetToken.id).where(
models.PasswordResetToken.user_id.in_(user_ids)
)
)
]
access_token_ids = [
AccessTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.AccessToken)
.where(models.AccessToken.user_id.in_(user_ids))
.returning(models.AccessToken.id)
select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
)
]
refresh_token_ids = [
RefreshTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.RefreshToken)
.where(models.RefreshToken.user_id.in_(user_ids))
.returning(models.RefreshToken.id)
select(models.RefreshToken.id).where(models.RefreshToken.user_id.in_(user_ids))
)
]
api_key_ids = [
ApiKeyId(id_)
async for id_ in await session.stream_scalars(
delete(models.ApiKey)
.where(models.ApiKey.user_id.in_(user_ids))
.returning(models.ApiKey.id)
select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids))
)
]
await session.execute(
update(models.User)
.where(models.User.id.in_(user_ids))
.values(deleted_at=func.now())
)
await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))
await token_store.revoke(
*password_reset_token_ids,
*access_token_ids,
Expand All @@ -319,9 +308,7 @@ def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:

def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
return (
select(models.User)
.where(and_(models.User.id == user_id, models.User.deleted_at.is_(None)))
.options(joinedload(models.User.role))
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
)


Expand Down
15 changes: 3 additions & 12 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,7 @@ async def users(
stmt = (
select(models.User)
.join(models.UserRole)
.where(
and_(
models.UserRole.name != enums.UserRole.SYSTEM.value,
models.User.deleted_at.is_(None),
)
)
.where(models.UserRole.name != enums.UserRole.SYSTEM.value)
.order_by(models.User.email)
.options(joinedload(models.User.role))
)
Expand Down Expand Up @@ -477,9 +472,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
async with info.context.db() as session:
if not (
user := await session.scalar(
select(models.User).where(
and_(models.User.id == node_id, models.User.deleted_at.is_(None))
)
select(models.User).where(models.User.id == node_id)
)
):
raise NotFound(f"Unknown user: {id}")
Expand All @@ -499,9 +492,7 @@ async def viewer(self, info: Info[Context, None]) -> Optional[User]:
if (
user := await session.scalar(
select(models.User)
.where(
and_(models.User.id == int(user.identity), models.User.deleted_at.is_(None))
)
.where(models.User.id == int(user.identity))
.options(joinedload(models.User.role))
)
) is None:
Expand Down
15 changes: 5 additions & 10 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import secrets
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Tuple

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy import Select, select
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from starlette.status import (
HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -80,7 +79,7 @@ async def login(request: Request) -> Response:

async with request.app.state.db() as session:
user = await session.scalar(
_select_active_user().filter_by(email=email).options(joinedload(models.User.role))
select(models.User).filter_by(email=email).options(joinedload(models.User.role))
)
if (
user is None
Expand Down Expand Up @@ -160,7 +159,7 @@ async def refresh_tokens(request: Request) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(
_select_active_user().filter_by(id=user_id).options(joinedload(models.User.role))
select(models.User).filter_by(id=user_id).options(joinedload(models.User.role))
)
) is None:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
Expand Down Expand Up @@ -191,7 +190,7 @@ async def initiate_password_reset(request: Request) -> Response:
assert isinstance(token_expiry := request.app.state.password_reset_token_expiry, timedelta)
async with request.app.state.db() as session:
user = await session.scalar(
_select_active_user()
select(models.User)
.filter_by(email=email)
.options(
joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id)
Expand Down Expand Up @@ -228,7 +227,7 @@ async def reset_password(request: Request) -> Response:
raise INVALID_TOKEN
assert (user_id := claims.subject)
async with request.app.state.db() as session:
user = await session.scalar(_select_active_user().filter_by(id=int(user_id)))
user = await session.scalar(select(models.User).filter_by(id=int(user_id)))
if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
# Withold privileged information
return Response(status_code=HTTP_204_NO_CONTENT)
Expand All @@ -249,10 +248,6 @@ async def reset_password(request: Request) -> Response:
return response


def _select_active_user() -> Select[Tuple[models.User]]:
return select(models.User).where(models.User.deleted_at.is_(None))


LOGIN_FAILED_MESSAGE = "Invalid email and/or password"

MISSING_EMAIL = HTTPException(
Expand Down
Loading