From 745cba7530336192cc8ef010682ed1f25d7f0589 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Sat, 7 Sep 2024 16:09:27 -0700 Subject: [PATCH] feat(auth): add deleteUsers mutation (#4537) --- app/schema.graphql | 8 ++ integration_tests/auth/conftest.py | 59 ++++++++++++ integration_tests/auth/test_auth.py | 57 ++++++++++++ src/phoenix/auth.py | 2 + src/phoenix/db/facilitator.py | 12 ++- src/phoenix/server/api/exceptions.py | 7 ++ .../server/api/mutations/user_mutations.py | 90 ++++++++++++++++++- 7 files changed, 230 insertions(+), 5 deletions(-) diff --git a/app/schema.graphql b/app/schema.graphql index 28890cba7b..af830defb2 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -355,6 +355,10 @@ input DeleteExperimentsInput { experimentIds: [GlobalID!]! } +input DeleteUsersInput { + userIds: [GlobalID!]! +} + type Dimension implements Node { """The Globally Unique ID of this object""" id: GlobalID! @@ -933,6 +937,7 @@ type Mutation { createUser(input: CreateUserInput!): UserMutationPayload! patchUser(input: PatchUserInput!): UserMutationPayload! patchViewer(input: PatchViewerInput!): UserMutationPayload! + deleteUsers(input: DeleteUsersInput!): Void } """An object with a Globally Unique ID""" @@ -1523,3 +1528,6 @@ type ValidationResult { enum VectorDriftMetric { euclideanDistance } + +"""Represents NULL values""" +scalar Void diff --git a/integration_tests/auth/conftest.py b/integration_tests/auth/conftest.py index 50bd6dad6b..1ce0069ce8 100644 --- a/integration_tests/auth/conftest.py +++ b/integration_tests/auth/conftest.py @@ -14,6 +14,7 @@ List, Optional, Protocol, + Sequence, Tuple, cast, ) @@ -91,6 +92,10 @@ def __call__( ) -> None: ... +class _DeleteUsers(Protocol): + def __call__(self, token: _Token, /, *, user_ids: Sequence[_GqlId]) -> None: ... + + class _PatchViewer(Protocol): def __call__( self, @@ -236,6 +241,18 @@ def admin_token( yield token +@pytest.fixture +def member_token( + get_new_user: _GetNewUser, + member_email: str, + member_password: str, + log_in: _LogIn, +) -> _Token: + member = get_new_user(UserRoleInput.MEMBER) + assert (token := member.token) is not None + return token + + @pytest.fixture(scope="module") def admin_email() -> _Email: return "admin@localhost" @@ -246,6 +263,16 @@ def admin_password() -> _Email: return "admin" +@pytest.fixture(scope="module") +def member_email() -> _Email: + return "member@domain.com" + + +@pytest.fixture(scope="module") +def member_password() -> _Password: + return "Member-password1234" + + @pytest.fixture(scope="module") def create_user( httpx_client: Callable[[], httpx.Client], @@ -316,6 +343,38 @@ def _( return _ +@pytest.fixture(scope="module") +def delete_users( + httpx_client: Callable[[], httpx.Client], +) -> _DeleteUsers: + def _( + token: _Token, + /, + *, + user_ids: Sequence[_GqlId], + ) -> None: + mutation = """ + mutation ($userIds: [GlobalID!]!) { + deleteUsers(input: {userIds: $userIds}) + } + """ + response = httpx_client().post( + urljoin(get_base_url(), "/graphql"), + json={ + "query": mutation, + "variables": {"userIds": list(user_ids)}, + }, + cookies={PHOENIX_ACCESS_TOKEN_COOKIE_NAME: token}, + ) + response.raise_for_status() + if (errors := response.json().get("errors")) is not None: + assert len(errors) == 1 + error_message = errors[0]["message"] + raise Exception(error_message) + + return _ + + @pytest.fixture(scope="module") def patch_viewer( httpx_client: Callable[[], httpx.Client], diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index fa8fac4569..dd0707f698 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -10,6 +10,7 @@ Iterator, Optional, Protocol, + Sequence, Tuple, ) from urllib.parse import urljoin @@ -29,6 +30,7 @@ from phoenix.config import get_base_url from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput +from strawberry.relay import GlobalID from typing_extensions import TypeAlias NOW = datetime.now(timezone.utc) @@ -88,6 +90,10 @@ def __call__( ) -> None: ... +class _DeleteUsers(Protocol): + def __call__(self, token: _Token, /, *, user_ids: Sequence[_GqlId]) -> None: ... + + class _PatchViewer(Protocol): def __call__( self, @@ -471,6 +477,57 @@ def test_only_admin_can_change_username_for_non_self( with expectation: patch_user(token, gid, new_username=new_username) + def test_admin_can_delete_non_self_admin_and_member_users( + self, + admin_token: _Token, + get_new_user: _GetNewUser, + delete_users: _DeleteUsers, + ) -> None: + admin = get_new_user(UserRoleInput.ADMIN) + member = get_new_user(UserRoleInput.MEMBER) + delete_users(admin_token, user_ids=[admin.gid, member.gid]) + + def test_admin_cannot_delete_system_user( + self, + admin_token: _Token, + get_new_user: _GetNewUser, + delete_users: _DeleteUsers, + ) -> None: + system_user_gid = str(GlobalID(type_name="User", node_id="1")) + with pytest.raises(Exception, match="Some user IDs could not be found"): + delete_users(admin_token, user_ids=[system_user_gid]) + + def test_error_is_raised_when_deleting_a_non_existent_user_id( + self, + admin_token: _Token, + get_new_user: _GetNewUser, + delete_users: _DeleteUsers, + ) -> None: + system_user_gid = str(GlobalID(type_name="User", node_id="10000")) + with pytest.raises(Exception, match="Some user IDs could not be found"): + delete_users(admin_token, user_ids=[system_user_gid]) + + def test_member_cannot_delete_users( + self, + member_token: _Token, + get_new_user: _GetNewUser, + delete_users: _DeleteUsers, + ) -> None: + admin = get_new_user(UserRoleInput.ADMIN) + member = get_new_user(UserRoleInput.MEMBER) + with pytest.raises(Exception, match="Only admin can perform this action"): + delete_users(member_token, user_ids=[admin.gid, member.gid]) + + def test_admin_cannot_delete_default_admin_user( + self, + admin_token: _Token, + get_new_user: _GetNewUser, + delete_users: _DeleteUsers, + ) -> None: + admin_user_gid = str(GlobalID(type_name="User", node_id="2")) + with pytest.raises(Exception, match="Cannot delete the default admin user"): + delete_users(admin_token, user_ids=[admin_user_gid]) + def create_user_key(httpx_client: Callable[[], httpx.Client], token: str) -> str: create_user_key_mutation = """ diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 369656c485..b9ab3b15ab 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -176,6 +176,8 @@ def validate( raise ValueError(err_text) +DEFAULT_ADMIN_USERNAME = "admin" +DEFAULT_ADMIN_EMAIL = "admin@localhost" DEFAULT_ADMIN_PASSWORD = "admin" DEFAULT_SECRET_LENGTH = 32 """The default length of a secret key in bytes.""" diff --git a/src/phoenix/db/facilitator.py b/src/phoenix/db/facilitator.py index a21fbcccad..0e61861442 100644 --- a/src/phoenix/db/facilitator.py +++ b/src/phoenix/db/facilitator.py @@ -14,7 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.functions import coalesce -from phoenix.auth import DEFAULT_ADMIN_PASSWORD, DEFAULT_SECRET_LENGTH, compute_password_hash +from phoenix.auth import ( + DEFAULT_ADMIN_EMAIL, + DEFAULT_ADMIN_PASSWORD, + DEFAULT_ADMIN_USERNAME, + DEFAULT_SECRET_LENGTH, + compute_password_hash, +) from phoenix.config import ENABLE_AUTH from phoenix.db import models from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, UserRole @@ -93,8 +99,8 @@ async def _ensure_user_roles(session: AsyncSession) -> None: ) is not None: admin_user = models.User( user_role_id=admin_role_id, - username="admin", - email="admin@localhost", + username=DEFAULT_ADMIN_USERNAME, + email=DEFAULT_ADMIN_EMAIL, auth_method=AuthMethod.LOCAL.value, reset_password=True, ) diff --git a/src/phoenix/server/api/exceptions.py b/src/phoenix/server/api/exceptions.py index 2424bd29dc..745952ec29 100644 --- a/src/phoenix/server/api/exceptions.py +++ b/src/phoenix/server/api/exceptions.py @@ -27,6 +27,13 @@ class Unauthorized(CustomGraphQLError): """ +class Conflict(CustomGraphQLError): + """ + An error raised when a mutation cannot be completed due to a conflict with + the current state of one or more resources. + """ + + def get_mask_errors_extension() -> MaskErrors: return MaskErrors( should_mask_error=_should_mask_error, diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 71e71d1d61..8cd6f668a0 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -1,10 +1,10 @@ import secrets from contextlib import AsyncExitStack from datetime import datetime, timezone -from typing import Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple import strawberry -from sqlalchemy import Select, select +from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select from sqlalchemy.orm import joinedload from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped] from strawberry import UNSET @@ -12,6 +12,8 @@ from strawberry.types import Info from phoenix.auth import ( + DEFAULT_ADMIN_EMAIL, + DEFAULT_ADMIN_USERNAME, DEFAULT_SECRET_LENGTH, PASSWORD_REQUIREMENTS, validate_email_format, @@ -20,6 +22,7 @@ from phoenix.db import enums, models from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly from phoenix.server.api.context import Context +from phoenix.server.api.exceptions import Conflict, NotFound from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.api.types.User import User, to_gql_user @@ -63,6 +66,11 @@ def __post_init__(self) -> None: PASSWORD_REQUIREMENTS.validate(self.new_password) +@strawberry.input +class DeleteUsersInput: + user_ids: List[GlobalID] + + @strawberry.type class UserMutationPayload: user: User @@ -201,6 +209,84 @@ async def patch_viewer( await info.context.log_out(user.id) return UserMutationPayload(user=to_gql_user(user)) + @strawberry.mutation( + permission_classes=[ + IsNotReadOnly, + IsAuthenticated, + IsAdmin, + ] + ) # type: ignore + async def delete_users( + self, + info: Info[Context, None], + input: DeleteUsersInput, + ) -> None: + if not input.user_ids: + return + user_ids = tuple( + map( + lambda gid: from_global_id_with_expected_type(gid, User.__name__), + set(input.user_ids), + ) + ) + system_user_role_id = ( + select(models.UserRole.id) + .where(models.UserRole.name == enums.UserRole.SYSTEM.value) + .scalar_subquery() + ) + admin_user_role_id = ( + select(models.UserRole.id) + .where(models.UserRole.name == enums.UserRole.ADMIN.value) + .scalar_subquery() + ) + default_admin_user_id = ( + select(models.User.id) + .where( + ( + and_( + models.User.user_role_id == admin_user_role_id, + models.User.username == DEFAULT_ADMIN_USERNAME, + models.User.email == DEFAULT_ADMIN_EMAIL, + ) + ) + ) + .scalar_subquery() + ) + async with info.context.db() as session: + [ + ( + deletes_default_admin, + num_resolved_user_ids, + ) + ] = ( + await session.execute( + select( + cast( + func.coalesce( + func.max( + case((models.User.id == default_admin_user_id, 1), else_=0) + ), + 0, + ), + Boolean, + ).label("deletes_default_admin"), + func.count(distinct(models.User.id)).label("num_resolved_user_ids"), + ) + .select_from(models.User) + .where( + and_( + models.User.id.in_(user_ids), + models.User.user_role_id != system_user_role_id, + ) + ) + ) + ).all() + if deletes_default_admin: + raise Conflict("Cannot delete the default admin user") + if num_resolved_user_ids < len(user_ids): + raise NotFound("Some user IDs could not be found") + await session.execute(delete(models.User).where(models.User.id.in_(user_ids))) + def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]: return select(models.UserRole.id).where(models.UserRole.name == role_name)