feat(persistence): dataloader for document retrieval metrics (#2978)
RogerHYang authored Apr 25, 2024
1 parent 3e7cbad commit f55c458
Showing 5 changed files with 106 additions and 29 deletions.
4 changes: 4 additions & 0 deletions src/phoenix/server/api/context.py
@@ -11,6 +11,7 @@
 
 from phoenix.core.model_schema import Model
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
 from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation
 
@@ -20,6 +21,9 @@ class DataLoaders:
     span_evaluations: DataLoader[int, List[SpanEvaluation]]
     document_evaluations: DataLoader[int, List[DocumentEvaluation]]
     trace_evaluations: DataLoader[int, List[TraceEvaluation]]
+    document_retrieval_metrics: DataLoader[
+        Tuple[int, Optional[str], int], List[DocumentRetrievalMetrics]
+    ]
 
 
 @dataclass
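The new loader is keyed by the triple (span_rowid, evaluation_name, num_documents) rather than a bare row id, so one span can be queried for a single named evaluation or, by passing None, for every evaluation on the span. A minimal sketch of the two key shapes (row id, name, and document count are hypothetical values, not from the commit):

    key_named = (42, "relevance", 5)  # metrics for the "relevance" evaluation only
    key_any = (42, None, 5)           # metrics for every evaluation on span 42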
2 changes: 2 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
@@ -1,4 +1,5 @@
 from .document_evaluations import DocumentEvaluationsDataLoader
+from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader
 from .latency_ms_quantile import LatencyMsQuantileDataLoader
 from .span_evaluations import SpanEvaluationsDataLoader
 from .trace_evaluations import TraceEvaluationsDataLoader
@@ -8,4 +9,5 @@
     "LatencyMsQuantileDataLoader",
     "SpanEvaluationsDataLoader",
     "TraceEvaluationsDataLoader",
+    "DocumentRetrievalMetricsDataLoader",
 ]
96 changes: 96 additions & 0 deletions src/phoenix/server/api/dataloaders/document_retrieval_metrics.py
@@ -0,0 +1,96 @@
from collections import defaultdict
from itertools import groupby
from typing import (
    AsyncContextManager,
    Callable,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
)

import numpy as np
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.metrics.retrieval_metrics import RetrievalMetrics
from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics

RowId: TypeAlias = int
NumDocs: TypeAlias = int
EvalName: TypeAlias = Optional[str]
Key: TypeAlias = Tuple[RowId, EvalName, NumDocs]


class DocumentRetrievalMetricsDataLoader(DataLoader[Key, List[DocumentRetrievalMetrics]]):
    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[List[DocumentRetrievalMetrics]]:
        mda = models.DocumentAnnotation
        stmt = (
            select(
                mda.span_rowid,
                mda.name,
                mda.score,
                mda.document_position,
            )
            .where(mda.score != None)  # noqa: E711
            .where(mda.annotator_kind == "LLM")
            .where(mda.document_position >= 0)
            .order_by(mda.span_rowid, mda.name)
        )
        # Using CTE with VALUES clause is possible in SQLite, but not in
        # SQLAlchemy v2.0.29, hence the workaround below with over-fetching.
        # We could use CTE with VALUES for postgresql, but for now we'll keep
        # it simple and just use one approach for all backends.
        all_row_ids = {row_id for row_id, _, _ in keys}
        stmt = stmt.where(mda.span_rowid.in_(all_row_ids))
        all_eval_names = {eval_name for _, eval_name, _ in keys}
        if None not in all_eval_names:
            stmt = stmt.where(mda.name.in_(all_eval_names))
        max_position = max(num_docs for _, _, num_docs in keys)
        stmt = stmt.where(mda.document_position < max_position)
        async with self._db() as session:
            data = await session.execute(stmt)
        if not data:
            return [[] for _ in keys]
        results: Dict[Key, List[DocumentRetrievalMetrics]] = {key: [] for key in keys}
        requested_num_docs: DefaultDict[Tuple[RowId, EvalName], Set[NumDocs]] = defaultdict(set)
        for row_id, eval_name, num_docs in results.keys():
            requested_num_docs[(row_id, eval_name)].add(num_docs)
        for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)):
            # We need to fulfill two types of potential requests: 1. when it
            # specifies an evaluation name, and 2. when it doesn't care about
            # the evaluation name by specifying None.
            max_requested_num_docs = max(
                (
                    num_docs
                    for eval_name in (name, None)
                    for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ())
                ),
                default=0,
            )
            if max_requested_num_docs <= 0:
                # We have over-fetched. Skip this group.
                continue
            scores = [np.nan] * max_requested_num_docs
            for row in group:
                # Length check is necessary due to over-fetching.
                if row.document_position < len(scores):
                    scores[row.document_position] = row.score
            for eval_name in (name, None):
                for num_docs in requested_num_docs.get((span_rowid, eval_name)) or ():
                    metrics = RetrievalMetrics(scores[:num_docs])
                    doc_metrics = DocumentRetrievalMetrics(evaluation_name=name, metrics=metrics)
                    key = (span_rowid, eval_name, num_docs)
                    results[key].append(doc_metrics)
        # Make sure to copy the result, so we don't return the same list
        # object to two different requesters.
        return [results[key].copy() for key in keys]
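The comment in _load_fn only alludes to the CTE-with-VALUES alternative. As a rough sketch, not part of this commit, a PostgreSQL variant could join the requested keys directly using SQLAlchemy's values() construct, fetching exactly the rows each key needs instead of over-fetching. The key tuples below are hypothetical, the None evaluation-name case is omitted for brevity, and the sketch assumes values() compiles for the target dialect (it did not for SQLite as of SQLAlchemy v2.0.29):

    from sqlalchemy import Integer, String, column, select, values

    from phoenix.db import models

    mda = models.DocumentAnnotation
    requested = values(  # one row per requested key: (span_rowid, name, num_docs)
        column("span_rowid", Integer),
        column("name", String),
        column("num_docs", Integer),
        name="requested",
    ).data([(1, "relevance", 5), (2, "hallucination", 10)])  # hypothetical keys
    keys_cte = select(requested).cte("keys")
    stmt = (
        select(mda.span_rowid, mda.name, mda.score, mda.document_position)
        .join(
            keys_cte,
            (mda.span_rowid == keys_cte.c.span_rowid)
            & (mda.name == keys_cte.c.name)
            & (mda.document_position < keys_cte.c.num_docs),
        )
        .where(mda.score != None)  # noqa: E711
        .where(mda.annotator_kind == "LLM")
        .where(mda.document_position >= 0)
        .order_by(mda.span_rowid, mda.name)
    )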
31 changes: 2 additions & 29 deletions src/phoenix/server/api/types/Span.py
@@ -1,7 +1,6 @@
 import json
 from datetime import datetime
 from enum import Enum
-from itertools import groupby
 from typing import Any, List, Mapping, Optional, Sized, cast
 
 import numpy as np
@@ -14,7 +13,6 @@
 
 import phoenix.trace.schemas as trace_schema
 from phoenix.db import models
-from phoenix.metrics.retrieval_metrics import RetrievalMetrics
 from phoenix.server.api.context import Context
 from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
 from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
@@ -171,34 +169,9 @@ async def document_retrieval_metrics(
     ) -> List[DocumentRetrievalMetrics]:
         if not self.num_documents:
             return []
-        mda = models.DocumentAnnotation
-        stmt = (
-            select(mda.name, mda.score, mda.document_position)
-            .where(mda.score != None)  # noqa: E711
-            .where(mda.span_rowid == self.span_rowid)
-            .where(mda.document_position >= 0)
-            .where(mda.document_position < self.num_documents)
-            .where(mda.annotator_kind == "LLM")
-            .order_by(mda.name)
+        return await info.context.data_loaders.document_retrieval_metrics.load(
+            (self.span_rowid, evaluation_name or None, self.num_documents),
         )
-        if evaluation_name:
-            stmt = stmt.where(mda.name == evaluation_name)
-        async with info.context.db() as session:
-            rows = await session.execute(stmt)
-        if not rows:
-            return []
-        retrieval_metrics = []
-        for name, group in groupby(rows, lambda r: r.name):
-            scores: List[float] = [np.nan] * self.num_documents
-            for row in group:
-                scores[row.document_position] = row.score
-            retrieval_metrics.append(
-                DocumentRetrievalMetrics(
-                    evaluation_name=name,
-                    metrics=RetrievalMetrics(scores),
-                )
-            )
-        return retrieval_metrics
 
     @strawberry.field(
         description="All descendant spans (children, grandchildren, etc.)",
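The point of routing this resolver through a dataloader is batching: the deleted code ran one query per span, so a GraphQL response touching N spans cost N round trips, whereas concurrent .load() calls now coalesce into a single _load_fn call and a single query. A self-contained sketch of that coalescing with strawberry's DataLoader, using a toy load_fn and hypothetical keys:

    import asyncio
    from typing import List, Tuple

    from strawberry.dataloader import DataLoader

    Key = Tuple[int, str, int]  # (span_rowid, evaluation_name, num_documents)

    async def load_fn(keys: List[Key]) -> List[List[float]]:
        # The real loader performs one SQL round trip per batch here.
        print(f"one batch of {len(keys)} keys")
        return [[] for _ in keys]

    async def main() -> None:
        loader: DataLoader[Key, List[float]] = DataLoader(load_fn=load_fn)
        # Ten spans resolved concurrently -> load_fn runs once, not ten times.
        await asyncio.gather(*(loader.load((i, "relevance", 5)) for i in range(10)))

    asyncio.run(main())  # prints: one batch of 10 keys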
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
@@ -47,6 +47,7 @@
 from phoenix.server.api.context import Context, DataLoaders
 from phoenix.server.api.dataloaders import (
     DocumentEvaluationsDataLoader,
+    DocumentRetrievalMetricsDataLoader,
     LatencyMsQuantileDataLoader,
     SpanEvaluationsDataLoader,
     TraceEvaluationsDataLoader,
@@ -157,6 +158,7 @@ async def get_context(
                 span_evaluations=SpanEvaluationsDataLoader(self.db),
                 document_evaluations=DocumentEvaluationsDataLoader(self.db),
                 trace_evaluations=TraceEvaluationsDataLoader(self.db),
+                document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db),
             ),
         )
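The loaders are constructed inside get_context, i.e. once per request. Strawberry's DataLoader also caches results by key within an instance, so per-request construction keeps that cache request-scoped: repeated loads of one key in a request hit load_fn once, and nothing leaks across requests. A small sketch of the caching behavior (toy load_fn, not the commit's code):

    import asyncio

    from strawberry.dataloader import DataLoader

    async def load_fn(keys):
        print(f"load_fn called with {len(keys)} key(s)")
        return [k.upper() for k in keys]

    async def main() -> None:
        loader = DataLoader(load_fn=load_fn)
        first = await loader.load("x")
        second = await loader.load("x")  # served from the instance's cache
        assert first == second == "X"    # load_fn ran only once

    asyncio.run(main())  # prints: load_fn called with 1 key(s)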
