From f3fde5093bf5c0f7de41c29b30c1b61d57c6ce48 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Wed, 10 Apr 2024 14:15:51 -0700 Subject: [PATCH] feat: ingest document evals (#2847) --- src/phoenix/db/bulk_inserter.py | 76 +++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py index b27b2ce992..1c161db180 100644 --- a/src/phoenix/db/bulk_inserter.py +++ b/src/phoenix/db/bulk_inserter.py @@ -2,7 +2,17 @@ import logging from itertools import islice from time import time -from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast +from typing import ( + Any, + AsyncContextManager, + Callable, + Iterable, + List, + Optional, + Tuple, + assert_never, + cast, +) from openinference.semconv.trace import SpanAttributes from sqlalchemy import func, insert, select, update @@ -104,25 +114,55 @@ async def _insert_evaluations(self) -> None: async def _insert_evaluation(session: AsyncSession, evaluation: pb.Evaluation) -> None: - if not ( - span_rowid := await session.scalar( - select(models.Span.id).where(models.Span.span_id == evaluation.subject_id.span_id) - ) - ): + if (evaluation_kind := evaluation.subject_id.WhichOneof("kind")) is None: return - await session.scalar( - insert(models.SpanAnnotation) - .values( - span_rowid=span_rowid, - name=evaluation.name, - label=evaluation.result.label.value, - score=evaluation.result.score.value, - explanation=evaluation.result.explanation.value, - metadata_={}, - annotator_kind="LLM", + elif evaluation_kind == "trace_id": + raise NotImplementedError() + elif evaluation_kind == "span_id": + if not ( + span_rowid := await session.scalar( + select(models.Span.id).where(models.Span.span_id == evaluation.subject_id.span_id) + ) + ): + return + await session.scalar( + insert(models.SpanAnnotation) + .values( + span_rowid=span_rowid, + name=evaluation.name, + label=evaluation.result.label.value, + score=evaluation.result.score.value, + explanation=evaluation.result.explanation.value, + metadata_={}, + annotator_kind="LLM", + ) + .returning(models.SpanAnnotation.id) ) - .returning(models.SpanAnnotation.id) - ) + elif evaluation_kind == "document_retrieval_id": + if not ( + span_rowid := await session.scalar( + select(models.Span.id).where( + models.Span.span_id == evaluation.subject_id.document_retrieval_id.span_id + ) + ) + ): + return + await session.scalar( + insert(models.DocumentAnnotation) + .values( + span_rowid=span_rowid, + document_index=evaluation.subject_id.document_retrieval_id.document_position, + name=evaluation.name, + label=evaluation.result.label.value, + score=evaluation.result.score.value, + explanation=evaluation.result.explanation.value, + metadata_={}, + annotator_kind="LLM", + ) + .returning(models.DocumentAnnotation.id) + ) + else: + assert_never(evaluation_kind) async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: