Skip to content

Commit

Permalink
use dataloader for span annotations (#4136)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Aug 6, 2024
1 parent 9550a7e commit ab53325
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DataLoaders:
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
record_counts: RecordCountDataLoader
span_annotations: SpanAnnotationsDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
span_evaluations: SpanEvaluationsDataLoader
Expand All @@ -62,7 +63,6 @@ class DataLoaders:
trace_evaluations: TraceEvaluationsDataLoader
trace_row_ids: TraceRowIdsDataLoader
project_by_name: ProjectByNameDataLoader
span_annotations: SpanAnnotationsDataLoader


ProjectRowId: TypeAlias = int
Expand Down
15 changes: 6 additions & 9 deletions src/phoenix/server/api/dataloaders/span_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation
from phoenix.db.models import SpanAnnotation as ORMSpanAnnotation
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = List[SpanAnnotation]
Result: TypeAlias = List[ORMSpanAnnotation]


class SpanAnnotationsDataLoader(DataLoader[Key, Result]):
Expand All @@ -23,11 +22,9 @@ def __init__(self, db: DbSessionFactory) -> None:

async def _load_fn(self, keys: List[Key]) -> List[Result]:
span_annotations_by_id: DefaultDict[Key, Result] = defaultdict(list)
msa = models.SpanAnnotation
async with self._db() as session:
data = await session.stream_scalars(select(msa).where(msa.span_rowid.in_(keys)))
async for span_annotation in data:
span_annotations_by_id[span_annotation.span_rowid].append(
to_gql_span_annotation(span_annotation)
)
async for span_annotation in await session.stream_scalars(
select(ORMSpanAnnotation).where(ORMSpanAnnotation.span_rowid.in_(keys))
):
span_annotations_by_id[span_annotation.span_rowid].append(span_annotation)
return [span_annotations_by_id[key] for key in keys]
27 changes: 14 additions & 13 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import strawberry
from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
from sqlalchemy import select
from strawberry import ID, UNSET
from strawberry.relay import Node, NodeID
from strawberry.types import Info
Expand All @@ -20,7 +19,10 @@
get_dataset_example_input,
get_dataset_example_output,
)
from phoenix.server.api.input_types.SpanAnnotationSort import SpanAnnotationSort
from phoenix.server.api.input_types.SpanAnnotationSort import (
SpanAnnotationColumn,
SpanAnnotationSort,
)
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation
from phoenix.trace.attributes import get_attribute_value
Expand Down Expand Up @@ -190,17 +192,16 @@ async def span_annotations(
info: Info[Context, None],
sort: Optional[SpanAnnotationSort] = UNSET,
) -> List[SpanAnnotation]:
async with info.context.db() as session:
stmt = select(models.SpanAnnotation).filter_by(span_rowid=self.id_attr)
if sort:
sort_col = getattr(models.SpanAnnotation, sort.col.value)
if sort.dir is SortDir.desc:
stmt = stmt.order_by(sort_col.desc(), models.SpanAnnotation.id.desc())
else:
stmt = stmt.order_by(sort_col.asc(), models.SpanAnnotation.id.asc())
else:
stmt = stmt.order_by(models.SpanAnnotation.created_at.desc())
annotations = await session.scalars(stmt)
span_id = self.id_attr
annotations = await info.context.data_loaders.span_annotations.load(span_id)
sort_key = SpanAnnotationColumn.createdAt.value
sort_descending = True
if sort:
sort_key = sort.col.value
sort_descending = sort.dir is SortDir.desc
annotations.sort(
key=lambda annotation: getattr(annotation, sort_key), reverse=sort_descending
)
return [to_gql_span_annotation(annotation) for annotation in annotations]

@strawberry.field(
Expand Down

0 comments on commit ab53325

Please sign in to comment.