diff --git a/app/schema.graphql b/app/schema.graphql index f2c2595cd6..02a2dc1c36 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -486,7 +486,6 @@ type Mutation { """ exportClusters(clusters: [ClusterInput!]!, fileName: String): ExportedFile! deleteProject(id: GlobalID!): Query! - archiveProject(id: GlobalID!): Query! } """A node in the graph with a globally unique ID""" diff --git a/src/phoenix/core/traces.py b/src/phoenix/core/traces.py index 0681e5479e..540c527e69 100644 --- a/src/phoenix/core/traces.py +++ b/src/phoenix/core/traces.py @@ -3,7 +3,7 @@ from queue import SimpleQueue from threading import RLock, Thread from types import MethodType -from typing import DefaultDict, Iterator, Optional, Tuple, Union +from typing import DefaultDict, Optional, Tuple, Union from typing_extensions import assert_never @@ -38,23 +38,6 @@ def get_project(self, project_name: str) -> Optional["Project"]: with self._lock: return self._projects.get(project_name) - def get_projects(self) -> Iterator[Tuple[int, str, "Project"]]: - with self._lock: - for project_id, (project_name, project) in enumerate(self._projects.items()): - if project.is_archived: - continue - yield project_id, project_name, project - - def archive_project(self, id: int) -> Optional["Project"]: - if id == 0: - raise ValueError("Cannot archive the default project") - with self._lock: - for project_id, _, project in self.get_projects(): - if id == project_id: - project.archive() - return project - return None - def put( self, item: Union[Span, pb.Evaluation], diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 023e852bfd..1f59b17e71 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -49,7 +49,13 @@ def upgrade() -> None: op.create_table( "traces", sa.Column("id", sa.Integer, primary_key=True), - sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False), + sa.Column( + "project_rowid", + sa.Integer, + sa.ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), # TODO(mikeldking): might not be the right place for this sa.Column("session_id", sa.String, nullable=True), sa.Column("trace_id", sa.String, nullable=False, unique=True), @@ -61,7 +67,13 @@ def upgrade() -> None: op.create_table( "spans", sa.Column("id", sa.Integer, primary_key=True), - sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), + sa.Column( + "trace_rowid", + sa.Integer, + sa.ForeignKey("traces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), sa.Column("span_id", sa.String, nullable=False, unique=True), sa.Column("parent_span_id", sa.String, nullable=True, index=True), sa.Column("name", sa.String, nullable=False), @@ -89,7 +101,13 @@ def upgrade() -> None: op.create_table( "span_annotations", sa.Column("id", sa.Integer, primary_key=True), - sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False), + sa.Column( + "span_rowid", + sa.Integer, + sa.ForeignKey("spans.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), sa.Column("name", sa.String, nullable=False), sa.Column("label", sa.String, nullable=True), sa.Column("score", sa.Float, nullable=True), @@ -128,7 +146,13 @@ def upgrade() -> None: op.create_table( "trace_annotations", sa.Column("id", sa.Integer, primary_key=True), - sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), + sa.Column( + "trace_rowid", + sa.Integer, + sa.ForeignKey("traces.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), sa.Column("name", sa.String, nullable=False), sa.Column("label", sa.String, nullable=True), sa.Column("score", sa.Float, nullable=True), @@ -167,7 +191,13 @@ def upgrade() -> None: op.create_table( "document_annotations", sa.Column("id", sa.Integer, primary_key=True), - sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False), + sa.Column( + "span_rowid", + sa.Integer, + sa.ForeignKey("spans.id", ondelete="CASCADE"), + nullable=False, + index=True, + ), sa.Column("document_index", sa.Integer, nullable=False), sa.Column("name", sa.String, nullable=False), sa.Column("label", sa.String, nullable=True), diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 41f159d0b0..ba06240934 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -93,10 +93,12 @@ class Project(Base): UtcTimeStamp, server_default=func.now(), onupdate=func.now() ) - traces: WriteOnlyMapped["Trace"] = relationship( + traces: WriteOnlyMapped[List["Trace"]] = relationship( "Trace", back_populates="project", cascade="all, delete-orphan", + passive_deletes=True, + uselist=True, ) __table_args__ = ( UniqueConstraint( @@ -110,7 +112,10 @@ class Project(Base): class Trace(Base): __tablename__ = "traces" id: Mapped[int] = mapped_column(primary_key=True) - project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) + project_rowid: Mapped[int] = mapped_column( + ForeignKey("projects.id", ondelete="CASCADE"), + index=True, + ) session_id: Mapped[Optional[str]] trace_id: Mapped[str] start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) @@ -125,6 +130,7 @@ class Trace(Base): "Span", back_populates="trace", cascade="all, delete-orphan", + uselist=True, ) __table_args__ = ( UniqueConstraint( @@ -138,7 +144,10 @@ class Trace(Base): class Span(Base): __tablename__ = "spans" id: Mapped[int] = mapped_column(primary_key=True) - trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id")) + trace_rowid: Mapped[int] = mapped_column( + ForeignKey("traces.id", ondelete="CASCADE"), + index=True, + ) span_id: Mapped[str] parent_span_id: Mapped[Optional[str]] = mapped_column(index=True) name: Mapped[str] @@ -183,7 +192,10 @@ async def init_models(engine: AsyncEngine) -> None: class SpanAnnotation(Base): __tablename__ = "span_annotations" id: Mapped[int] = mapped_column(primary_key=True) - span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id")) + span_rowid: Mapped[int] = mapped_column( + ForeignKey("spans.id", ondelete="CASCADE"), + index=True, + ) name: Mapped[str] label: Mapped[Optional[str]] score: Mapped[Optional[float]] @@ -209,7 +221,10 @@ class SpanAnnotation(Base): class TraceAnnotation(Base): __tablename__ = "trace_annotations" id: Mapped[int] = mapped_column(primary_key=True) - trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id")) + trace_rowid: Mapped[int] = mapped_column( + ForeignKey("traces.id", ondelete="CASCADE"), + index=True, + ) name: Mapped[str] label: Mapped[Optional[str]] score: Mapped[Optional[float]] @@ -235,7 +250,10 @@ class TraceAnnotation(Base): class DocumentAnnotation(Base): __tablename__ = "document_annotations" id: Mapped[int] = mapped_column(primary_key=True) - span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id")) + span_rowid: Mapped[int] = mapped_column( + ForeignKey("spans.id", ondelete="CASCADE"), + index=True, + ) document_index: Mapped[int] name: Mapped[str] label: Mapped[Optional[str]] diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index e78d59eec5..a11835ffdc 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -16,7 +16,7 @@ @dataclass class DataLoaders: - latency_ms_quantile: DataLoader[Tuple[str, Optional[TimeRange], float], Optional[float]] + latency_ms_quantile: DataLoader[Tuple[int, Optional[TimeRange], float], Optional[float]] span_evaluations: DataLoader[int, List[SpanEvaluation]] document_evaluations: DataLoader[int, List[DocumentEvaluation]] diff --git a/src/phoenix/server/api/dataloaders/latency_ms_quantile.py b/src/phoenix/server/api/dataloaders/latency_ms_quantile.py index f8076187e2..c733ab516d 100644 --- a/src/phoenix/server/api/dataloaders/latency_ms_quantile.py +++ b/src/phoenix/server/api/dataloaders/latency_ms_quantile.py @@ -19,11 +19,11 @@ from phoenix.db import models from phoenix.server.api.input_types.TimeRange import TimeRange -ProjectName: TypeAlias = str +ProjectId: TypeAlias = int TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] -Segment: TypeAlias = Tuple[ProjectName, TimeInterval] +Segment: TypeAlias = Tuple[ProjectId, TimeInterval] Probability: TypeAlias = float -Key: TypeAlias = Tuple[ProjectName, Optional[TimeRange], Probability] +Key: TypeAlias = Tuple[ProjectId, Optional[TimeRange], Probability] ResultPosition: TypeAlias = int QuantileValue: TypeAlias = float OrmExpression: TypeAlias = Any @@ -69,21 +69,21 @@ async def _load_fn(self, keys: List[Key]) -> List[Optional[QuantileValue]]: def _get_filter_condition(segment: Segment) -> OrmExpression: - name, (start_time, stop_time) = segment + id_, (start_time, stop_time) = segment if start_time and stop_time: return and_( - models.Project.name == name, + models.Project.id == id_, start_time <= models.Trace.start_time, models.Trace.start_time < stop_time, ) if start_time: return and_( - models.Project.name == name, + models.Project.id == id_, start_time <= models.Trace.start_time, ) if stop_time: return and_( - models.Project.name == name, + models.Project.id == id_, models.Trace.start_time < stop_time, ) - return models.Project.name == name + return models.Project.id == id_ diff --git a/src/phoenix/server/api/schema.py b/src/phoenix/server/api/schema.py index d3f524bef4..c75880ec98 100644 --- a/src/phoenix/server/api/schema.py +++ b/src/phoenix/server/api/schema.py @@ -4,11 +4,16 @@ import numpy as np import numpy.typing as npt import strawberry +from sqlalchemy import select +from sqlalchemy.orm import load_only from strawberry import ID, UNSET from strawberry.types import Info from typing_extensions import Annotated +from phoenix.config import DEFAULT_PROJECT_NAME +from phoenix.db import models from phoenix.pointcloud.clustering import Hdbscan +from phoenix.server.api.context import Context from phoenix.server.api.helpers import ensure_list from phoenix.server.api.input_types.ClusterInput import ClusterInput from phoenix.server.api.input_types.Coordinates import ( @@ -16,29 +21,37 @@ InputCoordinate3D, ) from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters -from phoenix.server.api.types.Project import Project - -from .context import Context -from .types.DatasetRole import AncillaryDatasetRole, DatasetRole -from .types.Dimension import to_gql_dimension -from .types.EmbeddingDimension import ( +from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole +from phoenix.server.api.types.Dimension import to_gql_dimension +from phoenix.server.api.types.EmbeddingDimension import ( DEFAULT_CLUSTER_SELECTION_EPSILON, DEFAULT_MIN_CLUSTER_SIZE, DEFAULT_MIN_SAMPLES, to_gql_embedding_dimension, ) -from .types.Event import create_event_id, unpack_event_id -from .types.ExportEventsMutation import ExportEventsMutation -from .types.Functionality import Functionality -from .types.Model import Model -from .types.node import GlobalID, Node, from_global_id, from_global_id_with_expected_type -from .types.pagination import Connection, ConnectionArgs, Cursor, connection_from_list +from phoenix.server.api.types.Event import create_event_id, unpack_event_id +from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation +from phoenix.server.api.types.Functionality import Functionality +from phoenix.server.api.types.Model import Model +from phoenix.server.api.types.node import ( + GlobalID, + Node, + from_global_id, + from_global_id_with_expected_type, +) +from phoenix.server.api.types.pagination import ( + Connection, + ConnectionArgs, + Cursor, + connection_from_list, +) +from phoenix.server.api.types.Project import Project @strawberry.type class Query: @strawberry.field - def projects( + async def projects( self, info: Info[Context, None], first: Optional[int] = 50, @@ -52,14 +65,16 @@ def projects( last=last, before=before if isinstance(before, Cursor) else None, ) - data = ( - [] - if (traces := info.context.traces) is None - else [ - Project(id_attr=project_id, name=project_name, project=project) - for project_id, project_name, project in traces.get_projects() - ] - ) + async with info.context.db() as session: + projects = await session.scalars(select(models.Project)) + data = [ + Project( + id_attr=project.id, + name=project.name, + project=info.context.traces.get_project(project.name), # type: ignore + ) + for project in projects + ] return connection_from_list(data=data, args=args) @strawberry.field @@ -76,7 +91,7 @@ def model(self) -> Model: return Model() @strawberry.field - def node(self, id: GlobalID, info: Info[Context, None]) -> Node: + async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: type_name, node_id = from_global_id(str(id)) if type_name == "Dimension": dimension = info.context.model.scalar_dimensions[node_id] @@ -85,17 +100,18 @@ def node(self, id: GlobalID, info: Info[Context, None]) -> Node: embedding_dimension = info.context.model.embedding_dimensions[node_id] return to_gql_embedding_dimension(node_id, embedding_dimension) elif type_name == "Project": - if (traces := info.context.traces) is not None: - projects = { - project_id: (project_name, project) - for project_id, project_name, project in traces.get_projects() - } - if node_id in projects: - name, project = projects[node_id] - return Project(id_attr=node_id, name=name, project=project) - raise Exception(f"Unknown project: {id}") - - raise Exception(f"Unknown node type: {type}") + async with info.context.db() as session: + project = await session.scalar( + select(models.Project).where(models.Project.id == node_id) + ) + if project is None: + raise ValueError(f"Unknown project: {id}") + return Project( + id_attr=project.id, + name=project.name, + project=info.context.traces.get_project(project.name), # type: ignore + ) + raise Exception(f"Unknown node type: {type_name}") @strawberry.field def clusters( @@ -229,17 +245,19 @@ def hdbscan_clustering( @strawberry.type class Mutation(ExportEventsMutation): @strawberry.mutation - def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query: - if (traces := info.context.traces) is not None: - node_id = from_global_id_with_expected_type(str(id), "Project") - traces.archive_project(node_id) - return Query() - - @strawberry.mutation - def archive_project(self, info: Info[Context, None], id: GlobalID) -> Query: - if (traces := info.context.traces) is not None: - node_id = from_global_id_with_expected_type(str(id), "Project") - traces.archive_project(node_id) + async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query: + node_id = from_global_id_with_expected_type(str(id), "Project") + async with info.context.db() as session: + project = await session.scalar( + select(models.Project) + .where(models.Project.id == node_id) + .options(load_only(models.Project.name)) + ) + if project is None: + raise ValueError(f"Unknown project: {id}") + if project.name == DEFAULT_PROJECT_NAME: + raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project") + await session.delete(project) return Query() diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index c59a387569..c2be1090dd 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -29,7 +29,7 @@ from phoenix.server.api.types.Trace import Trace from phoenix.server.api.types.ValidationResult import ValidationResult from phoenix.trace.dsl import SpanFilter -from phoenix.trace.schemas import SpanID, TraceID +from phoenix.trace.schemas import SpanID @strawberry.type @@ -42,10 +42,8 @@ async def start_time( self, info: Info[Context, None], ) -> Optional[datetime]: - stmt = ( - select(func.min(models.Trace.start_time)) - .join(models.Project) - .where(models.Project.name == self.name) + stmt = select(func.min(models.Trace.start_time)).where( + models.Trace.project_rowid == self.id_attr ) async with info.context.db() as session: start_time = await session.scalar(stmt) @@ -57,10 +55,8 @@ async def end_time( self, info: Info[Context, None], ) -> Optional[datetime]: - stmt = ( - select(func.max(models.Trace.end_time)) - .join(models.Project) - .where(models.Project.name == self.name) + stmt = select(func.max(models.Trace.end_time)).where( + models.Trace.project_rowid == self.id_attr ) async with info.context.db() as session: end_time = await session.scalar(stmt) @@ -76,8 +72,7 @@ async def record_count( stmt = ( select(func.count(models.Span.id)) .join(models.Trace) - .join(models.Project) - .where(models.Project.name == self.name) + .where(models.Trace.project_rowid == self.id_attr) ) if time_range: stmt = stmt.where( @@ -95,11 +90,7 @@ async def trace_count( info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - stmt = ( - select(func.count(models.Trace.id)) - .join(models.Project) - .where(models.Project.name == self.name) - ) + stmt = select(func.count(models.Trace.id)).where(models.Trace.project_rowid == self.id_attr) if time_range: stmt = stmt.where( and_( @@ -121,8 +112,7 @@ async def token_count_total( stmt = ( select(coalesce(func.sum(prompt), 0) + coalesce(func.sum(completion), 0)) .join(models.Trace) - .join(models.Project) - .where(models.Project.name == self.name) + .where(models.Trace.project_rowid == self.id_attr) ) if time_range: stmt = stmt.where( @@ -142,14 +132,19 @@ async def latency_ms_quantile( time_range: Optional[TimeRange] = UNSET, ) -> Optional[float]: return await info.context.data_loaders.latency_ms_quantile.load( - (self.name, time_range, probability) + (self.id_attr, time_range, probability) ) @strawberry.field - def trace(self, trace_id: ID) -> Optional[Trace]: - if self.project.has_trace(TraceID(trace_id)): - return Trace(trace_id=trace_id, project=self.project) - return None + async def trace(self, info: Info[Context, None], trace_id: ID) -> Optional[Trace]: + async with info.context.db() as session: + if not await session.scalar( + select(models.Trace.id) + .where(models.Trace.trace_id == str(trace_id)) + .where(models.Trace.project_rowid == self.id_attr), + ): + return None + return Trace(trace_id=trace_id, project=self.project) @strawberry.field async def spans( @@ -173,8 +168,7 @@ async def spans( stmt = ( select(models.Span) .join(models.Trace) - .join(models.Project) - .where(models.Project.name == self.name) + .where(models.Trace.project_rowid == self.id_attr) .options(contains_eager(models.Span.trace)) ) if time_range: