Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sessions): add dataloaders for session queries #5222

Merged
merged 10 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions scripts/fixtures/multi-turn_chat_sessions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -Uqqq datasets openinference-instrumentation-openai openai-responses openai tiktoken langchain langchain-openai llama-index llama-index-llms-openai faker mdgen"
"%pip install -Uqqq datasets openinference-semantic-conventions openinference-instrumentation-openai faker openai-responses openai tiktoken"
]
},
{
Expand All @@ -25,7 +25,6 @@
"import pandas as pd\n",
"from datasets import load_dataset\n",
"from faker import Faker\n",
"from mdgen import MarkdownPostProvider\n",
"from openai_responses import OpenAIMock\n",
"from openinference.instrumentation import using_session, using_user\n",
"from openinference.instrumentation.openai import OpenAIInstrumentor\n",
Expand All @@ -39,8 +38,7 @@
"import phoenix as px\n",
"from phoenix.trace.span_evaluations import SpanEvaluations\n",
"\n",
"fake = Faker(\"ja_JP\")\n",
"fake.add_provider(MarkdownPostProvider)"
"fake = Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"])"
]
},
{
Expand Down Expand Up @@ -104,7 +102,7 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).address()\n",
" return fake.address()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
Expand All @@ -113,15 +111,17 @@
" if p < 0.1:\n",
" return \":\" * randint(1, 5)\n",
" if p < 0.9:\n",
" return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).name()\n",
" return fake.name()\n",
" return int(abs(random()) * 1_000_000_000)\n",
"\n",
"\n",
"def export_spans():\n",
"def export_spans(prob_drop_root):\n",
" \"\"\"Export spans in random order for receiver testing\"\"\"\n",
" spans = list(in_memory_span_exporter.get_finished_spans())\n",
" shuffle(spans)\n",
" for span in spans:\n",
" if span.parent is None and random() < prob_drop_root:\n",
" continue\n",
" otlp_span_exporter.export([span])\n",
" in_memory_span_exporter.clear()\n",
" session_count = len({id_ for span in spans if (id_ := span.attributes.get(\"session.id\"))})\n",
Expand All @@ -147,7 +147,7 @@
" return\n",
" has_yielded = False\n",
" with tracer.start_as_current_span(\n",
" Faker(\"ja_JP\").kana_name(),\n",
" fake.city(),\n",
" attributes=dict(rand_span_kind()),\n",
" end_on_exit=False,\n",
" ) as root:\n",
Expand Down Expand Up @@ -185,6 +185,7 @@
"source": [
"session_count = randint(5, 10)\n",
"tree_complexity = 4 # set to 0 for single span under root\n",
"prob_drop_root = 0.0 # probability that a root span gets dropped\n",
"\n",
"\n",
"def simulate_openai():\n",
Expand Down Expand Up @@ -237,7 +238,7 @@
" simulate_openai()\n",
"finally:\n",
" OpenAIInstrumentor().uninstrument()\n",
"spans = export_spans()\n",
"spans = export_spans(prob_drop_root)\n",
"\n",
"# Annotate root spans\n",
"root_span_ids = pd.Series(\n",
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@
MinStartOrMaxEndTimeDataLoader,
ProjectByNameDataLoader,
RecordCountDataLoader,
SessionIODataLoader,
SessionNumTracesDataLoader,
SessionTokenUsagesDataLoader,
SpanAnnotationsDataLoader,
SpanDatasetExamplesDataLoader,
SpanDescendantsDataLoader,
SpanProjectsDataLoader,
TokenCountDataLoader,
TraceByTraceIdsDataLoader,
TraceRootSpansDataLoader,
UserRolesDataLoader,
UsersDataLoader,
)
Expand Down Expand Up @@ -68,12 +72,17 @@ class DataLoaders:
latency_ms_quantile: LatencyMsQuantileDataLoader
min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader
record_counts: RecordCountDataLoader
session_first_inputs: SessionIODataLoader
session_last_outputs: SessionIODataLoader
session_num_traces: SessionNumTracesDataLoader
session_token_usages: SessionTokenUsagesDataLoader
span_annotations: SpanAnnotationsDataLoader
span_dataset_examples: SpanDatasetExamplesDataLoader
span_descendants: SpanDescendantsDataLoader
span_projects: SpanProjectsDataLoader
token_counts: TokenCountDataLoader
trace_by_trace_ids: TraceByTraceIdsDataLoader
trace_root_spans: TraceRootSpansDataLoader
project_by_name: ProjectByNameDataLoader
users: UsersDataLoader
user_roles: UserRolesDataLoader
Expand Down
8 changes: 8 additions & 0 deletions src/phoenix/server/api/dataloaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from .min_start_or_max_end_times import MinStartOrMaxEndTimeCache, MinStartOrMaxEndTimeDataLoader
from .project_by_name import ProjectByNameDataLoader
from .record_counts import RecordCountCache, RecordCountDataLoader
from .session_io import SessionIODataLoader
from .session_num_traces import SessionNumTracesDataLoader
from .session_token_usages import SessionTokenUsagesDataLoader
from .span_annotations import SpanAnnotationsDataLoader
from .span_dataset_examples import SpanDatasetExamplesDataLoader
from .span_descendants import SpanDescendantsDataLoader
from .span_projects import SpanProjectsDataLoader
from .token_counts import TokenCountCache, TokenCountDataLoader
from .trace_by_trace_ids import TraceByTraceIdsDataLoader
from .trace_root_spans import TraceRootSpansDataLoader
from .user_roles import UserRolesDataLoader
from .users import UsersDataLoader

Expand All @@ -45,11 +49,15 @@
"LatencyMsQuantileDataLoader",
"MinStartOrMaxEndTimeDataLoader",
"RecordCountDataLoader",
"SessionIODataLoader",
"SessionNumTracesDataLoader",
"SessionTokenUsagesDataLoader",
"SpanDatasetExamplesDataLoader",
"SpanDescendantsDataLoader",
"SpanProjectsDataLoader",
"TokenCountDataLoader",
"TraceByTraceIdsDataLoader",
"TraceRootSpansDataLoader",
"ProjectByNameDataLoader",
"SpanAnnotationsDataLoader",
"UsersDataLoader",
Expand Down
75 changes: 75 additions & 0 deletions src/phoenix/server/api/dataloaders/session_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import cached_property
from typing import Literal, Optional, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import Select, func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import MimeType, SpanIOValue

Key: TypeAlias = int  # project_session rowid
Result: TypeAlias = Optional[SpanIOValue]  # None when the session has no matching root span

# Which end of the session to fetch: the first trace's root-span input
# or the last trace's root-span output.
Kind = Literal["first_input", "last_output"]


class SessionIODataLoader(DataLoader[Key, Result]):
    """Batch-loads the first input or last output of project sessions.

    Each key is a ``project_session`` rowid. For ``kind="first_input"`` the
    result is the input value/mime-type of the root span of the session's
    earliest trace; for ``kind="last_output"`` it is the output of the root
    span of the latest trace. Sessions with no matching root span yield None.
    """

    def __init__(self, db: DbSessionFactory, kind: Kind) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db
        self._kind = kind

    @cached_property
    def _subq(self) -> Select[tuple[Optional[int], str, str, int]]:
        # Choose the span attribute columns and the trace ordering for the
        # requested kind: earliest-first for inputs, latest-first for outputs.
        if self._kind == "first_input":
            value_col = models.Span.attributes[INPUT_VALUE]
            mime_col = models.Span.attributes[INPUT_MIME_TYPE]
            ordering = [models.Trace.start_time.asc(), models.Trace.id.asc()]
        elif self._kind == "last_output":
            value_col = models.Span.attributes[OUTPUT_VALUE]
            mime_col = models.Span.attributes[OUTPUT_MIME_TYPE]
            ordering = [models.Trace.start_time.desc(), models.Trace.id.desc()]
        else:
            assert_never(self._kind)
        # Rank each session's root spans so that rank == 1 is the wanted row.
        rank = (
            func.row_number()
            .over(partition_by=models.Trace.project_session_rowid, order_by=ordering)
            .label("rank")
        )
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                value_col.label("value"),
                mime_col.label("mime_type"),
                rank,
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))  # root spans only
        )
        return cast(Select[tuple[Optional[int], str, str, int]], stmt)

    def _stmt(self, *keys: Key) -> Select[tuple[int, str, str]]:
        # Restrict the ranked rows to the requested sessions, then keep only
        # the top-ranked row per session.
        ranked = self._subq.where(models.Trace.project_session_rowid.in_(keys)).subquery()
        return select(ranked.c.id_, ranked.c.value, ranked.c.mime_type).filter_by(rank=1)

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        results: dict[Key, SpanIOValue] = {}
        async with self._db() as session:
            rows = await session.stream(self._stmt(*keys))
            async for id_, value, mime_type in rows:
                if id_ is None:
                    continue
                results[id_] = SpanIOValue(value=value, mime_type=MimeType(mime_type))
        # Preserve the input key order, as the DataLoader contract requires.
        return [results.get(key) for key in keys]


# Dotted attribute names split into key lists so they can be used as JSON
# paths when indexing into the Span.attributes JSON column.
INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".")
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".")
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".")
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".")
30 changes: 30 additions & 0 deletions src/phoenix/server/api/dataloaders/session_num_traces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from sqlalchemy import func, select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = int


