Skip to content

Commit

Permalink
feat(persistence): sql sorting for spans (#2823)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Apr 10, 2024
1 parent 2821bb4 commit eeafb64
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 132 deletions.
20 changes: 11 additions & 9 deletions app/src/pages/project/TracesTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ export function TracesTable(props: TracesTableProps) {
statusCode: propagatedStatusCode
startTime
latencyMs
tokenCountTotal: cumulativeTokenCountTotal
tokenCountPrompt: cumulativeTokenCountPrompt
tokenCountCompletion: cumulativeTokenCountCompletion
cumulativeTokenCountTotal
cumulativeTokenCountPrompt
cumulativeTokenCountCompletion
parentId
input {
value
Expand Down Expand Up @@ -154,9 +154,9 @@ export function TracesTable(props: TracesTableProps) {
startTime
latencyMs
parentId
tokenCountTotal
tokenCountPrompt
tokenCountCompletion
cumulativeTokenCountTotal: tokenCountTotal
cumulativeTokenCountPrompt: tokenCountPrompt
cumulativeTokenCountCompletion: tokenCountCompletion
input {
value
}
Expand Down Expand Up @@ -389,7 +389,7 @@ export function TracesTable(props: TracesTableProps) {
{
header: "total tokens",
minSize: 80,
accessorKey: "tokenCountTotal",
accessorKey: "cumulativeTokenCountTotal",
cell: ({ row, getValue }) => {
const value = getValue();
if (value === null) {
Expand All @@ -398,8 +398,10 @@ export function TracesTable(props: TracesTableProps) {
return (
<TokenCount
tokenCountTotal={value as number}
tokenCountPrompt={row.original.tokenCountPrompt || 0}
tokenCountCompletion={row.original.tokenCountCompletion || 0}
tokenCountPrompt={row.original.cumulativeTokenCountPrompt || 0}
tokenCountCompletion={
row.original.cumulativeTokenCountCompletion || 0
}
/>
);
},
Expand Down
146 changes: 81 additions & 65 deletions app/src/pages/project/__generated__/ProjectPageQuery.graphql.ts

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions app/src/pages/project/__generated__/TracesTableQuery.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 14 additions & 14 deletions app/src/pages/project/__generated__/TracesTable_spans.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 46 additions & 23 deletions src/phoenix/server/api/input_types/SpanSort.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,56 @@
from enum import Enum
from enum import Enum, auto
from functools import partial
from typing import Any, Iterable, Iterator, Optional, Protocol

import pandas as pd
import strawberry
from openinference.semconv.trace import SpanAttributes
from sqlalchemy import Integer, cast, desc, nulls_last
from strawberry import UNSET
from typing_extensions import assert_never

import phoenix.trace.v1 as pb
from phoenix.core.project import WrappedSpan
from phoenix.db import models
from phoenix.server.api.types.SortDir import SortDir
from phoenix.trace.schemas import ComputedAttributes, SpanID
from phoenix.trace.schemas import SpanID

LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL


@strawberry.enum
class SpanColumn(Enum):
    """Sortable span columns exposed through the GraphQL API.

    Members use ``auto()`` placeholders: the enum value itself carries no
    meaning, and the actual SQLAlchemy ORDER BY expression for each member
    is looked up by the server at query time.
    """

    startTime = auto()
    endTime = auto()
    latencyMs = auto()
    tokenCountTotal = auto()
    tokenCountPrompt = auto()
    tokenCountCompletion = auto()
    cumulativeTokenCountTotal = auto()
    cumulativeTokenCountPrompt = auto()
    cumulativeTokenCountCompletion = auto()


# Maps each sortable GraphQL SpanColumn to the SQLAlchemy expression used
# in the ORDER BY clause when sorting spans in SQL.
_SPAN_COLUMN_TO_ORM_EXPR_MAP = {
SpanColumn.startTime: models.Span.start_time,
SpanColumn.endTime: models.Span.end_time,
SpanColumn.latencyMs: models.Span.latency_ms,
# Per-span token counts live inside the JSON `attributes` blob, so each is
# extracted via its semantic-convention key and cast to Integer so the sort
# is numeric rather than lexicographic.
SpanColumn.tokenCountTotal: cast(
models.Span.attributes[LLM_TOKEN_COUNT_TOTAL].as_string(), Integer
),
SpanColumn.tokenCountPrompt: cast(
models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_string(), Integer
),
SpanColumn.tokenCountCompletion: cast(
models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_string(), Integer
),
# NOTE(review): the cumulative total is computed as prompt + completion —
# presumably the Span model has no dedicated cumulative-total column; confirm
# against phoenix.db.models.
SpanColumn.cumulativeTokenCountTotal: models.Span.cumulative_llm_token_count_prompt
+ models.Span.cumulative_llm_token_count_completion,
SpanColumn.cumulativeTokenCountPrompt: models.Span.cumulative_llm_token_count_prompt,
SpanColumn.cumulativeTokenCountCompletion: models.Span.cumulative_llm_token_count_completion,
}


@strawberry.enum
Expand Down Expand Up @@ -52,6 +78,14 @@ class SpanSort:
eval_result_key: Optional[EvalResultKey] = UNSET
dir: SortDir

def to_orm_expr(self) -> Any:
    """Build the SQLAlchemy ORDER BY expression for this sort spec.

    Looks up the column's ORM expression, reverses it for descending
    sorts, and wraps it so NULLs always sort last.

    Returns:
        A SQLAlchemy ordering expression suitable for ``stmt.order_by``.

    Raises:
        NotImplementedError: if no column is set (e.g. an eval-based
            sort, which is not supported in SQL here).
    """
    if self.col:
        expr = _SPAN_COLUMN_TO_ORM_EXPR_MAP[self.col]
        if self.dir == SortDir.desc:
            expr = desc(expr)
        return nulls_last(expr)
    # Bug fix: the original built the exception without `raise`, so this
    # path silently fell through and returned None.
    raise NotImplementedError("not implemented")

def __call__(
self,
spans: Iterable[WrappedSpan],
Expand All @@ -68,24 +102,13 @@ def __call__(
evals=evals,
)
else:
get_sort_key_value = partial(
_get_column_value,
span_column=self.col or SpanColumn.startTime,
)
NotImplementedError("This should be unreachable. Use SQL instead.")
yield from pd.Series(spans, dtype=object).sort_values(
key=lambda series: series.apply(get_sort_key_value),
ascending=self.dir.value == SortDir.asc.value,
)


def _get_column_value(span: WrappedSpan, span_column: SpanColumn) -> Any:
if span_column is SpanColumn.startTime:
return span.start_time
if span_column is SpanColumn.endTime:
return span.end_time
return span[span_column.value]


def _get_eval_result_value(
span: WrappedSpan,
eval_name: str,
Expand Down
4 changes: 3 additions & 1 deletion src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ async def spans(
parent,
models.Span.parent_span_id == parent.c.span_id,
).where(parent.c.span_id.is_(None))
# TODO(persistence): enable sort and filter
# TODO(persistence): enable filter
if sort:
stmt = stmt.order_by(sort.to_orm_expr())
async with info.context.db() as session:
spans = await session.scalars(stmt)
data = [to_gql_span(span, self.project) for span in spans]
Expand Down
19 changes: 9 additions & 10 deletions tests/server/api/input_types/test_SpanSort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
from google.protobuf.wrappers_pb2 import DoubleValue, StringValue
from openinference.semconv.trace import SpanAttributes
from phoenix.core.project import WrappedSpan
from phoenix.server.api.input_types.SpanSort import EvalAttr, EvalResultKey, SpanColumn, SpanSort
from phoenix.server.api.input_types.SpanSort import (
_SPAN_COLUMN_TO_ORM_EXPR_MAP,
EvalAttr,
EvalResultKey,
SpanColumn,
SpanSort,
)
from phoenix.server.api.types.SortDir import SortDir


@pytest.mark.parametrize(
"col", [SpanColumn.endTime, SpanColumn.latencyMs, SpanColumn.tokenCountTotal]
)
def test_sort_by_col(spans, col):
span0, span1, span2, span3, span4 = spans
sort = SpanSort(col=col, dir=SortDir.desc)
assert list(sort(spans)) == [span4, span2, span0, span1, span3]
def test_span_column_has_orm_expr():
    """Every SpanColumn member must map to an ORM expression, and vice versa."""
    columns = set(SpanColumn)
    mapped = set(_SPAN_COLUMN_TO_ORM_EXPR_MAP)
    assert not (columns - mapped), "SpanColumn member missing an ORM expression"
    assert not (mapped - columns), "ORM expression for an unknown SpanColumn"


@pytest.mark.parametrize("eval_attr", list(EvalAttr))
Expand Down Expand Up @@ -62,7 +63,5 @@ def spans():
attributes={} if i % 2 else {SpanAttributes.LLM_TOKEN_COUNT_TOTAL: i},
)
)
if i % 2 == 0:
span[SpanColumn.latencyMs.value] = i
_spans.append(span)
return _spans

0 comments on commit eeafb64

Please sign in to comment.