feat: add db table for trace sessions #4961

Merged: 1 commit, Oct 18, 2024
2 changes: 1 addition & 1 deletion .github/workflows/python-CI.yml
@@ -450,4 +450,4 @@ jobs:
- name: Set up `tox` with `tox-uv`
run: uv tool install tox --with tox-uv
- name: Run integration tests
run: tox run -e integration_tests -- -ra -x -n auto
run: tox run -e integration_tests -- -ra -x -n 10 --reruns 5
44 changes: 44 additions & 0 deletions app/schema.graphql
@@ -1120,6 +1120,7 @@ type Project implements Node {
spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float
trace(traceId: ID!): Trace
spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection!
sessions(timeRange: TimeRange, first: Int = 50, after: String): ProjectSessionConnection!

"""
Names of all available annotations for traces. (The list contains no duplicates.)
@@ -1158,6 +1159,31 @@ type ProjectEdge {
node: Project!
}

type ProjectSession implements Node {
"""The Globally Unique ID of this object"""
id: GlobalID!
sessionId: String!
traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection!
}

"""A connection to a list of items."""
type ProjectSessionConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [ProjectSessionEdge!]!
}

"""An edge in a connection."""
type ProjectSessionEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: ProjectSession!
}

type PromptResponse {
"""The prompt submitted to the LLM"""
prompt: String
@@ -1537,6 +1563,24 @@ input TraceAnnotationSort {
dir: SortDir!
}

"""A connection to a list of items."""
type TraceConnection {
"""Pagination data for this connection"""
pageInfo: PageInfo!

"""Contains the nodes in this connection"""
edges: [TraceEdge!]!
}

"""An edge in a connection."""
type TraceEdge {
"""A cursor for use in pagination"""
cursor: String!

"""The item at the end of the edge"""
node: Trace!
}

type UMAPPoint {
id: GlobalID!

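For context, here is a sketch of how the new `sessions` connection could be queried once this schema change ships. This is not part of the PR; the server URL, the `/graphql` path, and the project GlobalID placeholder are assumptions, and the selection set only touches fields added in this diff.

```python
# Sketch only: exercises the new Project.sessions connection over GraphQL.
# The base URL, the /graphql path, and PROJECT_GLOBAL_ID are assumptions.
import httpx

QUERY = """
query ProjectSessions($id: GlobalID!, $first: Int) {
  node(id: $id) {
    ... on Project {
      sessions(first: $first) {
        edges {
          cursor
          node {
            id
            sessionId
            traces(first: 10) {
              edges { cursor }
            }
          }
        }
        pageInfo { hasNextPage endCursor }
      }
    }
  }
}
"""

PROJECT_GLOBAL_ID = "UHJvamVjdDox"  # placeholder GlobalID for an existing project

response = httpx.post(
    "http://localhost:6006/graphql",  # assumed local Phoenix server address
    json={"query": QUERY, "variables": {"id": PROJECT_GLOBAL_ID, "first": 50}},
)
response.raise_for_status()
print(response.json()["data"]["node"]["sessions"]["edges"])
```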
1 change: 1 addition & 0 deletions requirements/integration-tests.txt
@@ -8,6 +8,7 @@ portpicker
psutil
pyjwt
pytest-randomly
pytest-rerunfailures
pytest-smtpd
types-beautifulsoup4
types-psutil
107 changes: 82 additions & 25 deletions src/phoenix/db/insertion/span.py
@@ -28,42 +28,99 @@ async def insert_span(
dialect = SupportedSQLDialect(session.bind.dialect.name)
if (
project_rowid := await session.scalar(
select(models.Project.id).where(models.Project.name == project_name)
select(models.Project.id).filter_by(name=project_name)
)
) is None:
project_rowid = await session.scalar(
insert(models.Project).values(dict(name=project_name)).returning(models.Project.id)
insert(models.Project).values(name=project_name).returning(models.Project.id)
)
assert project_rowid is not None
if trace := await session.scalar(
select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
):
trace_rowid = trace.id
if span.start_time < trace.start_time or trace.end_time < span.end_time:
trace_start_time = min(trace.start_time, span.start_time)
trace_end_time = max(trace.end_time, span.end_time)

project_session: Optional[models.ProjectSession] = None
session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID)
if session_id is not None and (not isinstance(session_id, str) or session_id.strip()):
session_id = str(session_id).strip()
assert isinstance(session_id, str)
project_session = await session.scalar(
select(models.ProjectSession).filter_by(session_id=session_id)
)
if project_session:
project_session_needs_update = False
project_session_end_time = None
project_session_project_id = None
if project_session.end_time < span.end_time:
project_session_needs_update = True
project_session_end_time = span.end_time
project_session_project_id = project_rowid
project_session_start_time = None
if span.start_time < project_session.start_time:
project_session_needs_update = True
project_session_start_time = span.start_time
if project_session_needs_update:
project_session = await session.scalar(
update(models.ProjectSession)
.filter_by(id=project_session.id)
.values(
start_time=project_session_start_time or project_session.start_time,
end_time=project_session_end_time or project_session.end_time,
project_id=project_session_project_id or project_session.project_id,
)
.returning(models.ProjectSession)
)
else:
project_session = await session.scalar(
insert(models.ProjectSession)
.values(
project_id=project_rowid,
session_id=session_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.ProjectSession)
)

trace_id = span.context.trace_id
trace: Optional[models.Trace] = await session.scalar(
select(models.Trace).filter_by(trace_id=trace_id)
)
if trace:
trace_needs_update = False
trace_end_time = None
trace_project_rowid = None
trace_project_session_id = None
if trace.end_time < span.end_time:
trace_needs_update = True
trace_end_time = span.end_time
trace_project_rowid = project_rowid
trace_project_session_id = project_session.id if project_session else None
trace_start_time = None
if span.start_time < trace.start_time:
trace_needs_update = True
trace_start_time = span.start_time
if trace_needs_update:
await session.execute(
update(models.Trace)
.where(models.Trace.id == trace_rowid)
.filter_by(id=trace.id)
.values(
start_time=trace_start_time,
end_time=trace_end_time,
start_time=trace_start_time or trace.start_time,
end_time=trace_end_time or trace.end_time,
project_rowid=trace_project_rowid or trace.project_rowid,
project_session_id=trace_project_session_id or trace.project_session_id,
)
)
else:
trace_rowid = cast(
int,
await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.Trace.id)
),
trace = await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
project_session_id=project_session.id if project_session else None,
)
.returning(models.Trace)
)
assert trace is not None
cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
cumulative_llm_token_count_prompt = cast(
int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0
@@ -94,7 +151,7 @@ async def insert_span(
insert_on_conflict(
dict(
span_id=span.context.span_id,
trace_rowid=trace_rowid,
trace_rowid=trace.id,
parent_id=span.parent_id,
span_kind=span.span_kind.value,
name=span.name,
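The core of the change above is: read `SESSION_ID` from the span attributes, look up an existing `project_sessions` row, and either create one or widen its time window. A simplified, self-contained sketch of that select-then-insert-or-widen step follows; it deliberately omits the PR's extra detail of reassigning `project_id` when the end time advances, and the async session is assumed to be provided by the caller.

```python
# Minimal sketch of the "find the session row, else create it; otherwise widen
# its time bounds" flow used by insert_span. Table and column names follow the
# migration in this PR; the AsyncSession is assumed to come from the caller.
from datetime import datetime
from typing import Optional

from sqlalchemy import insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models


async def upsert_project_session(
    session: AsyncSession,
    project_rowid: int,
    session_id: str,
    span_start: datetime,
    span_end: datetime,
) -> models.ProjectSession:
    existing: Optional[models.ProjectSession] = await session.scalar(
        select(models.ProjectSession).filter_by(session_id=session_id)
    )
    if existing is None:
        # First span seen for this session: create the row.
        created = await session.scalar(
            insert(models.ProjectSession)
            .values(
                project_id=project_rowid,
                session_id=session_id,
                start_time=span_start,
                end_time=span_end,
            )
            .returning(models.ProjectSession)
        )
        assert created is not None
        return created
    # Later span: only widen the [start_time, end_time] window when needed.
    new_start = min(existing.start_time, span_start)
    new_end = max(existing.end_time, span_end)
    if new_start != existing.start_time or new_end != existing.end_time:
        updated = await session.scalar(
            update(models.ProjectSession)
            .filter_by(id=existing.id)
            .values(start_time=new_start, end_time=new_end)
            .returning(models.ProjectSession)
        )
        assert updated is not None
        return updated
    return existing
```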
@@ -0,0 +1,55 @@
"""create project_session table

Revision ID: 4ded9e43755f
Revises: cd164e83824f
Create Date: 2024-10-08 22:53:24.539786

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "4ded9e43755f"
down_revision: Union[str, None] = "cd164e83824f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"project_sessions",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("session_id", sa.String, unique=True, nullable=False),
sa.Column(
"project_id",
sa.Integer,
sa.ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
)
with op.batch_alter_table("traces") as batch_op:
batch_op.add_column(
sa.Column(
"project_session_id",
sa.Integer,
sa.ForeignKey("project_sessions.id", ondelete="CASCADE"),
nullable=True,
),
)
op.create_index(
"ix_traces_project_session_id",
"traces",
["project_session_id"],
)


def downgrade() -> None:
op.drop_index("ix_traces_project_session_id")
with op.batch_alter_table("traces") as batch_op:
batch_op.drop_column("project_session_id")
op.drop_table("project_sessions")
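If the migration is run by hand rather than by the server, Alembic's Python command API can step to exactly this revision; the `alembic.ini` location below is an assumption.

```python
# Sketch only: applying or rolling back this revision with Alembic's command API.
# The alembic.ini path is an assumption about the local setup.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "4ded9e43755f")    # create project_sessions + traces.project_session_id
command.downgrade(cfg, "cd164e83824f")  # revert to the previous revision
```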
27 changes: 27 additions & 0 deletions src/phoenix/db/models.py
@@ -156,6 +156,24 @@ class Project(Base):
)


class ProjectSession(Base):
__tablename__ = "project_sessions"
id: Mapped[int] = mapped_column(primary_key=True)
session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
project_id: Mapped[int] = mapped_column(
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
traces: Mapped[List["Trace"]] = relationship(
"Trace",
back_populates="project_session",
uselist=True,
)


class Trace(Base):
__tablename__ = "traces"
id: Mapped[int] = mapped_column(primary_key=True)
@@ -164,6 +182,11 @@ class Trace(Base):
index=True,
)
trace_id: Mapped[str]
project_session_id: Mapped[int] = mapped_column(
ForeignKey("project_sessions.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

@@ -188,6 +211,10 @@ def _latency_ms_expression(cls) -> ColumnElement[float]:
cascade="all, delete-orphan",
uselist=True,
)
project_session: Mapped[ProjectSession] = relationship(
"ProjectSession",
back_populates="traces",
)
experiment_runs: Mapped[List["ExperimentRun"]] = relationship(
primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
back_populates="trace",
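The `ProjectSession.traces` / `Trace.project_session` relationship pair gives a straightforward way to load a session together with its traces. A sketch assuming an async session factory named `db`:

```python
# Sketch only: loading a session and its traces through the new relationship.
# The async session factory `db` is assumed to exist in the calling code.
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from phoenix.db import models


async def get_session_with_traces(db, session_id: str):
    async with db() as session:
        project_session = await session.scalar(
            select(models.ProjectSession)
            .filter_by(session_id=session_id)
            .options(selectinload(models.ProjectSession.traces))
        )
        if project_session is None:
            return None
        # Each trace keeps its own time bounds; the session row holds the
        # overall [start_time, end_time] window maintained by insert_span.
        return project_session, list(project_session.traces)
```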
9 changes: 9 additions & 0 deletions src/phoenix/server/api/queries.py
@@ -72,6 +72,7 @@
connection_from_list,
)
from phoenix.server.api.types.Project import Project
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.SystemApiKey import SystemApiKey
@@ -476,6 +477,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
if span is None:
raise NotFound(f"Unknown span: {id}")
return to_gql_span(span)
elif type_name == ProjectSession.__name__:
async with info.context.db() as session:
project_session = await session.scalar(
select(models.ProjectSession).filter_by(id=node_id)
)
if project_session is None:
raise NotFound(f"Unknown project_session: {id}")
return to_gql_project_session(project_session)
elif type_name == Dataset.__name__:
dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
async with info.context.db() as session:
41 changes: 41 additions & 0 deletions src/phoenix/server/api/types/Project.py
@@ -30,6 +30,7 @@
CursorString,
connection_from_cursors_and_nodes,
)
from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
from phoenix.server.api.types.SortDir import SortDir
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.api.types.Trace import Trace
@@ -248,6 +249,46 @@ async def spans(
has_next_page=has_next_page,
)

@strawberry.field
async def sessions(
self,
info: Info[Context, None],
time_range: Optional[TimeRange] = UNSET,
first: Optional[int] = 50,
after: Optional[CursorString] = UNSET,
) -> Connection[ProjectSession]:
table = models.ProjectSession
stmt = select(table).filter_by(project_id=self.id_attr)
if time_range:
if time_range.start:
stmt = stmt.where(time_range.start <= table.start_time)
if time_range.end:
stmt = stmt.where(table.start_time < time_range.end)
if after:
cursor = Cursor.from_string(after)
stmt = stmt.where(table.id > cursor.rowid)
if first:
stmt = stmt.limit(
first + 1 # over-fetch by one to determine whether there's a next page
)
stmt = stmt.order_by(table.id)
cursors_and_nodes = []
async with info.context.db() as session:
records = await session.scalars(stmt)
async for project_session in islice(records, first):
cursor = Cursor(rowid=project_session.id)
cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
has_next_page = True
try:
next(records)
except StopIteration:
has_next_page = False
return connection_from_cursors_and_nodes(
cursors_and_nodes,
has_previous_page=False,
has_next_page=has_next_page,
)

@strawberry.field(
description="Names of all available annotations for traces. "
"(The list contains no duplicates.)"
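The `sessions` resolver above uses keyset pagination on the primary key: filter `id > cursor`, order by `id`, over-fetch by one row, and use the extra row only to decide `hasNextPage`. Below is a stripped-down sketch of the same pattern, with illustrative names rather than the Phoenix cursor helpers.

```python
# Stripped-down keyset pagination sketch mirroring the `sessions` resolver:
# filter by id > cursor, order by id, fetch first+1 rows, and use the extra
# row only to decide whether another page exists. Names are illustrative.
from dataclasses import dataclass
from typing import List, Optional, Tuple

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models


@dataclass
class Page:
    rows: List[Tuple[int, models.ProjectSession]]  # (cursor rowid, row)
    has_next_page: bool


async def paginate_sessions(
    session: AsyncSession,
    project_rowid: int,
    first: int = 50,
    after_rowid: Optional[int] = None,
) -> Page:
    stmt = select(models.ProjectSession).filter_by(project_id=project_rowid)
    if after_rowid is not None:
        stmt = stmt.where(models.ProjectSession.id > after_rowid)
    stmt = stmt.order_by(models.ProjectSession.id).limit(first + 1)  # over-fetch by one
    records = list(await session.scalars(stmt))
    has_next_page = len(records) > first
    rows = [(r.id, r) for r in records[:first]]
    return Page(rows=rows, has_next_page=has_next_page)
```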