diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index bd5627e154..e78d59eec5 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -11,13 +11,14 @@ from phoenix.core.model_schema import Model from phoenix.core.traces import Traces from phoenix.server.api.input_types.TimeRange import TimeRange -from phoenix.server.api.types.Evaluation import SpanEvaluation +from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation @dataclass class DataLoaders: latency_ms_quantile: DataLoader[Tuple[str, Optional[TimeRange], float], Optional[float]] span_evaluations: DataLoader[int, List[SpanEvaluation]] + document_evaluations: DataLoader[int, List[DocumentEvaluation]] @dataclass diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index 819d2aaf0c..58cf63d3e3 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -1,7 +1,9 @@ +from .document_evaluations import DocumentEvaluationsDataLoader from .latency_ms_quantile import LatencyMsQuantileDataLoader from .span_evaluations import SpanEvaluationsDataLoader __all__ = [ + "DocumentEvaluationsDataLoader", "LatencyMsQuantileDataLoader", "SpanEvaluationsDataLoader", ] diff --git a/src/phoenix/server/api/dataloaders/document_evaluations.py b/src/phoenix/server/api/dataloaders/document_evaluations.py new file mode 100644 index 0000000000..f45ea60a98 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/document_evaluations.py @@ -0,0 +1,39 @@ +from collections import defaultdict +from typing import ( + AsyncContextManager, + Callable, + DefaultDict, + List, +) + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.api.types.Evaluation import DocumentEvaluation + +Key: TypeAlias = int + + +class DocumentEvaluationsDataLoader(DataLoader[Key, List[DocumentEvaluation]]): + def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[List[DocumentEvaluation]]: + document_evaluations_by_id: DefaultDict[Key, List[DocumentEvaluation]] = defaultdict(list) + async with self._db() as session: + for document_evaluation in await session.scalars( + select(models.DocumentAnnotation).where( + and_( + models.DocumentAnnotation.span_rowid.in_(keys), + models.DocumentAnnotation.annotator_kind == "LLM", + ) + ) + ): + document_evaluations_by_id[document_evaluation.span_rowid].append( + DocumentEvaluation.from_sql_document_annotation(document_evaluation) + ) + return [document_evaluations_by_id[key] for key in keys] diff --git a/src/phoenix/server/api/types/Evaluation.py b/src/phoenix/server/api/types/Evaluation.py index 841bf58cc2..15cf50f28f 100644 --- a/src/phoenix/server/api/types/Evaluation.py +++ b/src/phoenix/server/api/types/Evaluation.py @@ -1,10 +1,12 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import strawberry import phoenix.trace.v1 as pb -from phoenix.db.models import SpanAnnotation -from phoenix.trace.schemas import SpanID, TraceID +from phoenix.trace.schemas import TraceID + +if TYPE_CHECKING: + from phoenix.db.models import DocumentAnnotation, SpanAnnotation @strawberry.interface @@ -72,7 +74,6 @@ def from_sql_span_annotation(annotation: "SpanAnnotation") -> "SpanEvaluation": @strawberry.type class DocumentEvaluation(Evaluation): - span_id: strawberry.Private[SpanID] document_position: int = strawberry.field( description="The zero-based index among retrieved documents, which " "is collected as a list (even when ordering is not inherently meaningful)." @@ -86,12 +87,20 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "DocumentEvaluation": explanation = result.explanation.value if result.HasField("explanation") else None document_retrieval_id = evaluation.subject_id.document_retrieval_id document_position = document_retrieval_id.document_position - span_id = SpanID(document_retrieval_id.span_id) return DocumentEvaluation( name=evaluation.name, score=score, label=label, explanation=explanation, document_position=document_position, - span_id=span_id, + ) + + @staticmethod + def from_sql_document_annotation(annotation: "DocumentAnnotation") -> "DocumentEvaluation": + return DocumentEvaluation( + name=annotation.name, + score=annotation.score, + label=annotation.label, + explanation=annotation.explanation, + document_position=annotation.document_index, ) diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 1c78dc2c9c..5071737275 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -158,12 +158,8 @@ async def span_evaluations(self, info: Info[Context, None]) -> List[SpanEvaluati "a list, and each evaluation is identified by its document's (zero-based) " "index in that list." ) # type: ignore - def document_evaluations(self) -> List[DocumentEvaluation]: - span_id = SpanID(str(self.context.span_id)) - return [ - DocumentEvaluation.from_pb_evaluation(evaluation) - for evaluation in self.project.get_document_evaluations_by_span_id(span_id) - ] + async def document_evaluations(self, info: Info[Context, None]) -> List[DocumentEvaluation]: + return await info.context.data_loaders.document_evaluations.load(self.span_rowid) @strawberry.field( description="Retrieval metrics: NDCG@K, Precision@K, Reciprocal Rank, etc.", diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 66ff92d527..2a125b57e1 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -45,6 +45,7 @@ from phoenix.pointcloud.umap_parameters import UMAPParameters from phoenix.server.api.context import Context, DataLoaders from phoenix.server.api.dataloaders import ( + DocumentEvaluationsDataLoader, LatencyMsQuantileDataLoader, SpanEvaluationsDataLoader, ) @@ -151,6 +152,7 @@ async def get_context( data_loaders=DataLoaders( latency_ms_quantile=LatencyMsQuantileDataLoader(self.db), span_evaluations=SpanEvaluationsDataLoader(self.db), + document_evaluations=DocumentEvaluationsDataLoader(self.db), ), )