Skip to content

Commit

Permalink
feat(auth): add deleteUsers mutation (#4537)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored and RogerHYang committed Sep 21, 2024
1 parent b80f532 commit 745cba7
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 5 deletions.
8 changes: 8 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ input DeleteExperimentsInput {
experimentIds: [GlobalID!]!
}

input DeleteUsersInput {
userIds: [GlobalID!]!
}

type Dimension implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
Expand Down Expand Up @@ -933,6 +937,7 @@ type Mutation {
createUser(input: CreateUserInput!): UserMutationPayload!
patchUser(input: PatchUserInput!): UserMutationPayload!
patchViewer(input: PatchViewerInput!): UserMutationPayload!
deleteUsers(input: DeleteUsersInput!): Void
}

"""An object with a Globally Unique ID"""
Expand Down Expand Up @@ -1523,3 +1528,6 @@ type ValidationResult {
enum VectorDriftMetric {
euclideanDistance
}

"""Represents NULL values"""
scalar Void
59 changes: 59 additions & 0 deletions integration_tests/auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
List,
Optional,
Protocol,
Sequence,
Tuple,
cast,
)
Expand Down Expand Up @@ -91,6 +92,10 @@ def __call__(
) -> None: ...


class _DeleteUsers(Protocol):
def __call__(self, token: _Token, /, *, user_ids: Sequence[_GqlId]) -> None: ...


class _PatchViewer(Protocol):
def __call__(
self,
Expand Down Expand Up @@ -236,6 +241,18 @@ def admin_token(
yield token


@pytest.fixture
def member_token(
get_new_user: _GetNewUser,
member_email: str,
member_password: str,
log_in: _LogIn,
) -> _Token:
member = get_new_user(UserRoleInput.MEMBER)
assert (token := member.token) is not None
return token


@pytest.fixture(scope="module")
def admin_email() -> _Email:
return "admin@localhost"
Expand All @@ -246,6 +263,16 @@ def admin_password() -> _Email:
return "admin"


@pytest.fixture(scope="module")
def member_email() -> _Email:
return "member@domain.com"


@pytest.fixture(scope="module")
def member_password() -> _Password:
return "Member-password1234"


@pytest.fixture(scope="module")
def create_user(
httpx_client: Callable[[], httpx.Client],
Expand Down Expand Up @@ -316,6 +343,38 @@ def _(
return _


@pytest.fixture(scope="module")
def delete_users(
httpx_client: Callable[[], httpx.Client],
) -> _DeleteUsers:
def _(
token: _Token,
/,
*,
user_ids: Sequence[_GqlId],
) -> None:
mutation = """
mutation ($userIds: [GlobalID!]!) {
deleteUsers(input: {userIds: $userIds})
}
"""
response = httpx_client().post(
urljoin(get_base_url(), "/graphql"),
json={
"query": mutation,
"variables": {"userIds": list(user_ids)},
},
cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token},
)
response.raise_for_status()
if (errors := response.json().get("errors")) is not None:
assert len(errors) == 1
error_message = errors[0]["message"]
raise Exception(error_message)

return _


@pytest.fixture(scope="module")
def patch_viewer(
httpx_client: Callable[[], httpx.Client],
Expand Down
57 changes: 57 additions & 0 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterator,
Optional,
Protocol,
Sequence,
Tuple,
)
from urllib.parse import urljoin
Expand All @@ -29,6 +30,7 @@
from phoenix.config import get_base_url
from phoenix.server.api.exceptions import Unauthorized
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from strawberry.relay import GlobalID
from typing_extensions import TypeAlias

NOW = datetime.now(timezone.utc)
Expand Down Expand Up @@ -88,6 +90,10 @@ def __call__(
) -> None: ...


class _DeleteUsers(Protocol):
def __call__(self, token: _Token, /, *, user_ids: Sequence[_GqlId]) -> None: ...


class _PatchViewer(Protocol):
def __call__(
self,
Expand Down Expand Up @@ -471,6 +477,57 @@ def test_only_admin_can_change_username_for_non_self(
with expectation:
patch_user(token, gid, new_username=new_username)

def test_admin_can_delete_non_self_admin_and_member_users(
self,
admin_token: _Token,
get_new_user: _GetNewUser,
delete_users: _DeleteUsers,
) -> None:
admin = get_new_user(UserRoleInput.ADMIN)
member = get_new_user(UserRoleInput.MEMBER)
delete_users(admin_token, user_ids=[admin.gid, member.gid])

def test_admin_cannot_delete_system_user(
self,
admin_token: _Token,
get_new_user: _GetNewUser,
delete_users: _DeleteUsers,
) -> None:
system_user_gid = str(GlobalID(type_name="User", node_id="1"))
with pytest.raises(Exception, match="Some user IDs could not be found"):
delete_users(admin_token, user_ids=[system_user_gid])

def test_error_is_raised_when_deleting_a_non_existent_user_id(
self,
admin_token: _Token,
get_new_user: _GetNewUser,
delete_users: _DeleteUsers,
) -> None:
system_user_gid = str(GlobalID(type_name="User", node_id="10000"))
with pytest.raises(Exception, match="Some user IDs could not be found"):
delete_users(admin_token, user_ids=[system_user_gid])

def test_member_cannot_delete_users(
self,
member_token: _Token,
get_new_user: _GetNewUser,
delete_users: _DeleteUsers,
) -> None:
admin = get_new_user(UserRoleInput.ADMIN)
member = get_new_user(UserRoleInput.MEMBER)
with pytest.raises(Exception, match="Only admin can perform this action"):
delete_users(member_token, user_ids=[admin.gid, member.gid])

def test_admin_cannot_delete_default_admin_user(
self,
admin_token: _Token,
get_new_user: _GetNewUser,
delete_users: _DeleteUsers,
) -> None:
admin_user_gid = str(GlobalID(type_name="User", node_id="2"))
with pytest.raises(Exception, match="Cannot delete the default admin user"):
delete_users(admin_token, user_ids=[admin_user_gid])


def create_user_key(httpx_client: Callable[[], httpx.Client], token: str) -> str:
create_user_key_mutation = """
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def validate(
raise ValueError(err_text)


DEFAULT_ADMIN_USERNAME = "admin"
DEFAULT_ADMIN_EMAIL = "admin@localhost"
DEFAULT_ADMIN_PASSWORD = "admin"
DEFAULT_SECRET_LENGTH = 32
"""The default length of a secret key in bytes."""
Expand Down
12 changes: 9 additions & 3 deletions src/phoenix/db/facilitator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql.functions import coalesce

from phoenix.auth import DEFAULT_ADMIN_PASSWORD, DEFAULT_SECRET_LENGTH, compute_password_hash
from phoenix.auth import (
DEFAULT_ADMIN_EMAIL,
DEFAULT_ADMIN_PASSWORD,
DEFAULT_ADMIN_USERNAME,
DEFAULT_SECRET_LENGTH,
compute_password_hash,
)
from phoenix.config import ENABLE_AUTH
from phoenix.db import models
from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, UserRole
Expand Down Expand Up @@ -93,8 +99,8 @@ async def _ensure_user_roles(session: AsyncSession) -> None:
) is not None:
admin_user = models.User(
user_role_id=admin_role_id,
username="admin",
email="admin@localhost",
username=DEFAULT_ADMIN_USERNAME,
email=DEFAULT_ADMIN_EMAIL,
auth_method=AuthMethod.LOCAL.value,
reset_password=True,
)
Expand Down
7 changes: 7 additions & 0 deletions src/phoenix/server/api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class Unauthorized(CustomGraphQLError):
"""


class Conflict(CustomGraphQLError):
"""
An error raised when a mutation cannot be completed due to a conflict with
the current state of one or more resources.
"""


def get_mask_errors_extension() -> MaskErrors:
return MaskErrors(
should_mask_error=_should_mask_error,
Expand Down
90 changes: 88 additions & 2 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import secrets
from contextlib import AsyncExitStack
from datetime import datetime, timezone
from typing import Literal, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import strawberry
from sqlalchemy import Select, select
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
from sqlalchemy.orm import joinedload
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry import UNSET
from strawberry.relay import GlobalID
from strawberry.types import Info

from phoenix.auth import (
DEFAULT_ADMIN_EMAIL,
DEFAULT_ADMIN_USERNAME,
DEFAULT_SECRET_LENGTH,
PASSWORD_REQUIREMENTS,
validate_email_format,
Expand All @@ -20,6 +22,7 @@
from phoenix.db import enums, models
from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import Conflict, NotFound
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from phoenix.server.api.types.node import from_global_id_with_expected_type
from phoenix.server.api.types.User import User, to_gql_user
Expand Down Expand Up @@ -63,6 +66,11 @@ def __post_init__(self) -> None:
PASSWORD_REQUIREMENTS.validate(self.new_password)


@strawberry.input
class DeleteUsersInput:
user_ids: List[GlobalID]


@strawberry.type
class UserMutationPayload:
user: User
Expand Down Expand Up @@ -201,6 +209,84 @@ async def patch_viewer(
await info.context.log_out(user.id)
return UserMutationPayload(user=to_gql_user(user))

@strawberry.mutation(
permission_classes=[
IsNotReadOnly,
IsAuthenticated,
IsAdmin,
]
) # type: ignore
async def delete_users(
self,
info: Info[Context, None],
input: DeleteUsersInput,
) -> None:
if not input.user_ids:
return
user_ids = tuple(
map(
lambda gid: from_global_id_with_expected_type(gid, User.__name__),
set(input.user_ids),
)
)
system_user_role_id = (
select(models.UserRole.id)
.where(models.UserRole.name == enums.UserRole.SYSTEM.value)
.scalar_subquery()
)
admin_user_role_id = (
select(models.UserRole.id)
.where(models.UserRole.name == enums.UserRole.ADMIN.value)
.scalar_subquery()
)
default_admin_user_id = (
select(models.User.id)
.where(
(
and_(
models.User.user_role_id == admin_user_role_id,
models.User.username == DEFAULT_ADMIN_USERNAME,
models.User.email == DEFAULT_ADMIN_EMAIL,
)
)
)
.scalar_subquery()
)
async with info.context.db() as session:
[
(
deletes_default_admin,
num_resolved_user_ids,
)
] = (
await session.execute(
select(
cast(
func.coalesce(
func.max(
case((models.User.id == default_admin_user_id, 1), else_=0)
),
0,
),
Boolean,
).label("deletes_default_admin"),
func.count(distinct(models.User.id)).label("num_resolved_user_ids"),
)
.select_from(models.User)
.where(
and_(
models.User.id.in_(user_ids),
models.User.user_role_id != system_user_role_id,
)
)
)
).all()
if deletes_default_admin:
raise Conflict("Cannot delete the default admin user")
if num_resolved_user_ids < len(user_ids):
raise NotFound("Some user IDs could not be found")
await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))


def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:
return select(models.UserRole.id).where(models.UserRole.name == role_name)
Expand Down

0 comments on commit 745cba7

Please sign in to comment.