diff --git a/app/schema.graphql b/app/schema.graphql
index 7c1e1649872..28bc42d657f 100644
--- a/app/schema.graphql
+++ b/app/schema.graphql
@@ -284,6 +284,16 @@ type EmbeddingMetadata {
   linkToData: String
 }

+enum EvalAttr {
+  score
+  label
+}
+
+input EvalResultKey {
+  name: String!
+  attr: EvalAttr!
+}
+
 interface Evaluation {
   """Name of the evaluation, e.g. 'helpfulness' or 'relevance'."""
   name: String!
@@ -654,8 +664,12 @@ enum SpanKind {
   unknown
 }

+"""
+The sort key and direction for span connections. Must specify one and only one of either `col` or `evalResultKey`.
+"""
 input SpanSort {
-  col: SpanColumn!
+  col: SpanColumn = null
+  evalResultKey: EvalResultKey = null
   dir: SortDir!
 }

diff --git a/app/src/pages/tracing/__generated__/SpansTableSpansQuery.graphql.ts b/app/src/pages/tracing/__generated__/SpansTableSpansQuery.graphql.ts
index 597fbde2e6f..024d55c9553 100644
--- a/app/src/pages/tracing/__generated__/SpansTableSpansQuery.graphql.ts
+++ b/app/src/pages/tracing/__generated__/SpansTableSpansQuery.graphql.ts
@@ -1,5 +1,5 @@
 /**
- * @generated SignedSource<>
+ * @generated SignedSource<<40bdb7f53168019e2ad59f8809889d9c>>
  * @lightSyntaxTransform
  * @nogrep
  */
@@ -10,11 +10,17 @@

 import { ConcreteRequest, Query } from 'relay-runtime';
 import { FragmentRefs } from "relay-runtime";
+export type EvalAttr = "label" | "score";
 export type SortDir = "asc" | "desc";
 export type SpanColumn = "cumulativeTokenCountCompletion" | "cumulativeTokenCountPrompt" | "cumulativeTokenCountTotal" | "endTime" | "latencyMs" | "startTime" | "tokenCountCompletion" | "tokenCountPrompt" | "tokenCountTotal";
 export type SpanSort = {
-  col: SpanColumn;
+  col?: SpanColumn | null;
   dir: SortDir;
+  evalResultKey?: EvalResultKey | null;
+};
+export type EvalResultKey = {
+  attr: EvalAttr;
+  name: string;
 };
 export type SpansTableSpansQuery$variables = {
   after?: string | null;
diff --git a/app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts b/app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts
index 93c232a8066..7181998b3c9 100644
--- a/app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts
+++ b/app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts
@@ -1,5 +1,5 @@
 /**
- * @generated SignedSource<<32eaa17bdc10fa7efd4f68063092dad0>>
+ * @generated SignedSource<>
  * @lightSyntaxTransform
  * @nogrep
  */
@@ -10,11 +10,17 @@

 import { ConcreteRequest, Query } from 'relay-runtime';
 import { FragmentRefs } from "relay-runtime";
+export type EvalAttr = "label" | "score";
 export type SortDir = "asc" | "desc";
 export type SpanColumn = "cumulativeTokenCountCompletion" | "cumulativeTokenCountPrompt" | "cumulativeTokenCountTotal" | "endTime" | "latencyMs" | "startTime" | "tokenCountCompletion" | "tokenCountPrompt" | "tokenCountTotal";
 export type SpanSort = {
-  col: SpanColumn;
+  col?: SpanColumn | null;
   dir: SortDir;
+  evalResultKey?: EvalResultKey | null;
+};
+export type EvalResultKey = {
+  attr: EvalAttr;
+  name: string;
 };
 export type TracesTableQuery$variables = {
   after?: string | null;
diff --git a/src/phoenix/core/evals.py b/src/phoenix/core/evals.py
index 8bbadcd75d0..da386cfc6d7 100644
--- a/src/phoenix/core/evals.py
+++ b/src/phoenix/core/evals.py
@@ -84,6 +84,10 @@ def _process_evaluation(self, evaluation: pb.Evaluation) -> None:
         else:
             assert_never(subject_id_kind)

+    def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
+        with self._lock:
+            return self._evaluations_by_span_id[span_id].get(name)
+
     def get_span_evaluation_names(self) -> List[EvaluationName]:
         with self._lock:
             return list(self._span_evaluations_by_name.keys())
diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py
index ceebcfd0f0a..97036ad53ba 100644
--- a/src/phoenix/server/api/input_types/SpanSort.py
+++ b/src/phoenix/server/api/input_types/SpanSort.py
@@ -1,20 +1,20 @@
 from enum import Enum
 from functools import partial
-from typing import Any, Iterable, Iterator
+from typing import Any, Iterable, Iterator, Optional, Protocol

 import pandas as pd
 import strawberry
+from typing_extensions import assert_never

+import phoenix.trace.v1 as pb
 from phoenix.core.traces import (
     END_TIME,
-    LLM_TOKEN_COUNT_COMPLETION,
-    LLM_TOKEN_COUNT_PROMPT,
-    LLM_TOKEN_COUNT_TOTAL,
     START_TIME,
     ComputedAttributes,
 )
 from phoenix.server.api.types.SortDir import SortDir
-from phoenix.trace.schemas import Span
+from phoenix.trace import semantic_conventions
+from phoenix.trace.schemas import Span, SpanID


 @strawberry.enum
@@ -22,36 +22,92 @@ class SpanColumn(Enum):
     startTime = START_TIME
     endTime = END_TIME
     latencyMs = ComputedAttributes.LATENCY_MS.value
-    tokenCountTotal = LLM_TOKEN_COUNT_TOTAL
-    tokenCountPrompt = LLM_TOKEN_COUNT_PROMPT
-    tokenCountCompletion = LLM_TOKEN_COUNT_COMPLETION
+    tokenCountTotal = semantic_conventions.LLM_TOKEN_COUNT_TOTAL
+    tokenCountPrompt = semantic_conventions.LLM_TOKEN_COUNT_PROMPT
+    tokenCountCompletion = semantic_conventions.LLM_TOKEN_COUNT_COMPLETION
     cumulativeTokenCountTotal = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL.value
     cumulativeTokenCountPrompt = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT.value
     cumulativeTokenCountCompletion = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION.value


+@strawberry.enum
+class EvalAttr(Enum):
+    score = "score"
+    label = "label"
+
+
 @strawberry.input
-class SpanSort:
-    """
-    The sort column and direction for span connections
-    """
+class EvalResultKey:
+    name: str
+    attr: EvalAttr
+
+
+class SupportsGetSpanEvaluation(Protocol):
+    def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
+        ...

-    col: SpanColumn
+
+@strawberry.input(
+    description="The sort key and direction for span connections. Must "
+    "specify one and only one of either `col` or `evalResultKey`."
+)
+class SpanSort:
+    col: Optional[SpanColumn] = None
+    eval_result_key: Optional[EvalResultKey] = None
     dir: SortDir

-    def __call__(self, spans: Iterable[Span]) -> Iterator[Span]:
+    def __call__(
+        self,
+        spans: Iterable[Span],
+        evals: Optional[SupportsGetSpanEvaluation] = None,
+    ) -> Iterator[Span]:
         """
-        Sorts the spans by the given column and direction
+        Sorts the spans by the given key (column or eval) and direction
         """
+        if self.eval_result_key is not None:
+            get_sort_key_value = partial(
+                _get_eval_result_value,
+                eval_name=self.eval_result_key.name,
+                eval_attr=self.eval_result_key.attr,
+                evals=evals,
+            )
+        else:
+            get_sort_key_value = partial(
+                _get_column_value,
+                span_column=self.col or SpanColumn.startTime,
+            )
         yield from pd.Series(spans, dtype=object).sort_values(
-            key=lambda series: series.apply(partial(_get_column, span_column=self.col)),
+            key=lambda series: series.apply(get_sort_key_value),
             ascending=self.dir.value == SortDir.asc.value,
         )


-def _get_column(span: Span, span_column: SpanColumn) -> Any:
+def _get_column_value(span: Span, span_column: SpanColumn) -> Any:
     if span_column is SpanColumn.startTime:
         return span.start_time
     if span_column is SpanColumn.endTime:
         return span.end_time
     return span.attributes.get(span_column.value)
+
+
+def _get_eval_result_value(
+    span: Span,
+    eval_name: str,
+    eval_attr: EvalAttr,
+    evals: Optional[SupportsGetSpanEvaluation] = None,
+) -> Any:
+    """
+    Returns the evaluation result for the given span
+    """
+    if evals is None:
+        return None
+    span_id = span.context.span_id
+    evaluation = evals.get_span_evaluation(span_id, eval_name)
+    if evaluation is None:
+        return None
+    result = evaluation.result
+    if eval_attr is EvalAttr.score:
+        return result.score.value if result.HasField("score") else None
+    if eval_attr is EvalAttr.label:
+        return result.label.value if result.HasField("label") else None
+    assert_never(eval_attr)
diff --git a/src/phoenix/server/api/schema.py b/src/phoenix/server/api/schema.py
index 844cdb32a8d..8ac43bbc2ab 100644
--- a/src/phoenix/server/api/schema.py
+++ b/src/phoenix/server/api/schema.py
@@ -235,7 +235,7 @@ def spans(
         if predicate:
             spans = filter(predicate, spans)
         if sort:
-            spans = sort(spans)
+            spans = sort(spans, evals=info.context.evals)
         data = list(map(to_gql_span, spans))
         return connection_from_list(data=data, args=args)

diff --git a/tests/server/api/input_types/test_SpanSort.py b/tests/server/api/input_types/test_SpanSort.py
new file mode 100644
index 00000000000..9082c1a79cc
--- /dev/null
+++ b/tests/server/api/input_types/test_SpanSort.py
@@ -0,0 +1,60 @@
+from collections import namedtuple
+from itertools import count, islice
+from random import random
+
+import phoenix.trace.v1 as pb
+import pytest
+from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
+from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort
+from phoenix.server.api.types.SortDir import SortDir
+
+
+@pytest.mark.parametrize("col", [SpanColumn.endTime, SpanColumn.latencyMs])
+def test_sort_by_col(spans, col):
+    span0, span1, span2 = islice(spans, 3)
+    sort = SpanSort(col=col, dir=SortDir.desc)
+    assert list(sort([span0, span1, span2])) == [span2, span0, span1]
+
+
+@pytest.mark.parametrize("eval_attr", list(EvalAttr))
+def test_sort_by_eval(spans, evals, eval_name, eval_attr):
+    span0, span1, span2 = islice(spans, 3)
+
+    eval_result_key = EvalResultKey(name=eval_name, attr=eval_attr)
+    sort = SpanSort(eval_result_key=eval_result_key, dir=SortDir.desc)
+    assert list(sort([span0, span2, span1], evals)) == [span1, span0, span2]
+
+    # non-existent evaluation name
+    no_op_key = EvalResultKey(name=random(), attr=eval_attr)
+    no_op_sort = SpanSort(eval_result_key=no_op_key, dir=SortDir.desc)
+    assert list(no_op_sort([span2, span0, span1], evals)) == [span2, span0, span1]
+
+
+Span = namedtuple("Span", "context end_time attributes")
+Context = namedtuple("Context", "span_id")
+Evals = namedtuple("Evals", "get_span_evaluation")
+
+
+@pytest.fixture
+def evals(eval_name):
+    result0 = pb.Evaluation.Result(score=DoubleValue(value=0))
+    result1 = pb.Evaluation.Result(score=DoubleValue(value=1), label=StringValue(value="1"))
+    evaluations = {eval_name: {0: pb.Evaluation(result=result0), 1: pb.Evaluation(result=result1)}}
+    return Evals(lambda span_id, name: evaluations.get(name, {}).get(span_id))
+
+
+@pytest.fixture
+def eval_name():
+    return "correctness"
+
+
+@pytest.fixture
+def spans():
+    return (
+        Span(
+            context=Context(i),
+            end_time=None if i % 2 else i,
+            attributes={} if i % 2 else {SpanColumn.latencyMs.value: i},
+        )
+        for i in count()
+    )