feat: Add annotation summaries to projects #4108

Merged · 5 commits · Aug 2, 2024
Changes from 3 commits
11 changes: 11 additions & 0 deletions app/schema.graphql
@@ -30,6 +30,15 @@ interface Annotation {
explanation: String
}

type AnnotationSummary {
count: Int!
labels: [String!]!
labelFractions: [LabelFraction!]!
meanScore: Float
scoreCount: Int!
labelCount: Int!
}

enum AnnotatorKind {
LLM
HUMAN
@@ -992,6 +1001,8 @@ type Project implements Node {
documentEvaluationNames(spanId: ID): [String!]!
traceEvaluationSummary(evaluationName: String!, timeRange: TimeRange): EvaluationSummary
spanEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): EvaluationSummary
traceAnnotationSummary(evaluationName: String!, timeRange: TimeRange): AnnotationSummary
spanAnnotationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): AnnotationSummary
documentEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): DocumentEvaluationSummary
streamingLastUpdatedAt: DateTime
validateSpanFilterCondition(condition: String!): ValidationResult!
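
For orientation, here is a minimal client-side sketch of the new fields (not part of the diff). It assumes a Phoenix server reachable at http://localhost:6006 with a /graphql endpoint and relay-style node lookup; the node ID and annotation name are placeholders, not values from this PR.

import httpx

QUERY = """
query ($id: ID!, $name: String!) {
  node(id: $id) {
    ... on Project {
      spanAnnotationSummary(evaluationName: $name) {
        count
        meanScore
        labelFractions { label fraction }
      }
    }
  }
}
"""

# Hypothetical values: fetch a real project ID from a prior query.
response = httpx.post(
    "http://localhost:6006/graphql",
    json={"query": QUERY, "variables": {"id": "UHJvamVjdDox", "name": "correctness"}},
)
print(response.json())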
2 changes: 2 additions & 0 deletions src/phoenix/server/api/context.py
@@ -8,6 +8,7 @@

from phoenix.core.model_schema import Model
from phoenix.server.api.dataloaders import (
AnnotationSummaryDataLoader,
AverageExperimentRunLatencyDataLoader,
CacheForDataLoaders,
DatasetExampleRevisionsDataLoader,
@@ -44,6 +45,7 @@ class DataLoaders:
document_evaluation_summaries: DocumentEvaluationSummaryDataLoader
document_evaluations: DocumentEvaluationsDataLoader
document_retrieval_metrics: DocumentRetrievalMetricsDataLoader
annotation_summaries: AnnotationSummaryDataLoader
evaluation_summaries: EvaluationSummaryDataLoader
experiment_annotation_summaries: ExperimentAnnotationSummaryDataLoader
experiment_error_rates: ExperimentErrorRatesDataLoader
8 changes: 8 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
@@ -8,6 +8,7 @@
)
from phoenix.db.insertion.span import ClearProjectSpansEvent, SpanInsertionEvent

from .annotation_summaries import AnnotationSummaryCache, AnnotationSummaryDataLoader
from .average_experiment_run_latency import AverageExperimentRunLatencyDataLoader
from .dataset_example_revisions import DatasetExampleRevisionsDataLoader
from .dataset_example_spans import DatasetExampleSpansDataLoader
@@ -43,6 +44,7 @@
"DocumentEvaluationSummaryDataLoader",
"DocumentEvaluationsDataLoader",
"DocumentRetrievalMetricsDataLoader",
"AnnotationSummaryDataLoader",
"EvaluationSummaryDataLoader",
"ExperimentAnnotationSummaryDataLoader",
"ExperimentErrorRatesDataLoader",
@@ -68,6 +70,9 @@ class CacheForDataLoaders:
document_evaluation_summary: DocumentEvaluationSummaryCache = field(
default_factory=DocumentEvaluationSummaryCache,
)
annotation_summary: AnnotationSummaryCache = field(
default_factory=AnnotationSummaryCache,
)
evaluation_summary: EvaluationSummaryCache = field(
default_factory=EvaluationSummaryCache,
)
@@ -92,6 +97,7 @@ def _update_spans(self, project_rowid: int) -> None:

def _clear_spans(self, project_rowid: int) -> None:
self._update_spans(project_rowid)
self.annotation_summary.invalidate_project(project_rowid)
self.evaluation_summary.invalidate_project(project_rowid)
self.document_evaluation_summary.invalidate_project(project_rowid)

@@ -113,9 +119,11 @@ def _(self, event: DocumentEvaluationInsertionEvent) -> None:
@invalidate.register
def _(self, event: SpanEvaluationInsertionEvent) -> None:
project_rowid, evaluation_name = event
self.annotation_summary.invalidate((project_rowid, evaluation_name, "span"))
self.evaluation_summary.invalidate((project_rowid, evaluation_name, "span"))

@invalidate.register
def _(self, event: TraceEvaluationInsertionEvent) -> None:
project_rowid, evaluation_name = event
self.annotation_summary.invalidate((project_rowid, evaluation_name, "trace"))
self.evaluation_summary.invalidate((project_rowid, evaluation_name, "trace"))
149 changes: 149 additions & 0 deletions src/phoenix/server/api/dataloaders/annotation_summaries.py
@@ -0,0 +1,149 @@
from collections import defaultdict
from datetime import datetime
from typing import (
Any,
DefaultDict,
List,
Literal,
Optional,
Tuple,
)

import pandas as pd
from aioitertools.itertools import groupby
from cachetools import LFUCache, TTLCache
from sqlalchemy import Select, func, or_, select
from strawberry.dataloader import AbstractCache, DataLoader
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.server.api.dataloaders.cache import TwoTierCache
from phoenix.server.api.input_types.TimeRange import TimeRange
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
from phoenix.server.types import DbSessionFactory
from phoenix.trace.dsl import SpanFilter

Kind: TypeAlias = Literal["span", "trace"]
ProjectRowId: TypeAlias = int
TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]]
FilterCondition: TypeAlias = Optional[str]
EvalName: TypeAlias = str

Segment: TypeAlias = Tuple[Kind, ProjectRowId, TimeInterval, FilterCondition]
Param: TypeAlias = EvalName

Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, EvalName]
Result: TypeAlias = Optional[AnnotationSummary]
ResultPosition: TypeAlias = int
DEFAULT_VALUE: Result = None


def _cache_key_fn(key: Key) -> Tuple[Segment, Param]:
kind, project_rowid, time_range, filter_condition, eval_name = key
interval = (
(time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None)
)
return (kind, project_rowid, interval, filter_condition), eval_name


_Section: TypeAlias = Tuple[ProjectRowId, EvalName, Kind]
_SubKey: TypeAlias = Tuple[TimeInterval, FilterCondition]


class AnnotationSummaryCache(
TwoTierCache[Key, Result, _Section, _SubKey],
):
def __init__(self) -> None:
super().__init__(
# TTL=3600 (1-hour) because time intervals are always moving forward, but
# interval endpoints are rounded down to the hour by the UI, so anything
# older than an hour most likely won't be a cache-hit anyway.
main_cache=TTLCache(maxsize=64 * 32 * 2, ttl=3600),
sub_cache_factory=lambda: LFUCache(maxsize=2 * 2),
)

def invalidate_project(self, project_rowid: ProjectRowId) -> None:
for section in self._cache.keys():
if section[0] == project_rowid:
del self._cache[section]

def _cache_key(self, key: Key) -> Tuple[_Section, _SubKey]:
(kind, project_rowid, interval, filter_condition), eval_name = _cache_key_fn(key)
return (project_rowid, eval_name, kind), (interval, filter_condition)
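
A minimal sketch of the keying (not part of the diff), assuming phoenix is importable; it calls the private _cache_key purely for illustration. Everything for one (project, eval name, kind) triple shares a section, so a single invalidate() drops every cached (interval, filter) variant at once.

from phoenix.server.api.dataloaders.annotation_summaries import AnnotationSummaryCache

cache = AnnotationSummaryCache()
key = ("span", 1, None, None, "correctness")  # no TimeRange, no filter condition
section, sub_key = cache._cache_key(key)
assert section == (1, "correctness", "span")  # invalidation granularity
assert sub_key == ((None, None), None)        # per-interval/filter variant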


class AnnotationSummaryDataLoader(DataLoader[Key, Result]):
def __init__(
self,
db: DbSessionFactory,
cache_map: Optional[AbstractCache[Key, Result]] = None,
) -> None:
super().__init__(
load_fn=self._load_fn,
cache_key_fn=_cache_key_fn,
cache_map=cache_map,
)
self._db = db

async def _load_fn(self, keys: List[Key]) -> List[Result]:
results: List[Result] = [DEFAULT_VALUE] * len(keys)
arguments: DefaultDict[
Segment,
DefaultDict[Param, List[ResultPosition]],
] = defaultdict(lambda: defaultdict(list))
for position, key in enumerate(keys):
segment, param = _cache_key_fn(key)
arguments[segment][param].append(position)
for segment, params in arguments.items():
stmt = _get_stmt(segment, *params.keys())
async with self._db() as session:
data = await session.stream(stmt)
async for eval_name, group in groupby(data, lambda row: row.name):
summary = AnnotationSummary(pd.DataFrame(group))
for position in params[eval_name]:
results[position] = summary
return results
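
To make the batching concrete, here is a toy version of the grouping above (not part of the diff; simplified keys, no database). Keys that share a (kind, project, interval, filter) segment collapse into one statement whose IN-list carries all of their eval names:

from collections import defaultdict

keys = [
    ("span", 1, (None, None), None, "toxicity"),
    ("span", 1, (None, None), None, "relevance"),
    ("trace", 1, (None, None), None, "toxicity"),
]
arguments = defaultdict(lambda: defaultdict(list))
for position, (kind, project, interval, cond, name) in enumerate(keys):
    arguments[(kind, project, interval, cond)][name].append(position)
for segment, params in arguments.items():
    print(segment, "->", list(params))
# ('span', 1, (None, None), None) -> ['toxicity', 'relevance']
# ('trace', 1, (None, None), None) -> ['toxicity']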


def _get_stmt(
segment: Segment,
*eval_names: Param,
) -> Select[Any]:
kind, project_rowid, (start_time, end_time), filter_condition = segment
stmt = select()
if kind == "span":
msa = models.SpanAnnotation
name_column, label_column, score_column = msa.name, msa.label, msa.score
annotator_kind_column = msa.annotator_kind
time_column = models.Span.start_time
stmt = stmt.join(models.Span).join_from(models.Span, models.Trace)
if filter_condition:
sf = SpanFilter(filter_condition)
stmt = sf(stmt)
elif kind == "trace":
mta = models.TraceAnnotation
name_column, label_column, score_column = mta.name, mta.label, mta.score
annotator_kind_column = mta.annotator_kind
time_column = models.Trace.start_time
stmt = stmt.join(models.Trace)
else:
assert_never(kind)
stmt = stmt.add_columns(
name_column,
label_column,
func.count().label("record_count"),
func.count(label_column).label("label_count"),
func.count(score_column).label("score_count"),
func.sum(score_column).label("score_sum"),
)
stmt = stmt.group_by(name_column, label_column)
stmt = stmt.order_by(name_column, label_column)
stmt = stmt.where(models.Trace.project_rowid == project_rowid)
stmt = stmt.where(annotator_kind_column == "LLM")
stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None)))
stmt = stmt.where(name_column.in_(eval_names))
if start_time:
stmt = stmt.where(start_time <= time_column)
if end_time:
stmt = stmt.where(time_column < end_time)
return stmt
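
In a dev environment where phoenix is importable, the generated SQL can be inspected directly, since str() compiles a Select with SQLAlchemy's generic dialect. A hedged sketch:

from phoenix.server.api.dataloaders.annotation_summaries import _get_stmt

segment = ("span", 1, (None, None), None)  # kind, project, open interval, no filter
print(_get_stmt(segment, "correctness"))
# -> one aggregate row per (annotation name, label), carrying the record/label/score
#    counts and score sum that AnnotationSummary consumes below.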
55 changes: 55 additions & 0 deletions src/phoenix/server/api/types/AnnotationSummary.py
@@ -0,0 +1,55 @@
from typing import List, Optional, Union, cast

import pandas as pd
import strawberry
from strawberry import Private

from phoenix.db import models
from phoenix.server.api.types.LabelFraction import LabelFraction

AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]


@strawberry.type
class AnnotationSummary:
df: Private[pd.DataFrame]

def __init__(self, dataframe: pd.DataFrame) -> None:
self.df = dataframe

@strawberry.field
def count(self) -> int:
return cast(int, self.df.record_count.sum())

@strawberry.field
def labels(self) -> List[str]:
return self.df.label.dropna().tolist()

@strawberry.field
def label_fractions(self) -> List[LabelFraction]:
if not (n := self.df.label_count.sum()):
return []
return [
LabelFraction(
label=cast(str, row.label),
fraction=row.label_count / n,
)
for row in self.df.loc[
self.df.label.notna(),
["label", "label_count"],
].itertuples()
]

@strawberry.field
def mean_score(self) -> Optional[float]:
if not (n := self.df.score_count.sum()):
return None
return cast(float, self.df.score_sum.sum() / n)

@strawberry.field
def score_count(self) -> int:
return cast(int, self.df.score_count.sum())

@strawberry.field
def label_count(self) -> int:
return cast(int, self.df.label_count.sum())
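
A worked example of those reductions (not part of the diff): two grouped rows for one annotation name, in the shape _get_stmt emits, with the arithmetic done directly on the frame to mirror the resolvers above.

import pandas as pd

df = pd.DataFrame(
    {
        "label": ["correct", "incorrect"],
        "record_count": [8, 2],
        "label_count": [8, 2],
        "score_count": [8, 2],
        "score_sum": [7.0, 0.5],
    }
)
assert int(df.record_count.sum()) == 10                          # count
assert float(df.score_sum.sum() / df.score_count.sum()) == 0.75  # meanScore
fractions = {r.label: r.label_count / df.label_count.sum() for r in df.itertuples()}
assert fractions == {"correct": 0.8, "incorrect": 0.2}           # labelFractions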
7 changes: 1 addition & 6 deletions src/phoenix/server/api/types/EvaluationSummary.py
@@ -5,16 +5,11 @@
from strawberry import Private

from phoenix.db import models
from phoenix.server.api.types.LabelFraction import LabelFraction

AnnotationType = Union[models.SpanAnnotation, models.TraceAnnotation]


@strawberry.type
class LabelFraction:
label: str
fraction: float


@strawberry.type
class EvaluationSummary:
df: Private[pd.DataFrame]
7 changes: 7 additions & 0 deletions src/phoenix/server/api/types/LabelFraction.py
@@ -0,0 +1,7 @@
import strawberry


@strawberry.type
class LabelFraction:
label: str
fraction: float
24 changes: 24 additions & 0 deletions src/phoenix/server/api/types/Project.py
@@ -20,6 +20,7 @@
from phoenix.server.api.context import Context
from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig
from phoenix.server.api.input_types.TimeRange import TimeRange
from phoenix.server.api.types.AnnotationSummary import AnnotationSummary
from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary
from phoenix.server.api.types.EvaluationSummary import EvaluationSummary
from phoenix.server.api.types.pagination import (
@@ -356,6 +357,29 @@ async def span_evaluation_summary(
("span", self.id_attr, time_range, filter_condition, evaluation_name),
)

@strawberry.field
async def trace_annotation_summary(
self,
info: Info[Context, None],
evaluation_name: str,
time_range: Optional[TimeRange] = UNSET,
) -> Optional[AnnotationSummary]:
return await info.context.data_loaders.annotation_summaries.load(
("trace", self.id_attr, time_range, None, evaluation_name),
)

@strawberry.field
async def span_annotation_summary(
self,
info: Info[Context, None],
evaluation_name: str,
time_range: Optional[TimeRange] = UNSET,
filter_condition: Optional[str] = UNSET,
) -> Optional[AnnotationSummary]:
return await info.context.data_loaders.annotation_summaries.load(
("span", self.id_attr, time_range, filter_condition, evaluation_name),
)

@strawberry.field
async def document_evaluation_summary(
self,
7 changes: 7 additions & 0 deletions src/phoenix/server/app.py
@@ -58,6 +58,7 @@
from phoenix.pointcloud.umap_parameters import UMAPParameters
from phoenix.server.api.context import Context, DataLoaders
from phoenix.server.api.dataloaders import (
AnnotationSummaryDataLoader,
AverageExperimentRunLatencyDataLoader,
CacheForDataLoaders,
DatasetExampleRevisionsDataLoader,
@@ -286,6 +287,12 @@ def get_context() -> Context:
),
document_evaluations=DocumentEvaluationsDataLoader(db),
document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(db),
annotation_summaries=AnnotationSummaryDataLoader(
db,
cache_map=cache_for_dataloaders.annotation_summary
if cache_for_dataloaders
else None,
),
evaluation_summaries=EvaluationSummaryDataLoader(
db,
cache_map=cache_for_dataloaders.evaluation_summary