Skip to content

Commit

Permalink
feat(traces): server-side sort of spans by evaluation result (score o…
Browse files Browse the repository at this point in the history
…r label) (#1812)

* feat: server-side sort by eval
  • Loading branch information
RogerHYang authored and mikeldking committed Dec 1, 2023
1 parent 291332c commit 00b80d2
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 23 deletions.
16 changes: 15 additions & 1 deletion app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ type EmbeddingMetadata {
linkToData: String
}

enum EvalAttr {
score
label
}

input EvalResultKey {
name: String!
attr: EvalAttr!
}

interface Evaluation {
"""Name of the evaluation, e.g. 'helpfulness' or 'relevance'."""
name: String!
Expand Down Expand Up @@ -654,8 +664,12 @@ enum SpanKind {
unknown
}

"""
The sort key and direction for span connections. Must specify one and only one of either `col` or `evalResultKey`.
"""
input SpanSort {
col: SpanColumn!
col: SpanColumn = null
evalResultKey: EvalResultKey = null
dir: SortDir!
}

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions app/src/pages/tracing/__generated__/TracesTableQuery.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions src/phoenix/core/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def _process_evaluation(self, evaluation: pb.Evaluation) -> None:
else:
assert_never(subject_id_kind)

def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
with self._lock:
return self._evaluations_by_span_id[span_id].get(name)

def get_span_evaluation_names(self) -> List[EvaluationName]:
with self._lock:
return list(self._span_evaluations_by_name.keys())
Expand Down
90 changes: 73 additions & 17 deletions src/phoenix/server/api/input_types/SpanSort.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,113 @@
from enum import Enum
from functools import partial
from typing import Any, Iterable, Iterator
from typing import Any, Iterable, Iterator, Optional, Protocol

import pandas as pd
import strawberry
from typing_extensions import assert_never

import phoenix.trace.v1 as pb
from phoenix.core.traces import (
END_TIME,
LLM_TOKEN_COUNT_COMPLETION,
LLM_TOKEN_COUNT_PROMPT,
LLM_TOKEN_COUNT_TOTAL,
START_TIME,
ComputedAttributes,
)
from phoenix.server.api.types.SortDir import SortDir
from phoenix.trace.schemas import Span
from phoenix.trace import semantic_conventions
from phoenix.trace.schemas import Span, SpanID


@strawberry.enum
class SpanColumn(Enum):
startTime = START_TIME
endTime = END_TIME
latencyMs = ComputedAttributes.LATENCY_MS.value
tokenCountTotal = LLM_TOKEN_COUNT_TOTAL
tokenCountPrompt = LLM_TOKEN_COUNT_PROMPT
tokenCountCompletion = LLM_TOKEN_COUNT_COMPLETION
tokenCountTotal = semantic_conventions.LLM_TOKEN_COUNT_TOTAL
tokenCountPrompt = semantic_conventions.LLM_TOKEN_COUNT_PROMPT
tokenCountCompletion = semantic_conventions.LLM_TOKEN_COUNT_COMPLETION
cumulativeTokenCountTotal = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL.value
cumulativeTokenCountPrompt = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT.value
cumulativeTokenCountCompletion = ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION.value


@strawberry.enum
class EvalAttr(Enum):
score = "score"
label = "label"


@strawberry.input
class SpanSort:
"""
The sort column and direction for span connections
"""
class EvalResultKey:
name: str
attr: EvalAttr


class SupportsGetSpanEvaluation(Protocol):
def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]:
...

col: SpanColumn

@strawberry.input(
description="The sort key and direction for span connections. Must "
"specify one and only one of either `col` or `evalResultKey`."
)
class SpanSort:
col: Optional[SpanColumn] = None
eval_result_key: Optional[EvalResultKey] = None
dir: SortDir

def __call__(self, spans: Iterable[Span]) -> Iterator[Span]:
def __call__(
self,
spans: Iterable[Span],
evals: Optional[SupportsGetSpanEvaluation] = None,
) -> Iterator[Span]:
"""
Sorts the spans by the given column and direction
Sorts the spans by the given key (column or eval) and direction
"""
if self.eval_result_key is not None:
get_sort_key_value = partial(
_get_eval_result_value,
eval_name=self.eval_result_key.name,
eval_attr=self.eval_result_key.attr,
evals=evals,
)
else:
get_sort_key_value = partial(
_get_column_value,
span_column=self.col or SpanColumn.startTime,
)
yield from pd.Series(spans, dtype=object).sort_values(
key=lambda series: series.apply(partial(_get_column, span_column=self.col)),
key=lambda series: series.apply(get_sort_key_value),
ascending=self.dir.value == SortDir.asc.value,
)


def _get_column(span: Span, span_column: SpanColumn) -> Any:
def _get_column_value(span: Span, span_column: SpanColumn) -> Any:
if span_column is SpanColumn.startTime:
return span.start_time
if span_column is SpanColumn.endTime:
return span.end_time
return span.attributes.get(span_column.value)


def _get_eval_result_value(
span: Span,
eval_name: str,
eval_attr: EvalAttr,
evals: Optional[SupportsGetSpanEvaluation] = None,
) -> Any:
"""
Returns the evaluation result for the given span
"""
if evals is None:
return None
span_id = span.context.span_id
evaluation = evals.get_span_evaluation(span_id, eval_name)
if evaluation is None:
return None
result = evaluation.result
if eval_attr is EvalAttr.score:
return result.score.value if result.HasField("score") else None
if eval_attr is EvalAttr.label:
return result.label.value if result.HasField("label") else None
assert_never(eval_attr)
2 changes: 1 addition & 1 deletion src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def spans(
if predicate:
spans = filter(predicate, spans)
if sort:
spans = sort(spans)
spans = sort(spans, evals=info.context.evals)
data = list(map(to_gql_span, spans))
return connection_from_list(data=data, args=args)

Expand Down
60 changes: 60 additions & 0 deletions tests/server/api/input_types/test_SpanSort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from collections import namedtuple
from itertools import count, islice
from random import random

import phoenix.trace.v1 as pb
import pytest
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort
from phoenix.server.api.types.SortDir import SortDir


@pytest.mark.parametrize("col", [SpanColumn.endTime, SpanColumn.latencyMs])
def test_sort_by_col(spans, col):
span0, span1, span2 = islice(spans, 3)
sort = SpanSort(col=col, dir=SortDir.desc)
assert list(sort([span0, span1, span2])) == [span2, span0, span1]


@pytest.mark.parametrize("eval_attr", list(EvalAttr))
def test_sort_by_eval(spans, evals, eval_name, eval_attr):
span0, span1, span2 = islice(spans, 3)

eval_result_key = EvalResultKey(name=eval_name, attr=eval_attr)
sort = SpanSort(eval_result_key=eval_result_key, dir=SortDir.desc)
assert list(sort([span0, span2, span1], evals)) == [span1, span0, span2]

# non-existent evaluation name
no_op_key = EvalResultKey(name=random(), attr=eval_attr)
no_op_sort = SpanSort(eval_result_key=no_op_key, dir=SortDir.desc)
assert list(no_op_sort([span2, span0, span1], evals)) == [span2, span0, span1]


Span = namedtuple("Span", "context end_time attributes")
Context = namedtuple("Context", "span_id")
Evals = namedtuple("Evals", "get_span_evaluation")


@pytest.fixture
def evals(eval_name):
result0 = pb.Evaluation.Result(score=DoubleValue(value=0))
result1 = pb.Evaluation.Result(score=DoubleValue(value=1), label=StringValue(value="1"))
evaluations = {eval_name: {0: pb.Evaluation(result=result0), 1: pb.Evaluation(result=result1)}}
return Evals(lambda span_id, name: evaluations.get(name, {}).get(span_id))


@pytest.fixture
def eval_name():
return "correctness"


@pytest.fixture
def spans():
return (
Span(
context=Context(i),
end_time=None if i % 2 else i,
attributes={} if i % 2 else {SpanColumn.latencyMs.value: i},
)
for i in count()
)

0 comments on commit 00b80d2

Please sign in to comment.