From 83125e2a040c7e8bfcea72d4f3b1b39eff7bc6be Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 22 Oct 2024 13:53:43 -0700 Subject: [PATCH] feat: add db table for sessions (#4961) --- app/schema.graphql | 44 +++ src/phoenix/db/insertion/span.py | 107 ++++++-- ...ed9e43755f_create_project_session_table.py | 55 ++++ src/phoenix/db/models.py | 27 ++ src/phoenix/server/api/queries.py | 9 + src/phoenix/server/api/types/Project.py | 41 +++ .../server/api/types/ProjectSession.py | 52 ++++ src/phoenix/server/api/types/Trace.py | 8 + .../test_up_and_down_migrations.py | 256 +++++++++++++++++- .../integration/project_sessions/__init__.py | 0 .../integration/project_sessions/conftest.py | 13 + .../project_sessions/test_project_sessions.py | 95 +++++++ 12 files changed, 681 insertions(+), 26 deletions(-) create mode 100644 src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py create mode 100644 src/phoenix/server/api/types/ProjectSession.py create mode 100644 tests/integration/project_sessions/__init__.py create mode 100644 tests/integration/project_sessions/conftest.py create mode 100644 tests/integration/project_sessions/test_project_sessions.py diff --git a/app/schema.graphql b/app/schema.graphql index cc712d2a50..370b8be31f 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1132,6 +1132,7 @@ type Project implements Node { spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float trace(traceId: ID!): Trace spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + sessions(timeRange: TimeRange, first: Int = 50, after: String): ProjectSessionConnection! """ Names of all available annotations for traces. (The list contains no duplicates.) @@ -1170,6 +1171,31 @@ type ProjectEdge { node: Project! } +type ProjectSession implements Node { + """The Globally Unique ID of this object""" + id: GlobalID! + sessionId: String! + traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection! +} + +"""A connection to a list of items.""" +type ProjectSessionConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [ProjectSessionEdge!]! +} + +"""An edge in a connection.""" +type ProjectSessionEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: ProjectSession! +} + type PromptResponse { """The prompt submitted to the LLM""" prompt: String @@ -1549,6 +1575,24 @@ input TraceAnnotationSort { dir: SortDir! } +"""A connection to a list of items.""" +type TraceConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [TraceEdge!]! +} + +"""An edge in a connection.""" +type TraceEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Trace! +} + type UMAPPoint { id: GlobalID! diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index 21f46de1f0..1ca80b2b08 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -28,42 +28,99 @@ async def insert_span( dialect = SupportedSQLDialect(session.bind.dialect.name) if ( project_rowid := await session.scalar( - select(models.Project.id).where(models.Project.name == project_name) + select(models.Project.id).filter_by(name=project_name) ) ) is None: project_rowid = await session.scalar( - insert(models.Project).values(dict(name=project_name)).returning(models.Project.id) + insert(models.Project).values(name=project_name).returning(models.Project.id) ) assert project_rowid is not None - if trace := await session.scalar( - select(models.Trace).where(models.Trace.trace_id == span.context.trace_id) - ): - trace_rowid = trace.id - if span.start_time < trace.start_time or trace.end_time < span.end_time: - trace_start_time = min(trace.start_time, span.start_time) - trace_end_time = max(trace.end_time, span.end_time) + + project_session: Optional[models.ProjectSession] = None + session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID) + if session_id is not None and (not isinstance(session_id, str) or session_id.strip()): + session_id = str(session_id).strip() + assert isinstance(session_id, str) + project_session = await session.scalar( + select(models.ProjectSession).filter_by(session_id=session_id) + ) + if project_session: + project_session_needs_update = False + project_session_end_time = None + project_session_project_id = None + if project_session.end_time < span.end_time: + project_session_needs_update = True + project_session_end_time = span.end_time + project_session_project_id = project_rowid + project_session_start_time = None + if span.start_time < project_session.start_time: + project_session_needs_update = True + project_session_start_time = span.start_time + if project_session_needs_update: + project_session = await session.scalar( + update(models.ProjectSession) + .filter_by(id=project_session.id) + .values( + start_time=project_session_start_time or project_session.start_time, + end_time=project_session_end_time or project_session.end_time, + project_id=project_session_project_id or project_session.project_id, + ) + .returning(models.ProjectSession) + ) + else: + project_session = await session.scalar( + insert(models.ProjectSession) + .values( + project_id=project_rowid, + session_id=session_id, + start_time=span.start_time, + end_time=span.end_time, + ) + .returning(models.ProjectSession) + ) + + trace_id = span.context.trace_id + trace: Optional[models.Trace] = await session.scalar( + select(models.Trace).filter_by(trace_id=trace_id) + ) + if trace: + trace_needs_update = False + trace_end_time = None + trace_project_rowid = None + trace_project_session_id = None + if trace.end_time < span.end_time: + trace_needs_update = True + trace_end_time = span.end_time + trace_project_rowid = project_rowid + trace_project_session_id = project_session.id if project_session else None + trace_start_time = None + if span.start_time < trace.start_time: + trace_needs_update = True + trace_start_time = span.start_time + if trace_needs_update: await session.execute( update(models.Trace) - .where(models.Trace.id == trace_rowid) + .filter_by(id=trace.id) .values( - start_time=trace_start_time, - end_time=trace_end_time, + start_time=trace_start_time or trace.start_time, + end_time=trace_end_time or trace.end_time, + project_rowid=trace_project_rowid or trace.project_rowid, + project_session_id=trace_project_session_id or trace.project_session_id, ) ) else: - trace_rowid = cast( - int, - await session.scalar( - insert(models.Trace) - .values( - project_rowid=project_rowid, - trace_id=span.context.trace_id, - start_time=span.start_time, - end_time=span.end_time, - ) - .returning(models.Trace.id) - ), + trace = await session.scalar( + insert(models.Trace) + .values( + project_rowid=project_rowid, + trace_id=span.context.trace_id, + start_time=span.start_time, + end_time=span.end_time, + project_session_id=project_session.id if project_session else None, + ) + .returning(models.Trace) ) + assert trace is not None cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) cumulative_llm_token_count_prompt = cast( int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 @@ -94,7 +151,7 @@ async def insert_span( insert_on_conflict( dict( span_id=span.context.span_id, - trace_rowid=trace_rowid, + trace_rowid=trace.id, parent_id=span.parent_id, span_kind=span.span_kind.value, name=span.name, diff --git a/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py new file mode 100644 index 0000000000..c49486d42c --- /dev/null +++ b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py @@ -0,0 +1,55 @@ +"""create project_session table + +Revision ID: 4ded9e43755f +Revises: cd164e83824f +Create Date: 2024-10-08 22:53:24.539786 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4ded9e43755f" +down_revision: Union[str, None] = "cd164e83824f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "project_sessions", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("session_id", sa.String, unique=True, nullable=False), + sa.Column( + "project_id", + sa.Integer, + sa.ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), + sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), + ) + with op.batch_alter_table("traces") as batch_op: + batch_op.add_column( + sa.Column( + "project_session_id", + sa.Integer, + sa.ForeignKey("project_sessions.id", ondelete="CASCADE"), + nullable=True, + ), + ) + op.create_index( + "ix_traces_project_session_id", + "traces", + ["project_session_id"], + ) + + +def downgrade() -> None: + op.drop_index("ix_traces_project_session_id") + with op.batch_alter_table("traces") as batch_op: + batch_op.drop_column("project_session_id") + op.drop_table("project_sessions") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 2adef3b19f..f3b29161d7 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -156,6 +156,24 @@ class Project(Base): ) +class ProjectSession(Base): + __tablename__ = "project_sessions" + id: Mapped[int] = mapped_column(primary_key=True) + session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True) + project_id: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) + end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) + traces: Mapped[List["Trace"]] = relationship( + "Trace", + back_populates="project_session", + uselist=True, + ) + + class Trace(Base): __tablename__ = "traces" id: Mapped[int] = mapped_column(primary_key=True) @@ -164,6 +182,11 @@ class Trace(Base): index=True, ) trace_id: Mapped[str] + project_session_id: Mapped[int] = mapped_column( + ForeignKey("project_sessions.id", ondelete="CASCADE"), + nullable=True, + index=True, + ) start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) @@ -188,6 +211,10 @@ def _latency_ms_expression(cls) -> ColumnElement[float]: cascade="all, delete-orphan", uselist=True, ) + project_session: Mapped[ProjectSession] = relationship( + "ProjectSession", + back_populates="traces", + ) experiment_runs: Mapped[List["ExperimentRun"]] = relationship( primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id", back_populates="trace", diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index fe09ca6f25..a29911d381 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -72,6 +72,7 @@ connection_from_list, ) from phoenix.server.api.types.Project import Project +from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.SystemApiKey import SystemApiKey @@ -476,6 +477,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: if span is None: raise NotFound(f"Unknown span: {id}") return to_gql_span(span) + elif type_name == ProjectSession.__name__: + async with info.context.db() as session: + project_session = await session.scalar( + select(models.ProjectSession).filter_by(id=node_id) + ) + if project_session is None: + raise NotFound(f"Unknown project_session: {id}") + return to_gql_project_session(project_session) elif type_name == Dataset.__name__: dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id) async with info.context.db() as session: diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 9fc1859bda..409d35b4cf 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -30,6 +30,7 @@ CursorString, connection_from_cursors_and_nodes, ) +from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.Trace import Trace @@ -243,6 +244,46 @@ async def spans( has_next_page=has_next_page, ) + @strawberry.field + async def sessions( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + first: Optional[int] = 50, + after: Optional[CursorString] = UNSET, + ) -> Connection[ProjectSession]: + table = models.ProjectSession + stmt = select(table).filter_by(project_id=self.id_attr) + if time_range: + if time_range.start: + stmt = stmt.where(time_range.start <= table.start_time) + if time_range.end: + stmt = stmt.where(table.start_time < time_range.end) + if after: + cursor = Cursor.from_string(after) + stmt = stmt.where(table.id > cursor.rowid) + if first: + stmt = stmt.limit( + first + 1 # over-fetch by one to determine whether there's a next page + ) + stmt = stmt.order_by(table.id) + cursors_and_nodes = [] + async with info.context.db() as session: + records = await session.scalars(stmt) + async for project_session in islice(records, first): + cursor = Cursor(rowid=project_session.id) + cursors_and_nodes.append((cursor, to_gql_project_session(project_session))) + has_next_page = True + try: + next(records) + except StopIteration: + has_next_page = False + return connection_from_cursors_and_nodes( + cursors_and_nodes, + has_previous_page=False, + has_next_page=has_next_page, + ) + @strawberry.field( description="Names of all available annotations for traces. " "(The list contains no duplicates.)" diff --git a/src/phoenix/server/api/types/ProjectSession.py b/src/phoenix/server/api/types/ProjectSession.py new file mode 100644 index 0000000000..436b0f494a --- /dev/null +++ b/src/phoenix/server/api/types/ProjectSession.py @@ -0,0 +1,52 @@ +from typing import Optional + +import strawberry +from sqlalchemy import desc, select +from strawberry import UNSET, Info +from strawberry.relay import Connection, Node, NodeID + +from phoenix.db import models +from phoenix.server.api.context import Context +from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list +from phoenix.server.api.types.Trace import Trace, to_gql_trace + + +@strawberry.type +class ProjectSession(Node): + id_attr: NodeID[int] + session_id: str + + @strawberry.field + async def traces( + self, + info: Info[Context, None], + first: Optional[int] = 50, + last: Optional[int] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, + ) -> Connection[Trace]: + args = ConnectionArgs( + first=first, + after=after if isinstance(after, CursorString) else None, + last=last, + before=before if isinstance(before, CursorString) else None, + ) + stmt = ( + select(models.Trace) + .filter_by(project_session_id=self.id_attr) + .order_by(desc(models.Trace.id)) + .limit(first) + ) + async with info.context.db() as session: + traces = await session.stream_scalars(stmt) + data = [to_gql_trace(trace) async for trace in traces] + return connection_from_list(data=data, args=args) + + +def to_gql_project_session( + project_session: models.ProjectSession, +) -> ProjectSession: + return ProjectSession( + id_attr=project_session.id, + session_id=project_session.session_id, + ) diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index a1a648cd8e..c22c9078d1 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -82,3 +82,11 @@ async def span_annotations( stmt = stmt.order_by(models.TraceAnnotation.created_at.desc()) annotations = await session.scalars(stmt) return [to_gql_trace_annotation(annotation) for annotation in annotations] + + +def to_gql_trace(trace: models.Trace) -> Trace: + return Trace( + id_attr=trace.id, + project_rowid=trace.project_rowid, + trace_id=trace.trace_id, + ) diff --git a/tests/integration/db_migrations/test_up_and_down_migrations.py b/tests/integration/db_migrations/test_up_and_down_migrations.py index 3335d46761..e2178addf7 100644 --- a/tests/integration/db_migrations/test_up_and_down_migrations.py +++ b/tests/integration/db_migrations/test_up_and_down_migrations.py @@ -5,7 +5,18 @@ from alembic import command from alembic.config import Config from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA -from sqlalchemy import Engine, Row, text +from sqlalchemy import ( + INTEGER, + TIMESTAMP, + VARCHAR, + Engine, + ForeignKeyConstraint, + MetaData, + PrimaryKeyConstraint, + Row, + UniqueConstraint, + text, +) def test_up_and_down_migrations( @@ -35,6 +46,249 @@ def test_up_and_down_migrations( _down(_engine, _alembic_config, "3be8647b87d8") _up(_engine, _alembic_config, "cd164e83824f") + for _ in range(2): + _up(_engine, _alembic_config, "4ded9e43755f") + + metadata = MetaData() + metadata.reflect(bind=_engine) + + assert (project_sessions := metadata.tables.get("project_sessions")) is not None + + columns = {str(col.name): col for col in project_sessions.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("session_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in project_sessions.indexes} + + index = indexes.pop("ix_project_sessions_start_time", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_project_sessions_end_time", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in project_sessions.constraints} + + constraint = constraints.pop("pk_project_sessions", None) + assert constraint is not None + assert isinstance(constraint, PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_project_sessions_session_id", None) + assert constraint is not None + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_project_sessions_project_id_projects", None) + assert constraint is not None + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + + assert (traces := metadata.tables.get("traces")) is not None + + columns = {str(col.name): col for col in traces.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("trace_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_rowid", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("project_session_id", None) + assert column is not None + assert column.nullable + assert isinstance(column.type, INTEGER) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in traces.indexes} + + index = indexes.pop("ix_traces_project_rowid", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_start_time", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_project_session_id", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in traces.constraints} + + constraint = constraints.pop("pk_traces", None) + assert isinstance(constraint, PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_traces_trace_id", None) + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_traces_project_rowid_projects", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + constraint = constraints.pop("fk_traces_project_session_id_project_sessions", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + + _down(_engine, _alembic_config, "cd164e83824f") + + metadata = MetaData() + metadata.reflect(bind=_engine) + + assert metadata.tables.get("project_sessions") is None + + assert (traces := metadata.tables.get("traces")) is not None + + columns = {str(col.name): col for col in traces.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("trace_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_rowid", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in traces.indexes} + + index = indexes.pop("ix_traces_project_rowid", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_start_time", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in traces.constraints} + + constraint = constraints.pop("pk_traces", None) + assert isinstance(constraint, PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_traces_trace_id", None) + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_traces_project_rowid_projects", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + _up(_engine, _alembic_config, "4ded9e43755f") + def _up(_engine: Engine, _alembic_config: Config, revision: str) -> None: with _engine.connect() as conn: diff --git a/tests/integration/project_sessions/__init__.py b/tests/integration/project_sessions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/project_sessions/conftest.py b/tests/integration/project_sessions/conftest.py new file mode 100644 index 0000000000..00c9e6bf14 --- /dev/null +++ b/tests/integration/project_sessions/conftest.py @@ -0,0 +1,13 @@ +from typing import Any, Iterator + +import pytest + +from .._helpers import _server + + +@pytest.fixture(autouse=True, scope="module") +def _app( + _env_phoenix_sql_database_url: Any, +) -> Iterator[None]: + with _server(): + yield diff --git a/tests/integration/project_sessions/test_project_sessions.py b/tests/integration/project_sessions/test_project_sessions.py new file mode 100644 index 0000000000..185063e4d8 --- /dev/null +++ b/tests/integration/project_sessions/test_project_sessions.py @@ -0,0 +1,95 @@ +import sys +from contextlib import ExitStack +from secrets import token_hex +from time import sleep +from typing import List + +import pytest +from openinference.semconv.trace import SpanAttributes +from opentelemetry.trace import Span, use_span +from opentelemetry.util.types import AttributeValue +from pytest import param + +from .._helpers import _gql, _grpc_span_exporter, _start_span + + +@pytest.mark.skipif(sys.platform == "win32", reason="FIXME: unclear why it fails") +class TestProjectSessions: + @pytest.mark.parametrize( + "session_id", + [ + param(0, id="integer"), + param(3.14, id="float"), + param(True, id="bool"), + param("abc", id="string"), + param(" a b c ", id="string with extra spaces"), + param(" ", id="empty string"), + param([1, 2], id="list of integers"), + param([1.1, 2.2], id="list of floats"), + param([True, False], id="list of bools"), + param(["a", "b"], id="list of strings"), + param([], id="empty list"), + ], + ) + def test_span_ingestion_with_session_id( + self, + session_id: AttributeValue, + ) -> None: + # remove extra whitespaces + str_session_id = str(session_id).strip() + num_traces, num_spans_per_trace = 2, 3 + assert num_traces > 1 and num_spans_per_trace > 2 + project_names = [token_hex(8)] + spans: List[Span] = [] + for _ in range(num_traces): + project_names.append(token_hex(8)) + with ExitStack() as stack: + for i in range(num_spans_per_trace): + if i == 0: + # Not all spans are required to have `session_id`. + attributes = None + elif i == 1: + # In case of conflict, the `Span` with later `end_time` wins. + attributes = {SpanAttributes.SESSION_ID: session_id} + elif str_session_id: + attributes = {SpanAttributes.SESSION_ID: token_hex(8)} + span = _start_span( + project_name=project_names[-1], + exporter=_grpc_span_exporter(), + attributes=attributes, + ) + spans.append(span) + stack.enter_context(use_span(span, end_on_exit=True)) + sleep(0.001) + assert len(spans) == num_traces * num_spans_per_trace + sleep(5) + res, *_ = _gql( + query="query{" + "projects(first: 1000){edges{node{name " + "sessions(first: 1000){edges{node{sessionId " + "traces(first: 1000){edges{node{" + "spans(first: 1000){edges{node{context{spanId}}}}}}}}}}}}}}" + ) + project_name = project_names[-1] + sessions_by_project = { + edge["node"]["name"]: { + session["node"]["sessionId"]: session + for session in edge["node"]["sessions"]["edges"] + } + for edge in res["data"]["projects"]["edges"] + if edge["node"]["name"] == project_name + } + sessions_by_id = sessions_by_project.get(project_name) + if not str_session_id: + assert not sessions_by_id + return + assert sessions_by_id + assert (session := sessions_by_id.get(str_session_id)) + assert (traces := [edge["node"] for edge in session["node"]["traces"]["edges"]]) + assert len(traces) == num_traces + gql_spans = [edge["node"] for trace in traces for edge in trace["spans"]["edges"]] + assert len(gql_spans) == len(spans) + expected_span_ids = { + span.get_span_context().span_id.to_bytes(8, "big").hex() for span in spans + } + assert {span["context"]["spanId"] for span in gql_spans} == expected_span_ids