diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 3ca73d5142..76241d1b16 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -11,6 +11,7 @@ from phoenix.core.model_schema import Model from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation @@ -20,6 +21,9 @@ class DataLoaders: span_evaluations: DataLoader[int, List[SpanEvaluation]] document_evaluations: DataLoader[int, List[DocumentEvaluation]] trace_evaluations: DataLoader[int, List[TraceEvaluation]] + document_retrieval_metrics: DataLoader[ + Tuple[int, Optional[str], int], List[DocumentRetrievalMetrics] + ] @dataclass diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index 07748f5e29..efa1acf78b 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -1,4 +1,5 @@ from .document_evaluations import DocumentEvaluationsDataLoader +from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader from .latency_ms_quantile import LatencyMsQuantileDataLoader from .span_evaluations import SpanEvaluationsDataLoader from .trace_evaluations import TraceEvaluationsDataLoader @@ -8,4 +9,5 @@ "LatencyMsQuantileDataLoader", "SpanEvaluationsDataLoader", "TraceEvaluationsDataLoader", + "DocumentRetrievalMetricsDataLoader", ] diff --git a/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py b/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py new file mode 100644 index 0000000000..5af058a12e --- /dev/null +++ b/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py @@ -0,0 +1,96 @@ +from collections import defaultdict +from itertools import groupby +from typing import ( + AsyncContextManager, + Callable, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, +) + +import numpy as np +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.metrics.retrieval_metrics import RetrievalMetrics +from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics + +RowId: TypeAlias = int +NumDocs: TypeAlias = int +EvalName: TypeAlias = Optional[str] +Key: TypeAlias = Tuple[RowId, EvalName, NumDocs] + + +class DocumentRetrievalMetricsDataLoader(DataLoader[Key, List[DocumentRetrievalMetrics]]): + def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[List[DocumentRetrievalMetrics]]: + mda = models.DocumentAnnotation + stmt = ( + select( + mda.span_rowid, + mda.name, + mda.score, + mda.document_position, + ) + .where(mda.score != None) # noqa: E711 + .where(mda.annotator_kind == "LLM") + .where(mda.document_position >= 0) + .order_by(mda.span_rowid, mda.name) + ) + # Using CTE with VALUES clause is possible in SQLite, but not in + # SQLAlchemy v2.0.29, hence the workaround below with over-fetching. + # We could use CTE with VALUES for postgresql, but for now we'll keep + # it simple and just use one approach for all backends. + all_row_ids = {row_id for row_id, _, _ in keys} + stmt = stmt.where(mda.span_rowid.in_(all_row_ids)) + all_eval_names = {eval_name for _, eval_name, _ in keys} + if None not in all_eval_names: + stmt = stmt.where(mda.name.in_(all_eval_names)) + max_position = max(num_docs for _, _, num_docs in keys) + stmt = stmt.where(mda.document_position < max_position) + async with self._db() as session: + data = await session.execute(stmt) + if not data: + return [[] for _ in keys] + results: Dict[Key, List[DocumentRetrievalMetrics]] = {key: [] for key in keys} + requested_num_docs: DefaultDict[Tuple[RowId, EvalName], Set[NumDocs]] = defaultdict(set) + for row_id, eval_name, num_docs in results.keys(): + requested_num_docs[(row_id, eval_name)].add(num_docs) + for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)): + # We need to fulfill two types of potential requests: 1. when it + # specifies an evaluation name, and 2. when it doesn't care about + # the evaluation name by specifying None. + max_requested_num_docs = max( + ( + num_docs + for eval_name in (name, None) + for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ()) + ), + default=0, + ) + if max_requested_num_docs <= 0: + # We have over-fetched. Skip this group. + continue + scores = [np.nan] * max_requested_num_docs + for row in group: + # Length check is necessary due to over-fetching. + if row.document_position < len(scores): + scores[row.document_position] = row.score + for eval_name in (name, None): + for num_docs in requested_num_docs.get((span_rowid, eval_name)) or (): + metrics = RetrievalMetrics(scores[:num_docs]) + doc_metrics = DocumentRetrievalMetrics(evaluation_name=name, metrics=metrics) + key = (span_rowid, eval_name, num_docs) + results[key].append(doc_metrics) + # Make sure to copy the result, so we don't return the same list + # object to two different requesters. + return [results[key].copy() for key in keys] diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 29fd4dce88..e99c079d8a 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -1,7 +1,6 @@ import json from datetime import datetime from enum import Enum -from itertools import groupby from typing import Any, List, Mapping, Optional, Sized, cast import numpy as np @@ -14,7 +13,6 @@ import phoenix.trace.schemas as trace_schema from phoenix.db import models -from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation @@ -171,34 +169,9 @@ async def document_retrieval_metrics( ) -> List[DocumentRetrievalMetrics]: if not self.num_documents: return [] - mda = models.DocumentAnnotation - stmt = ( - select(mda.name, mda.score, mda.document_position) - .where(mda.score != None) # noqa: E711 - .where(mda.span_rowid == self.span_rowid) - .where(mda.document_position >= 0) - .where(mda.document_position < self.num_documents) - .where(mda.annotator_kind == "LLM") - .order_by(mda.name) + return await info.context.data_loaders.document_retrieval_metrics.load( + (self.span_rowid, evaluation_name or None, self.num_documents), ) - if evaluation_name: - stmt = stmt.where(mda.name == evaluation_name) - async with info.context.db() as session: - rows = await session.execute(stmt) - if not rows: - return [] - retrieval_metrics = [] - for name, group in groupby(rows, lambda r: r.name): - scores: List[float] = [np.nan] * self.num_documents - for row in group: - scores[row.document_position] = row.score - retrieval_metrics.append( - DocumentRetrievalMetrics( - evaluation_name=name, - metrics=RetrievalMetrics(scores), - ) - ) - return retrieval_metrics @strawberry.field( description="All descendant spans (children, grandchildren, etc.)", diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 288a77fd38..42e6aff0a5 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -47,6 +47,7 @@ from phoenix.server.api.context import Context, DataLoaders from phoenix.server.api.dataloaders import ( DocumentEvaluationsDataLoader, + DocumentRetrievalMetricsDataLoader, LatencyMsQuantileDataLoader, SpanEvaluationsDataLoader, TraceEvaluationsDataLoader, @@ -157,6 +158,7 @@ async def get_context( span_evaluations=SpanEvaluationsDataLoader(self.db), document_evaluations=DocumentEvaluationsDataLoader(self.db), trace_evaluations=TraceEvaluationsDataLoader(self.db), + document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db), ), )