Skip to content

Commit

Permalink
feat(sessions): add db table for sessions (#4961)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Oct 29, 2024
1 parent 661ec17 commit ba9391f
Show file tree
Hide file tree
Showing 12 changed files with 681 additions and 26 deletions.
44 changes: 44 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ type Project implements Node {
spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float
trace(traceId: ID!): Trace
spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection!
sessions(timeRange: TimeRange, first: Int = 50, after: String): ProjectSessionConnection!

"""
Names of all available annotations for traces. (The list contains no duplicates.)
Expand Down Expand Up @@ -1174,6 +1175,31 @@ type ProjectEdge {
node: Project!
}

type ProjectSession implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
sessionId: String!
traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection!
}

"""A connection to a list of items."""
type ProjectSessionConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [ProjectSessionEdge!]!
}

"""An edge in a connection."""
type ProjectSessionEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: ProjectSession!
}

type PromptResponse {
"""The prompt submitted to the LLM"""
prompt: String
Expand Down Expand Up @@ -1553,6 +1579,24 @@ input TraceAnnotationSort {
dir: SortDir!
}

"""A connection to a list of items."""
type TraceConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [TraceEdge!]!
}

"""An edge in a connection."""
type TraceEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: Trace!
}

type UMAPPoint {
id: GlobalID!

Expand Down
107 changes: 82 additions & 25 deletions src/phoenix/db/insertion/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,99 @@ async def insert_span(
dialect = SupportedSQLDialect(session.bind.dialect.name)
if (
project_rowid := await session.scalar(
select(models.Project.id).where(models.Project.name == project_name)
select(models.Project.id).filter_by(name=project_name)
)
) is None:
project_rowid = await session.scalar(
insert(models.Project).values(dict(name=project_name)).returning(models.Project.id)
insert(models.Project).values(name=project_name).returning(models.Project.id)
)
assert project_rowid is not None
if trace := await session.scalar(
select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
):
trace_rowid = trace.id
if span.start_time < trace.start_time or trace.end_time < span.end_time:
trace_start_time = min(trace.start_time, span.start_time)
trace_end_time = max(trace.end_time, span.end_time)

project_session: Optional[models.ProjectSession] = None
session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID)
if session_id is not None and (not isinstance(session_id, str) or session_id.strip()):
session_id = str(session_id).strip()
assert isinstance(session_id, str)
project_session = await session.scalar(
select(models.ProjectSession).filter_by(session_id=session_id)
)
if project_session:
project_session_needs_update = False
project_session_end_time = None
project_session_project_id = None
if project_session.end_time < span.end_time:
project_session_needs_update = True
project_session_end_time = span.end_time
project_session_project_id = project_rowid
project_session_start_time = None
if span.start_time < project_session.start_time:
project_session_needs_update = True
project_session_start_time = span.start_time
if project_session_needs_update:
project_session = await session.scalar(
update(models.ProjectSession)
.filter_by(id=project_session.id)
.values(
start_time=project_session_start_time or project_session.start_time,
end_time=project_session_end_time or project_session.end_time,
project_id=project_session_project_id or project_session.project_id,
)
.returning(models.ProjectSession)
)
else:
project_session = await session.scalar(
insert(models.ProjectSession)
.values(
project_id=project_rowid,
session_id=session_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.ProjectSession)
)

trace_id = span.context.trace_id
trace: Optional[models.Trace] = await session.scalar(
select(models.Trace).filter_by(trace_id=trace_id)
)
if trace:
trace_needs_update = False
trace_end_time = None
trace_project_rowid = None
trace_project_session_id = None
if trace.end_time < span.end_time:
trace_needs_update = True
trace_end_time = span.end_time
trace_project_rowid = project_rowid
trace_project_session_id = project_session.id if project_session else None
trace_start_time = None
if span.start_time < trace.start_time:
trace_needs_update = True
trace_start_time = span.start_time
if trace_needs_update:
await session.execute(
update(models.Trace)
.where(models.Trace.id == trace_rowid)
.filter_by(id=trace.id)
.values(
start_time=trace_start_time,
end_time=trace_end_time,
start_time=trace_start_time or trace.start_time,
end_time=trace_end_time or trace.end_time,
project_rowid=trace_project_rowid or trace.project_rowid,
project_session_id=trace_project_session_id or trace.project_session_id,
)
)
else:
trace_rowid = cast(
int,
await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.Trace.id)
),
trace = await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
project_session_id=project_session.id if project_session else None,
)
.returning(models.Trace)
)
assert trace is not None
cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
cumulative_llm_token_count_prompt = cast(
int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
Expand Down Expand Up @@ -94,7 +151,7 @@ async def insert_span(
insert_on_conflict(
dict(
span_id=span.context.span_id,
trace_rowid=trace_rowid,
trace_rowid=trace.id,
parent_id=span.parent_id,
span_kind=span.span_kind.value,
name=span.name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""create project_session table
Revision ID: 4ded9e43755f
Revises: cd164e83824f
Create Date: 2024-10-08 22:53:24.539786
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "4ded9e43755f"
down_revision: Union[str, None] = "cd164e83824f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"project_sessions",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("session_id", sa.String, unique=True, nullable=False),
sa.Column(
"project_id",
sa.Integer,
sa.ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
)
with op.batch_alter_table("traces") as batch_op:
batch_op.add_column(
sa.Column(
"project_session_id",
sa.Integer,
sa.ForeignKey("project_sessions.id", ondelete="CASCADE"),
nullable=True,
),
)
op.create_index(
"ix_traces_project_session_id",
"traces",
["project_session_id"],
)


def downgrade() -> None:
op.drop_index("ix_traces_project_session_id")
with op.batch_alter_table("traces") as batch_op:
batch_op.drop_column("project_session_id")
op.drop_table("project_sessions")
27 changes: 27 additions & 0 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ class Project(Base):
)


class ProjectSession(Base):
__tablename__ = "project_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
project_id: Mapped[int] = mapped_column(
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
traces: Mapped[list["Trace"]] = relationship(
"Trace",
back_populates="project_session",
uselist=True,
)


class Trace(Base):
__tablename__ = "traces"
id: Mapped[int] = mapped_column(primary_key=True)
Expand All @@ -164,6 +182,11 @@ class Trace(Base):
index=True,
)
trace_id: Mapped[str]
project_session_id: Mapped[int] = mapped_column(
ForeignKey("project_sessions.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

Expand All @@ -188,6 +211,10 @@ def _latency_ms_expression(cls) -> ColumnElement[float]:
cascade="all, delete-orphan",
uselist=True,
)
project_session: Mapped[ProjectSession] = relationship(
"ProjectSession",
back_populates="traces",
)
experiment_runs: Mapped[list["ExperimentRun"]] = relationship(
primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
back_populates="trace",
Expand Down
9 changes: 9 additions & 0 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
connection_from_list,
)
from phoenix.server.api.types.Project import Project
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.SystemApiKey import SystemApiKey
Expand Down Expand Up @@ -476,6 +477,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
if span is None:
raise NotFound(f"Unknown span: {id}")
return to_gql_span(span)
elif type_name == ProjectSession.__name__:
async with info.context.db() as session:
project_session = await session.scalar(
select(models.ProjectSession).filter_by(id=node_id)
)
if project_session is None:
raise NotFound(f"Unknown project_session: {id}")
return to_gql_project_session(project_session)
elif type_name == Dataset.__name__:
dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
async with info.context.db() as session:
Expand Down
41 changes: 41 additions & 0 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CursorString,
connection_from_cursors_and_nodes,
)
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.Trace import Trace
Expand Down Expand Up @@ -241,6 +242,46 @@ async def spans(
has_next_page=has_next_page,
)

@strawberry.field
async def sessions(
self,
info: Info[Context, None],
time_range: Optional[TimeRange] = UNSET,
first: Optional[int] = 50,
after: Optional[CursorString] = UNSET,
) -> Connection[ProjectSession]:
table = models.ProjectSession
stmt = select(table).filter_by(project_id=self.id_attr)
if time_range:
if time_range.start:
stmt = stmt.where(time_range.start <= table.start_time)
if time_range.end:
stmt = stmt.where(table.start_time < time_range.end)
if after:
cursor = Cursor.from_string(after)
stmt = stmt.where(table.id > cursor.rowid)
if first:
stmt = stmt.limit(
first + 1 # over-fetch by one to determine whether there's a next page
)
stmt = stmt.order_by(table.id)
cursors_and_nodes = []
async with info.context.db() as session:
records = await session.scalars(stmt)
async for project_session in islice(records, first):
cursor = Cursor(rowid=project_session.id)
cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
has_next_page = True
try:
next(records)
except StopIteration:
has_next_page = False
return connection_from_cursors_and_nodes(
cursors_and_nodes,
has_previous_page=False,
has_next_page=has_next_page,
)

@strawberry.field(
description="Names of all available annotations for traces. "
"(The list contains no duplicates.)"
Expand Down
Loading

0 comments on commit ba9391f

Please sign in to comment.