Skip to content

Commit

Permalink
feat(auth): user of a given key (#4442)
Browse files Browse the repository at this point in the history
* feat(auth): user of a given key

* remove user role not used

* format

* fix casing
  • Loading branch information
mikeldking authored and RogerHYang committed Sep 21, 2024
1 parent bcf273d commit f8bbf25
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 8 deletions.
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,7 @@ type UserApiKey implements ApiKey & Node {

"""The Globally Unique ID of this object"""
id: GlobalID!
user: User!
}

"""A connection to a list of items."""
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
SpanProjectsDataLoader,
TokenCountDataLoader,
TraceRowIdsDataLoader,
UserRolesDataLoader,
UsersDataLoader,
)
from phoenix.server.dml_event import DmlEvent
from phoenix.server.types import CanGetLastUpdatedAt, CanPutItem, DbSessionFactory, TokenStore
Expand Down Expand Up @@ -59,6 +61,8 @@ class DataLoaders:
token_counts: TokenCountDataLoader
trace_row_ids: TraceRowIdsDataLoader
project_by_name: ProjectByNameDataLoader
users: UsersDataLoader
user_roles: UserRolesDataLoader


class _NoOp:
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from .span_projects import SpanProjectsDataLoader
from .token_counts import TokenCountCache, TokenCountDataLoader
from .trace_row_ids import TraceRowIdsDataLoader
from .user_roles import UserRolesDataLoader
from .users import UsersDataLoader

__all__ = [
"CacheForDataLoaders",
Expand All @@ -50,6 +52,8 @@
"TraceRowIdsDataLoader",
"ProjectByNameDataLoader",
"SpanAnnotationsDataLoader",
"UsersDataLoader",
"UserRolesDataLoader",
]


Expand Down
30 changes: 30 additions & 0 deletions src/phoenix/server/api/dataloaders/user_roles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional

from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

UserRoleId: TypeAlias = int
Key: TypeAlias = UserRoleId
Result: TypeAlias = Optional[models.UserRole]


class UserRolesDataLoader(DataLoader[Key, Result]):
"""DataLoader that batches together user roles by their ids."""

def __init__(self, db: DbSessionFactory) -> None:
super().__init__(load_fn=self._load_fn)
self._db = db

async def _load_fn(self, keys: List[Key]) -> List[Result]:
user_roles_by_id: DefaultDict[Key, Result] = defaultdict(None)
async with self._db() as session:
data = await session.stream_scalars(select(models.UserRole))
async for user_role in data:
user_roles_by_id[user_role.id] = user_role

return [user_roles_by_id.get(role_id) for role_id in keys]
33 changes: 33 additions & 0 deletions src/phoenix/server/api/dataloaders/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional

from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

UserId: TypeAlias = int
Key: TypeAlias = UserId
Result: TypeAlias = Optional[models.User]


class UsersDataLoader(DataLoader[Key, Result]):
"""DataLoader that batches together users by their ids."""

def __init__(self, db: DbSessionFactory) -> None:
super().__init__(load_fn=self._load_fn)
self._db = db

async def _load_fn(self, keys: List[Key]) -> List[Result]:
user_ids = list(set(keys))
users_by_id: DefaultDict[Key, Result] = defaultdict(None)
async with self._db() as session:
data = await session.stream_scalars(
select(models.User).where(models.User.id.in_(user_ids))
)
async for user in data:
users_by_id[user.id] = user

return [users_by_id.get(user_id) for user_id in keys]
3 changes: 1 addition & 2 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from phoenix.server.api.input_types.UserRoleInput import UserRoleInput
from phoenix.server.api.mutations.auth import HasSecret, IsAdmin, IsAuthenticated, IsNotReadOnly
from phoenix.server.api.types.User import User
from phoenix.server.api.types.UserRole import UserRole


@strawberry.input
Expand Down Expand Up @@ -78,7 +77,7 @@ async def create_user(
email=user.email,
username=user.username,
created_at=user.created_at,
role=UserRole(id_attr=user.user_role_id, name=role_name),
user_role_id=user.user_role_id,
)
)

Expand Down
5 changes: 1 addition & 4 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ async def users(
email=user.email,
username=user.username,
created_at=user.created_at,
role=UserRole(
id_attr=user.role.id,
name=user.role.name,
),
user_role_id=user.user_role_id,
)
async for user in users
]
Expand Down
30 changes: 28 additions & 2 deletions src/phoenix/server/api/types/User.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
from typing import Optional

import strawberry
from strawberry import Private
from strawberry.relay import Node, NodeID
from strawberry.types import Info

from .UserRole import UserRole
from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import NotFound

from .UserRole import UserRole, to_gql_user_role


@strawberry.type
Expand All @@ -13,4 +19,24 @@ class User(Node):
email: str
username: Optional[str]
created_at: datetime
role: UserRole
user_role_id: Private[int]

@strawberry.field
async def role(self, info: Info[Context, None]) -> UserRole:
role = await info.context.data_loaders.user_roles.load(self.user_role_id)
if role is None:
raise NotFound(f"User role with id {self.user_role_id} not found")
return to_gql_user_role(role)


def to_gql_user(user: models.User) -> User:
"""
Converts an ORM user to a GraphQL user.
"""
return User(
id_attr=user.id,
username=user.username,
email=user.email,
created_at=user.created_at,
user_role_id=user.user_role_id,
)
12 changes: 12 additions & 0 deletions src/phoenix/server/api/types/UserApiKey.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import strawberry
from strawberry import Private
from strawberry.relay.types import Node, NodeID
from strawberry.types import Info

from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import NotFound

from .ApiKey import ApiKey
from .User import User, to_gql_user


@strawberry.type
class UserApiKey(ApiKey, Node):
id_attr: NodeID[int]
user_id: Private[int]

@strawberry.field
async def user(self, info: Info[Context, None]) -> User:
user = await info.context.data_loaders.users.load(self.user_id)
if user is None:
raise NotFound(f"User with id {self.user_id} not found")
return to_gql_user(user)
7 changes: 7 additions & 0 deletions src/phoenix/server/api/types/UserRole.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import strawberry
from strawberry.relay import Node, NodeID

from phoenix.db import models


@strawberry.type
class UserRole(Node):
id_attr: NodeID[int]
name: str


def to_gql_user_role(role: models.UserRole) -> UserRole:
"""Convert an ORM user role to a GraphQL user role."""
return UserRole(id_attr=role.id, name=role.name)
4 changes: 4 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
SpanProjectsDataLoader,
TokenCountDataLoader,
TraceRowIdsDataLoader,
UserRolesDataLoader,
UsersDataLoader,
)
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.routers.v1 import router as v1_router
Expand Down Expand Up @@ -528,6 +530,8 @@ def get_context() -> Context:
),
trace_row_ids=TraceRowIdsDataLoader(db),
project_by_name=ProjectByNameDataLoader(db),
users=UsersDataLoader(db),
user_roles=UserRolesDataLoader(db),
),
cache_for_dataloaders=cache_for_dataloaders,
read_only=read_only,
Expand Down

0 comments on commit f8bbf25

Please sign in to comment.