diff --git a/src/phoenix/core/project.py b/src/phoenix/core/project.py index de99c66f86..4d554e594b 100644 --- a/src/phoenix/core/project.py +++ b/src/phoenix/core/project.py @@ -202,9 +202,6 @@ def get_span_evaluation_span_ids(self, name: EvaluationName) -> Tuple[SpanID, .. def get_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]: return self._evals.get_evaluations_by_span_id(span_id) - def get_document_evaluation_span_ids(self, name: EvaluationName) -> Tuple[SpanID, ...]: - return self._evals.get_document_evaluation_span_ids(name) - def get_document_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]: return self._evals.get_document_evaluations_by_span_id(span_id) @@ -671,11 +668,6 @@ def get_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]: evaluations = self._evaluations_by_span_id.get(span_id) return list(evaluations.values()) if evaluations else [] - def get_document_evaluation_span_ids(self, name: EvaluationName) -> Tuple[SpanID, ...]: - with self._lock: - document_evaluations = self._document_evaluations_by_name.get(name) - return tuple(document_evaluations.keys()) if document_evaluations else () - def get_document_evaluations_by_span_id(self, span_id: SpanID) -> List[pb.Evaluation]: all_evaluations: List[pb.Evaluation] = [] with self._lock: diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index ba06240934..8044fe53e8 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -168,6 +168,7 @@ class Span(Base): cumulative_llm_token_count_completion: Mapped[int] trace: Mapped["Trace"] = relationship("Trace", back_populates="spans") + document_annotations: Mapped[List["DocumentAnnotation"]] = relationship(back_populates="span") __table_args__ = ( UniqueConstraint( @@ -267,6 +268,8 @@ class DocumentAnnotation(Base): updated_at: Mapped[datetime] = mapped_column( UtcTimeStamp, server_default=func.now(), onupdate=func.now() ) + span: Mapped["Span"] = relationship(back_populates="document_annotations") + __table_args__ = ( UniqueConstraint( "span_rowid", diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index c2be1090dd..bd2b76f558 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,10 +1,11 @@ from datetime import datetime from typing import List, Optional +import numpy as np import strawberry from openinference.semconv.trace import SpanAttributes from sqlalchemy import and_, func, select -from sqlalchemy.orm import contains_eager +from sqlalchemy.orm import contains_eager, selectinload from sqlalchemy.sql.functions import coalesce from strawberry import ID, UNSET from strawberry.types import Info @@ -293,37 +294,43 @@ def span_evaluation_summary( return EvaluationSummary(evaluations, labels) @strawberry.field - def document_evaluation_summary( + async def document_evaluation_summary( self, + info: Info[Context, None], evaluation_name: str, time_range: Optional[TimeRange] = UNSET, filter_condition: Optional[str] = UNSET, ) -> Optional[DocumentEvaluationSummary]: - project = self.project - predicate = ( - SpanFilter(condition=filter_condition, evals=project) if filter_condition else None - ) - span_ids = project.get_document_evaluation_span_ids(evaluation_name) - if not span_ids: - return None - spans = project.get_spans( - start_time=time_range.start if time_range else None, - stop_time=time_range.end if time_range else None, - span_ids=span_ids, + stmt = ( + select(models.Span) + .join(models.Trace) + .where( + models.Trace.project_rowid == self.id_attr, + ) + .options(selectinload(models.Span.document_annotations)) + .options(contains_eager(models.Span.trace)) ) - if predicate: - spans = filter(predicate, spans) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + # todo: add filter_condition + async with info.context.db() as session: + sql_spans = await session.scalars(stmt) metrics_collection = [] - for span in spans: - span_id = span.context.span_id - num_documents = project.get_num_documents(span_id) - if not num_documents: + for sql_span in sql_spans: + span = to_gql_span(sql_span, self.project) + if not (num_documents := span.num_documents): continue - evaluation_scores = project.get_document_evaluation_scores( - span_id=span_id, - evaluation_name=evaluation_name, - num_documents=num_documents, - ) + evaluation_scores: List[float] = [np.nan] * num_documents + for annotation in sql_span.document_annotations: + if (score := annotation.score) is not None and ( + document_position := annotation.document_index + ) < num_documents: + evaluation_scores[document_position] = score metrics_collection.append(RetrievalMetrics(evaluation_scores)) if not metrics_collection: return None