diff --git a/app/schema.graphql b/app/schema.graphql
index d09e7f9ea4..5a0a35e08b 100644
--- a/app/schema.graphql
+++ b/app/schema.graphql
@@ -30,6 +30,15 @@ interface Annotation {
   explanation: String
 }
 
+type AnnotationSummary {
+  count: Int!
+  labels: [String!]!
+  labelFractions: [LabelFraction!]!
+  meanScore: Float
+  scoreCount: Int!
+  labelCount: Int!
+}
+
 enum AnnotatorKind {
   LLM
   HUMAN
@@ -992,6 +1001,8 @@ type Project implements Node {
   documentEvaluationNames(spanId: ID): [String!]!
   traceEvaluationSummary(evaluationName: String!, timeRange: TimeRange): EvaluationSummary
   spanEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): EvaluationSummary
+  traceAnnotationSummary(evaluationName: String!, timeRange: TimeRange): AnnotationSummary
+  spanAnnotationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): AnnotationSummary
  documentEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): DocumentEvaluationSummary
  streamingLastUpdatedAt: DateTime
  validateSpanFilterCondition(condition: String!): ValidationResult!
diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py
index aa70f6c27b..7ca40db2fb 100644
--- a/src/phoenix/server/api/context.py
+++ b/src/phoenix/server/api/context.py
@@ -8,6 +8,7 @@
 
 from phoenix.core.model_schema import Model
 from phoenix.server.api.dataloaders import (
+    AnnotationSummaryDataLoader,
     AverageExperimentRunLatencyDataLoader,
     CacheForDataLoaders,
     DatasetExampleRevisionsDataLoader,
@@ -44,6 +45,7 @@ class DataLoaders:
     document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
     document_evaluations: DocumentEvaluationsDataLoader
     document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
+    annotation_summaries: AnnotationSummaryDataLoader
     evaluation_summaries: EvaluationSummaryDataLoader
     experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
     experiment_error_rates: ExperimentErrorRatesDataLoader
diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py
index 26b90d38c2..edd35d9cb5 100644
--- a/src/phoenix/server/api/dataloaders/__init__.py
+++ b/src/phoenix/server/api/dataloaders/__init__.py
@@ -8,6 +8,7 @@
 )
 from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent
 
+from .annotation_summaries import AnnotationSummaryCache, AnnotationSummaryDataLoader
 from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
 from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
 from .dataset_example_spans import DatasetExampleSpansDataLoader
@@ -43,6 +44,7 @@
     "DocumentEvaluationSummaryDataLoader",
     "DocumentEvaluationsDataLoader",
     "DocumentRetrievalMetricsDataLoader",
+    "AnnotationSummaryDataLoader",
     "EvaluationSummaryDataLoader",
     "ExperimentAnnotationSummaryDataLoader",
     "ExperimentErrorRatesDataLoader",
@@ -68,6 +70,9 @@ class CacheForDataLoaders:
     document_evaluation_summary: DocumentEvaluationSummaryCache = field(
         default_factory=DocumentEvaluationSummaryCache,
     )
+    annotation_summary: AnnotationSummaryCache = field(
+        default_factory=AnnotationSummaryCache,
+    )
     evaluation_summary: EvaluationSummaryCache = field(
         default_factory=EvaluationSummaryCache,
     )
@@ -92,6 +97,7 @@ def _update_spans(self, project_rowid: int) -> None:
 
     def _clear_spans(self, project_rowid: int) -> None:
         self._update_spans(project_rowid)
+        self.annotation_summary.invalidate_project(project_rowid)
         self.evaluation_summary.invalidate_project(project_rowid)
         self.document_evaluation_summary.invalidate_project(project_rowid)
 
@@ -113,9 +119,11 @@ def _(self, event: DocumentEvaluationInsertionEvent) -> None:
     @invalidate.register
     def _(self, event: SpanEvaluationInsertionEvent) -> None:
         project_rowid, evaluation_name = event
+        self.annotation_summary.invalidate((project_rowid, evaluation_name, "span"))
         self.evaluation_summary.invalidate((project_rowid, evaluation_name, "span"))
 
     @invalidate.register
     def _(self, event: TraceEvaluationInsertionEvent) -> None:
         project_rowid, evaluation_name = event
+        self.annotation_summary.invalidate((project_rowid, evaluation_name, "trace"))
         self.evaluation_summary.invalidate((project_rowid, evaluation_name, "trace"))
diff --git a/src/phoenix/server/api/dataloaders/annotation_summaries.py b/src/phoenix/server/api/dataloaders/annotation_summaries.py
new file mode 100644
index 0000000000..0f17d437d8
--- /dev/null
+++ b/src/phoenix/server/api/dataloaders/annotation_summaries.py
@@ -0,0 +1,146 @@
+from collections import defaultdict
+from datetime import datetime
+from typing import (
+    Any,
+    DefaultDict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+)
+
+import pandas as pd
+from aioitertools.itertools import groupby
+from cachetools import LFUCache, TTLCache
+from sqlalchemy import Select, func, or_, select
+from strawberry.dataloader import AbstractCache, DataLoader
+from typing_extensions import TypeAlias, assert_never
+
+from phoenix.db import models
+from phoenix.server.api.dataloaders.cache import TwoTierCache
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
+from phoenix.server.types import DbSessionFactory
+from phoenix.trace.dsl import SpanFilter
+
+Kind: TypeAlias = Literal["span", "trace"]
+ProjectRowId: TypeAlias = int
+TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
+FilterCondition: TypeAlias = Optional[str]
+EvalName: TypeAlias = str
+
+Segment: TypeAlias = Tuple[Kind, ProjectRowId, TimeInterval, FilterCondition]
+Param: TypeAlias = EvalName
+
+Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, EvalName]
+Result: TypeAlias = Optional[AnnotationSummary]
+ResultPosition: TypeAlias = int
+DEFAULT_VALUE: Result = None
+
+
+def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
+    kind, project_rowid, time_range, filter_condition, eval_name = key
+    interval = (
+        (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
+    )
+    return (kind, project_rowid, interval, filter_condition), eval_name
+
+
+_Section: TypeAlias = Tuple[ProjectRowId, EvalName, Kind]
+_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition]
+
+
+class AnnotationSummaryCache(
+    TwoTierCache[Key, Result, _Section, _SubKey],
+):
+    def __init__(self) -> None:
+        super().__init__(
+            # TTL=3600 (1-hour) because time intervals are always moving forward, but
+            # interval endpoints are rounded down to the hour by the UI, so anything
+            # older than an hour most likely won't be a cache-hit anyway.
+            main_cache=TTLCache(maxsize=64 * 32 * 2, ttl=3600),
+            sub_cache_factory=lambda: LFUCache(maxsize=2 * 2),
+        )
+
+    def invalidate_project(self, project_rowid: ProjectRowId) -> None:
+        for section in self._cache.keys():
+            if section[0] == project_rowid:
+                del self._cache[section]
+
+    def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
+        (kind, project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
+        return (project_rowid, eval_name, kind), (interval, filter_condition)
+
+
+class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
+    def __init__(
+        self,
+        db: DbSessionFactory,
+        cache_map: Optional[AbstractCache[Key, Result]] = None,
+    ) -> None:
+        super().__init__(
+            load_fn=self._load_fn,
+            cache_key_fn=_cache_key_fn,
+            cache_map=cache_map,
+        )
+        self._db = db
+
+    async def _load_fn(self, keys: List[Key]) -> List[Result]:
+        results: List[Result] = [DEFAULT_VALUE] * len(keys)
+        arguments: DefaultDict[
+            Segment,
+            DefaultDict[Param, List[ResultPosition]],
+        ] = defaultdict(lambda: defaultdict(list))
+        for position, key in enumerate(keys):
+            segment, param = _cache_key_fn(key)
+            arguments[segment][param].append(position)
+        for segment, params in arguments.items():
+            stmt = _get_stmt(segment, *params.keys())
+            async with self._db() as session:
+                data = await session.stream(stmt)
+                async for eval_name, group in groupby(data, lambda row: row.name):
+                    summary = AnnotationSummary(pd.DataFrame(group))
+                    for position in params[eval_name]:
+                        results[position] = summary
+        return results
+
+
+def _get_stmt(
+    segment: Segment,
+    *eval_names: Param,
+) -> Select[Any]:
+    kind, project_rowid, (start_time, end_time), filter_condition = segment
+    stmt = select()
+    if kind == "span":
+        msa = models.SpanAnnotation
+        name_column, label_column, score_column = msa.name, msa.label, msa.score
+        time_column = models.Span.start_time
+        stmt = stmt.join(models.Span).join_from(models.Span, models.Trace)
+        if filter_condition:
+            sf = SpanFilter(filter_condition)
+            stmt = sf(stmt)
+    elif kind == "trace":
+        mta = models.TraceAnnotation
+        name_column, label_column, score_column = mta.name, mta.label, mta.score
+        time_column = models.Trace.start_time
+        stmt = stmt.join(models.Trace)
+    else:
+        assert_never(kind)
+    stmt = stmt.add_columns(
+        name_column,
+        label_column,
+        func.count().label("record_count"),
+        func.count(label_column).label("label_count"),
+        func.count(score_column).label("score_count"),
+        func.sum(score_column).label("score_sum"),
+    )
+    stmt = stmt.group_by(name_column, label_column)
+    stmt = stmt.order_by(name_column, label_column)
+    stmt = stmt.where(models.Trace.project_rowid == project_rowid)
+    stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
+    stmt = stmt.where(name_column.in_(eval_names))
+    if start_time:
+        stmt = stmt.where(start_time <= time_column)
+    if end_time:
+        stmt = stmt.where(time_column < end_time)
+    return stmt
diff --git a/src/phoenix/server/api/types/AnnotationSummary.py b/src/phoenix/server/api/types/AnnotationSummary.py
new file mode 100644
index 0000000000..d52c6556bd
--- /dev/null
+++ b/src/phoenix/server/api/types/AnnotationSummary.py
@@ -0,0 +1,55 @@
+from typing import List, Optional, Union, cast
+
+import pandas as pd
+import strawberry
+from strawberry import Private
+
+from phoenix.db import models
+from phoenix.server.api.types.LabelFraction import LabelFraction
+
+AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]
+
+
+@strawberry.type
+class AnnotationSummary:
+    df: Private[pd.DataFrame]
+
+    def __init__(self, dataframe: pd.DataFrame) -> None:
+        self.df = dataframe
+
+    @strawberry.field
+    def count(self) -> int:
+        return cast(int, self.df.record_count.sum())
+
+    @strawberry.field
+    def labels(self) -> List[str]:
+        return self.df.label.dropna().tolist()
+
+    @strawberry.field
+    def label_fractions(self) -> List[LabelFraction]:
+        if not (n := self.df.label_count.sum()):
+            return []
+        return [
+            LabelFraction(
+                label=cast(str, row.label),
+                fraction=row.label_count / n,
+            )
+            for row in self.df.loc[
+                self.df.label.notna(),
+                ["label", "label_count"],
+            ].itertuples()
+        ]
+
+    @strawberry.field
+    def mean_score(self) -> Optional[float]:
+        if not (n := self.df.score_count.sum()):
+            return None
+        return cast(float, self.df.score_sum.sum() / n)
+
+    @strawberry.field
+    def score_count(self) -> int:
+        return cast(int, self.df.score_count.sum())
+
+    @strawberry.field
+    def label_count(self) -> int:
+        return cast(int, self.df.label_count.sum())
diff --git a/src/phoenix/server/api/types/EvaluationSummary.py b/src/phoenix/server/api/types/EvaluationSummary.py
index 664d0f2094..3addb3b13b 100644
--- a/src/phoenix/server/api/types/EvaluationSummary.py
+++ b/src/phoenix/server/api/types/EvaluationSummary.py
@@ -5,16 +5,11 @@
 from strawberry import Private
 
 from phoenix.db import models
+from phoenix.server.api.types.LabelFraction import LabelFraction
 
 AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]
 
 
-@strawberry.type
-class LabelFraction:
-    label: str
-    fraction: float
-
-
 @strawberry.type
 class EvaluationSummary:
     df: Private[pd.DataFrame]
diff --git a/src/phoenix/server/api/types/LabelFraction.py b/src/phoenix/server/api/types/LabelFraction.py
new file mode 100644
index 0000000000..5d67b099e7
--- /dev/null
+++ b/src/phoenix/server/api/types/LabelFraction.py
@@ -0,0 +1,7 @@
+import strawberry
+
+
+@strawberry.type
+class LabelFraction:
+    label: str
+    fraction: float
diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py
index 5e6e25be11..6b239b8933 100644
--- a/src/phoenix/server/api/types/Project.py
+++ b/src/phoenix/server/api/types/Project.py
@@ -20,6 +20,7 @@
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
 from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
 from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
 from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
 from phoenix.server.api.types.pagination import (
@@ -356,6 +357,29 @@ async def span_evaluation_summary(
             ("span", self.id_attr, time_range, filter_condition, evaluation_name),
         )
 
+    @strawberry.field
+    async def trace_annotation_summary(
+        self,
+        info: Info[Context, None],
+        evaluation_name: str,
+        time_range: Optional[TimeRange] = UNSET,
+    ) -> Optional[AnnotationSummary]:
+        return await info.context.data_loaders.annotation_summaries.load(
+            ("trace", self.id_attr, time_range, None, evaluation_name),
+        )
+
+    @strawberry.field
+    async def span_annotation_summary(
+        self,
+        info: Info[Context, None],
+        evaluation_name: str,
+        time_range: Optional[TimeRange] = UNSET,
+        filter_condition: Optional[str] = UNSET,
+    ) -> Optional[AnnotationSummary]:
+        return await info.context.data_loaders.annotation_summaries.load(
+            ("span", self.id_attr, time_range, filter_condition, evaluation_name),
+        )
+
     @strawberry.field
     async def document_evaluation_summary(
         self,
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index 2a16e9ede0..d4428fced9 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -58,6 +58,7 @@
 from phoenix.pointcloud.umap_parameters import UMAPParameters
 from phoenix.server.api.context import Context, DataLoaders
 from phoenix.server.api.dataloaders import (
+    AnnotationSummaryDataLoader,
     AverageExperimentRunLatencyDataLoader,
     CacheForDataLoaders,
     DatasetExampleRevisionsDataLoader,
@@ -286,6 +287,12 @@ def get_context() -> Context:
                 ),
                 document_evaluations=DocumentEvaluationsDataLoader(db),
                 document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
+                annotation_summaries=AnnotationSummaryDataLoader(
+                    db,
+                    cache_map=cache_for_dataloaders.annotation_summary
+                    if cache_for_dataloaders
+                    else None,
+                ),
                 evaluation_summaries=EvaluationSummaryDataLoader(
                     db,
                     cache_map=cache_for_dataloaders.evaluation_summary
diff --git a/tests/server/api/dataloaders/test_annotation_summaries.py b/tests/server/api/dataloaders/test_annotation_summaries.py
new file mode 100644
index 0000000000..8e2ef32f5b
--- /dev/null
+++ b/tests/server/api/dataloaders/test_annotation_summaries.py
@@ -0,0 +1,74 @@
+from datetime import datetime
+
+import pandas as pd
+import pytest
+from phoenix.db import models
+from phoenix.server.api.dataloaders import AnnotationSummaryDataLoader
+from phoenix.server.api.input_types.TimeRange import TimeRange
+from phoenix.server.types import DbSessionFactory
+from sqlalchemy import func, select
+
+
+async def test_annotation_summaries(
+    db: DbSessionFactory,
+    data_for_testing_dataloaders: None,
+) -> None:
+    start_time = datetime.fromisoformat("2021-01-01T00:00:10.000+00:00")
+    end_time = datetime.fromisoformat("2021-01-01T00:10:00.000+00:00")
+    pid = models.Trace.project_rowid
+    async with db() as session:
+        span_df = await session.run_sync(
+            lambda s: pd.read_sql_query(
+                select(
+                    pid,
+                    models.SpanAnnotation.name,
+                    func.avg(models.SpanAnnotation.score).label("mean_score"),
+                )
+                .group_by(pid, models.SpanAnnotation.name)
+                .order_by(pid, models.SpanAnnotation.name)
+                .join_from(models.Trace, models.Span)
+                .join_from(models.Span, models.SpanAnnotation)
+                .where(models.Span.name.contains("_5_"))
+                .where(models.SpanAnnotation.name.in_(("A", "C")))
+                .where(start_time <= models.Span.start_time)
+                .where(models.Span.start_time < end_time),
+                s.connection(),
+            )
+        )
+        trace_df = await session.run_sync(
+            lambda s: pd.read_sql_query(
+                select(
+                    pid,
+                    models.TraceAnnotation.name,
+                    func.avg(models.TraceAnnotation.score).label("mean_score"),
+                )
+                .group_by(pid, models.TraceAnnotation.name)
+                .order_by(pid, models.TraceAnnotation.name)
+                .join_from(models.Trace, models.TraceAnnotation)
+                .where(models.TraceAnnotation.name.in_(("B", "D")))
+                .where(start_time <= models.Trace.start_time)
+                .where(models.Trace.start_time < end_time),
+                s.connection(),
+            )
+        )
+    expected = trace_df.loc[:, "mean_score"].to_list() + span_df.loc[:, "mean_score"].to_list()
+    actual = [
+        smry.mean_score()
+        for smry in (
+            await AnnotationSummaryDataLoader(db)._load_fn(
+                [
+                    (
+                        kind,
+                        id_ + 1,
+                        TimeRange(start=start_time, end=end_time),
+                        "'_5_' in name" if kind == "span" else None,
+                        eval_name,
+                    )
+                    for kind in ("trace", "span")
+                    for id_ in range(10)
+                    for eval_name in (("B", "D") if kind == "trace" else ("A", "C"))
+                ]
+            )
+        )
+    ]
+    assert actual == pytest.approx(expected, 1e-7)
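
For reference, a minimal sketch of a query exercising the new schema fields. The Relay-style node lookup is assumed to be available on the root query (Project implements Node), and the node ID and annotation name below are placeholders:

query {
  node(id: "UHJvamVjdDox") {
    ... on Project {
      spanAnnotationSummary(evaluationName: "correctness") {
        count
        labels
        meanScore
        labelFractions {
          label
          fraction
        }
      }
    }
  }
}

traceAnnotationSummary takes the same arguments minus filterCondition; both fields resolve through the new AnnotationSummaryDataLoader and return null when no matching annotations exist.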