class SessionNumTracesDataLoader(DataLoader[Key, Result]):
    """Batch-loads the number of traces in each project session.

    Each key is a ``project_session`` rowid; the result is the trace count,
    or 0 when the session has no traces.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # A single grouped COUNT query covers every requested session.
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                func.count(models.Trace.id).label("value"),
            )
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        counts: dict[Key, int] = {}
        async with self._db() as session:
            async for id_, value in await session.stream(stmt):
                if id_ is not None:
                    counts[id_] = value
        # Preserve the input key order; sessions without traces count as 0.
        return [counts.get(key, 0) for key in keys]
41 changes: 41 additions & 0 deletions src/phoenix/server/api/dataloaders/session_token_usages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sqlalchemy import func, select
from sqlalchemy.sql.functions import coalesce
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import TokenUsage

Key: TypeAlias = int
Result: TypeAlias = TokenUsage


class SessionTokenUsagesDataLoader(DataLoader[Key, Result]):
    """Batch-loads aggregate prompt/completion token usage per session.

    Each key is a ``project_session`` rowid; the result sums the cumulative
    token counts over the session's root spans. Sessions with no root spans
    yield an empty TokenUsage.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: list[Key]) -> list[Result]:
        # COALESCE treats missing cumulative counts as zero so the sums are
        # never NULL for sessions that do have root spans.
        prompt_sum = func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0))
        completion_sum = func.sum(
            coalesce(models.Span.cumulative_llm_token_count_completion, 0)
        )
        stmt = (
            select(
                models.Trace.project_session_rowid.label("id_"),
                prompt_sum.label("prompt"),
                completion_sum.label("completion"),
            )
            .join_from(models.Span, models.Trace)
            .where(models.Span.parent_id.is_(None))  # root spans only
            .where(models.Trace.project_session_rowid.in_(keys))
            .group_by(models.Trace.project_session_rowid)
        )
        usages: dict[Key, TokenUsage] = {}
        async with self._db() as session:
            async for id_, prompt, completion in await session.stream(stmt):
                if id_ is not None:
                    usages[id_] = TokenUsage(prompt=prompt, completion=completion)
        # Preserve the input key order, defaulting to an empty TokenUsage.
        return [usages.get(key, TokenUsage()) for key in keys]
9 changes: 4 additions & 5 deletions src/phoenix/server/api/dataloaders/trace_by_trace_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from phoenix.db import models
from phoenix.server.types import DbSessionFactory

TraceId: TypeAlias = str
Key: TypeAlias = TraceId
TraceRowId: TypeAlias = int
ProjectRowId: TypeAlias = int
Key: TypeAlias = str
Result: TypeAlias = Optional[models.Trace]


Expand All @@ -22,5 +19,7 @@ def __init__(self, db: DbSessionFactory) -> None:
async def _load_fn(self, keys: List[Key]) -> List[Result]:
stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys))
async with self._db() as session:
result = {trace.trace_id: trace for trace in await session.scalars(stmt)}
result: dict[Key, models.Trace] = {
trace.trace_id: trace async for trace in await session.stream_scalars(stmt)
}
return [result.get(trace_id) for trace_id in keys]
32 changes: 32 additions & 0 deletions src/phoenix/server/api/dataloaders/trace_root_spans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Optional

from sqlalchemy import select
from sqlalchemy.orm import contains_eager
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

from phoenix.db import models
from phoenix.server.types import DbSessionFactory

Key: TypeAlias = int
Result: TypeAlias = Optional[models.Span]


class TraceRootSpansDataLoader(DataLoader[Key, Result]):
    """Batch-loads the root span of each trace.

    Each key is a trace rowid; the result is the trace's root span (the span
    with no parent), or None when the trace has no root span.
    """

    def __init__(self, db: DbSessionFactory) -> None:
        super().__init__(load_fn=self._load_fn)
        self._db = db

    async def _load_fn(self, keys: List[Key]) -> List[Result]:
        # Eager-load only the joined trace's trace_id column so resolvers can
        # read it without triggering extra lazy-load queries.
        stmt = (
            select(models.Span)
            .join(models.Trace)
            .where(models.Trace.id.in_(keys))
            .where(models.Span.parent_id.is_(None))  # root spans have no parent
            .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id))
        )
        root_spans: dict[Key, models.Span] = {}
        async with self._db() as session:
            async for span in await session.stream_scalars(stmt):
                root_spans[span.trace_rowid] = span
        # Preserve the input key order, as the DataLoader contract requires.
        return [root_spans.get(key) for key in keys]
Loading
Loading