Skip to content

Commit

Permalink
feat: add db table for trace sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Oct 14, 2024
1 parent 23e9d82 commit 7113eb7
Show file tree
Hide file tree
Showing 16 changed files with 579 additions and 62 deletions.
44 changes: 44 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ type Project implements Node {
spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float
trace(traceId: ID!): Trace
spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection!
sessions(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String): TraceSessionConnection!

"""
Names of all available annotations for traces. (The list contains no duplicates.)
Expand Down Expand Up @@ -1473,6 +1474,49 @@ input TraceAnnotationSort {
dir: SortDir!
}

"""A connection to a list of items."""
type TraceConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [TraceEdge!]!
}

"""An edge in a connection."""
type TraceEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: Trace!
}

type TraceSession implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
sessionId: String!
traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection!
}

"""A connection to a list of items."""
type TraceSessionConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [TraceSessionEdge!]!
}

"""An edge in a connection."""
type TraceSessionEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: TraceSession!
}

type UMAPPoint {
id: GlobalID!

Expand Down
73 changes: 54 additions & 19 deletions src/phoenix/db/insertion/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,77 @@ async def insert_span(
dialect = SupportedSQLDialect(session.bind.dialect.name)
if (
project_rowid := await session.scalar(
select(models.Project.id).where(models.Project.name == project_name)
select(models.Project.id).filter_by(name=project_name)
)
) is None:
project_rowid = await session.scalar(
insert(models.Project).values(dict(name=project_name)).returning(models.Project.id)
insert(models.Project).values(name=project_name).returning(models.Project.id)
)
assert project_rowid is not None
if trace := await session.scalar(
select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
):

trace_session: Optional[models.TraceSession] = None
session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID)
if session_id:
session_id = str(session_id)
assert isinstance(session_id, str)
trace_session = await session.scalar(
select(models.TraceSession).filter_by(session_id=session_id)
)
if not trace_session:
trace_session = await session.scalar(
insert(models.TraceSession)
.values(
project_id=project_rowid,
session_id=session_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.TraceSession)
)
elif span.start_time < trace_session.start_time or trace_session.end_time < span.end_time:
trace_session_start_time = min(trace_session.start_time, span.start_time)
trace_session_end_time = max(trace_session.end_time, span.end_time)
trace_session = await session.scalar(
update(models.TraceSession)
.where(models.TraceSession.id == trace_session.id)
.values(
start_time=trace_session_start_time,
end_time=trace_session_end_time,
)
.returning(models.TraceSession)
)

trace_id = span.context.trace_id
trace: Optional[models.Trace] = await session.scalar(
select(models.Trace).filter_by(trace_id=trace_id)
)
if trace:
trace_rowid = trace.id
if span.start_time < trace.start_time or trace.end_time < span.end_time:
trace_start_time = min(trace.start_time, span.start_time)
trace_end_time = max(trace.end_time, span.end_time)
await session.execute(
update(models.Trace)
.where(models.Trace.id == trace_rowid)
.filter_by(id=trace_rowid)
.values(
start_time=trace_start_time,
end_time=trace_end_time,
trace_session_id=trace_session.id if trace_session else trace.trace_session_id,
)
)
else:
trace_rowid = cast(
int,
await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.Trace.id)
),
trace = await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
trace_session_id=trace_session.id if trace_session else None,
)
.returning(models.Trace)
)
assert trace is not None
cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
cumulative_llm_token_count_prompt = cast(
int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
Expand Down Expand Up @@ -94,7 +129,7 @@ async def insert_span(
insert_on_conflict(
dict(
span_id=span.context.span_id,
trace_rowid=trace_rowid,
trace_rowid=trace.id,
parent_id=span.parent_id,
span_kind=span.span_kind.value,
name=span.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""create trace_session table
Revision ID: 4ded9e43755f
Revises: cd164e83824f
Create Date: 2024-10-08 22:53:24.539786
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "4ded9e43755f"
down_revision: Union[str, None] = "cd164e83824f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"trace_sessions",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("session_id", sa.String, unique=True),
sa.Column("project_id", sa.Integer, sa.ForeignKey("projects.id", ondelete="CASCADE")),
sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True),
sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True),
)
with op.batch_alter_table("traces") as batch_op:
batch_op.add_column(
sa.Column(
"trace_session_id",
sa.Integer,
sa.ForeignKey("trace_sessions.id", ondelete="CASCADE"),
nullable=True,
),
)
op.create_index(
"ix_traces_trace_session_id",
"traces",
["trace_session_id"],
)


def downgrade() -> None:
op.drop_index("ix_traces_trace_session_id")
with op.batch_alter_table("traces") as batch_op:
batch_op.drop_column("trace_session_id")
op.drop_table("trace_sessions")
27 changes: 27 additions & 0 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ class Project(Base):
)


class TraceSession(Base):
__tablename__ = "trace_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
project_id: Mapped[int] = mapped_column(
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
traces: Mapped[List["Trace"]] = relationship(
"Trace",
back_populates="trace_session",
uselist=True,
)


class Trace(Base):
__tablename__ = "traces"
id: Mapped[int] = mapped_column(primary_key=True)
Expand All @@ -164,6 +182,11 @@ class Trace(Base):
index=True,
)
trace_id: Mapped[str]
trace_session_id: Mapped[int] = mapped_column(
ForeignKey("trace_sessions.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

Expand All @@ -188,6 +211,10 @@ def _latency_ms_expression(cls) -> ColumnElement[float]:
cascade="all, delete-orphan",
uselist=True,
)
trace_session: Mapped[TraceSession] = relationship(
"TraceSession",
back_populates="traces",
)
experiment_runs: Mapped[List["ExperimentRun"]] = relationship(
primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
back_populates="trace",
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.SystemApiKey import SystemApiKey
from phoenix.server.api.types.Trace import Trace
from phoenix.server.api.types.TraceSession import TraceSession, to_gql_trace_session
from phoenix.server.api.types.User import User, to_gql_user
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
from phoenix.server.api.types.UserRole import UserRole
Expand Down Expand Up @@ -446,6 +447,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
if span is None:
raise NotFound(f"Unknown span: {id}")
return to_gql_span(span)
elif type_name == TraceSession.__name__:
async with info.context.db() as session:
trace_session = await session.scalar(
select(models.TraceSession).filter_by(id=node_id)
)
if trace_session is None:
raise NotFound(f"Unknown trace_session: {id}")
return to_gql_trace_session(trace_session)
elif type_name == Dataset.__name__:
dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
async with info.context.db() as session:
Expand Down
38 changes: 38 additions & 0 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.Trace import Trace
from phoenix.server.api.types.TraceSession import TraceSession, to_gql_trace_session
from phoenix.server.api.types.ValidationResult import ValidationResult
from phoenix.trace.dsl import SpanFilter

Expand Down Expand Up @@ -248,6 +249,43 @@ async def spans(
has_next_page=has_next_page,
)

@strawberry.field
async def sessions(
self,
info: Info[Context, None],
time_range: Optional[TimeRange] = UNSET,
first: Optional[int] = 50,
last: Optional[int] = UNSET,
after: Optional[CursorString] = UNSET,
before: Optional[CursorString] = UNSET,
) -> Connection[TraceSession]:
stmt = select(models.TraceSession).filter_by(project_id=self.id_attr)
cursor_rowid_column: Any = models.TraceSession.id
if after:
cursor = Cursor.from_string(after)
stmt = stmt.where(cursor_rowid_column > cursor.rowid)
if first:
stmt = stmt.limit(
first + 1 # over-fetch by one to determine whether there's a next page
)
stmt = stmt.order_by(cursor_rowid_column)
cursors_and_nodes = []
async with info.context.db() as session:
records = await session.scalars(stmt)
async for trace_session in islice(records, first):
cursor = Cursor(rowid=trace_session.id)
cursors_and_nodes.append((cursor, to_gql_trace_session(trace_session)))
has_next_page = True
try:
next(records)
except StopIteration:
has_next_page = False
return connection_from_cursors_and_nodes(
cursors_and_nodes,
has_previous_page=False,
has_next_page=has_next_page,
)

@strawberry.field(
description="Names of all available annotations for traces. "
"(The list contains no duplicates.)"
Expand Down
8 changes: 8 additions & 0 deletions src/phoenix/server/api/types/Trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ async def span_annotations(
stmt = stmt.order_by(models.TraceAnnotation.created_at.desc())
annotations = await session.scalars(stmt)
return [to_gql_trace_annotation(annotation) for annotation in annotations]


def to_gql_trace(trace: models.Trace) -> Trace:
return Trace(
id_attr=trace.id,
project_rowid=trace.project_rowid,
trace_id=trace.trace_id,
)
Loading

0 comments on commit 7113eb7

Please sign in to comment.