Skip to content

Commit

Permalink
feat: Add containedInDataset boolean field to gql Spans (#4015)
Browse files Browse the repository at this point in the history
* Add containedInDataset boolean field to gql Spans

* Improve names and resolve missing type annotations

* Build gql
  • Loading branch information
anticorrelator authored Jul 25, 2024
1 parent 0587462 commit 3c096ca
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 0 deletions.
3 changes: 3 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,9 @@ type Span implements Node {

"""The project that this span belongs to."""
project: Project!

"""Indicates if the span is contained in any dataset"""
containedInDataset: Boolean!
}

type SpanAnnotation implements Node & Annotation {
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ProjectByNameDataLoader,
RecordCountDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
SpanEvaluationsDataLoader,
SpanProjectsDataLoader,
Expand All @@ -51,6 +52,7 @@ class DataLoaders:
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
record_counts: RecordCountDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
span_evaluations: SpanEvaluationsDataLoader
span_projects: SpanProjectsDataLoader
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .project_by_name import ProjectByNameDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .span_annotations import SpanAnnotationsDataLoader
from .span_dataset_examples import SpanDatasetExamplesDataLoader
from .span_descendants import SpanDescendantsDataLoader
from .span_evaluations import SpanEvaluationsDataLoader
from .span_projects import SpanProjectsDataLoader
Expand All @@ -50,6 +51,7 @@
"LatencyMsQuantileDataLoader",
"MinStartOrMaxEndTimeDataLoader",
"RecordCountDataLoader",
"SpanDatasetExamplesDataLoader",
"SpanDescendantsDataLoader",
"SpanEvaluationsDataLoader",
"SpanProjectsDataLoader",
Expand Down
38 changes: 38 additions & 0 deletions src/phoenix/server/api/dataloaders/span_dataset_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import (
AsyncContextManager,
Callable,
Dict,
List,
)

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models

SpanID: TypeAlias = int
Key: TypeAlias = SpanID
Result: TypeAlias = List[models.DatasetExample]


class SpanDatasetExamplesDataLoader(DataLoader[Key, Result]):
    """Batch-load the dataset examples whose ``span_rowid`` matches each span row ID."""

    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        """Register the batch function and keep the session factory for later use."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        """
        Return one list of dataset examples per requested span row ID.

        The output is aligned with ``keys``; a span without any dataset
        example gets an empty list.
        """
        # Seed every requested key so spans without examples map to [].
        examples_by_span: Dict[Key, List[models.DatasetExample]] = {key: [] for key in keys}
        stmt = (
            select(models.Span.id, models.DatasetExample)
            .select_from(models.Span)
            .join(models.DatasetExample, models.DatasetExample.span_rowid == models.Span.id)
            .where(models.Span.id.in_(keys))
        )
        async with self._db() as session:
            rows = await session.stream(stmt)
            async for span_rowid, example in rows:
                examples_by_span[span_rowid].append(example)
        return [examples_by_span[key] for key in keys]
5 changes: 5 additions & 0 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ async def project(
project = await info.context.data_loaders.span_projects.load(span_id)
return to_gql_project(project)

@strawberry.field(description="Indicates if the span is contained in any dataset")  # type: ignore
async def contained_in_dataset(self, info: Info[Context, None]) -> bool:
    """Return True when the dataset-examples loader yields any example for this span."""
    loader = info.context.data_loaders.span_dataset_examples
    return len(await loader.load(self.id_attr)) > 0


def to_gql_span(span: models.Span) -> Span:
events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events))
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ProjectByNameDataLoader,
RecordCountDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
SpanEvaluationsDataLoader,
SpanProjectsDataLoader,
Expand Down Expand Up @@ -309,6 +310,7 @@ def get_context() -> Context:
cache_map=cache_for_dataloaders.record_count if cache_for_dataloaders else None,
),
span_annotations=SpanAnnotationsDataLoader(db),
span_dataset_examples=SpanDatasetExamplesDataLoader(db),
span_descendants=SpanDescendantsDataLoader(db),
span_evaluations=SpanEvaluationsDataLoader(db),
span_projects=SpanProjectsDataLoader(db),
Expand Down
105 changes: 105 additions & 0 deletions tests/server/api/types/test_Span.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,64 @@ async def test_project_resolver_returns_correct_project(
}


async def test_querying_spans_contained_in_datasets(
    httpx_client: httpx.AsyncClient, project_with_a_single_trace_and_span: Any, simple_dataset: Any
):
    """A span referenced by a dataset example reports ``containedInDataset`` as true."""
    query = """
query ($spanId: GlobalID!) {
span: node(id: $spanId) {
... on Span {
containedInDataset
}
}
}
"""
    # The simple_dataset fixture links its example to span row ID 1.
    variables = {"spanId": str(GlobalID(Span.__name__, str(1)))}
    response = await httpx_client.post(
        "/graphql",
        json={"query": query, "variables": variables},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload.get("errors") is None
    assert payload["data"]["span"]["containedInDataset"] is True


async def test_querying_spans_not_contained_in_datasets(
    httpx_client: httpx.AsyncClient, project_with_a_single_trace_and_span: Any
):
    """Without the dataset fixture, the span reports ``containedInDataset`` as false."""
    query = """
query ($spanId: GlobalID!) {
span: node(id: $spanId) {
... on Span {
containedInDataset
}
}
}
"""
    variables = {"spanId": str(GlobalID(Span.__name__, str(1)))}
    response = await httpx_client.post(
        "/graphql",
        json={"query": query, "variables": variables},
    )
    assert response.status_code == 200
    payload = response.json()
    assert payload.get("errors") is None
    assert payload["data"]["span"]["containedInDataset"] is False


@pytest.fixture
async def project_with_a_single_trace_and_span(
db: Callable[[], AsyncContextManager[AsyncSession]],
Expand Down Expand Up @@ -91,3 +149,50 @@ async def project_with_a_single_trace_and_span(
)
.returning(models.Span.id)
)


@pytest.fixture
async def simple_dataset(
    db: Callable[[], AsyncContextManager[AsyncSession]],
) -> None:
    """
    Insert a dataset, one version, one example and that example's CREATE revision.

    The example points at span row ID 1. Rows are added and flushed one at a
    time, in order, so each row exists before a later row refers to its ID.
    """
    rows = (
        models.Dataset(
            id=0,
            name="simple dataset",
            description=None,
            metadata_={"info": "a test dataset"},
        ),
        models.DatasetVersion(
            id=0,
            dataset_id=0,
            description="the first version",
            metadata_={"info": "gotta get some test data somewhere"},
        ),
        models.DatasetExample(
            id=0,
            dataset_id=0,
            span_rowid=1,
        ),
        models.DatasetExampleRevision(
            id=0,
            dataset_example_id=0,
            dataset_version_id=0,
            input={"in": "foo"},
            output={"out": "bar"},
            metadata_={"info": "the first reivision"},
            revision_kind="CREATE",
        ),
    )
    async with db() as session:
        for row in rows:
            session.add(row)
            await session.flush()

0 comments on commit 3c096ca

Please sign in to comment.