diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index 5220ea564e..896d415d4e 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -55,7 +55,7 @@ get_env_grpc_port, get_env_host, ) -from phoenix.server.api.auth import IsAdmin, IsAuthenticated +from phoenix.server.api.auth import IsAdmin from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.input_types.UserRoleInput import UserRoleInput from psutil import STATUS_ZOMBIE, Popen @@ -130,6 +130,13 @@ def email(self) -> _Email: def username(self) -> Optional[_Username]: return self.profile.username + def gql( + self, + query: str, + variables: Optional[Mapping[str, Any]] = None, + ) -> Dict[str, Any]: + return _gql(self, query=query, variables=variables) + def create_user( self, role: UserRoleInput = _MEMBER, @@ -739,7 +746,7 @@ def _json( assert (resp_dict := cast(Dict[str, Any], resp.json())) if errers := resp_dict.get("errors"): msg = errers[0]["message"] - if "not auth" in msg or IsAuthenticated.message in msg or IsAdmin.message in msg: + if "not auth" in msg or IsAdmin.message in msg: raise Unauthorized(msg) raise RuntimeError(msg) return resp_dict diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index cbcea45ffc..359e78c516 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -107,11 +107,11 @@ def test_can_log_out( _get_user: _GetUser, ) -> None: u = _get_user(role_or_user) - doers = [u.log_in() for _ in range(2)] - for logged_in_user in doers: + logged_in_users = [u.log_in() for _ in range(2)] + for logged_in_user in logged_in_users: logged_in_user.create_api_key() - doers[0].log_out() - for logged_in_user in doers: + logged_in_users[0].log_out() + for logged_in_user in logged_in_users: with _EXPECTATION_401: logged_in_user.create_api_key() @@ -167,35 +167,35 @@ def test_end_to_end_credentials_flow( _get_user: _GetUser, ) -> None: u = _get_user(role_or_user) - doers: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict) + logged_in_users: DefaultDict[int, Dict[int, _LoggedInUser]] = defaultdict(dict) # user logs into first browser - doers[0][0] = u.log_in() + logged_in_users[0][0] = u.log_in() # user creates api key in the first browser - doers[0][0].create_api_key() + logged_in_users[0][0].create_api_key() # tokens are refreshed in the first browser - doers[0][1] = doers[0][0].refresh() + logged_in_users[0][1] = logged_in_users[0][0].refresh() # user creates api key in the first browser - doers[0][1].create_api_key() + logged_in_users[0][1].create_api_key() # refresh token is good for one use only with pytest.raises(HTTPStatusError): - doers[0][0].refresh() + logged_in_users[0][0].refresh() # original access token is invalid after refresh with _EXPECTATION_401: - doers[0][0].create_api_key() + logged_in_users[0][0].create_api_key() # user logs into second browser - doers[1][0] = u.log_in() + logged_in_users[1][0] = u.log_in() # user creates api key in the second browser - doers[1][0].create_api_key() + logged_in_users[1][0].create_api_key() # user logs out in first browser - doers[0][1].log_out() + logged_in_users[0][1].log_out() # user is logged out of both browsers with _EXPECTATION_401: - doers[0][1].create_api_key() + logged_in_users[0][1].create_api_key() with _EXPECTATION_401: - doers[1][0].create_api_key() + logged_in_users[1][0].create_api_key() class TestCreateUser: @@ -580,6 +580,70 @@ def test_only_admin_can_delete_system_api_key( logged_in_user.delete_api_key(api_key) +class TestGraphQLQuery: + @pytest.mark.parametrize( + "role_or_user,expectation", + [ + (_MEMBER, _DENIED), + (_ADMIN, _OK), + (_DEFAULT_ADMIN, _OK), + ], + ) + @pytest.mark.parametrize( + "query", + [ + "query{users{edges{node{id}}}}", + "query{userApiKeys{id}}", + "query{systemApiKeys{id}}", + ], + ) + def test_only_admin_can_list_users_and_api_keys( + self, + role_or_user: _RoleOrUser, + query: str, + expectation: _OK_OR_DENIED, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + with expectation: + logged_in_user.gql(query) + + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN]) + def test_can_query_user_node_for_self( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + query = 'query{node(id:"' + u.gid + '"){__typename}}' + logged_in_user.gql(query) + + @pytest.mark.parametrize( + "role_or_user,expectation", + [ + (_MEMBER, _DENIED), + (_ADMIN, _OK), + (_DEFAULT_ADMIN, _OK), + ], + ) + @pytest.mark.parametrize("role", list(UserRoleInput)) + def test_only_admin_can_query_user_node_for_non_self( + self, + role_or_user: _RoleOrUser, + role: UserRoleInput, + expectation: _OK_OR_DENIED, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + logged_in_user = u.log_in() + non_self = _get_user(role) + query = 'query{node(id:"' + non_self.gid + '"){__typename}}' + with expectation: + logged_in_user.gql(query) + + class TestSpanExporters: @pytest.mark.parametrize( "with_headers,expires_at,expected", @@ -612,5 +676,5 @@ def test_headers( for _ in range(2): assert export(spans) is expected if api_key and expected is SpanExportResult.SUCCESS: - _DEFAULT_ADMIN.log_in().delete_api_key(api_key) + _DEFAULT_ADMIN.delete_api_key(api_key) assert export(spans) is SpanExportResult.FAILURE diff --git a/src/phoenix/server/api/README.md b/src/phoenix/server/api/README.md new file mode 100644 index 0000000000..a646a42c71 --- /dev/null +++ b/src/phoenix/server/api/README.md @@ -0,0 +1,28 @@ +# Permission Matrix for GraphQL API + +## Mutations + +| Action | Admin | Member | +|:-----------------------------|:-----:|:------:| +| Create User | Yes | No | +| Delete User | Yes | No | +| Change Own Password | Yes | Yes | +| Change Other's Password | Yes | No | +| Change Own Username | Yes | Yes | +| Change Other's Username | Yes | No | +| Change Own Email | No | No | +| Change Other's Email | No | No | +| Create System API Keys | Yes | No | +| Delete System API Keys | Yes | No | +| Create Own User API Keys | Yes | Yes | +| Delete Own User API Keys | Yes | Yes | +| Delete Other's User API Keys | Yes | No | + +## Queries + +| Action | Admin | Member | +|:-------------------------------------|:-----:|:------:| +| List All System API Keys | Yes | No | +| List All User API Keys | Yes | No | +| List All Users | Yes | No | +| Fetch Other User's Info, e.g. emails | Yes | No | diff --git a/src/phoenix/server/api/auth.py b/src/phoenix/server/api/auth.py index f24424e1bb..2e937dd0c2 100644 --- a/src/phoenix/server/api/auth.py +++ b/src/phoenix/server/api/auth.py @@ -4,7 +4,6 @@ from strawberry import Info from strawberry.permission import BasePermission -from phoenix.db import enums from phoenix.server.api.exceptions import Unauthorized from phoenix.server.bearer_auth import PhoenixUser @@ -21,40 +20,13 @@ def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: return not info.context.read_only -class IsAuthenticated(Authorization): - message = "User is not authenticated" - - def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - if info.context.token_store is None: - return True - try: - user = info.context.request.user - except AttributeError: - return False - return isinstance(user, PhoenixUser) and user.is_authenticated +MSG_ADMIN_ONLY = "Only admin can perform this action" class IsAdmin(Authorization): - message = "Only admin can perform this action" + message = MSG_ADMIN_ONLY def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - if info.context.token_store is None: - return False - try: - user = info.context.request.user - except AttributeError: + if not info.context.auth_enabled: return False - return ( - isinstance(user, PhoenixUser) - and user.is_authenticated - and user.claims is not None - and user.claims.attributes is not None - and user.claims.attributes.user_role == enums.UserRole.ADMIN - ) - - -class HasSecret(BasePermission): - message = "Application secret is not set" - - def has_permission(self, source: Any, info: Info, **kwargs: Any) -> bool: - return info.context.secret is not None + return isinstance((user := info.context.user), PhoenixUser) and user.is_admin diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index a98e637a55..8ed113e57d 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,8 +1,8 @@ from asyncio import get_running_loop from dataclasses import dataclass -from functools import partial +from functools import cached_property, partial from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast from starlette.requests import Request as StarletteRequest from starlette.responses import Response as StarletteResponse @@ -42,6 +42,7 @@ UserRolesDataLoader, UsersDataLoader, ) +from phoenix.server.bearer_auth import PhoenixUser from phoenix.server.dml_event import DmlEvent from phoenix.server.types import ( CanGetLastUpdatedAt, @@ -96,6 +97,7 @@ class Context(BaseContext): event_queue: CanPutItem[DmlEvent] = _NoOp() corpus: Optional[Model] = None read_only: bool = False + auth_enabled: bool = False secret: Optional[str] = None token_store: Optional[TokenStore] = None @@ -146,3 +148,7 @@ async def log_out(self, user_id: int) -> None: response = self.get_response() response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) + + @cached_property + def user(self) -> PhoenixUser: + return cast(PhoenixUser, self.get_request().user) diff --git a/src/phoenix/server/api/mutations/api_key_mutations.py b/src/phoenix/server/api/mutations/api_key_mutations.py index 97738a5b90..d7877c9418 100644 --- a/src/phoenix/server/api/mutations/api_key_mutations.py +++ b/src/phoenix/server/api/mutations/api_key_mutations.py @@ -8,8 +8,7 @@ from strawberry.types import Info from phoenix.db import enums, models -from phoenix.db.models import ApiKey as OrmApiKey -from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsAdmin, IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import Unauthorized from phoenix.server.api.queries import Query @@ -59,32 +58,9 @@ class DeleteApiKeyMutationPayload: query: Query -def can_delete_user_key(info: Info[Context, None], key: OrmApiKey) -> bool: - try: - user = info.context.request.user # type: ignore - except AttributeError: - return False - return ( - isinstance(user, PhoenixUser) - and user.claims is not None - and user.claims.attributes is not None - and ( - user.claims.attributes.user_role == enums.UserRole.ADMIN - or int(user.identity) == key.user_id - ) - ) - - @strawberry.type class ApiKeyMutationMixin: - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def create_system_api_key( self, info: Info[Context, None], input: CreateApiKeyInput ) -> CreateSystemApiKeyMutationPayload: @@ -125,13 +101,7 @@ async def create_system_api_key( query=Query(), ) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_user_api_key( self, info: Info[Context, None], input: CreateUserApiKeyInput ) -> CreateUserApiKeyMutationPayload: @@ -166,7 +136,7 @@ async def create_user_api_key( query=Query(), ) - @strawberry.mutation(permission_classes=[HasSecret, IsAuthenticated, IsAdmin, IsNotReadOnly]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def delete_system_api_key( self, info: Info[Context, None], input: DeleteApiKeyInput ) -> DeleteApiKeyMutationPayload: @@ -177,12 +147,7 @@ async def delete_system_api_key( await token_store.revoke(ApiKeyId(api_key_id)) return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query()) - @strawberry.mutation( - permission_classes=[ - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_user_api_key( self, info: Info[Context, None], input: DeleteApiKeyInput ) -> DeleteApiKeyMutationPayload: @@ -196,7 +161,7 @@ async def delete_user_api_key( ) if api_key is None: raise ValueError(f"API key with id {input.id} not found") - if not can_delete_user_key(info, api_key): + if int((user := info.context.user).identity) != api_key.user_id and not user.is_admin: raise Unauthorized("User not authorized to delete") await token_store.revoke(ApiKeyId(api_key_id)) return DeleteApiKeyMutationPayload(apiKeyId=input.id, query=Query()) diff --git a/src/phoenix/server/api/mutations/dataset_mutations.py b/src/phoenix/server/api/mutations/dataset_mutations.py index baea741b4c..ddd4ab4fde 100644 --- a/src/phoenix/server/api/mutations/dataset_mutations.py +++ b/src/phoenix/server/api/mutations/dataset_mutations.py @@ -12,7 +12,7 @@ from phoenix.db import models from phoenix.db.helpers import get_eval_trace_ids_for_datasets, get_project_names_for_datasets -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import BadRequest, NotFound from phoenix.server.api.helpers.dataset_helpers import ( @@ -44,7 +44,7 @@ class DatasetMutationPayload: @strawberry.type class DatasetMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_dataset( self, info: Info[Context, None], @@ -67,7 +67,7 @@ async def create_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_dataset( self, info: Info[Context, None], @@ -96,7 +96,7 @@ async def patch_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def add_spans_to_dataset( self, info: Info[Context, None], @@ -225,7 +225,7 @@ async def add_spans_to_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def add_examples_to_dataset( self, info: Info[Context, None], input: AddExamplesToDatasetInput ) -> DatasetMutationPayload: @@ -351,7 +351,7 @@ async def add_examples_to_dataset( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_dataset( self, info: Info[Context, None], @@ -382,7 +382,7 @@ async def delete_dataset( info.context.event_queue.put(DatasetDeleteEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_dataset_examples( self, info: Info[Context, None], @@ -474,7 +474,7 @@ async def patch_dataset_examples( info.context.event_queue.put(DatasetInsertEvent((dataset.id,))) return DatasetMutationPayload(dataset=to_gql_dataset(dataset)) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_dataset_examples( self, info: Info[Context, None], input: DeleteDatasetExamplesInput ) -> DatasetMutationPayload: diff --git a/src/phoenix/server/api/mutations/experiment_mutations.py b/src/phoenix/server/api/mutations/experiment_mutations.py index 1372cdad3a..8ebfce4c6c 100644 --- a/src/phoenix/server/api/mutations/experiment_mutations.py +++ b/src/phoenix/server/api/mutations/experiment_mutations.py @@ -8,7 +8,7 @@ from phoenix.db import models from phoenix.db.helpers import get_eval_trace_ids_for_experiments, get_project_names_for_experiments -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import CustomGraphQLError from phoenix.server.api.input_types.DeleteExperimentsInput import DeleteExperimentsInput @@ -25,7 +25,7 @@ class ExperimentMutationPayload: @strawberry.type class ExperimentMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_experiments( self, info: Info[Context, None], diff --git a/src/phoenix/server/api/mutations/export_events_mutations.py b/src/phoenix/server/api/mutations/export_events_mutations.py index c051af65a8..57ad4cb04d 100644 --- a/src/phoenix/server/api/mutations/export_events_mutations.py +++ b/src/phoenix/server/api/mutations/export_events_mutations.py @@ -8,7 +8,7 @@ from strawberry.types import Info import phoenix.core.model_schema as ms -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.ClusterInput import ClusterInput from phoenix.server.api.types.Event import parse_event_ids_by_inferences_role, unpack_event_id @@ -19,7 +19,7 @@ @strawberry.type class ExportEventsMutationMixin: @strawberry.mutation( - permission_classes=[IsNotReadOnly, IsAuthenticated], + permission_classes=[IsNotReadOnly], description=( "Given a list of event ids, export the corresponding data subset in Parquet format." " File name is optional, but if specified, should be without file extension. By default" @@ -51,7 +51,7 @@ async def export_events( return ExportedFile(file_name=file_name) @strawberry.mutation( - permission_classes=[IsNotReadOnly, IsAuthenticated], + permission_classes=[IsNotReadOnly], description=( "Given a list of clusters, export the corresponding data subset in Parquet format." " File name is optional, but if specified, should be without file extension. By default" diff --git a/src/phoenix/server/api/mutations/project_mutations.py b/src/phoenix/server/api/mutations/project_mutations.py index aa51b49b86..30d38620f1 100644 --- a/src/phoenix/server/api/mutations/project_mutations.py +++ b/src/phoenix/server/api/mutations/project_mutations.py @@ -6,7 +6,7 @@ from phoenix.config import DEFAULT_PROJECT_NAME from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.ClearProjectInput import ClearProjectInput from phoenix.server.api.queries import Query @@ -16,7 +16,7 @@ @strawberry.type class ProjectMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query: project_id = from_global_id_with_expected_type(global_id=id, expected_type_name="Project") async with info.context.db() as session: @@ -33,7 +33,7 @@ async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query info.context.event_queue.put(ProjectDeleteEvent((project_id,))) return Query() - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def clear_project(self, info: Info[Context, None], input: ClearProjectInput) -> Query: project_id = from_global_id_with_expected_type( global_id=input.id, expected_type_name="Project" diff --git a/src/phoenix/server/api/mutations/span_annotations_mutations.py b/src/phoenix/server/api/mutations/span_annotations_mutations.py index 95c38c3ba1..f007f73d41 100644 --- a/src/phoenix/server/api/mutations/span_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/span_annotations_mutations.py @@ -6,7 +6,7 @@ from strawberry.types import Info from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.CreateSpanAnnotationInput import CreateSpanAnnotationInput from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput @@ -25,7 +25,7 @@ class SpanAnnotationMutationPayload: @strawberry.type class SpanAnnotationMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_span_annotations( self, info: Info[Context, None], input: List[CreateSpanAnnotationInput] ) -> SpanAnnotationMutationPayload: @@ -59,7 +59,7 @@ async def create_span_annotations( query=Query(), ) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_span_annotations( self, info: Info[Context, None], input: List[PatchAnnotationInput] ) -> SpanAnnotationMutationPayload: @@ -99,7 +99,7 @@ async def patch_span_annotations( info.context.event_queue.put(SpanAnnotationInsertEvent((span_annotation.id,))) return SpanAnnotationMutationPayload(span_annotations=patched_annotations, query=Query()) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_span_annotations( self, info: Info[Context, None], input: DeleteAnnotationsInput ) -> SpanAnnotationMutationPayload: diff --git a/src/phoenix/server/api/mutations/trace_annotations_mutations.py b/src/phoenix/server/api/mutations/trace_annotations_mutations.py index 3fccc94b29..2aeaca77e1 100644 --- a/src/phoenix/server/api/mutations/trace_annotations_mutations.py +++ b/src/phoenix/server/api/mutations/trace_annotations_mutations.py @@ -6,7 +6,7 @@ from strawberry.types import Info from phoenix.db import models -from phoenix.server.api.auth import IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.input_types.CreateTraceAnnotationInput import CreateTraceAnnotationInput from phoenix.server.api.input_types.DeleteAnnotationsInput import DeleteAnnotationsInput @@ -25,7 +25,7 @@ class TraceAnnotationMutationPayload: @strawberry.type class TraceAnnotationMutationMixin: - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def create_trace_annotations( self, info: Info[Context, None], input: List[CreateTraceAnnotationInput] ) -> TraceAnnotationMutationPayload: @@ -59,7 +59,7 @@ async def create_trace_annotations( query=Query(), ) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_trace_annotations( self, info: Info[Context, None], input: List[PatchAnnotationInput] ) -> TraceAnnotationMutationPayload: @@ -98,7 +98,7 @@ async def patch_trace_annotations( info.context.event_queue.put(TraceAnnotationInsertEvent((trace_annotation.id,))) return TraceAnnotationMutationPayload(trace_annotations=patched_annotations, query=Query()) - @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAuthenticated]) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def delete_trace_annotations( self, info: Info[Context, None], input: DeleteAnnotationsInput ) -> TraceAnnotationMutationPayload: diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 8cd6f668a0..7b0c822c7f 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -20,7 +20,7 @@ validate_password_format, ) from phoenix.db import enums, models -from phoenix.server.api.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly +from phoenix.server.api.auth import IsAdmin, IsNotReadOnly from phoenix.server.api.context import Context from phoenix.server.api.exceptions import Conflict, NotFound from phoenix.server.api.input_types.UserRoleInput import UserRoleInput @@ -78,14 +78,7 @@ class UserMutationPayload: @strawberry.type class UserMutationMixin: - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def create_user( self, info: Info[Context, None], @@ -117,14 +110,7 @@ async def create_user( raise ValueError(_user_operation_error_message(error)) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def patch_user( self, info: Info[Context, None], @@ -165,13 +151,7 @@ async def patch_user( await info.context.log_out(user.id) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - HasSecret, - IsAuthenticated, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly]) # type: ignore async def patch_viewer( self, info: Info[Context, None], @@ -209,13 +189,7 @@ async def patch_viewer( await info.context.log_out(user.id) return UserMutationPayload(user=to_gql_user(user)) - @strawberry.mutation( - permission_classes=[ - IsNotReadOnly, - IsAuthenticated, - IsAdmin, - ] - ) # type: ignore + @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore async def delete_users( self, info: Info[Context, None], diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index ef5b0f21ae..0c52856db8 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -33,8 +33,9 @@ Trace as OrmTrace, ) from phoenix.pointcloud.clustering import Hdbscan +from phoenix.server.api.auth import MSG_ADMIN_ONLY, IsAdmin from phoenix.server.api.context import Context -from phoenix.server.api.exceptions import NotFound +from phoenix.server.api.exceptions import NotFound, Unauthorized from phoenix.server.api.helpers import ensure_list from phoenix.server.api.input_types.ClusterInput import ClusterInput from phoenix.server.api.input_types.Coordinates import ( @@ -77,7 +78,7 @@ @strawberry.type class Query: - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def users( self, info: Info[Context, None], @@ -121,9 +122,8 @@ async def user_roles( for role in roles ] - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]: - # TODO(auth): add access control stmt = ( select(models.ApiKey) .join(models.User) @@ -134,9 +134,8 @@ async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]: api_keys = await session.scalars(stmt) return [to_gql_api_key(api_key) for api_key in api_keys] - @strawberry.field + @strawberry.field(permission_classes=[IsAdmin]) # type: ignore async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]: - # TODO(auth): add access control stmt = ( select(models.ApiKey) .join(models.User) @@ -468,6 +467,8 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: raise NotFound(f"Unknown experiment run: {id}") return to_gql_experiment_run(run) elif type_name == User.__name__: + if int((user := info.context.user).identity) != node_id and not user.is_admin: + raise Unauthorized(MSG_ADMIN_ONLY) async with info.context.db() as session: if not ( user := await session.scalar( diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 73443cdce4..9f77be6d7c 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -537,6 +537,7 @@ def get_context() -> Context: ), cache_for_dataloaders=cache_for_dataloaders, read_only=read_only, + auth_enabled=authentication_enabled, secret=secret, token_store=token_store, ) diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 613a44fb9c..4a986f12da 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -1,4 +1,5 @@ from abc import ABC +from functools import cached_property from typing import Any, Awaitable, Callable, Optional, Tuple import grpc @@ -16,6 +17,7 @@ ClaimSetStatus, Token, ) +from phoenix.db import enums from phoenix.server.types import AccessTokenClaims, ApiKeyClaims, UserClaimSet, UserId @@ -50,12 +52,21 @@ class PhoenixUser(BaseUser): def __init__(self, user_id: UserId, claims: UserClaimSet) -> None: self._user_id = user_id self.claims = claims + assert claims.attributes + self._is_admin = ( + claims.status is ClaimSetStatus.VALID + and claims.attributes.user_role == enums.UserRole.ADMIN + ) - @property + @cached_property + def is_admin(self) -> bool: + return self._is_admin + + @cached_property def identity(self) -> UserId: return self._user_id - @property + @cached_property def is_authenticated(self) -> bool: return True