diff --git a/src/phoenix/server/api/mutations/dataset_mutations.py b/src/phoenix/server/api/mutations/dataset_mutations.py index 860ea5fabd..3fb9db02d6 100644 --- a/src/phoenix/server/api/mutations/dataset_mutations.py +++ b/src/phoenix/server/api/mutations/dataset_mutations.py @@ -154,6 +154,36 @@ async def add_spans_to_dataset( raise ValueError( f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}" ) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221 + + span_annotations = ( + await session.execute( + select( + models.SpanAnnotation.span_rowid, + models.SpanAnnotation.name, + models.SpanAnnotation.label, + models.SpanAnnotation.score, + models.SpanAnnotation.explanation, + models.SpanAnnotation.metadata_, + models.SpanAnnotation.annotator_kind, + ) + .select_from(models.SpanAnnotation) + .where(models.SpanAnnotation.span_rowid.in_(span_rowids)) + ) + ).all() + + span_annotations_by_span: Dict[int, Dict[Any, Any]] = {span.id: {} for span in spans} + for annotation in span_annotations: + span_id = annotation.span_rowid + if span_id not in span_annotations_by_span: + span_annotations_by_span[span_id] = dict() + span_annotations_by_span[span_id][annotation.name] = { + "label": annotation.label, + "score": annotation.score, + "explanation": annotation.explanation, + "metadata": annotation.metadata_, + "annotator_kind": annotation.annotator_kind, + } + DatasetExample = models.DatasetExample dataset_example_rowids = ( await session.scalars( @@ -170,6 +200,7 @@ async def add_spans_to_dataset( assert len(dataset_example_rowids) == len(spans) assert all(map(lambda id: isinstance(id, int), dataset_example_rowids)) DatasetExampleRevision = models.DatasetExampleRevision + await session.execute( insert(DatasetExampleRevision), [ @@ -178,7 +209,10 @@ async def add_spans_to_dataset( DatasetExampleRevision.dataset_version_id.key: dataset_version_rowid, DatasetExampleRevision.input.key: get_dataset_example_input(span), DatasetExampleRevision.output.key: get_dataset_example_output(span), - DatasetExampleRevision.metadata_.key: span.attributes, + DatasetExampleRevision.metadata_.key: { + **span.attributes, + "annotations": span_annotations_by_span[span.id], + }, DatasetExampleRevision.revision_kind.key: "CREATE", } for dataset_example_rowid, span in zip(dataset_example_rowids, spans) diff --git a/tests/server/api/mutations/test_dataset_mutations.py b/tests/server/api/mutations/test_dataset_mutations.py index 850618d4c7..865392fbd6 100644 --- a/tests/server/api/mutations/test_dataset_mutations.py +++ b/tests/server/api/mutations/test_dataset_mutations.py @@ -168,6 +168,7 @@ async def test_add_span_to_dataset( httpx_client: httpx.AsyncClient, empty_dataset, spans, + span_annotation, ) -> None: dataset_id = GlobalID(type_name="Dataset", node_id=str(1)) mutation = """ @@ -239,7 +240,8 @@ async def test_add_span_to_dataset( } } ], - } + }, + "annotations": {}, }, "output": { "messages": [ @@ -281,6 +283,7 @@ async def test_add_span_to_dataset( } ] }, + "annotations": {}, }, } } @@ -299,6 +302,15 @@ async def test_add_span_to_dataset( "value": "chain-span-output-value", "mime_type": "text/plain", }, + "annotations": { + "test annotation": { + "label": "ambiguous", + "score": 0.5, + "explanation": "meaningful words", + "metadata": {}, + "annotator_kind": "HUMAN", + } + }, }, } } @@ -704,6 +716,21 @@ async def spans(db: DbSessionFactory) -> None: ) +@pytest.fixture +async def span_annotation(db): + async with db() as session: + span_annotation = models.SpanAnnotation( + span_rowid=1, + name="test annotation", + annotator_kind="HUMAN", + label="ambiguous", + score=0.5, + explanation="meaningful words", + ) + session.add(span_annotation) + await session.flush() + + @pytest.fixture async def dataset_with_a_single_version( db: DbSessionFactory,