From 295b432ba95fb6798d704928d497c20f297fd56c Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Thu, 17 Oct 2024 14:13:05 -0700 Subject: [PATCH] feat: add db table for sessions --- .github/workflows/python-CI.yml | 2 +- app/schema.graphql | 44 +++ requirements/integration-tests.txt | 1 + src/phoenix/db/insertion/span.py | 107 ++++-- ...4ded9e43755f_create_trace_session_table.py | 55 +++ src/phoenix/db/models.py | 27 ++ src/phoenix/server/api/queries.py | 9 + src/phoenix/server/api/types/Project.py | 41 +++ .../server/api/types/ProjectSession.py | 52 +++ src/phoenix/server/api/types/Trace.py | 8 + tests/integration/_helpers.py | 16 +- tests/integration/conftest.py | 45 ++- tests/integration/db_migrations/conftest.py | 43 ++- .../test_up_and_down_migrations.py | 319 ++++++++++++++++-- .../integration/project_sessions/__init__.py | 0 .../integration/project_sessions/conftest.py | 13 + .../project_sessions/test_project_sessions.py | 93 +++++ tests/integration/server/test_launch_app.py | 2 +- 18 files changed, 793 insertions(+), 84 deletions(-) create mode 100644 src/phoenix/db/migrations/versions/4ded9e43755f_create_trace_session_table.py create mode 100644 src/phoenix/server/api/types/ProjectSession.py create mode 100644 tests/integration/project_sessions/__init__.py create mode 100644 tests/integration/project_sessions/conftest.py create mode 100644 tests/integration/project_sessions/test_project_sessions.py diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml index 4fbcc35ea2..a25edbd808 100644 --- a/.github/workflows/python-CI.yml +++ b/.github/workflows/python-CI.yml @@ -450,4 +450,4 @@ jobs: - name: Set up `tox` with `tox-uv` run: uv tool install tox --with tox-uv - name: Run integration tests - run: tox run -e integration_tests -- -ra -x -n auto + run: tox run -e integration_tests -- -ra -x -n 10 --reruns 5 diff --git a/app/schema.graphql b/app/schema.graphql index e42847c242..1e8b8e5967 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1120,6 +1120,7 @@ type Project implements Node { spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float trace(traceId: ID!): Trace spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + sessions(timeRange: TimeRange, first: Int = 50, after: String): ProjectSessionConnection! """ Names of all available annotations for traces. (The list contains no duplicates.) @@ -1158,6 +1159,31 @@ type ProjectEdge { node: Project! } +type ProjectSession implements Node { + """The Globally Unique ID of this object""" + id: GlobalID! + sessionId: String! + traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection! +} + +"""A connection to a list of items.""" +type ProjectSessionConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [ProjectSessionEdge!]! +} + +"""An edge in a connection.""" +type ProjectSessionEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: ProjectSession! +} + type PromptResponse { """The prompt submitted to the LLM""" prompt: String @@ -1537,6 +1563,24 @@ input TraceAnnotationSort { dir: SortDir! } +"""A connection to a list of items.""" +type TraceConnection { + """Pagination data for this connection""" + pageInfo: PageInfo! + + """Contains the nodes in this connection""" + edges: [TraceEdge!]! 
+} + +"""An edge in a connection.""" +type TraceEdge { + """A cursor for use in pagination""" + cursor: String! + + """The item at the end of the edge""" + node: Trace! +} + type UMAPPoint { id: GlobalID! diff --git a/requirements/integration-tests.txt b/requirements/integration-tests.txt index 95051ff827..a1b3207c31 100644 --- a/requirements/integration-tests.txt +++ b/requirements/integration-tests.txt @@ -8,6 +8,7 @@ portpicker psutil pyjwt pytest-randomly +pytest-rerunfailures pytest-smtpd types-beautifulsoup4 types-psutil diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index 21f46de1f0..1ca80b2b08 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -28,42 +28,99 @@ async def insert_span( dialect = SupportedSQLDialect(session.bind.dialect.name) if ( project_rowid := await session.scalar( - select(models.Project.id).where(models.Project.name == project_name) + select(models.Project.id).filter_by(name=project_name) ) ) is None: project_rowid = await session.scalar( - insert(models.Project).values(dict(name=project_name)).returning(models.Project.id) + insert(models.Project).values(name=project_name).returning(models.Project.id) ) assert project_rowid is not None - if trace := await session.scalar( - select(models.Trace).where(models.Trace.trace_id == span.context.trace_id) - ): - trace_rowid = trace.id - if span.start_time < trace.start_time or trace.end_time < span.end_time: - trace_start_time = min(trace.start_time, span.start_time) - trace_end_time = max(trace.end_time, span.end_time) + + project_session: Optional[models.ProjectSession] = None + session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID) + if session_id is not None and (not isinstance(session_id, str) or session_id.strip()): + session_id = str(session_id).strip() + assert isinstance(session_id, str) + project_session = await session.scalar( + select(models.ProjectSession).filter_by(session_id=session_id) + ) + if project_session: + project_session_needs_update = False + project_session_end_time = None + project_session_project_id = None + if project_session.end_time < span.end_time: + project_session_needs_update = True + project_session_end_time = span.end_time + project_session_project_id = project_rowid + project_session_start_time = None + if span.start_time < project_session.start_time: + project_session_needs_update = True + project_session_start_time = span.start_time + if project_session_needs_update: + project_session = await session.scalar( + update(models.ProjectSession) + .filter_by(id=project_session.id) + .values( + start_time=project_session_start_time or project_session.start_time, + end_time=project_session_end_time or project_session.end_time, + project_id=project_session_project_id or project_session.project_id, + ) + .returning(models.ProjectSession) + ) + else: + project_session = await session.scalar( + insert(models.ProjectSession) + .values( + project_id=project_rowid, + session_id=session_id, + start_time=span.start_time, + end_time=span.end_time, + ) + .returning(models.ProjectSession) + ) + + trace_id = span.context.trace_id + trace: Optional[models.Trace] = await session.scalar( + select(models.Trace).filter_by(trace_id=trace_id) + ) + if trace: + trace_needs_update = False + trace_end_time = None + trace_project_rowid = None + trace_project_session_id = None + if trace.end_time < span.end_time: + trace_needs_update = True + trace_end_time = span.end_time + trace_project_rowid = project_rowid + 
trace_project_session_id = project_session.id if project_session else None + trace_start_time = None + if span.start_time < trace.start_time: + trace_needs_update = True + trace_start_time = span.start_time + if trace_needs_update: await session.execute( update(models.Trace) - .where(models.Trace.id == trace_rowid) + .filter_by(id=trace.id) .values( - start_time=trace_start_time, - end_time=trace_end_time, + start_time=trace_start_time or trace.start_time, + end_time=trace_end_time or trace.end_time, + project_rowid=trace_project_rowid or trace.project_rowid, + project_session_id=trace_project_session_id or trace.project_session_id, ) ) else: - trace_rowid = cast( - int, - await session.scalar( - insert(models.Trace) - .values( - project_rowid=project_rowid, - trace_id=span.context.trace_id, - start_time=span.start_time, - end_time=span.end_time, - ) - .returning(models.Trace.id) - ), + trace = await session.scalar( + insert(models.Trace) + .values( + project_rowid=project_rowid, + trace_id=span.context.trace_id, + start_time=span.start_time, + end_time=span.end_time, + project_session_id=project_session.id if project_session else None, + ) + .returning(models.Trace) ) + assert trace is not None cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) cumulative_llm_token_count_prompt = cast( int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 @@ -94,7 +151,7 @@ async def insert_span( insert_on_conflict( dict( span_id=span.context.span_id, - trace_rowid=trace_rowid, + trace_rowid=trace.id, parent_id=span.parent_id, span_kind=span.span_kind.value, name=span.name, diff --git a/src/phoenix/db/migrations/versions/4ded9e43755f_create_trace_session_table.py b/src/phoenix/db/migrations/versions/4ded9e43755f_create_trace_session_table.py new file mode 100644 index 0000000000..c49486d42c --- /dev/null +++ b/src/phoenix/db/migrations/versions/4ded9e43755f_create_trace_session_table.py @@ -0,0 +1,55 @@ +"""create project_session table + +Revision ID: 4ded9e43755f +Revises: cd164e83824f +Create Date: 2024-10-08 22:53:24.539786 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "4ded9e43755f"
+down_revision: Union[str, None] = "cd164e83824f"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "project_sessions",
+        sa.Column("id", sa.Integer, primary_key=True),
+        sa.Column("session_id", sa.String, unique=True, nullable=False),
+        sa.Column(
+            "project_id",
+            sa.Integer,
+            sa.ForeignKey("projects.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
+        sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False),
+    )
+    with op.batch_alter_table("traces") as batch_op:
+        batch_op.add_column(
+            sa.Column(
+                "project_session_id",
+                sa.Integer,
+                sa.ForeignKey("project_sessions.id", ondelete="CASCADE"),
+                nullable=True,
+            ),
+        )
+    op.create_index(
+        "ix_traces_project_session_id",
+        "traces",
+        ["project_session_id"],
+    )
+
+
+def downgrade() -> None:
+    op.drop_index("ix_traces_project_session_id")
+    with op.batch_alter_table("traces") as batch_op:
+        batch_op.drop_column("project_session_id")
+    op.drop_table("project_sessions")
diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py
index 2adef3b19f..f3b29161d7 100644
--- a/src/phoenix/db/models.py
+++ b/src/phoenix/db/models.py
@@ -156,6 +156,24 @@ class Project(Base):
     )
 
 
+class ProjectSession(Base):
+    __tablename__ = "project_sessions"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True)
+    project_id: Mapped[int] = mapped_column(
+        ForeignKey("projects.id", ondelete="CASCADE"),
+        nullable=False,
+        index=True,
+    )
+    start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
+    end_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
+    traces: Mapped[List["Trace"]] = relationship(
+        "Trace",
+        back_populates="project_session",
+        uselist=True,
+    )
+
+
 class Trace(Base):
     __tablename__ = "traces"
     id: Mapped[int] = mapped_column(primary_key=True)
@@ -164,6 +182,11 @@ class Trace(Base):
         index=True,
     )
     trace_id: Mapped[str]
+    project_session_id: Mapped[Optional[int]] = mapped_column(
+        ForeignKey("project_sessions.id", ondelete="CASCADE"),
+        nullable=True,
+        index=True,
+    )
     start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
     end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
 
@@ -188,6 +211,10 @@ def _latency_ms_expression(cls) -> ColumnElement[float]:
         cascade="all, delete-orphan",
         uselist=True,
     )
+    project_session: Mapped[Optional[ProjectSession]] = relationship(
+        "ProjectSession",
+        back_populates="traces",
+    )
     experiment_runs: Mapped[List["ExperimentRun"]] = relationship(
         primaryjoin="foreign(ExperimentRun.trace_id) == Trace.trace_id",
         back_populates="trace",
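The two models above form a plain one-to-many pair: ProjectSession.traces and Trace.project_session are the back-populated sides of the same relationship. A minimal sketch of reading it back with a synchronous SQLAlchemy session (the engine URL and the "abc" session id are illustrative assumptions, not part of this patch):

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import Session

    from phoenix.db import models

    # Assumed database location; any URL pointing at a migrated Phoenix DB works.
    engine = create_engine("sqlite:///phoenix.db")

    with Session(engine) as session:
        # Look up the session row by its client-supplied identifier (a unique column).
        project_session = session.scalar(
            select(models.ProjectSession).filter_by(session_id="abc")
        )
        if project_session is not None:
            # Walk the back-populated relationship to the traces grouped under it.
            for trace in project_session.traces:
                print(trace.trace_id, trace.start_time, trace.end_time)
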
diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py
index fe09ca6f25..a29911d381 100644
--- a/src/phoenix/server/api/queries.py
+++ b/src/phoenix/server/api/queries.py
@@ -72,6 +72,7 @@
     connection_from_list,
 )
 from phoenix.server.api.types.Project import Project
+from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.api.types.SystemApiKey import SystemApiKey
@@ -476,6 +477,14 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
             if span is None:
                 raise NotFound(f"Unknown span: {id}")
             return to_gql_span(span)
+        elif type_name == ProjectSession.__name__:
+            async with info.context.db() as session:
+                project_session = await session.scalar(
+                    select(models.ProjectSession).filter_by(id=node_id)
+                )
+            if project_session is None:
+                raise NotFound(f"Unknown project_session: {id}")
+            return to_gql_project_session(project_session)
         elif type_name == Dataset.__name__:
             dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id)
             async with info.context.db() as session:
diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py
index 46426ba1c0..e6be2bf18e 100644
--- a/src/phoenix/server/api/types/Project.py
+++ b/src/phoenix/server/api/types/Project.py
@@ -30,6 +30,7 @@
     CursorString,
     connection_from_cursors_and_nodes,
 )
+from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session
 from phoenix.server.api.types.SortDir import SortDir
 from phoenix.server.api.types.Span import Span, to_gql_span
 from phoenix.server.api.types.Trace import Trace
@@ -248,6 +249,46 @@ async def spans(
             has_next_page=has_next_page,
         )
 
+    @strawberry.field
+    async def sessions(
+        self,
+        info: Info[Context, None],
+        time_range: Optional[TimeRange] = UNSET,
+        first: Optional[int] = 50,
+        after: Optional[CursorString] = UNSET,
+    ) -> Connection[ProjectSession]:
+        table = models.ProjectSession
+        stmt = select(table).filter_by(project_id=self.id_attr)
+        if time_range:
+            if time_range.start:
+                stmt = stmt.where(time_range.start <= table.start_time)
+            if time_range.end:
+                stmt = stmt.where(table.start_time < time_range.end)
+        if after:
+            cursor = Cursor.from_string(after)
+            stmt = stmt.where(table.id > cursor.rowid)
+        if first:
+            stmt = stmt.limit(
+                first + 1  # over-fetch by one to determine whether there's a next page
+            )
+        stmt = stmt.order_by(table.id)
+        cursors_and_nodes = []
+        async with info.context.db() as session:
+            records = await session.scalars(stmt)
+            async for project_session in islice(records, first):
+                cursor = Cursor(rowid=project_session.id)
+                cursors_and_nodes.append((cursor, to_gql_project_session(project_session)))
+            has_next_page = True
+            try:
+                next(records)
+            except StopIteration:
+                has_next_page = False
+        return connection_from_cursors_and_nodes(
+            cursors_and_nodes,
+            has_previous_page=False,
+            has_next_page=has_next_page,
+        )
+
     @strawberry.field(
         description="Names of all available annotations for traces. 
" "(The list contains no duplicates.)" diff --git a/src/phoenix/server/api/types/ProjectSession.py b/src/phoenix/server/api/types/ProjectSession.py new file mode 100644 index 0000000000..436b0f494a --- /dev/null +++ b/src/phoenix/server/api/types/ProjectSession.py @@ -0,0 +1,52 @@ +from typing import Optional + +import strawberry +from sqlalchemy import desc, select +from strawberry import UNSET, Info +from strawberry.relay import Connection, Node, NodeID + +from phoenix.db import models +from phoenix.server.api.context import Context +from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list +from phoenix.server.api.types.Trace import Trace, to_gql_trace + + +@strawberry.type +class ProjectSession(Node): + id_attr: NodeID[int] + session_id: str + + @strawberry.field + async def traces( + self, + info: Info[Context, None], + first: Optional[int] = 50, + last: Optional[int] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, + ) -> Connection[Trace]: + args = ConnectionArgs( + first=first, + after=after if isinstance(after, CursorString) else None, + last=last, + before=before if isinstance(before, CursorString) else None, + ) + stmt = ( + select(models.Trace) + .filter_by(project_session_id=self.id_attr) + .order_by(desc(models.Trace.id)) + .limit(first) + ) + async with info.context.db() as session: + traces = await session.stream_scalars(stmt) + data = [to_gql_trace(trace) async for trace in traces] + return connection_from_list(data=data, args=args) + + +def to_gql_project_session( + project_session: models.ProjectSession, +) -> ProjectSession: + return ProjectSession( + id_attr=project_session.id, + session_id=project_session.session_id, + ) diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index a1a648cd8e..c22c9078d1 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -82,3 +82,11 @@ async def span_annotations( stmt = stmt.order_by(models.TraceAnnotation.created_at.desc()) annotations = await session.scalars(stmt) return [to_gql_trace_annotation(annotation) for annotation in annotations] + + +def to_gql_trace(trace: models.Trace) -> Trace: + return Trace( + id_attr=trace.id, + project_rowid=trace.project_rowid, + trace_id=trace.trace_id, + ) diff --git a/tests/integration/_helpers.py b/tests/integration/_helpers.py index 541b2356b7..b5c9b7d201 100644 --- a/tests/integration/_helpers.py +++ b/tests/integration/_helpers.py @@ -2,7 +2,6 @@ import os import re -import secrets import sys from abc import ABC, abstractmethod from contextlib import contextmanager, nullcontext @@ -12,6 +11,7 @@ from email.message import Message from functools import cached_property from io import BytesIO +from secrets import token_hex from subprocess import PIPE, STDOUT from threading import Lock, Thread from time import sleep, time @@ -423,14 +423,18 @@ def _get_tracer( def _start_span( *, - project_name: str, - span_name: str, exporter: SpanExporter, + project_name: Optional[str] = None, + span_name: Optional[str] = None, + attributes: Optional[Mapping[str, AttributeValue]] = None, ) -> Span: return _get_tracer( - project_name=project_name, + project_name=project_name or token_hex(16), exporter=exporter, - ).start_span(span_name) + ).start_span( + name=span_name or token_hex(16), + attributes=attributes, + ) class _DefaultAdminTokens(ABC): @@ -660,7 +664,7 @@ def _random_schema( engine = create_engine(url.set(drivername="postgresql+psycopg")) 
engine.connect().close() engine.dispose() - schema = f"_{secrets.token_hex(15)}" + schema = f"_{token_hex(15)}" yield schema time_limit = time() + 30 while time() < time_limit: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3ab5a72c28..065a078255 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,7 +1,7 @@ import os from contextlib import ExitStack -from dataclasses import asdict from itertools import count, starmap +from secrets import token_hex from typing import Generator, Iterator, List, Optional, Tuple, cast from unittest import mock @@ -11,7 +11,6 @@ from faker import Faker from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter -from phoenix.auth import REQUIREMENTS_FOR_PHOENIX_SECRET from phoenix.config import ( ENV_PHOENIX_GRPC_PORT, ENV_PHOENIX_PORT, @@ -113,30 +112,40 @@ def _env_phoenix_sql_database_url( _fake: Faker, ) -> Iterator[None]: values = [(ENV_PHOENIX_SQL_DATABASE_URL, _sql_database_url.render_as_string())] + with mock.patch.dict(os.environ, values): + yield + + +@pytest.fixture(autouse=True, scope="module") +def _env_postgresql_schema( + _sql_database_url: URL, +) -> Iterator[None]: + if not _sql_database_url.get_backend_name().startswith("postgresql"): + yield + return with ExitStack() as stack: - if _sql_database_url.get_backend_name().startswith("postgresql"): - schema = stack.enter_context(_random_schema(_sql_database_url)) - values.append((ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema)) + schema = stack.enter_context(_random_schema(_sql_database_url)) + values = [(ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema)] stack.enter_context(mock.patch.dict(os.environ, values)) yield -@pytest.fixture(scope="module") -def _emails(_fake: Faker) -> Iterator[_Email]: - return (_fake.unique.email() for _ in count()) +@pytest.fixture +def _emails() -> Iterator[_Email]: + return (f"{token_hex(32)}@{token_hex(32)}.com" for _ in count()) -@pytest.fixture(scope="module") -def _passwords(_fake: Faker) -> Iterator[_Password]: - return (_fake.unique.password(**asdict(REQUIREMENTS_FOR_PHOENIX_SECRET)) for _ in count()) +@pytest.fixture +def _passwords() -> Iterator[_Password]: + return (token_hex(32) for _ in count()) -@pytest.fixture(scope="module") -def _usernames(_fake: Faker) -> Iterator[_Username]: - return (_fake.unique.pystr() for _ in count()) +@pytest.fixture +def _usernames() -> Iterator[_Username]: + return (token_hex(32) for _ in count()) -@pytest.fixture(scope="module") +@pytest.fixture def _profiles( _emails: Iterator[_Email], _passwords: Iterator[_Password], @@ -145,7 +154,7 @@ def _profiles( return starmap(_Profile, zip(_emails, _passwords, _usernames)) -@pytest.fixture(scope="module") +@pytest.fixture def _users( _profiles: Iterator[_Profile], ) -> _UserGenerator: @@ -160,7 +169,7 @@ def _() -> Generator[Optional[_User], Tuple[UserRoleInput, Optional[_Profile]], return cast(_UserGenerator, g) -@pytest.fixture(scope="module") +@pytest.fixture def _new_user( _users: _UserGenerator, ) -> _UserFactory: @@ -175,7 +184,7 @@ def _( return _ -@pytest.fixture(scope="module") +@pytest.fixture def _get_user( _new_user: _UserFactory, ) -> _GetUser: diff --git a/tests/integration/db_migrations/conftest.py b/tests/integration/db_migrations/conftest.py index 6244cf3f7d..08f2b43541 100644 --- a/tests/integration/db_migrations/conftest.py +++ b/tests/integration/db_migrations/conftest.py @@ -1,16 +1,20 @@ import os +from contextlib import ExitStack 
from pathlib import Path from secrets import token_hex from typing import Any, Iterator +from unittest import mock import phoenix import pytest import sqlean # type: ignore[import-untyped] from alembic.config import Config -from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA, ENV_PHOENIX_SQL_DATABASE_URL +from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA from phoenix.db.engines import set_postgresql_search_path from pytest import TempPathFactory -from sqlalchemy import Engine, NullPool, create_engine, event, make_url +from sqlalchemy import URL, Engine, NullPool, create_engine, event + +from integration._helpers import _random_schema @pytest.fixture @@ -21,33 +25,46 @@ def _alembic_config() -> Config: return cfg +@pytest.fixture(autouse=True, scope="function") +def _env_postgresql_schema( + _sql_database_url: URL, +) -> Iterator[None]: + if not _sql_database_url.get_backend_name().startswith("postgresql"): + yield + return + with ExitStack() as stack: + schema = stack.enter_context(_random_schema(_sql_database_url)) + values = [(ENV_PHOENIX_SQL_DATABASE_SCHEMA, schema)] + stack.enter_context(mock.patch.dict(os.environ, values)) + yield + + @pytest.fixture def _engine( - _env_phoenix_sql_database_url: Any, + _sql_database_url: URL, + _env_postgresql_schema: Any, tmp_path_factory: TempPathFactory, ) -> Iterator[Engine]: - url = make_url(os.environ[ENV_PHOENIX_SQL_DATABASE_URL]) - backend = url.get_backend_name() - if backend.startswith("sqlite"): + backend = _sql_database_url.get_backend_name() + if backend == "sqlite": tmp = tmp_path_factory.getbasetemp() / Path(__file__).parent.name tmp.mkdir(parents=True, exist_ok=True) file = tmp / f".{token_hex(16)}.db" - database = f"file:///{file}" engine = create_engine( - url=url.set(drivername="sqlite", database=database), - creator=lambda: sqlean.connect(database, uri=True), + url=_sql_database_url.set(database=str(file)), + creator=lambda: sqlean.connect(f"file:///{file}", uri=True), poolclass=NullPool, echo=True, ) - elif backend.startswith("postgresql"): - assert (schema := os.environ.get(ENV_PHOENIX_SQL_DATABASE_SCHEMA)) + elif backend == "postgresql": + schema = os.environ[ENV_PHOENIX_SQL_DATABASE_SCHEMA] engine = create_engine( - url=url.set(drivername="postgresql+psycopg"), + url=_sql_database_url.set(drivername="postgresql+psycopg"), poolclass=NullPool, echo=True, ) event.listen(engine, "connect", set_postgresql_search_path(schema)) else: - pytest.fail(f"Unknown database backend: {backend}") + pytest.fail(f"Unknown backend: {backend}") yield engine engine.dispose() diff --git a/tests/integration/db_migrations/test_up_and_down_migrations.py b/tests/integration/db_migrations/test_up_and_down_migrations.py index 77b1dd322f..e2178addf7 100644 --- a/tests/integration/db_migrations/test_up_and_down_migrations.py +++ b/tests/integration/db_migrations/test_up_and_down_migrations.py @@ -1,38 +1,317 @@ import os +from typing import Optional, Tuple import pytest from alembic import command from alembic.config import Config from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA -from sqlalchemy import Engine, text +from sqlalchemy import ( + INTEGER, + TIMESTAMP, + VARCHAR, + Engine, + ForeignKeyConstraint, + MetaData, + PrimaryKeyConstraint, + Row, + UniqueConstraint, + text, +) def test_up_and_down_migrations( - _alembic_config: Config, _engine: Engine, + _alembic_config: Config, ) -> None: - table = "alembic_version" - if _engine.url.get_backend_name().startswith("postgresql"): - schema = 
os.environ[ENV_PHOENIX_SQL_DATABASE_SCHEMA] - assert schema - table = f"{schema}.{table}" - stmt = text(f"SELECT version_num FROM {table}") - with _engine.connect() as conn: - with pytest.raises(BaseException, match=table): - conn.execute(stmt) - _engine.dispose() + with pytest.raises(BaseException, match="alembic_version"): + _version_num(_engine) + + for _ in range(2): + _up(_engine, _alembic_config, "cf03bd6bae1d") + _down(_engine, _alembic_config, "base") + _up(_engine, _alembic_config, "cf03bd6bae1d") + + for _ in range(2): + _up(_engine, _alembic_config, "10460e46d750") + _down(_engine, _alembic_config, "cf03bd6bae1d") + _up(_engine, _alembic_config, "10460e46d750") + + for _ in range(2): + _up(_engine, _alembic_config, "3be8647b87d8") + _down(_engine, _alembic_config, "10460e46d750") + _up(_engine, _alembic_config, "3be8647b87d8") + + for _ in range(2): + _up(_engine, _alembic_config, "cd164e83824f") + _down(_engine, _alembic_config, "3be8647b87d8") + _up(_engine, _alembic_config, "cd164e83824f") + + for _ in range(2): + _up(_engine, _alembic_config, "4ded9e43755f") + + metadata = MetaData() + metadata.reflect(bind=_engine) + + assert (project_sessions := metadata.tables.get("project_sessions")) is not None + + columns = {str(col.name): col for col in project_sessions.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("session_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in project_sessions.indexes} + + index = indexes.pop("ix_project_sessions_start_time", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_project_sessions_end_time", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in project_sessions.constraints} + + constraint = constraints.pop("pk_project_sessions", None) + assert constraint is not None + assert isinstance(constraint, PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_project_sessions_session_id", None) + assert constraint is not None + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_project_sessions_project_id_projects", None) + assert constraint is not None + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + + assert (traces := metadata.tables.get("traces")) is not None + + columns = {str(col.name): col for col in traces.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("trace_id", None) + assert column is not None + assert not column.nullable + assert 
isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_rowid", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("project_session_id", None) + assert column is not None + assert column.nullable + assert isinstance(column.type, INTEGER) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in traces.indexes} + + index = indexes.pop("ix_traces_project_rowid", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_start_time", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_project_session_id", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in traces.constraints} + + constraint = constraints.pop("pk_traces", None) + assert isinstance(constraint, PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_traces_trace_id", None) + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_traces_project_rowid_projects", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + constraint = constraints.pop("fk_traces_project_session_id_project_sessions", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + + _down(_engine, _alembic_config, "cd164e83824f") + + metadata = MetaData() + metadata.reflect(bind=_engine) + + assert metadata.tables.get("project_sessions") is None + + assert (traces := metadata.tables.get("traces")) is not None + + columns = {str(col.name): col for col in traces.columns} + + column = columns.pop("id", None) + assert column is not None + assert column.primary_key + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("trace_id", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, VARCHAR) + del column + + column = columns.pop("project_rowid", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, INTEGER) + del column + + column = columns.pop("start_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + column = columns.pop("end_time", None) + assert column is not None + assert not column.nullable + assert isinstance(column.type, TIMESTAMP) + del column + + assert not columns + del columns + + indexes = {str(idx.name): idx for idx in traces.indexes} + + index = indexes.pop("ix_traces_project_rowid", None) + assert index is not None + assert not index.unique + del index + + index = indexes.pop("ix_traces_start_time", None) + assert index is not None + assert not index.unique + del index + + assert not indexes + del indexes + + constraints = {str(con.name): con for con in traces.constraints} + + constraint = constraints.pop("pk_traces", None) + assert isinstance(constraint, 
PrimaryKeyConstraint) + del constraint + + constraint = constraints.pop("uq_traces_trace_id", None) + assert isinstance(constraint, UniqueConstraint) + del constraint + + constraint = constraints.pop("fk_traces_project_rowid_projects", None) + assert isinstance(constraint, ForeignKeyConstraint) + assert constraint.ondelete == "CASCADE" + del constraint + + assert not constraints + del constraints + _up(_engine, _alembic_config, "4ded9e43755f") + + +def _up(_engine: Engine, _alembic_config: Config, revision: str) -> None: with _engine.connect() as conn: _alembic_config.attributes["connection"] = conn - command.upgrade(_alembic_config, "head") - _engine.dispose() - with _engine.connect() as conn: - version_num = conn.execute(stmt).first() - assert version_num == ("cd164e83824f",) + command.upgrade(_alembic_config, revision) _engine.dispose() + assert _version_num(_engine) == (revision,) + + +def _down(_engine: Engine, _alembic_config: Config, revision: str) -> None: with _engine.connect() as conn: _alembic_config.attributes["connection"] = conn - command.downgrade(_alembic_config, "base") + command.downgrade(_alembic_config, revision) _engine.dispose() + assert _version_num(_engine) == (None if revision == "base" else (revision,)) + + +def _version_num(_engine: Engine) -> Optional[Row[Tuple[str]]]: + schema_prefix = "" + if _engine.url.get_backend_name().startswith("postgresql"): + assert (schema := os.environ[ENV_PHOENIX_SQL_DATABASE_SCHEMA]) + schema_prefix = f"{schema}." + table, column = "alembic_version", "version_num" + stmt = text(f"SELECT {column} FROM {schema_prefix}{table}") with _engine.connect() as conn: - assert conn.execute(stmt).first() is None - _engine.dispose() + return conn.execute(stmt).first() diff --git a/tests/integration/project_sessions/__init__.py b/tests/integration/project_sessions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/project_sessions/conftest.py b/tests/integration/project_sessions/conftest.py new file mode 100644 index 0000000000..00c9e6bf14 --- /dev/null +++ b/tests/integration/project_sessions/conftest.py @@ -0,0 +1,13 @@ +from typing import Any, Iterator + +import pytest + +from .._helpers import _server + + +@pytest.fixture(autouse=True, scope="module") +def _app( + _env_phoenix_sql_database_url: Any, +) -> Iterator[None]: + with _server(): + yield diff --git a/tests/integration/project_sessions/test_project_sessions.py b/tests/integration/project_sessions/test_project_sessions.py new file mode 100644 index 0000000000..d13056dabd --- /dev/null +++ b/tests/integration/project_sessions/test_project_sessions.py @@ -0,0 +1,93 @@ +from contextlib import ExitStack +from secrets import token_hex +from time import sleep +from typing import List + +import pytest +from openinference.semconv.trace import SpanAttributes +from opentelemetry.trace import Span, use_span +from opentelemetry.util.types import AttributeValue +from pytest import param + +from .._helpers import _gql, _grpc_span_exporter, _start_span + + +class TestProjectSessions: + @pytest.mark.parametrize( + "session_id", + [ + param(0, id="integer"), + param(3.14, id="float"), + param(True, id="bool"), + param("abc", id="string"), + param(" a b c ", id="string with extra spaces"), + param(" ", id="empty string"), + param([1, 2], id="list of integers"), + param([1.1, 2.2], id="list of floats"), + param([True, False], id="list of bools"), + param(["a", "b"], id="list of strings"), + param([], id="empty list"), + ], + ) + def 
test_span_ingestion_with_session_id( + self, + session_id: AttributeValue, + ) -> None: + # remove extra whitespaces + str_session_id = str(session_id).strip() + num_traces, num_spans_per_trace = 2, 3 + assert num_traces > 1 and num_spans_per_trace > 2 + project_names = [token_hex(8)] + spans: List[Span] = [] + for _ in range(num_traces): + project_names.append(token_hex(8)) + with ExitStack() as stack: + for i in range(num_spans_per_trace): + if i == 0: + # Not all spans are required to have `session_id`. + attributes = None + elif i == 1: + # In case of conflict, the `Span` with later `end_time` wins. + attributes = {SpanAttributes.SESSION_ID: session_id} + elif str_session_id: + attributes = {SpanAttributes.SESSION_ID: token_hex(8)} + span = _start_span( + project_name=project_names[-1], + exporter=_grpc_span_exporter(), + attributes=attributes, + ) + spans.append(span) + stack.enter_context(use_span(span, end_on_exit=True)) + sleep(0.001) + assert len(spans) == num_traces * num_spans_per_trace + sleep(5) + res, *_ = _gql( + query="query{" + "projects(first: 1000){edges{node{name " + "sessions(first: 1000){edges{node{sessionId " + "traces(first: 1000){edges{node{" + "spans(first: 1000){edges{node{context{spanId}}}}}}}}}}}}}}" + ) + project_name = project_names[-1] + sessions_by_project = { + edge["node"]["name"]: { + session["node"]["sessionId"]: session + for session in edge["node"]["sessions"]["edges"] + } + for edge in res["data"]["projects"]["edges"] + if edge["node"]["name"] == project_name + } + sessions_by_id = sessions_by_project.get(project_name) + if not str_session_id: + assert not sessions_by_id + return + assert sessions_by_id + assert (session := sessions_by_id.get(str_session_id)) + assert (traces := [edge["node"] for edge in session["node"]["traces"]["edges"]]) + assert len(traces) == num_traces + gql_spans = [edge["node"] for trace in traces for edge in trace["spans"]["edges"]] + assert len(gql_spans) == len(spans) + expected_span_ids = { + span.get_span_context().span_id.to_bytes(8, "big").hex() for span in spans + } + assert {span["context"]["spanId"] for span in gql_spans} == expected_span_ids diff --git a/tests/integration/server/test_launch_app.py b/tests/integration/server/test_launch_app.py index b94453aa2e..065db44241 100644 --- a/tests/integration/server/test_launch_app.py +++ b/tests/integration/server/test_launch_app.py @@ -28,7 +28,7 @@ def test_send_spans(self, _fake: Faker) -> None: _start_span( project_name=project_name, span_name=span_name, - exporter=exporter(headers=None), + exporter=exporter(), ).end() sleep(2) project = _get_gql_spans(None, "name")[project_name]
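End to end, session grouping is opt-in and driven by a single OpenInference span attribute, as exercised by test_span_ingestion_with_session_id above. A minimal sketch of how an instrumented application might tag its spans so their traces land in the new project_sessions table (the OTLP endpoint, tracer-provider setup, and session-id value are illustrative assumptions, not part of this patch):

    from openinference.semconv.trace import SpanAttributes
    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor

    # Assumed endpoint: a local Phoenix server accepting OTLP over gRPC.
    provider = TracerProvider()
    provider.add_span_processor(
        SimpleSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317"))
    )
    tracer = provider.get_tracer(__name__)

    # Spans that share this attribute value are grouped under one project_sessions
    # row; spans without it stay session-less, like the i == 0 span in the test above.
    with tracer.start_as_current_span(
        "chat-turn",
        attributes={SpanAttributes.SESSION_ID: "session-abc"},
    ):
        pass  # application work happens here

When session ids conflict within one trace, the insertion logic in src/phoenix/db/insertion/span.py lets the span with the later end_time decide which session the trace belongs to.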