From 32584b9b4e7ffc11af974c964b624aa5f214560b Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 11 Apr 2024 16:44:25 -0700 Subject: [PATCH] convert graphql api to pull trace evals from a database --- src/phoenix/server/api/context.py | 3 +- .../server/api/dataloaders/__init__.py | 2 + .../api/dataloaders/trace_evaluations.py | 39 +++++++++++++++++++ src/phoenix/server/api/types/Evaluation.py | 24 ++++++------ src/phoenix/server/api/types/Project.py | 20 ++++++++-- src/phoenix/server/api/types/Trace.py | 14 +++---- src/phoenix/server/app.py | 2 + 7 files changed, 80 insertions(+), 24 deletions(-) create mode 100644 src/phoenix/server/api/dataloaders/trace_evaluations.py diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index e78d59eec55..0d078a157f4 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -11,7 +11,7 @@ from phoenix.core.model_schema import Model from phoenix.core.traces import Traces from phoenix.server.api.input_types.TimeRange import TimeRange -from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation +from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation @dataclass @@ -19,6 +19,7 @@ class DataLoaders: latency_ms_quantile: DataLoader[Tuple[str, Optional[TimeRange], float], Optional[float]] span_evaluations: DataLoader[int, List[SpanEvaluation]] document_evaluations: DataLoader[int, List[DocumentEvaluation]] + trace_evaluations: DataLoader[int, List[TraceEvaluation]] @dataclass diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index 58cf63d3e37..07748f5e291 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -1,9 +1,11 @@ from .document_evaluations import DocumentEvaluationsDataLoader from .latency_ms_quantile import LatencyMsQuantileDataLoader from .span_evaluations import SpanEvaluationsDataLoader +from .trace_evaluations import TraceEvaluationsDataLoader __all__ = [ "DocumentEvaluationsDataLoader", "LatencyMsQuantileDataLoader", "SpanEvaluationsDataLoader", + "TraceEvaluationsDataLoader", ] diff --git a/src/phoenix/server/api/dataloaders/trace_evaluations.py b/src/phoenix/server/api/dataloaders/trace_evaluations.py new file mode 100644 index 00000000000..85db431041f --- /dev/null +++ b/src/phoenix/server/api/dataloaders/trace_evaluations.py @@ -0,0 +1,39 @@ +from collections import defaultdict +from typing import ( + AsyncContextManager, + Callable, + DefaultDict, + List, +) + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.api.types.Evaluation import TraceEvaluation + +Key: TypeAlias = int + + +class TraceEvaluationsDataLoader(DataLoader[Key, List[TraceEvaluation]]): + def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[List[TraceEvaluation]]: + trace_evaluations_by_id: DefaultDict[Key, List[TraceEvaluation]] = defaultdict(list) + async with self._db() as session: + for trace_evaluation in await session.scalars( + select(models.TraceAnnotation).where( + and_( + models.TraceAnnotation.trace_rowid.in_(keys), + models.TraceAnnotation.annotator_kind == "LLM", + ) + ) + ): + trace_evaluations_by_id[trace_evaluation.trace_rowid].append( + TraceEvaluation.from_sql_trace_annotation(trace_evaluation) + ) + return [trace_evaluations_by_id[key] for key in keys] diff --git a/src/phoenix/server/api/types/Evaluation.py b/src/phoenix/server/api/types/Evaluation.py index 15cf50f28f5..09934d4f1d9 100644 --- a/src/phoenix/server/api/types/Evaluation.py +++ b/src/phoenix/server/api/types/Evaluation.py @@ -1,12 +1,9 @@ -from typing import TYPE_CHECKING, Optional +from typing import Optional import strawberry import phoenix.trace.v1 as pb -from phoenix.trace.schemas import TraceID - -if TYPE_CHECKING: - from phoenix.db.models import DocumentAnnotation, SpanAnnotation +from phoenix.db.models import DocumentAnnotation, SpanAnnotation, TraceAnnotation @strawberry.interface @@ -29,21 +26,26 @@ class Evaluation: @strawberry.type class TraceEvaluation(Evaluation): - trace_id: strawberry.Private[TraceID] - @staticmethod def from_pb_evaluation(evaluation: pb.Evaluation) -> "TraceEvaluation": result = evaluation.result score = result.score.value if result.HasField("score") else None label = result.label.value if result.HasField("label") else None explanation = result.explanation.value if result.HasField("explanation") else None - trace_id = TraceID(evaluation.subject_id.trace_id) return TraceEvaluation( name=evaluation.name, score=score, label=label, explanation=explanation, - trace_id=trace_id, + ) + + @staticmethod + def from_sql_trace_annotation(annotation: TraceAnnotation) -> "TraceEvaluation": + return TraceEvaluation( + name=annotation.name, + score=annotation.score, + label=annotation.label, + explanation=annotation.explanation, ) @@ -63,7 +65,7 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "SpanEvaluation": ) @staticmethod - def from_sql_span_annotation(annotation: "SpanAnnotation") -> "SpanEvaluation": + def from_sql_span_annotation(annotation: SpanAnnotation) -> "SpanEvaluation": return SpanEvaluation( name=annotation.name, score=annotation.score, @@ -96,7 +98,7 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "DocumentEvaluation": ) @staticmethod - def from_sql_document_annotation(annotation: "DocumentAnnotation") -> "DocumentEvaluation": + def from_sql_document_annotation(annotation: DocumentAnnotation) -> "DocumentEvaluation": return DocumentEvaluation( name=annotation.name, score=annotation.score, diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index c59a3875692..e46011536fd 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -146,10 +146,22 @@ async def latency_ms_quantile( ) @strawberry.field - def trace(self, trace_id: ID) -> Optional[Trace]: - if self.project.has_trace(TraceID(trace_id)): - return Trace(trace_id=trace_id, project=self.project) - return None + async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]: + async with info.context.db() as session: + if ( + trace_rowid := await session.scalar( + select(models.Trace.id) + .join(models.Project) + .where( + and_( + models.Trace.trace_id == TraceID(trace_id), + models.Project.name == self.name, + ) + ) + ) + ) is None: + return None + return Trace(trace_rowid=trace_rowid, project=self.project) @strawberry.field async def spans( diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 075928ba8a3..29a16edbed5 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -3,7 +3,7 @@ import strawberry from sqlalchemy import select from sqlalchemy.orm import contains_eager -from strawberry import ID, UNSET, Private +from strawberry import UNSET from strawberry.types import Info from phoenix.core.project import Project @@ -17,13 +17,12 @@ connection_from_list, ) from phoenix.server.api.types.Span import Span, to_gql_span -from phoenix.trace.schemas import TraceID @strawberry.type class Trace: - trace_id: ID - project: Private[Project] + trace_rowid: strawberry.Private[int] + project: strawberry.Private[Project] @strawberry.field async def spans( @@ -44,13 +43,12 @@ async def spans( spans = await session.scalars( select(models.Span) .join(models.Trace) - .filter(models.Trace.trace_id == self.trace_id) + .filter(models.Trace.id == self.trace_rowid) .options(contains_eager(models.Span.trace)) ) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) @strawberry.field(description="Evaluations associated with the trace") # type: ignore - def trace_evaluations(self) -> List[TraceEvaluation]: - evaluations = self.project.get_evaluations_by_trace_id(TraceID(self.trace_id)) - return [TraceEvaluation.from_pb_evaluation(evaluation) for evaluation in evaluations] + async def trace_evaluations(self, info: Info[Context, None]) -> List[TraceEvaluation]: + return await info.context.data_loaders.trace_evaluations.load(self.trace_rowid) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2a125b57e1e..31a13e6a924 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -48,6 +48,7 @@ DocumentEvaluationsDataLoader, LatencyMsQuantileDataLoader, SpanEvaluationsDataLoader, + TraceEvaluationsDataLoader, ) from phoenix.server.api.routers.evaluation_handler import EvaluationHandler from phoenix.server.api.routers.span_handler import SpanHandler @@ -153,6 +154,7 @@ async def get_context( latency_ms_quantile=LatencyMsQuantileDataLoader(self.db), span_evaluations=SpanEvaluationsDataLoader(self.db), document_evaluations=DocumentEvaluationsDataLoader(self.db), + trace_evaluations=TraceEvaluationsDataLoader(self.db), ), )