Skip to content

Commit

Permalink
feat: convert graphql api to pull trace evaluations from db (#2867)
Browse files Browse the repository at this point in the history
makes graphql api pull trace evaluations from db
  • Loading branch information
axiomofjoy authored Apr 14, 2024
1 parent 73ca2d7 commit 11aa455
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 28 deletions.
1 change: 0 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,6 @@ type TimeSeriesDataPoint {
}

type Trace {
traceId: ID!
spans(first: Int = 50, last: Int, after: String, before: String): SpanConnection!

"""Evaluations associated with the trace"""
Expand Down
3 changes: 2 additions & 1 deletion src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
from phoenix.core.model_schema import Model
from phoenix.core.traces import Traces
from phoenix.server.api.input_types.TimeRange import TimeRange
from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation


@dataclass
class DataLoaders:
latency_ms_quantile: DataLoader[Tuple[int, Optional[TimeRange], float], Optional[float]]
span_evaluations: DataLoader[int, List[SpanEvaluation]]
document_evaluations: DataLoader[int, List[DocumentEvaluation]]
trace_evaluations: DataLoader[int, List[TraceEvaluation]]


@dataclass
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .document_evaluations import DocumentEvaluationsDataLoader
from .latency_ms_quantile import LatencyMsQuantileDataLoader
from .span_evaluations import SpanEvaluationsDataLoader
from .trace_evaluations import TraceEvaluationsDataLoader

__all__ = [
"DocumentEvaluationsDataLoader",
"LatencyMsQuantileDataLoader",
"SpanEvaluationsDataLoader",
"TraceEvaluationsDataLoader",
]
39 changes: 39 additions & 0 deletions src/phoenix/server/api/dataloaders/trace_evaluations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from collections import defaultdict
from typing import (
AsyncContextManager,
Callable,
DefaultDict,
List,
)

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.api.types.Evaluation import TraceEvaluation

Key: TypeAlias = int


class TraceEvaluationsDataLoader(DataLoader[Key, List[TraceEvaluation]]):
def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
super().__init__(load_fn=self._load_fn)
self._db = db

async def _load_fn(self, keys: List[Key]) -> List[List[TraceEvaluation]]:
trace_evaluations_by_id: DefaultDict[Key, List[TraceEvaluation]] = defaultdict(list)
async with self._db() as session:
for trace_evaluation in await session.scalars(
select(models.TraceAnnotation).where(
and_(
models.TraceAnnotation.trace_rowid.in_(keys),
models.TraceAnnotation.annotator_kind == "LLM",
)
)
):
trace_evaluations_by_id[trace_evaluation.trace_rowid].append(
TraceEvaluation.from_sql_trace_annotation(trace_evaluation)
)
return [trace_evaluations_by_id[key] for key in keys]
24 changes: 13 additions & 11 deletions src/phoenix/server/api/types/Evaluation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import TYPE_CHECKING, Optional
from typing import Optional

import strawberry

import phoenix.trace.v1 as pb
from phoenix.trace.schemas import TraceID

if TYPE_CHECKING:
from phoenix.db.models import DocumentAnnotation, SpanAnnotation
from phoenix.db.models import DocumentAnnotation, SpanAnnotation, TraceAnnotation


@strawberry.interface
Expand All @@ -29,21 +26,26 @@ class Evaluation:

@strawberry.type
class TraceEvaluation(Evaluation):
trace_id: strawberry.Private[TraceID]

@staticmethod
def from_pb_evaluation(evaluation: pb.Evaluation) -> "TraceEvaluation":
result = evaluation.result
score = result.score.value if result.HasField("score") else None
label = result.label.value if result.HasField("label") else None
explanation = result.explanation.value if result.HasField("explanation") else None
trace_id = TraceID(evaluation.subject_id.trace_id)
return TraceEvaluation(
name=evaluation.name,
score=score,
label=label,
explanation=explanation,
trace_id=trace_id,
)

@staticmethod
def from_sql_trace_annotation(annotation: TraceAnnotation) -> "TraceEvaluation":
return TraceEvaluation(
name=annotation.name,
score=annotation.score,
label=annotation.label,
explanation=annotation.explanation,
)


Expand All @@ -63,7 +65,7 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "SpanEvaluation":
)

@staticmethod
def from_sql_span_annotation(annotation: "SpanAnnotation") -> "SpanEvaluation":
def from_sql_span_annotation(annotation: SpanAnnotation) -> "SpanEvaluation":
return SpanEvaluation(
name=annotation.name,
score=annotation.score,
Expand Down Expand Up @@ -96,7 +98,7 @@ def from_pb_evaluation(evaluation: pb.Evaluation) -> "DocumentEvaluation":
)

@staticmethod
def from_sql_document_annotation(annotation: "DocumentAnnotation") -> "DocumentEvaluation":
def from_sql_document_annotation(annotation: DocumentAnnotation) -> "DocumentEvaluation":
return DocumentEvaluation(
name=annotation.name,
score=annotation.score,
Expand Down
19 changes: 12 additions & 7 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,20 @@ async def latency_ms_quantile(
)

@strawberry.field
async def trace(self, info: Info[Context, None], trace_id: ID) -> Optional[Trace]:
async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]:
async with info.context.db() as session:
if not await session.scalar(
select(models.Trace.id)
.where(models.Trace.trace_id == str(trace_id))
.where(models.Trace.project_rowid == self.id_attr),
):
if (
trace_rowid := await session.scalar(
select(models.Trace.id).where(
and_(
models.Trace.trace_id == str(trace_id),
models.Trace.project_rowid == self.id_attr,
)
)
)
) is None:
return None
return Trace(trace_id=trace_id, project=self.project)
return Trace(trace_rowid=trace_rowid, project=self.project)

@strawberry.field
async def spans(
Expand Down
14 changes: 6 additions & 8 deletions src/phoenix/server/api/types/Trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import strawberry
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from strawberry import ID, UNSET, Private
from strawberry import UNSET
from strawberry.types import Info

from phoenix.core.project import Project
Expand All @@ -17,13 +17,12 @@
connection_from_list,
)
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.trace.schemas import TraceID


@strawberry.type
class Trace:
trace_id: ID
project: Private[Project]
trace_rowid: strawberry.Private[int]
project: strawberry.Private[Project]

@strawberry.field
async def spans(
Expand All @@ -44,13 +43,12 @@ async def spans(
spans = await session.scalars(
select(models.Span)
.join(models.Trace)
.filter(models.Trace.trace_id == self.trace_id)
.where(models.Trace.id == self.trace_rowid)
.options(contains_eager(models.Span.trace))
)
data = [to_gql_span(span, self.project) for span in spans]
return connection_from_list(data=data, args=args)

@strawberry.field(description="Evaluations associated with the trace") # type: ignore
def trace_evaluations(self) -> List[TraceEvaluation]:
evaluations = self.project.get_evaluations_by_trace_id(TraceID(self.trace_id))
return [TraceEvaluation.from_pb_evaluation(evaluation) for evaluation in evaluations]
async def trace_evaluations(self, info: Info[Context, None]) -> List[TraceEvaluation]:
return await info.context.data_loaders.trace_evaluations.load(self.trace_rowid)
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
DocumentEvaluationsDataLoader,
LatencyMsQuantileDataLoader,
SpanEvaluationsDataLoader,
TraceEvaluationsDataLoader,
)
from phoenix.server.api.routers.v1 import V1_ROUTES
from phoenix.server.api.schema import schema
Expand Down Expand Up @@ -156,6 +157,7 @@ async def get_context(
latency_ms_quantile=LatencyMsQuantileDataLoader(self.db),
span_evaluations=SpanEvaluationsDataLoader(self.db),
document_evaluations=DocumentEvaluationsDataLoader(self.db),
trace_evaluations=TraceEvaluationsDataLoader(self.db),
),
)

Expand Down

0 comments on commit 11aa455

Please sign in to comment.