Skip to content

Commit

Permalink
fix(auth): soft-delete users (#4562)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Sep 10, 2024
1 parent 83536af commit 98cd236
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 23 deletions.
13 changes: 13 additions & 0 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ def test_cannot_log_in_with_wrong_password(
with _EXPECTATION_401:
_log_in(wrong_password, email=u.email)

@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN])
def test_cannot_log_in_with_deleted_user(
self,
role_or_user: _RoleOrUser,
_get_user: _GetUser,
_passwords: Iterator[_Password],
) -> None:
admin_user = _get_user(UserRoleInput.ADMIN)
user = _get_user(role_or_user)
admin_user.delete_users(user)
with _EXPECTATION_401:
user.log_in()


class TestLogOut:
def test_default_admin_cannot_log_out_during_testing(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions src/phoenix/server/api/dataloaders/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional

from sqlalchemy import select
from sqlalchemy import and_, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

Expand All @@ -25,7 +25,9 @@ async def _load_fn(self, keys: List[Key]) -> List[Result]:
users_by_id: DefaultDict[Key, Result] = defaultdict(None)
async with self._db() as session:
data = await session.stream_scalars(
select(models.User).where(models.User.id.in_(user_ids))
select(models.User).where(
and_(models.User.id.in_(user_ids), models.User.deleted_at.is_(None))
)
)
async for user in data:
users_by_id[user.id] = user
Expand Down
6 changes: 4 additions & 2 deletions src/phoenix/server/api/mutations/api_key_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import strawberry
from sqlalchemy import select
from sqlalchemy import and_, select
from strawberry import UNSET
from strawberry.relay import GlobalID
from strawberry.types import Info
Expand Down Expand Up @@ -71,7 +71,9 @@ async def create_system_api_key(
system_user = await session.scalar(
select(models.User)
.join(models.UserRole) # Join User with UserRole
.where(models.UserRole.name == user_role.value) # Filter where role is SYSTEM
.where(
and_(models.UserRole.name == user_role.value, models.User.deleted_at.is_(None))
) # Filter where role is SYSTEM
.order_by(models.User.id)
.limit(1)
)
Expand Down
54 changes: 41 additions & 13 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Literal, Optional, Tuple

import strawberry
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
from sqlalchemy import Boolean, Select, and_, case, cast, distinct, func, select, update
from sqlalchemy.orm import joinedload
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry import UNSET
Expand All @@ -27,6 +27,7 @@
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.User import User, to_gql_user
from phoenix.server.bearer_auth import PhoenixUser
from phoenix.server.types import AccessTokenId, ApiKeyId, RefreshTokenId


@strawberry.input
Expand Down Expand Up @@ -100,14 +101,14 @@ async def create_user(
session = await stack.enter_async_context(info.context.db())
user_role_id = await session.scalar(_select_role_id_by_name(input.role.value))
if user_role_id is None:
raise ValueError(f"Role {input.role.value} not found")
raise NotFound(f"Role {input.role.value} not found")
stack.enter_context(session.no_autoflush)
user.user_role_id = user_role_id
session.add(user)
try:
await session.flush()
except IntegrityError as error:
raise ValueError(_user_operation_error_message(error))
raise Conflict(_user_operation_error_message(error))
return UserMutationPayload(user=to_gql_user(user))

@strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore
Expand All @@ -125,16 +126,16 @@ async def patch_user(
requester = await session.scalar(_select_user_by_id(requester_id))
assert requester
if not (user := await session.scalar(_select_user_by_id(user_id))):
raise ValueError("User not found")
raise NotFound("User not found")
stack.enter_context(session.no_autoflush)
if input.new_role:
user_role_id = await session.scalar(_select_role_id_by_name(input.new_role.value))
if user_role_id is None:
raise ValueError(f"Role {input.new_role.value} not found")
raise NotFound(f"Role {input.new_role.value} not found")
user.user_role_id = user_role_id
if password := input.new_password:
if user.auth_method != enums.AuthMethod.LOCAL.value:
raise ValueError("Cannot modify password for non-local user")
raise Conflict("Cannot modify password for non-local user")
validate_password_format(password)
user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
user.password_hash = await info.context.hash_password(password, user.password_salt)
Expand All @@ -145,7 +146,7 @@ async def patch_user(
try:
await session.flush()
except IntegrityError as error:
raise ValueError(_user_operation_error_message(error, "modify"))
raise Conflict(_user_operation_error_message(error, "modify"))
assert user
if input.new_password:
await info.context.log_out(user.id)
Expand All @@ -163,15 +164,15 @@ async def patch_viewer(
async with AsyncExitStack() as stack:
session = await stack.enter_async_context(info.context.db())
if not (user := await session.scalar(_select_user_by_id(user_id))):
raise ValueError("User not found")
raise NotFound("User not found")
stack.enter_context(session.no_autoflush)
if password := input.new_password:
if user.auth_method != enums.AuthMethod.LOCAL.value:
raise ValueError("Cannot modify password for non-local user")
raise Conflict("Cannot modify password for non-local user")
if not (
current_password := input.current_password
) or not await info.context.is_valid_password(current_password, user):
raise ValueError("Valid current password is required to modify password")
raise Conflict("Valid current password is required to modify password")
validate_password_format(password)
user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH)
user.password_hash = await info.context.hash_password(password, user.password_salt)
Expand All @@ -183,7 +184,7 @@ async def patch_viewer(
try:
await session.flush()
except IntegrityError as error:
raise ValueError(_user_operation_error_message(error, "modify"))
raise Conflict(_user_operation_error_message(error, "modify"))
assert user
if input.new_password:
await info.context.log_out(user.id)
Expand All @@ -195,6 +196,7 @@ async def delete_users(
info: Info[Context, None],
input: DeleteUsersInput,
) -> None:
assert (token_store := info.context.token_store) is not None
if not input.user_ids:
return
user_ids = tuple(
Expand Down Expand Up @@ -250,6 +252,7 @@ async def delete_users(
.where(
and_(
models.User.id.in_(user_ids),
models.User.deleted_at.is_(None),
models.User.user_role_id != system_user_role_id,
)
)
Expand All @@ -259,7 +262,30 @@ async def delete_users(
raise Conflict("Cannot delete the default admin user")
if num_resolved_user_ids < len(user_ids):
raise NotFound("Some user IDs could not be found")
await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))
access_token_ids = (
AccessTokenId(id)
for id in await session.scalars(
select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
)
)
refresh_token_ids = (
RefreshTokenId(id)
for id in await session.scalars(
select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
)
)
api_key_ids = (
ApiKeyId(id)
for id in await session.scalars(
select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids))
)
)
await token_store.revoke(*access_token_ids, *refresh_token_ids, *api_key_ids)
await session.execute(
update(models.User)
.where(models.User.id.in_(user_ids))
.values(deleted_at=func.now())
)


def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
Expand All @@ -268,7 +294,9 @@ def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:

def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
return (
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
select(models.User)
.where(and_(models.User.id == user_id, models.User.deleted_at.is_(None)))
.options(joinedload(models.User.role))
)


Expand Down
15 changes: 12 additions & 3 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ async def users(
stmt = (
select(models.User)
.join(models.UserRole)
.where(models.UserRole.name != enums.UserRole.SYSTEM.value)
.where(
and_(
models.UserRole.name != enums.UserRole.SYSTEM.value,
models.User.deleted_at.is_(None),
)
)
.order_by(models.User.email)
.options(joinedload(models.User.role))
)
Expand Down Expand Up @@ -472,7 +477,9 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
async with info.context.db() as session:
if not (
user := await session.scalar(
select(models.User).where(models.User.id == node_id)
select(models.User).where(
and_(models.User.id == node_id, models.User.deleted_at.is_(None))
)
)
):
raise NotFound(f"Unknown user: {id}")
Expand All @@ -492,7 +499,9 @@ async def viewer(self, info: Info[Context, None]) -> Optional[User]:
if (
user := await session.scalar(
select(models.User)
.where(models.User.id == int(user.identity))
.where(
and_(models.User.id == int(user.identity), models.User.deleted_at.is_(None))
)
.options(joinedload(models.User.role))
)
) is None:
Expand Down
10 changes: 7 additions & 3 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy import select
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from starlette.status import HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND

Expand Down Expand Up @@ -56,7 +56,9 @@ async def login(request: Request) -> Response:

async with request.app.state.db() as session:
user = await session.scalar(
select(OrmUser).where(OrmUser.email == email).options(joinedload(OrmUser.role))
select(OrmUser)
.where(and_(OrmUser.email == email, OrmUser.deleted_at.is_(None)))
.options(joinedload(OrmUser.role))
)
if (
user is None
Expand Down Expand Up @@ -142,7 +144,9 @@ async def refresh_tokens(request: Request) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(
select(OrmUser).where(OrmUser.id == user_id).options(joinedload(OrmUser.role))
select(OrmUser)
.where(and_(OrmUser.id == user_id, OrmUser.deleted_at.is_(None)))
.options(joinedload(OrmUser.role))
)
) is None:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
Expand Down

0 comments on commit 98cd236

Please sign in to comment.