Skip to content

Commit

Permalink
feat(sessions): add trace latency p50 to session details (#5236)
Browse files Browse the repository at this point in the history
* feat(sessions): add trace latency p50 to session details

* fix type for latency in ui

* update test, fix where clause

* add is not none check for db filter

* revert latency dataloader changes

* revert dataloader changes for project span and trace latency

* refactor to use its own data loader

* remove extra param from test

* add unit tests

* ruff

* fix import

* fix name filter in other dataloader tests

* update unit test fixture naming, add unit test for project session graphql trace_latency_ms_quantile field

* clean up imports

* pin aiohttp

* fix unit test deps
  • Loading branch information
Parker-Stafford authored and RogerHYang committed Dec 4, 2024
1 parent c400b52 commit cb6d27e
Show file tree
Hide file tree
Showing 16 changed files with 253 additions and 66 deletions.
2 changes: 2 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ type ProjectSession implements Node {
sessionUser: String
startTime: DateTime!
endTime: DateTime!
projectId: GlobalID!

"""Duration of the session in seconds"""
durationS: Float!
Expand All @@ -1336,6 +1337,7 @@ type ProjectSession implements Node {
lastOutput: SpanIOValue
tokenUsage: TokenUsage!
traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection!
traceLatencyMsQuantile(probability: Float!): Float
}

"""A connection to a list of items."""
Expand Down
13 changes: 13 additions & 0 deletions app/src/pages/trace/SessionDetails.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { css } from "@emotion/react";

import { Flex, Text, View } from "@arizeai/components";

import { LatencyText } from "@phoenix/components/trace/LatencyText";
import { TokenCount } from "@phoenix/components/trace/TokenCount";

import {
Expand All @@ -15,9 +16,11 @@ import { SessionDetailsTraceList } from "./SessionDetailsTraceList";
function SessionDetailsHeader({
traceCount,
tokenUsage,
latencyP50,
}: {
traceCount: number;
tokenUsage?: NonNullable<SessionDetailsQuery$data["session"]>["tokenUsage"];
latencyP50?: number | null;
}) {
return (
<View
Expand Down Expand Up @@ -45,6 +48,14 @@ function SessionDetailsHeader({
/>
</Flex>
) : null}
{latencyP50 != null ? (
<Flex direction={"column"}>
<Text elementType={"h3"} textSize={"medium"} color={"text-700"}>
Latency P50
</Text>
<LatencyText latencyMs={latencyP50} textSize={"xlarge"} />
</Flex>
) : null}
</Flex>
</View>
);
Expand Down Expand Up @@ -72,6 +83,7 @@ export function SessionDetails(props: SessionDetailsProps) {
prompt
}
sessionId
latencyP50: traceLatencyMsQuantile(probability: 0.50)
traces {
edges {
trace: node {
Expand Down Expand Up @@ -135,6 +147,7 @@ export function SessionDetails(props: SessionDetailsProps) {
<SessionDetailsHeader
traceCount={data.session.numTraces ?? 0}
tokenUsage={data.session.tokenUsage}
latencyP50={data.session.latencyP50}
/>
<SessionDetailsTraceList traces={data.session.traces} />
</main>
Expand Down
22 changes: 18 additions & 4 deletions app/src/pages/trace/__generated__/SessionDetailsQuery.graphql.ts

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
SessionIODataLoader,
SessionNumTracesDataLoader,
SessionTokenUsagesDataLoader,
SessionTraceLatencyMsQuantileDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
Expand Down Expand Up @@ -76,6 +77,7 @@ class DataLoaders:
session_last_outputs: SessionIODataLoader
session_num_traces: SessionNumTracesDataLoader
session_token_usages: SessionTokenUsagesDataLoader
session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader
span_annotations: SpanAnnotationsDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .session_io import SessionIODataLoader
from .session_num_traces import SessionNumTracesDataLoader
from .session_token_usages import SessionTokenUsagesDataLoader
from .session_trace_latency_ms_quantile import SessionTraceLatencyMsQuantileDataLoader
from .span_annotations import SpanAnnotationsDataLoader
from .span_dataset_examples import SpanDatasetExamplesDataLoader
from .span_descendants import SpanDescendantsDataLoader
Expand Down Expand Up @@ -52,6 +53,7 @@
"SessionIODataLoader",
"SessionNumTracesDataLoader",
"SessionTokenUsagesDataLoader",
"SessionTraceLatencyMsQuantileDataLoader",
"SpanDatasetExamplesDataLoader",
"SpanDescendantsDataLoader",
"SpanProjectsDataLoader",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from collections import defaultdict
from typing import Optional

import numpy as np
from aioitertools.itertools import groupby
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

# Row id of a project session; compared against Trace.project_session_rowid below.
SessionId: TypeAlias = int
# Quantile probability that is handed to np.quantile for each request.
Probability: TypeAlias = float
# Quantile computed over a session's Trace.latency_ms values.
QuantileValue: TypeAlias = float

# A dataloader key pairs a session row id with the requested probability.
Key: TypeAlias = tuple[SessionId, Probability]
# Per-key result; None when nothing was computed for the key.
Result: TypeAlias = Optional[QuantileValue]
# Index of a key within the batch passed to the load function.
ResultPosition: TypeAlias = int

# Value left in place for keys whose session yields no trace rows.
DEFAULT_VALUE: Result = None


class SessionTraceLatencyMsQuantileDataLoader(DataLoader[Key, Result]):
    """Batched loader of trace-latency quantiles for project sessions.

    Each key is a ``(session row id, probability)`` pair. Every key in a batch
    is answered from a single query over ``models.Trace``. A key whose session
    produces no trace rows resolves to ``DEFAULT_VALUE`` (``None``).
    """

    def __init__(self, db: DbSessionFactory) -> None:
        """Register the batch function and keep the DB session factory."""
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        """Resolve a batch of ``(session row id, probability)`` keys.

        Args:
            keys: The batched keys, in the order results must be returned.

        Returns:
            One entry per key, positionally aligned with ``keys``: the value
            of ``np.quantile`` over the session's trace latencies, or
            ``DEFAULT_VALUE`` when the query returned no rows for the session.
        """
        output: list[Result] = [DEFAULT_VALUE for _ in keys]

        # session row id -> probability -> every slot in `output` that asked
        # for it, so duplicate keys are computed once and fanned out.
        slots_by_session: defaultdict[
            SessionId, defaultdict[Probability, list[ResultPosition]]
        ] = defaultdict(lambda: defaultdict(list))
        for slot, (rowid, prob) in enumerate(keys):
            slots_by_session[rowid][prob].append(slot)
        wanted_rowids = set(slots_by_session)

        trace_table = models.Trace
        query = select(trace_table.project_session_rowid, trace_table.latency_ms)
        query = query.where(trace_table.project_session_rowid.in_(wanted_rowids))
        # groupby only merges adjacent rows, so rows must arrive ordered by session.
        query = query.order_by(trace_table.project_session_rowid)

        async with self._db() as session:
            rows = await session.stream(query)
            async for rowid, session_rows in groupby(rows, lambda r: r.project_session_rowid):
                latencies = [r.latency_ms for r in session_rows]
                for prob, slots in slots_by_session[rowid].items():
                    value = np.quantile(latencies, prob)
                    for slot in slots:
                        output[slot] = value
        return output
16 changes: 14 additions & 2 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ async def latency_ms_quantile(
time_range: Optional[TimeRange] = UNSET,
) -> Optional[float]:
return await info.context.data_loaders.latency_ms_quantile.load(
("trace", self.id_attr, time_range, None, probability),
(
"trace",
self.id_attr,
time_range,
None,
probability,
),
)

@strawberry.field
Expand All @@ -140,7 +146,13 @@ async def span_latency_ms_quantile(
filter_condition: Optional[str] = UNSET,
) -> Optional[float]:
return await info.context.data_loaders.latency_ms_quantile.load(
("span", self.id_attr, time_range, filter_condition, probability),
(
"span",
self.id_attr,
time_range,
filter_condition,
probability,
),
)

@strawberry.field
Expand Down
22 changes: 20 additions & 2 deletions src/phoenix/server/api/types/ProjectSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import strawberry
from openinference.semconv.trace import SpanAttributes
from sqlalchemy import select
from strawberry import UNSET, Info, lazy
from strawberry.relay import Connection, Node, NodeID
from strawberry import UNSET, Info, Private, lazy
from strawberry.relay import Connection, GlobalID, Node, NodeID

from phoenix.db import models
from phoenix.server.api.context import Context
Expand All @@ -22,11 +22,18 @@
class ProjectSession(Node):
_table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession
id_attr: NodeID[int]
project_rowid: Private[int]
session_id: str
session_user: Optional[str]
start_time: datetime
end_time: datetime

@strawberry.field
async def project_id(self) -> GlobalID:
    # Expose the private project row id as a relay GlobalID typed as `Project`.
    # NOTE(review): the import is kept function-local as in the original,
    # presumably to avoid a circular import between the two type modules.
    from phoenix.server.api.types.Project import Project

    project_type_name = Project.__name__
    project_node_id = str(self.project_rowid)
    return GlobalID(type_name=project_type_name, node_id=project_node_id)

@strawberry.field(description="Duration of the session in seconds")  # type: ignore
async def duration_s(self) -> float:
    # Elapsed time between the session's start and end timestamps, in seconds.
    elapsed = self.end_time - self.start_time
    return elapsed.total_seconds()
Expand Down Expand Up @@ -103,6 +110,16 @@ async def traces(
data = [to_gql_trace(trace) async for trace in traces]
return connection_from_list(data=data, args=args)

@strawberry.field
async def trace_latency_ms_quantile(
    self,
    info: Info[Context, None],
    probability: float,
) -> Optional[float]:
    # Delegate to the per-request dataloader, keyed by this session's row id
    # and the requested probability; the loader's result is returned as-is.
    loader = info.context.data_loaders.session_trace_latency_ms_quantile
    key = (self.id_attr, probability)
    return await loader.load(key)


def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession:
return ProjectSession(
Expand All @@ -111,6 +128,7 @@ def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSes
session_user=project_session.session_user,
start_time=project_session.start_time,
end_time=project_session.end_time,
project_rowid=project_session.project_id,
)


Expand Down
2 changes: 2 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
SessionIODataLoader,
SessionNumTracesDataLoader,
SessionTokenUsagesDataLoader,
SessionTraceLatencyMsQuantileDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
Expand Down Expand Up @@ -617,6 +618,7 @@ def get_context() -> Context:
session_last_outputs=SessionIODataLoader(db, "last_output"),
session_num_traces=SessionNumTracesDataLoader(db),
session_token_usages=SessionTokenUsagesDataLoader(db),
session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db),
span_annotations=SpanAnnotationsDataLoader(db),
span_dataset_examples=SpanDatasetExamplesDataLoader(db),
span_descendants=SpanDescendantsDataLoader(db),
Expand Down
Loading

0 comments on commit cb6d27e

Please sign in to comment.