Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: graphql node interface for trace and spans #3095

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,8 @@ enum SortDir {
desc
}

type Span {
type Span implements Node {
id: GlobalID!
name: String!
statusCode: SpanStatusCode!
statusMessage: String!
Expand Down Expand Up @@ -806,7 +807,8 @@ type TimeSeriesDataPoint {
value: Float
}

type Trace {
type Trace implements Node {
id: GlobalID!
spans(first: Int = 50, last: Int, after: String, before: String): SpanConnection!

"""Evaluations associated with the trace"""
Expand Down
33 changes: 29 additions & 4 deletions src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy.typing as npt
import strawberry
from sqlalchemy import delete, select
from sqlalchemy.orm import load_only
from sqlalchemy.orm import contains_eager, load_only
from strawberry import ID, UNSET
from strawberry.types import Info
from typing_extensions import Annotated
Expand Down Expand Up @@ -46,6 +46,8 @@
connection_from_list,
)
from phoenix.server.api.types.Project import Project
from phoenix.server.api.types.Span import to_gql_span
from phoenix.server.api.types.Trace import Trace


@strawberry.type
Expand Down Expand Up @@ -102,10 +104,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
embedding_dimension = info.context.model.embedding_dimensions[node_id]
return to_gql_embedding_dimension(node_id, embedding_dimension)
elif type_name == "Project":
project_stmt = select(
models.Project.id,
models.Project.name,
models.Project.gradient_start_color,
models.Project.gradient_end_color,
Comment on lines +108 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we manually specify attributes here to reduce the amount of data fetched?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, there are a few more columns than needed.

).where(models.Project.id == node_id)
async with info.context.db() as session:
project = await session.scalar(
select(models.Project).where(models.Project.id == node_id)
)
project = (await session.execute(project_stmt)).first()
if project is None:
raise ValueError(f"Unknown project: {id}")
return Project(
Expand All @@ -114,6 +120,25 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
gradient_start_color=project.gradient_start_color,
gradient_end_color=project.gradient_end_color,
)
elif type_name == "Trace":
trace_stmt = select(models.Trace.id).where(models.Trace.id == node_id)
async with info.context.db() as session:
id_attr = await session.scalar(trace_stmt)
if id_attr is None:
raise ValueError(f"Unknown trace: {id}")
return Trace(id_attr=id_attr)
elif type_name == "Span":
span_stmt = (
select(models.Span)
.join(models.Trace)
.options(contains_eager(models.Span.trace))
.where(models.Span.id == node_id)
)
async with info.context.db() as session:
span = await session.scalar(span_stmt)
if span is None:
raise ValueError(f"Unknown span: {id}")
return to_gql_span(span)
raise Exception(f"Unknown node type: {type_name}")

@strawberry.field
Expand Down
18 changes: 7 additions & 11 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,15 @@ async def span_latency_ms_quantile(

@strawberry.field
async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]:
stmt = (
select(models.Trace.id)
.where(models.Trace.trace_id == str(trace_id))
.where(models.Trace.project_rowid == self.id_attr)
)
async with info.context.db() as session:
if (
trace_rowid := await session.scalar(
select(models.Trace.id).where(
and_(
models.Trace.trace_id == str(trace_id),
models.Trace.project_rowid == self.id_attr,
)
)
)
) is None:
if (id_attr := await session.scalar(stmt)) is None:
return None
return Trace(trace_rowid=trace_rowid)
return Trace(id_attr=id_attr)

@strawberry.field
async def spans(
Expand Down
12 changes: 6 additions & 6 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics
from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation
from phoenix.server.api.types.MimeType import MimeType
from phoenix.server.api.types.node import Node
from phoenix.trace.attributes import get_attribute_value

EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
Expand Down Expand Up @@ -95,8 +96,7 @@ def from_dict(


@strawberry.type
class Span:
span_rowid: strawberry.Private[int]
class Span(Node):
name: str
status_code: SpanStatusCode
status_message: str
Expand Down Expand Up @@ -144,7 +144,7 @@ class Span:
"respect to its input."
) # type: ignore
async def span_evaluations(self, info: Info[Context, None]) -> List[SpanEvaluation]:
return await info.context.data_loaders.span_evaluations.load(self.span_rowid)
return await info.context.data_loaders.span_evaluations.load(self.id_attr)

@strawberry.field(
description="Evaluations of the documents associated with the span, e.g. "
Expand All @@ -155,7 +155,7 @@ async def span_evaluations(self, info: Info[Context, None]) -> List[SpanEvaluati
"index in that list."
) # type: ignore
async def document_evaluations(self, info: Info[Context, None]) -> List[DocumentEvaluation]:
return await info.context.data_loaders.document_evaluations.load(self.span_rowid)
return await info.context.data_loaders.document_evaluations.load(self.id_attr)

@strawberry.field(
description="Retrieval metrics: NDCG@K, Precision@K, Reciprocal Rank, etc.",
Expand All @@ -168,7 +168,7 @@ async def document_retrieval_metrics(
if not self.num_documents:
return []
return await info.context.data_loaders.document_retrieval_metrics.load(
(self.span_rowid, evaluation_name or None, self.num_documents),
(self.id_attr, evaluation_name or None, self.num_documents),
)

@strawberry.field(
Expand All @@ -190,7 +190,7 @@ def to_gql_span(span: models.Span) -> Span:
retrieval_documents = get_attribute_value(span.attributes, RETRIEVAL_DOCUMENTS)
num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None
return Span(
span_rowid=span.id,
id_attr=span.id,
name=span.name,
status_code=SpanStatusCode(span.status_code),
status_message=span.status_message,
Expand Down
9 changes: 4 additions & 5 deletions src/phoenix/server/api/types/Trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.types.Evaluation import TraceEvaluation
from phoenix.server.api.types.node import Node
from phoenix.server.api.types.pagination import (
Connection,
ConnectionArgs,
Expand All @@ -19,9 +20,7 @@


@strawberry.type
class Trace:
trace_rowid: strawberry.Private[int]

class Trace(Node):
@strawberry.field
async def spans(
self,
Expand All @@ -40,7 +39,7 @@ async def spans(
stmt = (
select(models.Span)
.join(models.Trace)
.where(models.Trace.id == self.trace_rowid)
.where(models.Trace.id == self.id_attr)
.options(contains_eager(models.Span.trace))
# Sort descending because the root span tends to show up later
# in the ingestion process.
Expand All @@ -54,4 +53,4 @@ async def spans(

@strawberry.field(description="Evaluations associated with the trace") # type: ignore
async def trace_evaluations(self, info: Info[Context, None]) -> List[TraceEvaluation]:
return await info.context.data_loaders.trace_evaluations.load(self.trace_rowid)
return await info.context.data_loaders.trace_evaluations.load(self.id_attr)
Loading