Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(persistence): get or delete projects using sql #2839

Merged
merged 14 commits
Apr 12, 2024
6 changes: 3 additions & 3 deletions src/phoenix/core/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ def __init__(self) -> None:

def get_project(self, project_name: str) -> Optional["Project"]:
with self._lock:
return self._projects.get(project_name)
return self._projects.get(project_name, Project())
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved

def get_projects(self) -> Iterator[Tuple[int, str, "Project"]]:
with self._lock:
for project_id, (project_name, project) in enumerate(self._projects.items()):
if project.is_archived:
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
continue
yield project_id, project_name, project
yield project_id + 1, project_name, project
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved

def archive_project(self, id: int) -> Optional["Project"]:
if id == 0:
if id == 1:
raise ValueError("Cannot archive the default project")
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
with self._lock:
for project_id, _, project in self.get_projects():
Expand Down
35 changes: 30 additions & 5 deletions src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def upgrade() -> None:
op.create_table(
"traces",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False),
sa.Column(
"project_rowid",
sa.Integer,
sa.ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
),
# TODO(mikeldking): might not be the right place for this
sa.Column("session_id", sa.String, nullable=True),
sa.Column("trace_id", sa.String, nullable=False, unique=True),
Expand All @@ -54,7 +59,12 @@ def upgrade() -> None:
op.create_table(
"spans",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False),
sa.Column(
"trace_rowid",
sa.Integer,
sa.ForeignKey("traces.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("span_id", sa.String, nullable=False, unique=True),
sa.Column("parent_span_id", sa.String, nullable=True, index=True),
sa.Column("name", sa.String, nullable=False),
Expand Down Expand Up @@ -82,7 +92,12 @@ def upgrade() -> None:
op.create_table(
"span_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False),
sa.Column(
"span_rowid",
sa.Integer,
sa.ForeignKey("spans.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
sa.Column("score", sa.Float, nullable=True),
Expand Down Expand Up @@ -121,7 +136,12 @@ def upgrade() -> None:
op.create_table(
"trace_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False),
sa.Column(
"trace_rowid",
sa.Integer,
sa.ForeignKey("traces.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
sa.Column("score", sa.Float, nullable=True),
Expand Down Expand Up @@ -160,7 +180,12 @@ def upgrade() -> None:
op.create_table(
"document_annotations",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id"), nullable=False),
sa.Column(
"span_rowid",
sa.Integer,
sa.ForeignKey("spans.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("document_index", sa.Integer, nullable=False),
sa.Column("name", sa.String, nullable=False),
sa.Column("label", sa.String, nullable=True),
Expand Down
15 changes: 9 additions & 6 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ class Project(Base):
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)

traces: WriteOnlyMapped["Trace"] = relationship(
traces: WriteOnlyMapped[List["Trace"]] = relationship(
"Trace",
back_populates="project",
cascade="all, delete-orphan",
passive_deletes=True,
uselist=True,
)
__table_args__ = (
UniqueConstraint(
Expand All @@ -104,7 +106,7 @@ class Project(Base):
class Trace(Base):
__tablename__ = "traces"
id: Mapped[int] = mapped_column(primary_key=True)
project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id"))
project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
session_id: Mapped[Optional[str]]
trace_id: Mapped[str]
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
Expand All @@ -119,6 +121,7 @@ class Trace(Base):
"Span",
back_populates="trace",
cascade="all, delete-orphan",
uselist=True,
)
__table_args__ = (
UniqueConstraint(
Expand All @@ -132,7 +135,7 @@ class Trace(Base):
class Span(Base):
__tablename__ = "spans"
id: Mapped[int] = mapped_column(primary_key=True)
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id"))
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id", ondelete="CASCADE"))
span_id: Mapped[str]
parent_span_id: Mapped[Optional[str]] = mapped_column(index=True)
name: Mapped[str]
Expand Down Expand Up @@ -177,7 +180,7 @@ async def init_models(engine: AsyncEngine) -> None:
class SpanAnnotation(Base):
__tablename__ = "span_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id"))
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id", ondelete="CASCADE"))
name: Mapped[str]
label: Mapped[Optional[str]]
score: Mapped[Optional[float]]
Expand All @@ -203,7 +206,7 @@ class SpanAnnotation(Base):
class TraceAnnotation(Base):
__tablename__ = "trace_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id"))
trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id", ondelete="CASCADE"))
name: Mapped[str]
label: Mapped[Optional[str]]
score: Mapped[Optional[float]]
Expand All @@ -229,7 +232,7 @@ class TraceAnnotation(Base):
class DocumentAnnotation(Base):
__tablename__ = "document_annotations"
id: Mapped[int] = mapped_column(primary_key=True)
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id"))
span_rowid: Mapped[int] = mapped_column(ForeignKey("spans.id", ondelete="CASCADE"))
document_index: Mapped[int]
name: Mapped[str]
label: Mapped[Optional[str]]
Expand Down
99 changes: 62 additions & 37 deletions src/phoenix/server/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,54 @@
import numpy as np
import numpy.typing as npt
import strawberry
from sqlalchemy import select
from sqlalchemy.orm import load_only
from strawberry import ID, UNSET
from strawberry.types import Info
from typing_extensions import Annotated

from phoenix.config import DEFAULT_PROJECT_NAME
from phoenix.db import models
from phoenix.pointcloud.clustering import Hdbscan
from phoenix.server.api.context import Context
from phoenix.server.api.helpers import ensure_list
from phoenix.server.api.input_types.ClusterInput import ClusterInput
from phoenix.server.api.input_types.Coordinates import (
InputCoordinate2D,
InputCoordinate3D,
)
from phoenix.server.api.types.Cluster import Cluster, to_gql_clusters
from phoenix.server.api.types.Project import Project

from .context import Context
from .types.DatasetRole import AncillaryDatasetRole, DatasetRole
from .types.Dimension import to_gql_dimension
from .types.EmbeddingDimension import (
from phoenix.server.api.types.DatasetRole import AncillaryDatasetRole, DatasetRole
from phoenix.server.api.types.Dimension import to_gql_dimension
from phoenix.server.api.types.EmbeddingDimension import (
DEFAULT_CLUSTER_SELECTION_EPSILON,
DEFAULT_MIN_CLUSTER_SIZE,
DEFAULT_MIN_SAMPLES,
to_gql_embedding_dimension,
)
from .types.Event import create_event_id, unpack_event_id
from .types.ExportEventsMutation import ExportEventsMutation
from .types.Functionality import Functionality
from .types.Model import Model
from .types.node import GlobalID, Node, from_global_id, from_global_id_with_expected_type
from .types.pagination import Connection, ConnectionArgs, Cursor, connection_from_list
from phoenix.server.api.types.Event import create_event_id, unpack_event_id
from phoenix.server.api.types.ExportEventsMutation import ExportEventsMutation
from phoenix.server.api.types.Functionality import Functionality
from phoenix.server.api.types.Model import Model
from phoenix.server.api.types.node import (
GlobalID,
Node,
from_global_id,
from_global_id_with_expected_type,
)
from phoenix.server.api.types.pagination import (
Connection,
ConnectionArgs,
Cursor,
connection_from_list,
)
from phoenix.server.api.types.Project import Project


@strawberry.type
class Query:
@strawberry.field
def projects(
async def projects(
self,
info: Info[Context, None],
first: Optional[int] = 50,
Expand All @@ -52,14 +65,16 @@ def projects(
last=last,
before=before if isinstance(before, Cursor) else None,
)
data = (
[]
if (traces := info.context.traces) is None
else [
Project(id_attr=project_id, name=project_name, project=project)
for project_id, project_name, project in traces.get_projects()
]
)
async with info.context.db() as session:
projects = await session.scalars(select(models.Project))
data = [
Project(
id_attr=project.id,
name=project.name,
project=info.context.traces.get_project(project.name), # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow-up, it might make sense to key the core projects by ID rather than name.

Copy link
Contributor Author

@RogerHYang RogerHYang Apr 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the (in-memory) core projects will be deleted after migration, so this is moot.

)
for project in projects
]
return connection_from_list(data=data, args=args)

@strawberry.field
Expand All @@ -76,7 +91,7 @@ def model(self) -> Model:
return Model()

@strawberry.field
def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
type_name, node_id = from_global_id(str(id))
if type_name == "Dimension":
dimension = info.context.model.scalar_dimensions[node_id]
Expand All @@ -85,17 +100,18 @@ def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
embedding_dimension = info.context.model.embedding_dimensions[node_id]
return to_gql_embedding_dimension(node_id, embedding_dimension)
elif type_name == "Project":
if (traces := info.context.traces) is not None:
projects = {
project_id: (project_name, project)
for project_id, project_name, project in traces.get_projects()
}
if node_id in projects:
name, project = projects[node_id]
return Project(id_attr=node_id, name=name, project=project)
raise Exception(f"Unknown project: {id}")

raise Exception(f"Unknown node type: {type}")
async with info.context.db() as session:
project = await session.scalar(
select(models.Project).where(models.Project.id == node_id)
)
if project is None:
raise ValueError(f"Unknown project: {id}")
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
return Project(
id_attr=project.id,
name=project.name,
project=info.context.traces.get_project(project.name), # type: ignore
)
raise Exception(f"Unknown node type: {type_name}")

@strawberry.field
def clusters(
Expand Down Expand Up @@ -229,10 +245,19 @@ def hdbscan_clustering(
@strawberry.type
class Mutation(ExportEventsMutation):
@strawberry.mutation
def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
if (traces := info.context.traces) is not None:
node_id = from_global_id_with_expected_type(str(id), "Project")
traces.archive_project(node_id)
async def delete_project(self, info: Info[Context, None], id: GlobalID) -> Query:
node_id = from_global_id_with_expected_type(str(id), "Project")
async with info.context.db() as session:
project = await session.scalar(
select(models.Project)
.where(models.Project.id == node_id)
.options(load_only(models.Project.name))
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
)
if project is None:
raise ValueError(f"Unknown project: {id}")
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
if project.name == DEFAULT_PROJECT_NAME:
raise ValueError(f"Cannot delete the {DEFAULT_PROJECT_NAME} project")
await session.delete(project)
return Query()

@strawberry.mutation
Expand Down
20 changes: 15 additions & 5 deletions src/phoenix/server/api/types/Project.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from phoenix.server.api.types.Trace import Trace
from phoenix.server.api.types.ValidationResult import ValidationResult
from phoenix.trace.dsl import SpanFilter
from phoenix.trace.schemas import SpanID, TraceID
from phoenix.trace.schemas import SpanID


@strawberry.type
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -149,10 +149,20 @@ async def latency_ms_quantile(
)

@strawberry.field
def trace(self, trace_id: ID) -> Optional[Trace]:
if self.project.has_trace(TraceID(trace_id)):
return Trace(trace_id=trace_id, project=self.project)
return None
async def trace(self, info: Info[Context, None], trace_id: ID) -> Optional[Trace]:
async with info.context.db() as session:
if not await session.scalar(
select(1)
.join_from(models.Trace, models.Project)
RogerHYang marked this conversation as resolved.
Show resolved Hide resolved
.where(
and_(
models.Trace.trace_id == str(trace_id),
models.Project.name == self.name,
),
)
):
return None
return Trace(trace_id=trace_id, project=self.project)

@strawberry.field
async def spans(
Expand Down
Loading