From 5bcdaaea02e05c621ed33b093e2a75efe07376e0 Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Thu, 31 Oct 2024 12:27:44 -0700 Subject: [PATCH] feat(sessions): session trace error count (#5244) --- app/schema.graphql | 1 + src/phoenix/server/api/context.py | 2 ++ .../server/api/dataloaders/__init__.py | 2 ++ .../session_num_traces_with_error.py | 32 +++++++++++++++++++ .../server/api/types/ProjectSession.py | 7 ++++ src/phoenix/server/app.py | 2 ++ tests/unit/_helpers.py | 3 +- .../server/api/types/test_ProjectSession.py | 10 ++++++ 8 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 src/phoenix/server/api/dataloaders/session_num_traces_with_error.py diff --git a/app/schema.graphql b/app/schema.graphql index 09713319031..59405fd2058 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1333,6 +1333,7 @@ type ProjectSession implements Node { """Duration of the session in seconds""" durationS: Float! numTraces: Int! + numTracesWithError: Int! firstInput: SpanIOValue lastOutput: SpanIOValue tokenUsage: TokenUsage! 
diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index a6fb4e38754..75e7544b136 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -33,6 +33,7 @@ RecordCountDataLoader, SessionIODataLoader, SessionNumTracesDataLoader, + SessionNumTracesWithErrorDataLoader, SessionTokenUsagesDataLoader, SessionTraceLatencyMsQuantileDataLoader, SpanAnnotationsDataLoader, @@ -76,6 +77,7 @@ class DataLoaders: session_first_inputs: SessionIODataLoader session_last_outputs: SessionIODataLoader session_num_traces: SessionNumTracesDataLoader + session_num_traces_with_error: SessionNumTracesWithErrorDataLoader session_token_usages: SessionTokenUsagesDataLoader session_trace_latency_ms_quantile: SessionTraceLatencyMsQuantileDataLoader span_annotations: SpanAnnotationsDataLoader diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index dfcc67288cf..9cea67eaca3 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -21,6 +21,7 @@ from .record_counts import RecordCountCache, RecordCountDataLoader from .session_io import SessionIODataLoader from .session_num_traces import SessionNumTracesDataLoader +from .session_num_traces_with_error import SessionNumTracesWithErrorDataLoader from .session_token_usages import SessionTokenUsagesDataLoader from .session_trace_latency_ms_quantile import SessionTraceLatencyMsQuantileDataLoader from .span_annotations import SpanAnnotationsDataLoader @@ -52,6 +53,7 @@ "RecordCountDataLoader", "SessionIODataLoader", "SessionNumTracesDataLoader", + "SessionNumTracesWithErrorDataLoader", "SessionTokenUsagesDataLoader", "SessionTraceLatencyMsQuantileDataLoader", "SpanDatasetExamplesDataLoader", diff --git a/src/phoenix/server/api/dataloaders/session_num_traces_with_error.py b/src/phoenix/server/api/dataloaders/session_num_traces_with_error.py new file mode 100644 index 
00000000000..595a928bb9a
--- /dev/null
+++ b/src/phoenix/server/api/dataloaders/session_num_traces_with_error.py
@@ -0,0 +1,43 @@
+from sqlalchemy import func, select
+from strawberry.dataloader import DataLoader
+from typing_extensions import TypeAlias
+
+from phoenix.db import models
+from phoenix.server.types import DbSessionFactory
+
+Key: TypeAlias = int
+Result: TypeAlias = int
+
+
+class SessionNumTracesWithErrorDataLoader(DataLoader[Key, Result]):
+    """Batch-load, per project session, the number of traces containing an error.
+
+    Keys are matched against ``Trace.project_session_rowid``; each result is the
+    number of distinct traces in that session having at least one span whose
+    ``cumulative_error_count`` is positive.
+    """
+
+    def __init__(self, db: DbSessionFactory) -> None:
+        super().__init__(load_fn=self._load_fn)
+        self._db = db
+
+    async def _load_fn(self, keys: list[Key]) -> list[Result]:
+        """Return one count per key, in key order; 0 when no row matches."""
+        stmt = (
+            select(
+                models.Trace.project_session_rowid.label("id_"),
+                # The join yields one row per matching span, so a trace with
+                # more than one such span would otherwise be counted
+                # repeatedly; count each trace id only once.
+                func.count(models.Trace.id.distinct()).label("value"),
+            )
+            .join(models.Span)
+            .group_by(models.Trace.project_session_rowid)
+            .where(models.Span.cumulative_error_count > 0)
+            .where(models.Trace.project_session_rowid.in_(keys))
+        )
+        async with self._db() as session:
+            result: dict[Key, int] = {
+                id_: value async for id_, value in await session.stream(stmt) if id_ is not None
+            }
+        return [result.get(key, 0) for key in keys]
diff --git a/src/phoenix/server/api/types/ProjectSession.py b/src/phoenix/server/api/types/ProjectSession.py
index c6af6d13fa6..4c782d8f2a8 100644
--- a/src/phoenix/server/api/types/ProjectSession.py
+++ b/src/phoenix/server/api/types/ProjectSession.py
@@ -45,6 +45,13 @@ async def num_traces(
     ) -> int:
         return await info.context.data_loaders.session_num_traces.load(self.id_attr)
 
+    @strawberry.field
+    async def num_traces_with_error(
+        self,
+        info: Info[Context, None],
+    ) -> int:
+        return await info.context.data_loaders.session_num_traces_with_error.load(self.id_attr)
+
     @strawberry.field
     async def first_input(
         self,
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index 010f130f29e..cffb1f3b4f6 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -89,6 +89,7 @@
     RecordCountDataLoader,
     SessionIODataLoader,
SessionNumTracesDataLoader, + SessionNumTracesWithErrorDataLoader, SessionTokenUsagesDataLoader, SessionTraceLatencyMsQuantileDataLoader, SpanAnnotationsDataLoader, @@ -617,6 +618,7 @@ def get_context() -> Context: session_first_inputs=SessionIODataLoader(db, "first_input"), session_last_outputs=SessionIODataLoader(db, "last_output"), session_num_traces=SessionNumTracesDataLoader(db), + session_num_traces_with_error=SessionNumTracesWithErrorDataLoader(db), session_token_usages=SessionTokenUsagesDataLoader(db), session_trace_latency_ms_quantile=SessionTraceLatencyMsQuantileDataLoader(db), span_annotations=SpanAnnotationsDataLoader(db), diff --git a/tests/unit/_helpers.py b/tests/unit/_helpers.py index 85f8107622c..a0172263c30 100644 --- a/tests/unit/_helpers.py +++ b/tests/unit/_helpers.py @@ -81,6 +81,7 @@ async def _add_span( end_time: Optional[datetime] = None, parent_span: Optional[models.Span] = None, span_kind: str = "LLM", + cumulative_error_count: int = 0, cumulative_llm_token_count_prompt: int = 0, cumulative_llm_token_count_completion: int = 0, ) -> models.Span: @@ -95,7 +96,7 @@ async def _add_span( end_time=end_time, status_code="OK", status_message="test_status_message", - cumulative_error_count=0, + cumulative_error_count=cumulative_error_count, cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, attributes=attributes or {}, diff --git a/tests/unit/server/api/types/test_ProjectSession.py b/tests/unit/server/api/types/test_ProjectSession.py index 54cdeebb427..1a0d69fbeb6 100644 --- a/tests/unit/server/api/types/test_ProjectSession.py +++ b/tests/unit/server/api/types/test_ProjectSession.py @@ -70,6 +70,7 @@ async def _data( attributes={"input": {"value": "123"}, "output": {"value": "321"}}, cumulative_llm_token_count_prompt=1, cumulative_llm_token_count_completion=2, + cumulative_error_count=2, ) ) traces.append( @@ -116,6 +117,15 @@ async def test_num_traces( 
field = "numTraces" assert await self._node(field, project_session, httpx_client) == 2 + async def test_num_traces_with_error( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + project_session = _data.project_sessions[0] + field = "numTracesWithError" + assert await self._node(field, project_session, httpx_client) == 1 + async def test_first_input( self, _data: _Data,