Skip to content

Commit

Permalink
chore(auth): implement viewer on query (#4429)
Browse files Browse the repository at this point in the history
* fix user

* fix

* simplify

* fix lint error

---------

Co-authored-by: Mikyo King <mikyo@arize.com>
  • Loading branch information
2 people authored and RogerHYang committed Sep 4, 2024
1 parent 7e475cd commit 7101772
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 13 deletions.
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,7 @@ type Query {
functionality: Functionality!
model: Model!
node(id: GlobalID!): Node!
viewer: User
clusters(clusters: [ClusterInput!]!): [Cluster!]!
hdbscanClustering(
"""Event ID of the coordinates"""
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Any, Optional

from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
from strawberry.fastapi import BaseContext

Expand Down Expand Up @@ -97,6 +98,14 @@ def get_secret(self) -> str:
)
return self.secret

def get_request(self) -> StarletteRequest:
"""
A type-safe way to get the request object. Throws an error if the request is not set.
"""
if not isinstance(request := self.request, StarletteRequest):
raise ValueError("no request is set")
return request

def get_response(self) -> StarletteResponse:
"""
A type-safe way to get the response object. Throws an error if the response is not set.
Expand Down
37 changes: 24 additions & 13 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import strawberry
from sqlalchemy import and_, distinct, func, select
from sqlalchemy.orm import joinedload
from starlette.authentication import UnauthenticatedUser
from strawberry import ID, UNSET
from strawberry.relay import Connection, GlobalID, Node
from strawberry.types import Info
Expand Down Expand Up @@ -69,8 +70,8 @@
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.SystemApiKey import SystemApiKey
from phoenix.server.api.types.Trace import Trace
from phoenix.server.api.types.User import User
from phoenix.server.api.types.UserApiKey import UserApiKey
from phoenix.server.api.types.User import User, to_gql_user
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
from phoenix.server.api.types.UserRole import UserRole


Expand Down Expand Up @@ -140,17 +141,7 @@ async def user_api_keys(self, info: Info[Context, None]) -> List[UserApiKey]:
)
async with info.context.db() as session:
api_keys = await session.scalars(stmt)
return [
UserApiKey(
id_attr=api_key.id,
user_id=api_key.user_id,
name=api_key.name,
description=api_key.description,
created_at=api_key.created_at,
expires_at=api_key.expires_at,
)
for api_key in api_keys
]
return [to_gql_api_key(api_key) for api_key in api_keys]

@strawberry.field
async def system_api_keys(self, info: Info[Context, None]) -> List[SystemApiKey]:
Expand Down Expand Up @@ -487,6 +478,26 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
return to_gql_experiment_run(run)
raise NotFound(f"Unknown node type: {type_name}")

@strawberry.field
async def viewer(self, info: Info[Context, None]) -> Optional[User]:
request = info.context.get_request()
try:
user = request.user
except AssertionError:
return None
if isinstance(user, UnauthenticatedUser):
return None
async with info.context.db() as session:
if (
user := await session.scalar(
select(models.User)
.where(models.User.id == int(user.identity))
.options(joinedload(models.User.role))
)
) is None:
return None
return to_gql_user(user)

@strawberry.field
def clusters(
self,
Expand Down
15 changes: 15 additions & 0 deletions src/phoenix/server/api/types/UserApiKey.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from strawberry.relay.types import Node, NodeID
from strawberry.types import Info

from phoenix.db.models import ApiKey as OrmApiKey
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import NotFound

Expand All @@ -21,3 +22,17 @@ async def user(self, info: Info[Context, None]) -> User:
if user is None:
raise NotFound(f"User with id {self.user_id} not found")
return to_gql_user(user)


def to_gql_api_key(api_key: OrmApiKey) -> UserApiKey:
"""
Converts an ORM API key to a GraphQL UserApiKey type.
"""
return UserApiKey(
id_attr=api_key.id,
user_id=api_key.user_id,
name=api_key.name,
description=api_key.description,
created_at=api_key.created_at,
expires_at=api_key.expires_at,
)

0 comments on commit 7101772

Please sign in to comment.