Skip to content

Commit

Permalink
feat: Add span annotations to dataset example metadata (#4123)
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator authored Aug 3, 2024
1 parent d49212e commit a16dd57
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
36 changes: 35 additions & 1 deletion src/phoenix/server/api/mutations/dataset_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,36 @@ async def add_spans_to_dataset(
raise ValueError(
f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221

span_annotations = (
await session.execute(
select(
models.SpanAnnotation.span_rowid,
models.SpanAnnotation.name,
models.SpanAnnotation.label,
models.SpanAnnotation.score,
models.SpanAnnotation.explanation,
models.SpanAnnotation.metadata_,
models.SpanAnnotation.annotator_kind,
)
.select_from(models.SpanAnnotation)
.where(models.SpanAnnotation.span_rowid.in_(span_rowids))
)
).all()

span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans}
for annotation in span_annotations:
span_id = annotation.span_rowid
if span_id not in span_annotations_by_span:
span_annotations_by_span[span_id] = dict()
span_annotations_by_span[span_id][annotation.name] = {
"label": annotation.label,
"score": annotation.score,
"explanation": annotation.explanation,
"metadata": annotation.metadata_,
"annotator_kind": annotation.annotator_kind,
}

DatasetExample = models.DatasetExample
dataset_example_rowids = (
await session.scalars(
Expand All @@ -170,6 +200,7 @@ async def add_spans_to_dataset(
assert len(dataset_example_rowids) == len(spans)
assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
DatasetExampleRevision = models.DatasetExampleRevision

await session.execute(
insert(DatasetExampleRevision),
[
Expand All @@ -178,7 +209,10 @@ async def add_spans_to_dataset(
DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid,
DatasetExampleRevision.input.key: get_dataset_example_input(span),
DatasetExampleRevision.output.key: get_dataset_example_output(span),
DatasetExampleRevision.metadata_.key: span.attributes,
DatasetExampleRevision.metadata_.key: {
**span.attributes,
"annotations": span_annotations_by_span[span.id],
},
DatasetExampleRevision.revision_kind.key: "CREATE",
}
for dataset_example_rowid, span in zip(dataset_example_rowids, spans)
Expand Down
29 changes: 28 additions & 1 deletion tests/server/api/mutations/test_dataset_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ async def test_add_span_to_dataset(
httpx_client: httpx.AsyncClient,
empty_dataset,
spans,
span_annotation,
) -> None:
dataset_id = GlobalID(type_name="Dataset", node_id=str(1))
mutation = """
Expand Down Expand Up @@ -239,7 +240,8 @@ async def test_add_span_to_dataset(
}
}
],
}
},
"annotations": {},
},
"output": {
"messages": [
Expand Down Expand Up @@ -281,6 +283,7 @@ async def test_add_span_to_dataset(
}
]
},
"annotations": {},
},
}
}
Expand All @@ -299,6 +302,15 @@ async def test_add_span_to_dataset(
"value": "chain-span-output-value",
"mime_type": "text/plain",
},
"annotations": {
"test annotation": {
"label": "ambiguous",
"score": 0.5,
"explanation": "meaningful words",
"metadata": {},
"annotator_kind": "HUMAN",
}
},
},
}
}
Expand Down Expand Up @@ -704,6 +716,21 @@ async def spans(db: DbSessionFactory) -> None:
)


@pytest.fixture
async def span_annotation(db):
async with db() as session:
span_annotation = models.SpanAnnotation(
span_rowid=1,
name="test annotation",
annotator_kind="HUMAN",
label="ambiguous",
score=0.5,
explanation="meaningful words",
)
session.add(span_annotation)
await session.flush()


@pytest.fixture
async def dataset_with_a_single_version(
db: DbSessionFactory,
Expand Down

0 comments on commit a16dd57

Please sign in to comment.