-
Notifications
You must be signed in to change notification settings - Fork 285
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(traces): server-side sort of spans by evaluation result (score o…
…r label) (#1812) * feat: server-side sort by eval
- Loading branch information
1 parent
291332c
commit 00b80d2
Showing
7 changed files
with
169 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
10 changes: 8 additions & 2 deletions
10
app/src/pages/tracing/__generated__/SpansTableSpansQuery.graphql.ts
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
10 changes: 8 additions & 2 deletions
10
app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,113 @@ | ||
from enum import Enum | ||
from functools import partial | ||
from typing import Any, Iterable, Iterator | ||
from typing import Any, Iterable, Iterator, Optional, Protocol | ||
|
||
import pandas as pd | ||
import strawberry | ||
from typing_extensions import assert_never | ||
|
||
import phoenix.trace.v1 as pb | ||
from phoenix.core.traces import ( | ||
END_TIME, | ||
LLM_TOKEN_COUNT_COMPLETION, | ||
LLM_TOKEN_COUNT_PROMPT, | ||
LLM_TOKEN_COUNT_TOTAL, | ||
START_TIME, | ||
ComputedAttributes, | ||
) | ||
from phoenix.server.api.types.SortDir import SortDir | ||
from phoenix.trace.schemas import Span | ||
from phoenix.trace import semantic_conventions | ||
from phoenix.trace.schemas import Span, SpanID | ||
|
||
|
||
@strawberry.enum | ||
class SpanColumn(Enum): | ||
startTime = START_TIME | ||
endTime = END_TIME | ||
latencyMs = ComputedAttributes.LATENCY_MS.value | ||
tokenCountTotal = LLM_TOKEN_COUNT_TOTAL | ||
tokenCountPrompt = LLM_TOKEN_COUNT_PROMPT | ||
tokenCountCompletion = LLM_TOKEN_COUNT_COMPLETION | ||
tokenCountTotal = semantic_conventions.LLM_TOKEN_COUNT_TOTAL | ||
tokenCountPrompt = semantic_conventions.LLM_TOKEN_COUNT_PROMPT | ||
tokenCountCompletion = semantic_conventions.LLM_TOKEN_COUNT_COMPLETION | ||
cumulativeTokenCountTotal = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL.value | ||
cumulativeTokenCountPrompt = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT.value | ||
cumulativeTokenCountCompletion = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION.value | ||
|
||
|
||
@strawberry.enum | ||
class EvalAttr(Enum): | ||
score = "score" | ||
label = "label" | ||
|
||
|
||
@strawberry.input | ||
class SpanSort: | ||
""" | ||
The sort column and direction for span connections | ||
""" | ||
class EvalResultKey: | ||
name: str | ||
attr: EvalAttr | ||
|
||
|
||
class SupportsGetSpanEvaluation(Protocol): | ||
def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]: | ||
... | ||
|
||
col: SpanColumn | ||
|
||
@strawberry.input( | ||
description="The sort key and direction for span connections. Must " | ||
"specify one and only one of either `col` or `evalResultKey`." | ||
) | ||
class SpanSort: | ||
col: Optional[SpanColumn] = None | ||
eval_result_key: Optional[EvalResultKey] = None | ||
dir: SortDir | ||
|
||
def __call__(self, spans: Iterable[Span]) -> Iterator[Span]: | ||
def __call__( | ||
self, | ||
spans: Iterable[Span], | ||
evals: Optional[SupportsGetSpanEvaluation] = None, | ||
) -> Iterator[Span]: | ||
""" | ||
Sorts the spans by the given column and direction | ||
Sorts the spans by the given key (column or eval) and direction | ||
""" | ||
if self.eval_result_key is not None: | ||
get_sort_key_value = partial( | ||
_get_eval_result_value, | ||
eval_name=self.eval_result_key.name, | ||
eval_attr=self.eval_result_key.attr, | ||
evals=evals, | ||
) | ||
else: | ||
get_sort_key_value = partial( | ||
_get_column_value, | ||
span_column=self.col or SpanColumn.startTime, | ||
) | ||
yield from pd.Series(spans, dtype=object).sort_values( | ||
key=lambda series: series.apply(partial(_get_column, span_column=self.col)), | ||
key=lambda series: series.apply(get_sort_key_value), | ||
ascending=self.dir.value == SortDir.asc.value, | ||
) | ||
|
||
|
||
def _get_column(span: Span, span_column: SpanColumn) -> Any: | ||
def _get_column_value(span: Span, span_column: SpanColumn) -> Any: | ||
if span_column is SpanColumn.startTime: | ||
return span.start_time | ||
if span_column is SpanColumn.endTime: | ||
return span.end_time | ||
return span.attributes.get(span_column.value) | ||
|
||
|
||
def _get_eval_result_value( | ||
span: Span, | ||
eval_name: str, | ||
eval_attr: EvalAttr, | ||
evals: Optional[SupportsGetSpanEvaluation] = None, | ||
) -> Any: | ||
""" | ||
Returns the evaluation result for the given span | ||
""" | ||
if evals is None: | ||
return None | ||
span_id = span.context.span_id | ||
evaluation = evals.get_span_evaluation(span_id, eval_name) | ||
if evaluation is None: | ||
return None | ||
result = evaluation.result | ||
if eval_attr is EvalAttr.score: | ||
return result.score.value if result.HasField("score") else None | ||
if eval_attr is EvalAttr.label: | ||
return result.label.value if result.HasField("label") else None | ||
assert_never(eval_attr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from collections import namedtuple | ||
from itertools import count, islice | ||
from random import random | ||
|
||
import phoenix.trace.v1 as pb | ||
import pytest | ||
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue | ||
from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort | ||
from phoenix.server.api.types.SortDir import SortDir | ||
|
||
|
||
@pytest.mark.parametrize("col", [SpanColumn.endTime, SpanColumn.latencyMs]) | ||
def test_sort_by_col(spans, col): | ||
span0, span1, span2 = islice(spans, 3) | ||
sort = SpanSort(col=col, dir=SortDir.desc) | ||
assert list(sort([span0, span1, span2])) == [span2, span0, span1] | ||
|
||
|
||
@pytest.mark.parametrize("eval_attr", list(EvalAttr)) | ||
def test_sort_by_eval(spans, evals, eval_name, eval_attr): | ||
span0, span1, span2 = islice(spans, 3) | ||
|
||
eval_result_key = EvalResultKey(name=eval_name, attr=eval_attr) | ||
sort = SpanSort(eval_result_key=eval_result_key, dir=SortDir.desc) | ||
assert list(sort([span0, span2, span1], evals)) == [span1, span0, span2] | ||
|
||
# non-existent evaluation name | ||
no_op_key = EvalResultKey(name=random(), attr=eval_attr) | ||
no_op_sort = SpanSort(eval_result_key=no_op_key, dir=SortDir.desc) | ||
assert list(no_op_sort([span2, span0, span1], evals)) == [span2, span0, span1] | ||
|
||
|
||
Span = namedtuple("Span", "context end_time attributes") | ||
Context = namedtuple("Context", "span_id") | ||
Evals = namedtuple("Evals", "get_span_evaluation") | ||
|
||
|
||
@pytest.fixture | ||
def evals(eval_name): | ||
result0 = pb.Evaluation.Result(score=DoubleValue(value=0)) | ||
result1 = pb.Evaluation.Result(score=DoubleValue(value=1), label=StringValue(value="1")) | ||
evaluations = {eval_name: {0: pb.Evaluation(result=result0), 1: pb.Evaluation(result=result1)}} | ||
return Evals(lambda span_id, name: evaluations.get(name, {}).get(span_id)) | ||
|
||
|
||
@pytest.fixture | ||
def eval_name(): | ||
return "correctness" | ||
|
||
|
||
@pytest.fixture | ||
def spans(): | ||
return ( | ||
Span( | ||
context=Context(i), | ||
end_time=None if i % 2 else i, | ||
attributes={} if i % 2 else {SpanColumn.latencyMs.value: i}, | ||
) | ||
for i in count() | ||
) |