diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml index d9cc2f35f0..d3744e845a 100644 --- a/.github/workflows/python-CI.yml +++ b/.github/workflows/python-CI.yml @@ -2,7 +2,7 @@ name: Python CI on: push: - branches: [main] + branches: [main, sql] pull_request: paths: - "src/**" diff --git a/Dockerfile b/Dockerfile index 0caeedfd44..40f2b2fa31 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,7 @@ COPY ./ /phoenix/ COPY --from=frontend-builder /phoenix/src/phoenix/server/static/ /phoenix/src/phoenix/server/static/ # Delete symbolic links used during development. RUN find src/ -xtype l -delete -RUN pip install --target ./env .[container] +RUN pip install --target ./env ".[container, pg]" # The production image is distroless, meaning that it is a minimal image that # contains only the necessary dependencies to run the application. This is diff --git a/app/schema.graphql b/app/schema.graphql index 0236fc879f..26cc2d66f5 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -543,11 +543,11 @@ type Project implements Node { endTime: DateTime recordCount(timeRange: TimeRange): Int! traceCount(timeRange: TimeRange): Int! - tokenCountTotal: Int! + tokenCountTotal(timeRange: TimeRange): Int! latencyMsP50: Float latencyMsP99: Float trace(traceId: ID!): Trace - spans(timeRange: TimeRange, traceIds: [ID!], first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! """ Names of all available evaluations for traces. (The list contains no duplicates.) diff --git a/cspell.json b/cspell.json index af4d274d9a..22d6f408bf 100644 --- a/cspell.json +++ b/cspell.json @@ -37,6 +37,8 @@ "respx", "rgba", "seafoam", + "sqlalchemy", + "Starlette", "templating", "tensorboard", "testset", diff --git a/pyproject.toml b/pyproject.toml index cab2a0a29c..e1bae67614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ dependencies = [ "openinference-instrumentation-langchain>=0.1.12", "openinference-instrumentation-llama-index>=1.2.0", "openinference-instrumentation-openai>=0.1.4", + "sqlalchemy>=2, <3", + "alembic>=1.3.0, <2", + "aiosqlite", ] dynamic = ["version"] @@ -86,6 +89,11 @@ llama-index = [ "llama-index-callbacks-arize-phoenix>=0.1.2", "openinference-instrumentation-llama-index>=1.2.0", ] + +pg = [ + "asyncpg", +] + container = [ "prometheus-client", ] diff --git a/src/phoenix/__init__.py b/src/phoenix/__init__.py index 58319bf2f2..2dece2695c 100644 --- a/src/phoenix/__init__.py +++ b/src/phoenix/__init__.py @@ -9,7 +9,14 @@ from .datasets.schema import EmbeddingColumnNames, RetrievalEmbeddingColumnNames, Schema from .session.client import Client from .session.evaluation import log_evaluations -from .session.session import NotebookEnvironment, Session, active_session, close_app, launch_app +from .session.session import ( + NotebookEnvironment, + Session, + active_session, + close_app, + launch_app, + reset_all, +) from .trace.fixtures import load_example_traces from .trace.trace_dataset import TraceDataset from .version import __version__ @@ -41,6 +48,7 @@ "active_session", "close_app", "launch_app", + "reset_all", "Session", "load_example_traces", "TraceDataset", diff --git a/src/phoenix/config.py b/src/phoenix/config.py index b952da4d28..bd278d2a90 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py 
@@ -22,6 +22,10 @@ """ The project name to use when logging traces and evals. defaults to 'default'. """ +ENV_PHOENIX_SQL_DATABASE = "__DANGEROUS__PHOENIX_SQL_DATABASE" +""" +The database URL to use when logging traces and evals. +""" ENV_SPAN_STORAGE_TYPE = "__DANGEROUS__PHOENIX_SPAN_STORAGE_TYPE" """ **EXPERIMENTAL** @@ -152,6 +156,14 @@ def get_env_span_storage_type() -> Optional["SpanStorageType"]: ) +def get_env_database_connection_str() -> str: + env_url = os.getenv(ENV_PHOENIX_SQL_DATABASE) + if env_url is None: + working_dir = get_working_dir() + return f"sqlite:///{working_dir}/phoenix.db" + return env_url + + class SpanStorageType(Enum): TEXT_FILES = "text-files" diff --git a/src/phoenix/db/__init__.py b/src/phoenix/db/__init__.py new file mode 100644 index 0000000000..2848bf384b --- /dev/null +++ b/src/phoenix/db/__init__.py @@ -0,0 +1,3 @@ +from .migrate import migrate + +__all__ = ["migrate"] diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini new file mode 100644 index 0000000000..402ff9e576 --- /dev/null +++ b/src/phoenix/db/alembic.ini @@ -0,0 +1,119 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +# Note this is overridden in .migrate during programatic migrations +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 
+ +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# NB: This is commented out intentionally as it is dynamic +# See migrations/env.py +# sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = WARN +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py new file mode 100644 index 0000000000..b863949d52 --- /dev/null +++ b/src/phoenix/db/bulk_inserter.py @@ -0,0 +1,187 @@ +import asyncio +import logging +from itertools import islice +from time import time +from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast + +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import func, insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from phoenix.db import models +from phoenix.trace.schemas import Span, SpanStatusCode + +logger = logging.getLogger(__name__) + + +class BulkInserter: + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None, + run_interval_in_seconds: float = 0.5, + max_num_per_transaction: int = 100, + ) -> None: + """ + :param db: A function to initiate a new database session. + :param initial_batch_of_spans: Initial batch of spans to insert. + :param run_interval_in_seconds: The time interval between the starts of each + bulk insert. If there's nothing to insert, the inserter goes back to sleep. + :param max_num_per_transaction: The maximum number of items to insert in a single + transaction. Multiple transactions will be used if there are more items in the batch. 
+ """ + self._db = db + self._running = False + self._run_interval_seconds = run_interval_in_seconds + self._max_num_per_transaction = max_num_per_transaction + self._spans: List[Tuple[Span, str]] = ( + [] if initial_batch_of_spans is None else list(initial_batch_of_spans) + ) + self._task: Optional[asyncio.Task[None]] = None + + async def __aenter__(self) -> Callable[[Span, str], None]: + self._running = True + self._task = asyncio.create_task(self._bulk_insert()) + return self._queue_span + + async def __aexit__(self, *args: Any) -> None: + self._running = False + + def _queue_span(self, span: Span, project_name: str) -> None: + self._spans.append((span, project_name)) + + async def _bulk_insert(self) -> None: + next_run_at = time() + self._run_interval_seconds + while self._spans or self._running: + await asyncio.sleep(next_run_at - time()) + next_run_at = time() + self._run_interval_seconds + if self._spans: + await self._insert_spans() + + async def _insert_spans(self) -> None: + spans = self._spans + self._spans = [] + for i in range(0, len(spans), self._max_num_per_transaction): + try: + async with self._db() as session: + for span, project_name in islice(spans, i, i + self._max_num_per_transaction): + try: + async with session.begin_nested(): + await _insert_span(session, span, project_name) + except Exception: + logger.exception( + f"Failed to insert span with span_id={span.context.span_id}" + ) + except Exception: + logger.exception("Failed to insert spans") + + +async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: + if await session.scalar(select(1).where(models.Span.span_id == span.context.span_id)): + # Span already exists + return + if not ( + project_rowid := await session.scalar( + select(models.Project.id).where(models.Project.name == project_name) + ) + ): + project_rowid = await session.scalar( + insert(models.Project).values(name=project_name).returning(models.Project.id) + ) + if trace := await session.scalar( + select(models.Trace).where(models.Trace.trace_id == span.context.trace_id) + ): + trace_rowid = trace.id + # TODO(persistence): Figure out how to reliably retrieve timezone-aware + # datetime from the (sqlite) database, because all datetime in our + # programs should be timezone-aware. 
+ if span.start_time < trace.start_time or trace.end_time < span.end_time: + trace.start_time = min(trace.start_time, span.start_time) + trace.end_time = max(trace.end_time, span.end_time) + await session.execute( + update(models.Trace) + .where(models.Trace.id == trace_rowid) + .values( + start_time=min(trace.start_time, span.start_time), + end_time=max(trace.end_time, span.end_time), + ) + ) + else: + trace_rowid = cast( + int, + await session.scalar( + insert(models.Trace) + .values( + project_rowid=project_rowid, + trace_id=span.context.trace_id, + start_time=span.start_time, + end_time=span.end_time, + ) + .returning(models.Trace.id) + ), + ) + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := ( + await session.execute( + select( + func.sum(models.Span.cumulative_error_count), + func.sum(models.Span.cumulative_llm_token_count_prompt), + func.sum(models.Span.cumulative_llm_token_count_completion), + ).where(models.Span.parent_span_id == span.context.span_id) + ) + ).first(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 + session.add( + models.Span( + span_id=span.context.span_id, + trace_rowid=trace_rowid, + parent_span_id=span.parent_id, + kind=span.span_kind.value, + name=span.name, + start_time=span.start_time, + end_time=span.end_time, + attributes=span.attributes, + events=span.events, + status=span.status_code.value, + status_message=span.status_message, + latency_ms=latency_ms, + cumulative_error_count=cumulative_error_count, + cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, + ) + ) + # Propagate cumulative values to ancestors. This is usually a no-op, since + # the parent usually arrives after the child. But in the event that a + # child arrives after its parent, we need to make sure the all the + # ancestors' cumulative values are updated. 
+ ancestors = ( + select(models.Span.id, models.Span.parent_span_id) + .where(models.Span.span_id == span.parent_id) + .cte(recursive=True) + ) + child = ancestors.alias() + ancestors = ancestors.union_all( + select(models.Span.id, models.Span.parent_span_id).join( + child, models.Span.span_id == child.c.parent_span_id + ) + ) + await session.execute( + update(models.Span) + .where(models.Span.id.in_(select(ancestors.c.id))) + .values( + cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count, + cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt + + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion + + cumulative_llm_token_count_completion, + ) + ) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py new file mode 100644 index 0000000000..2006de3dd6 --- /dev/null +++ b/src/phoenix/db/engines.py @@ -0,0 +1,94 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from pathlib import Path +from sqlite3 import Connection +from typing import Any, Union + +import numpy as np +from sqlalchemy import URL, event, make_url +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from phoenix.db.migrate import migrate +from phoenix.db.models import init_models + + +def set_sqlite_pragma(connection: Connection, _: Any) -> None: + cursor = connection.cursor() + cursor.execute("PRAGMA foreign_keys = ON;") + cursor.execute("PRAGMA journal_mode = WAL;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA cache_size = -32000;") + cursor.execute("PRAGMA busy_timeout = 10000;") + cursor.close() + + +def get_db_url(driver: str = "sqlite+aiosqlite", database: Union[str, Path] = ":memory:") -> URL: + return URL.create(driver, database=str(database)) + + +def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine: + """ + Factory to create a SQLAlchemy engine from a URL string. 
+ """ + url = make_url(connection_str) + if not url.database: + raise ValueError("Failed to parse database from connection string") + if "sqlite" in url.drivername: + # Split the URL to get the database name + return aio_sqlite_engine(database=url.database, echo=echo) + if "postgresql" in url.drivername: + return aio_postgresql_engine(url=url, echo=echo) + raise ValueError(f"Unsupported driver: {url.drivername}") + + +def aio_sqlite_engine( + database: Union[str, Path] = ":memory:", + echo: bool = False, +) -> AsyncEngine: + url = get_db_url(driver="sqlite+aiosqlite", database=database) + engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) + event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + if str(database) == ":memory:": + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.run(init_models(engine)) + else: + asyncio.create_task(init_models(engine)) + else: + migrate(engine.url) + return engine + + +def aio_postgresql_engine( + url: URL, + echo: bool = False, +) -> AsyncEngine: + # Swap out the engine + async_url = url.set(drivername="postgresql+asyncpg") + engine = create_async_engine(url=async_url, echo=echo, json_serializer=_dumps) + # TODO(persistence): figure out the postgres pragma + # event.listen(engine.sync_engine, "connect", set_pragma) + migrate(engine.url) + return engine + + +def _dumps(obj: Any) -> str: + return json.dumps(obj, cls=_Encoder) + + +class _Encoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + return super().default(obj) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py new file mode 100644 index 0000000000..990c390e14 --- /dev/null +++ b/src/phoenix/db/migrate.py @@ -0,0 +1,28 @@ +import logging +from pathlib import Path + +from alembic import command +from alembic.config import Config +from sqlalchemy import URL + +logger = logging.getLogger(__name__) + + +def migrate(url: URL) -> None: + """ + Runs migrations on the database. + NB: Migrate only works on non-memory databases. + + Args: + url: The database URL. + """ + logger.warning("Running migrations on the database") + config_path = str(Path(__file__).parent.resolve() / "alembic.ini") + alembic_cfg = Config(config_path) + + # Explicitly set the migration directory + scripts_location = str(Path(__file__).parent.resolve() / "migrations") + alembic_cfg.set_main_option("script_location", scripts_location) + alembic_cfg.set_main_option("sqlalchemy.url", str(url)) + + command.upgrade(alembic_cfg, "head") diff --git a/src/phoenix/db/migrations/README b/src/phoenix/db/migrations/README new file mode 100644 index 0000000000..98e4f9c44e --- /dev/null +++ b/src/phoenix/db/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py new file mode 100644 index 0000000000..601b7873a2 --- /dev/null +++ b/src/phoenix/db/migrations/env.py @@ -0,0 +1,101 @@ +import asyncio +from logging.config import fileConfig + +from alembic import context +from phoenix.db.models import Base +from sqlalchemy import Connection, engine_from_config, pool +from sqlalchemy.ext.asyncio import AsyncEngine + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support + +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = context.config.attributes.get("connection", None) + if connectable is None: + connectable = AsyncEngine( + engine_from_config( + context.config.get_section(context.config.config_ini_section) or {}, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + future=True, + ) + ) + + if isinstance(connectable, AsyncEngine): + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.run(run_async_migrations(connectable)) + else: + asyncio.create_task(run_async_migrations(connectable)) + else: + run_migrations(connectable) + + +async def run_async_migrations(engine: AsyncEngine) -> None: + async with engine.connect() as connection: + await connection.run_sync(run_migrations) + await engine.dispose() + + +def run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/phoenix/db/migrations/script.py.mako b/src/phoenix/db/migrations/script.py.mako new file mode 100644 index 0000000000..fbc4b07dce --- /dev/null +++ b/src/phoenix/db/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py new file mode 100644 index 0000000000..de4a4fc917 --- /dev/null +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -0,0 +1,91 @@ +"""init + +Revision ID: cf03bd6bae1d +Revises: +Create Date: 2024-04-03 19:41:48.871555 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "cf03bd6bae1d" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + projects_table = op.create_table( + "projects", + sa.Column("id", sa.Integer, primary_key=True), + # TODO does the uniqueness constraint need to be named + sa.Column("name", sa.String, nullable=False, unique=True), + sa.Column("description", sa.String, nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + ) + op.create_table( + "traces", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False), + # TODO(mikeldking): might not be the right place for this + sa.Column("session_id", sa.String, nullable=True), + sa.Column("trace_id", sa.String, nullable=False, unique=True), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), + ) + + op.create_table( + "spans", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), + sa.Column("span_id", sa.String, nullable=False, unique=True), + sa.Column("parent_span_id", sa.String, nullable=True, index=True), + sa.Column("name", sa.String, nullable=False), + sa.Column("kind", sa.String, nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("attributes", sa.JSON, nullable=False), + sa.Column("events", sa.JSON, nullable=False), + sa.Column( + "status", + sa.String, + # TODO(mikeldking): this doesn't seem to work... 
+ sa.CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", "valid_status"), + nullable=False, + default="UNSET", + server_default="UNSET", + ), + sa.Column("status_message", sa.String, nullable=False), + sa.Column("latency_ms", sa.REAL, nullable=False), + sa.Column("cumulative_error_count", sa.Integer, nullable=False), + sa.Column("cumulative_llm_token_count_prompt", sa.Integer, nullable=False), + sa.Column("cumulative_llm_token_count_completion", sa.Integer, nullable=False), + ) + op.bulk_insert( + projects_table, + [ + {"name": "default", "description": "Default project"}, + ], + ) + + +def downgrade() -> None: + op.drop_table("projects") + op.drop_table("traces") + op.drop_table("spans") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py new file mode 100644 index 0000000000..1ba56ef6f2 --- /dev/null +++ b/src/phoenix/db/models.py @@ -0,0 +1,173 @@ +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from sqlalchemy import ( + JSON, + CheckConstraint, + DateTime, + Dialect, + ForeignKey, + MetaData, + TypeDecorator, + UniqueConstraint, + func, + insert, +) +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + WriteOnlyMapped, + mapped_column, + relationship, +) + + +class UtcTimeStamp(TypeDecorator[datetime]): + """TODO(persistence): Figure out how to reliably store and retrieve + timezone-aware datetime objects from the (sqlite) database. Below is a + workaround to guarantee that the timestamps we fetch from the database is + always timezone-aware, in order to prevent comparisons of timezone-naive + datetime with timezone-aware datetime, because objects in the rest of our + programs are always timezone-aware. + """ + + cache_ok = True + impl = DateTime + _LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo + + def process_bind_param( + self, + value: Optional[datetime], + dialect: Dialect, + ) -> Optional[datetime]: + if not value: + return None + if value.tzinfo is None: + value = value.astimezone(self._LOCAL_TIMEZONE) + return value.astimezone(timezone.utc) + + def process_result_value( + self, + value: Optional[Any], + dialect: Dialect, + ) -> Optional[datetime]: + if not isinstance(value, datetime): + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +class Base(DeclarativeBase): + # Enforce best practices for naming constraints + # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate + metadata = MetaData( + naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_`%(constraint_name)s`", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", + } + ) + type_annotation_map = { + Dict[str, Any]: JSON, + List[Dict[str, Any]]: JSON, + } + + +class Project(Base): + __tablename__ = "projects" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + description: Mapped[Optional[str]] + updated_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) + created_at: Mapped[datetime] = mapped_column( + UtcTimeStamp, server_default=func.now(), onupdate=func.now() + ) + + traces: WriteOnlyMapped["Trace"] = relationship( + "Trace", + back_populates="project", + cascade="all, delete-orphan", + ) + __table_args__ = ( + UniqueConstraint( + "name", + name="uq_projects_name", + 
sqlite_on_conflict="IGNORE", + ), + ) + + +class Trace(Base): + __tablename__ = "traces" + id: Mapped[int] = mapped_column(primary_key=True) + project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) + session_id: Mapped[Optional[str]] + trace_id: Mapped[str] + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) + end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) + + project: Mapped["Project"] = relationship( + "Project", + back_populates="traces", + ) + spans: Mapped[List["Span"]] = relationship( + "Span", + back_populates="trace", + cascade="all, delete-orphan", + ) + __table_args__ = ( + UniqueConstraint( + "trace_id", + name="uq_traces_trace_id", + sqlite_on_conflict="IGNORE", + ), + ) + + +class Span(Base): + __tablename__ = "spans" + id: Mapped[int] = mapped_column(primary_key=True) + trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id")) + span_id: Mapped[str] + parent_span_id: Mapped[Optional[str]] = mapped_column(index=True) + name: Mapped[str] + kind: Mapped[str] + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp) + end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) + attributes: Mapped[Dict[str, Any]] + events: Mapped[List[Dict[str, Any]]] + status: Mapped[str] = mapped_column( + CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", "valid_status") + ) + status_message: Mapped[str] + + # TODO(mikeldking): is computed columns possible here + latency_ms: Mapped[float] + cumulative_error_count: Mapped[int] + cumulative_llm_token_count_prompt: Mapped[int] + cumulative_llm_token_count_completion: Mapped[int] + + trace: Mapped["Trace"] = relationship("Trace", back_populates="spans") + + __table_args__ = ( + UniqueConstraint( + "span_id", + name="uq_spans_span_id", + sqlite_on_conflict="IGNORE", + ), + ) + + +async def init_models(engine: AsyncEngine) -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await conn.execute( + insert(Project).values( + name="default", + description="default project", + ) + ) diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 2d6a95be7e..bb89e09f72 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional, Union +from typing import AsyncContextManager, Callable, Optional, Union +from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket @@ -14,6 +15,7 @@ class Context: request: Union[Request, WebSocket] response: Optional[Response] + db: Callable[[], AsyncContextManager[AsyncSession]] model: Model export_path: Path corpus: Optional[Model] = None diff --git a/src/phoenix/server/api/routers/trace_handler.py b/src/phoenix/server/api/routers/trace_handler.py index 221f40ad99..c50efcae96 100644 --- a/src/phoenix/server/api/routers/trace_handler.py +++ b/src/phoenix/server/api/routers/trace_handler.py @@ -54,7 +54,15 @@ async def post(self, request: Request) -> Response: for resource_spans in req.resource_spans: project_name = get_project_name(resource_spans.resource.attributes) for scope_span in resource_spans.scope_spans: - for span in scope_span.spans: - self.traces.put(decode(span), project_name=project_name) + for otlp_span in scope_span.spans: + span = decode(otlp_span) + # TODO(persistence): Decide which one is better: delayed + # bulk-insert or insert each request 
immediately, i.e. one + # transaction per request. The bulk-insert is more efficient, + # but it queues data in volatile (buffer) memory (for a short + # period of time), so the 200 response is not a genuine + # confirmation of data persistence. + request.state.queue_span_for_bulk_insert(span, project_name) + self.traces.put(span, project_name=project_name) await asyncio.sleep(0) return Response() diff --git a/src/phoenix/server/api/types/MimeType.py b/src/phoenix/server/api/types/MimeType.py index 4c33572994..9dbe40f477 100644 --- a/src/phoenix/server/api/types/MimeType.py +++ b/src/phoenix/server/api/types/MimeType.py @@ -8,8 +8,8 @@ @strawberry.enum class MimeType(Enum): - text = trace_schemas.MimeType.TEXT - json = trace_schemas.MimeType.JSON + text = trace_schemas.MimeType.TEXT.value + json = trace_schemas.MimeType.JSON.value @classmethod def _missing_(cls, v: Any) -> Optional["MimeType"]: diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 926d28296b..19c152ecb4 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,12 +1,19 @@ from datetime import datetime -from itertools import chain from typing import List, Optional import strawberry +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import Integer, and_, cast, func, select +from sqlalchemy.orm import contains_eager +from sqlalchemy.sql.functions import coalesce from strawberry import ID, UNSET +from strawberry.types import Info from phoenix.core.project import Project as CoreProject +from phoenix.datetime_utils import right_open_time_range +from phoenix.db import models from phoenix.metrics.retrieval_metrics import RetrievalMetrics +from phoenix.server.api.context import Context from phoenix.server.api.input_types.SpanSort import SpanSort from phoenix.server.api.input_types.TimeRange import TimeRange from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary @@ -31,36 +38,104 @@ class Project(Node): project: strawberry.Private[CoreProject] @strawberry.field - def start_time(self) -> Optional[datetime]: - start_time, _ = self.project.right_open_time_range + async def start_time( + self, + info: Info[Context, None], + ) -> Optional[datetime]: + stmt = ( + select(func.min(models.Trace.start_time)) + .join(models.Project) + .where(models.Project.name == self.name) + ) + async with info.context.db() as session: + start_time = await session.scalar(stmt) + start_time, _ = right_open_time_range(start_time, None) return start_time @strawberry.field - def end_time(self) -> Optional[datetime]: - _, end_time = self.project.right_open_time_range + async def end_time( + self, + info: Info[Context, None], + ) -> Optional[datetime]: + stmt = ( + select(func.max(models.Trace.end_time)) + .join(models.Project) + .where(models.Project.name == self.name) + ) + async with info.context.db() as session: + end_time = await session.scalar(stmt) + _, end_time = right_open_time_range(None, end_time) return end_time @strawberry.field - def record_count( + async def record_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.span_count() - return self.project.span_count(time_range.start, time_range.end) + stmt = ( + select(func.count(models.Span.id)) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) + ) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= 
models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def trace_count( + async def trace_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.trace_count() - return self.project.trace_count(time_range.start, time_range.end) + stmt = ( + select(func.count(models.Trace.id)) + .join(models.Project) + .where(models.Project.name == self.name) + ) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Trace.start_time, + models.Trace.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def token_count_total(self) -> int: - return self.project.token_count_total + async def token_count_total( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + ) -> int: + prompt = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT] + completion = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION] + stmt = ( + select( + coalesce(func.sum(cast(prompt, Integer)), 0) + + coalesce(func.sum(cast(completion, Integer)), 0) + ) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) + ) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field def latency_ms_p50(self) -> Optional[float]: @@ -77,10 +152,10 @@ def trace(self, trace_id: ID) -> Optional[Trace]: return None @strawberry.field - def spans( + async def spans( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, - trace_ids: Optional[List[ID]] = UNSET, first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -95,36 +170,32 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - start_time = time_range.start if time_range else None - stop_time = time_range.end if time_range else None - if not (project := self.project).span_count( - start_time=start_time, - stop_time=stop_time, - ): - return connection_from_list(data=[], args=args) - predicate = ( - SpanFilter( - condition=filter_condition, - evals=project, - ) - if filter_condition - else None + stmt = ( + select(models.Span) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) + .options(contains_eager(models.Span.trace)) ) - if not trace_ids: - spans = project.get_spans( - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - ) - else: - spans = chain.from_iterable( - project.get_trace(trace_id) for trace_id in map(TraceID, trace_ids) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) ) - if predicate: - spans = filter(predicate, spans) - if sort: - spans = sort(spans, evals=project) - data = [to_gql_span(span, project) for span in spans] + if root_spans_only: + # A root span is any span whose parent span is missing in the + # database, even if its `parent_span_id` may not be NULL. 
+ parent = select(models.Span.span_id).alias() + stmt = stmt.outerjoin( + parent, + models.Span.parent_span_id == parent.c.span_id, + ).where(parent.c.span_id.is_(None)) + # TODO(persistence): enable sort and filter + async with info.context.db() as session: + spans = await session.scalars(stmt) + data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) @strawberry.field( @@ -286,3 +357,7 @@ def validate_span_filter_condition(self, condition: str) -> ValidationResult: is_valid=False, error_message=e.msg, ) + + +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 52eebc7786..9025ab4ca9 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -7,17 +7,20 @@ import numpy as np import strawberry from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info import phoenix.trace.schemas as trace_schema -from phoenix.core.project import Project, WrappedSpan +from phoenix.core.project import Project +from phoenix.db import models from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation from phoenix.server.api.types.MimeType import MimeType -from phoenix.trace.schemas import ComputedAttributes, SpanID +from phoenix.trace.schemas import SpanID EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR @@ -40,14 +43,14 @@ class SpanKind(Enum): NB: this is actively under construction """ - chain = trace_schema.SpanKind.CHAIN - tool = trace_schema.SpanKind.TOOL - llm = trace_schema.SpanKind.LLM - retriever = trace_schema.SpanKind.RETRIEVER - embedding = trace_schema.SpanKind.EMBEDDING - agent = trace_schema.SpanKind.AGENT - reranker = trace_schema.SpanKind.RERANKER - unknown = trace_schema.SpanKind.UNKNOWN + chain = "CHAIN" + tool = "TOOL" + llm = "LLM" + retriever = "RETRIEVER" + embedding = "EMBEDDING" + agent = "AGENT" + reranker = "RERANKER" + unknown = "UNKNOWN" @classmethod def _missing_(cls, v: Any) -> Optional["SpanKind"]: @@ -68,9 +71,9 @@ class SpanIOValue: @strawberry.enum class SpanStatusCode(Enum): - OK = trace_schema.SpanStatusCode.OK - ERROR = trace_schema.SpanStatusCode.ERROR - UNSET = trace_schema.SpanStatusCode.UNSET + OK = "OK" + ERROR = "ERROR" + UNSET = "UNSET" @classmethod def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]: @@ -84,13 +87,13 @@ class SpanEvent: timestamp: datetime @staticmethod - def from_event( - event: trace_schema.SpanEvent, + def from_dict( + event: Mapping[str, Any], ) -> "SpanEvent": return SpanEvent( - name=event.name, - message=cast(str, event.attributes.get(trace_schema.EXCEPTION_MESSAGE) or ""), - timestamp=event.timestamp, + name=event["name"], + message=cast(str, event["attributes"].get(trace_schema.EXCEPTION_MESSAGE) or ""), + timestamp=event["timestamp"], ) @@ -202,18 +205,35 @@ def document_retrieval_metrics( @strawberry.field( description="All descendant spans (children, grandchildren, etc.)", ) # type: ignore - def descendants( + async def descendants( 
self, info: Info[Context, None], ) -> List["Span"]: - return [ - to_gql_span(span, self.project) - for span in self.project.get_descendant_spans(SpanID(self.context.span_id)) - ] + # TODO(persistence): add dataloader (to avoid N+1 queries) or change how this is fetched + async with info.context.db() as session: + descendant_ids = ( + select(models.Span.id, models.Span.span_id) + .filter(models.Span.parent_span_id == str(self.context.span_id)) + .cte(recursive=True) + ) + parent_ids = descendant_ids.alias() + descendant_ids = descendant_ids.union_all( + select(models.Span.id, models.Span.span_id).join( + parent_ids, + models.Span.parent_span_id == parent_ids.c.span_id, + ) + ) + spans = await session.scalars( + select(models.Span) + .join(descendant_ids, models.Span.id == descendant_ids.c.id) + .join(models.Trace) + .options(contains_eager(models.Span.trace)) + ) + return [to_gql_span(span, self.project) for span in spans] -def to_gql_span(span: WrappedSpan, project: Project) -> "Span": - events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events)) +def to_gql_span(span: models.Span, project: Project) -> Span: + events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events)) input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE)) output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE)) retrieval_documents = span.attributes.get(RETRIEVAL_DOCUMENTS) @@ -221,16 +241,16 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": return Span( project=project, name=span.name, - status_code=SpanStatusCode(span.status_code), + status_code=SpanStatusCode(span.status), status_message=span.status_message, - parent_id=cast(Optional[ID], span.parent_id), - span_kind=SpanKind(span.span_kind), + parent_id=cast(Optional[ID], span.parent_span_id), + span_kind=SpanKind(span.kind), start_time=span.start_time, end_time=span.end_time, - latency_ms=cast(Optional[float], span[ComputedAttributes.LATENCY_MS]), + latency_ms=span.latency_ms, context=SpanContext( - trace_id=cast(ID, span.context.trace_id), - span_id=cast(ID, span.context.span_id), + trace_id=cast(ID, span.trace.trace_id), + span_id=cast(ID, span.span_id), ), attributes=json.dumps( _nested_attributes(_hide_embedding_vectors(span.attributes)), @@ -250,22 +270,12 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": Optional[int], span.attributes.get(LLM_TOKEN_COUNT_COMPLETION), ), - cumulative_token_count_total=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL], - ), - cumulative_token_count_prompt=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT], - ), - cumulative_token_count_completion=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION], - ), + cumulative_token_count_total=span.cumulative_llm_token_count_prompt + + span.cumulative_llm_token_count_completion, + cumulative_token_count_prompt=span.cumulative_llm_token_count_prompt, + cumulative_token_count_completion=span.cumulative_llm_token_count_completion, propagated_status_code=( - SpanStatusCode.ERROR - if span[ComputedAttributes.CUMULATIVE_ERROR_COUNT] - else SpanStatusCode(span.status_code) + SpanStatusCode.ERROR if span.cumulative_error_count else SpanStatusCode(span.status) ), events=events, input=( diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 4a0b3331ba..075928ba8a 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -1,9 +1,14 @@ from 
typing import List, Optional import strawberry +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET, Private +from strawberry.types import Info from phoenix.core.project import Project +from phoenix.db import models +from phoenix.server.api.context import Context from phoenix.server.api.types.Evaluation import TraceEvaluation from phoenix.server.api.types.pagination import ( Connection, @@ -21,8 +26,9 @@ class Trace: project: Private[Project] @strawberry.field - def spans( + async def spans( self, + info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -34,10 +40,13 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - spans = sorted( - self.project.get_trace(TraceID(self.trace_id)), - key=lambda span: span.start_time, - ) + async with info.context.db() as session: + spans = await session.scalars( + select(models.Span) + .join(models.Trace) + .filter(models.Trace.trace_id == self.trace_id) + .options(contains_eager(models.Span.trace)) + ) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2b814ce04f..545fab15b7 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -1,7 +1,24 @@ +import contextlib import logging from pathlib import Path -from typing import Any, NamedTuple, Optional, Union +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + Callable, + Dict, + Iterable, + NamedTuple, + Optional, + Tuple, + Union, +) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, +) from starlette.applications import Starlette from starlette.datastructures import QueryParams from starlette.endpoints import HTTPEndpoint @@ -13,15 +30,17 @@ from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles from starlette.templating import Jinja2Templates -from starlette.types import Scope +from starlette.types import Scope, StatefulLifespan from starlette.websockets import WebSocket from strawberry.asgi import GraphQL from strawberry.schema import BaseSchema import phoenix -from phoenix.config import SERVER_DIR +from phoenix.config import DEFAULT_PROJECT_NAME, SERVER_DIR from phoenix.core.model_schema import Model from phoenix.core.traces import Traces +from phoenix.db.bulk_inserter import BulkInserter +from phoenix.db.engines import create_engine from phoenix.pointcloud.umap_parameters import UMAPParameters from phoenix.server.api.context import Context from phoenix.server.api.routers.evaluation_handler import EvaluationHandler @@ -29,6 +48,7 @@ from phoenix.server.api.routers.trace_handler import TraceHandler from phoenix.server.api.schema import schema from phoenix.storage.span_store import SpanStore +from phoenix.trace.schemas import Span logger = logging.getLogger(__name__) @@ -96,12 +116,14 @@ class GraphQLWithContext(GraphQL): # type: ignore def __init__( self, schema: BaseSchema, + db: Callable[[], AsyncContextManager[AsyncSession]], model: Model, export_path: Path, graphiql: bool = False, corpus: Optional[Model] = None, traces: Optional[Traces] = None, ) -> None: + self.db = db self.model = model self.corpus = corpus self.traces = traces @@ -116,6 +138,7 @@ async def get_context( return Context( request=request, response=response, + db=self.db, model=self.model, corpus=self.corpus, traces=self.traces, @@ -142,7 +165,31 @@ 
async def version(_: Request) -> PlainTextResponse: return PlainTextResponse(f"{phoenix.__version__}") +def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]: + Session = async_sessionmaker(engine, expire_on_commit=False) + + @contextlib.asynccontextmanager + async def factory() -> AsyncIterator[AsyncSession]: + async with Session.begin() as session: + yield session + + return factory + + +def _lifespan( + db: Callable[[], AsyncContextManager[AsyncSession]], + initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None, +) -> StatefulLifespan[Starlette]: + @contextlib.asynccontextmanager + async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]: + async with BulkInserter(db, initial_batch_of_spans) as queue_span: + yield {"queue_span_for_bulk_insert": queue_span} + + return lifespan + + def create_app( + database: str, export_path: Path, model: Model, umap_params: UMAPParameters, @@ -152,8 +199,20 @@ def create_app( debug: bool = False, read_only: bool = False, enable_prometheus: bool = False, + initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None, ) -> Starlette: + initial_batch_of_spans: Iterable[Tuple[Span, str]] = ( + () + if initial_spans is None + else ( + ((item, DEFAULT_PROJECT_NAME) if isinstance(item, Span) else item) + for item in initial_spans + ) + ) + engine = create_engine(database) + db = _db(engine) graphql = GraphQLWithContext( + db=db, schema=schema, model=model, corpus=corpus, @@ -168,6 +227,7 @@ def create_app( else: prometheus_middlewares = [] return Starlette( + lifespan=_lifespan(db, initial_batch_of_spans), middleware=[ Middleware(HeadersMiddleware), *prometheus_middlewares, diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 070a745f46..5b4e5d964e 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -13,9 +13,11 @@ from phoenix.config import ( EXPORT_DIR, + get_env_database_connection_str, get_env_host, get_env_port, get_pids_path, + get_working_dir, ) from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces @@ -48,7 +50,7 @@ ██████╔╝███████║██║ ██║█████╗ ██╔██╗ ██║██║ ╚███╔╝ ██╔═══╝ ██╔══██║██║ ██║██╔══╝ ██║╚██╗██║██║ ██╔██╗ ██║ ██║ ██║╚██████╔╝███████╗██║ ╚████║██║██╔╝ ██╗ -╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{0} +╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{version} | | 🌎 Join our Community 🌎 @@ -61,9 +63,9 @@ | https://docs.arize.com/phoenix | | 🚀 Phoenix Server 🚀 -| Phoenix UI: http://{1}:{2} +| Phoenix UI: http://{host}:{port} | Log traces: /v1/traces over HTTP -| +| Storage: {storage} """ @@ -189,13 +191,18 @@ def _load_items( trace_dataset_name = args.trace_fixture simulate_streaming = args.simulate_streaming + host = args.host or get_env_host() + port = args.port or get_env_port() + model = create_model_from_datasets( primary_dataset, reference_dataset, ) + traces = Traces() if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() + fixture_spans = [] if trace_dataset_name is not None: fixture_spans = list( # Apply `encode` here because legacy jsonl files contains UUIDs as strings. 
@@ -228,7 +235,11 @@ def _load_items( from phoenix.server.prometheus import start_prometheus start_prometheus() + + working_dir = get_working_dir().resolve() + db_connection_str = get_env_database_connection_str() app = create_app( + database=db_connection_str, export_path=export_path, model=model, umap_params=umap_params, @@ -238,17 +249,20 @@ def _load_items( read_only=read_only, span_store=span_store, enable_prometheus=enable_prometheus, + initial_spans=fixture_spans, ) - host = args.host or get_env_host() - port = args.port or get_env_port() server = Server(config=Config(app, host=host, port=port)) Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() # Print information about the server phoenix_version = pkg_resources.get_distribution("arize-phoenix").version - print( - _WELCOME_MESSAGE.format(phoenix_version, host if host != "0.0.0.0" else "localhost", port) - ) + config = { + "version": phoenix_version, + "host": host, + "port": port, + "storage": db_connection_str, + } + print(_WELCOME_MESSAGE.format(**config)) # Start the server server.run() diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index 4a69d6da76..19239609de 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -1,6 +1,7 @@ import json import logging import os +import shutil import warnings from abc import ABC, abstractmethod from collections import UserList @@ -29,10 +30,12 @@ ENV_PHOENIX_COLLECTOR_ENDPOINT, ENV_PHOENIX_HOST, ENV_PHOENIX_PORT, + get_env_database_connection_str, get_env_host, get_env_port, get_env_project_name, get_exported_files, + get_working_dir, ) from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces @@ -118,27 +121,6 @@ def __init__( self.corpus_dataset = corpus_dataset self.trace_dataset = trace_dataset self.umap_parameters = get_umap_parameters(default_umap_parameters) - self.model = create_model_from_datasets( - primary_dataset, - reference_dataset, - ) - - self.corpus = ( - create_model_from_datasets( - corpus_dataset, - ) - if corpus_dataset is not None - else None - ) - - self.traces = Traces() - if trace_dataset: - for span in trace_dataset.to_spans(): - self.traces.put(span) - for evaluations in trace_dataset.evaluations: - for pb_evaluation in encode_evaluations(evaluations): - self.traces.put(pb_evaluation) - self.host = host or get_env_host() self.port = port or get_env_port() self.temp_dir = TemporaryDirectory() @@ -284,6 +266,7 @@ def get_evaluations( class ThreadSession(Session): def __init__( self, + database: str, primary_dataset: Dataset, reference_dataset: Optional[Dataset] = None, corpus_dataset: Optional[Dataset] = None, @@ -304,6 +287,24 @@ def __init__( port=port, notebook_env=notebook_env, ) + self.model = create_model_from_datasets( + primary_dataset, + reference_dataset, + ) + self.corpus = ( + create_model_from_datasets( + corpus_dataset, + ) + if corpus_dataset is not None + else None + ) + self.traces = Traces() + if trace_dataset: + for span in trace_dataset.to_spans(): + self.traces.put(span) + for evaluations in trace_dataset.evaluations: + for pb_evaluation in encode_evaluations(evaluations): + self.traces.put(pb_evaluation) if span_store := get_span_store(): Thread( target=load_traces_data_from_store, @@ -312,12 +313,14 @@ def __init__( ).start() # Initialize an app service that keeps the server running self.app = create_app( + database=database, export_path=self.export_path, model=self.model, corpus=self.corpus, 
traces=self.traces, umap_params=self.umap_parameters, span_store=span_store, + initial_spans=trace_dataset.to_spans() if trace_dataset else None, ) self.server = ThreadServer( app=self.app, @@ -423,6 +426,19 @@ def get_evaluations( return project.export_evaluations() +def reset_all(hard: Optional[bool] = False) -> None: + """ + Resets everything to the initial state. + """ + working_dir = get_working_dir() + + # See if the working directory exists + if working_dir.exists(): + if not hard: + input(f"Working directory exists at {working_dir}. Press Enter to delete the directory") + shutil.rmtree(working_dir) + + def launch_app( primary: Optional[Dataset] = None, reference: Optional[Dataset] = None, @@ -533,9 +549,11 @@ def launch_app( host = host or get_env_host() port = port or get_env_port() + database = get_env_database_connection_str() if run_in_thread: _session = ThreadSession( + database, primary, reference, corpus, @@ -568,7 +586,7 @@ def launch_app( return None print(f"🌍 To view the Phoenix app in your browser, visit {_session.url}") - print("📺 To view the Phoenix app in a notebook, run `px.active_session().view()`") + print(f"💽 Your data is being persisted to {database}") print("📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix") return _session @@ -582,10 +600,15 @@ def active_session() -> Optional[Session]: return None -def close_app() -> None: +def close_app(reset: bool = False) -> None: """ Closes the phoenix application. The application server is shut down and will no longer be accessible. + + Parameters + ---------- + reset : bool, optional + Whether to reset the working directory. Defaults to False. """ global _session if _session is None: @@ -594,6 +617,9 @@ def close_app() -> None: _session.end() _session = None logger.info("Session closed") + if reset: + logger.info("Resetting working directory") + reset_all(hard=True) def _get_url(host: str, port: int, notebook_env: NotebookEnvironment) -> str: diff --git a/src/phoenix/trace/otel.py b/src/phoenix/trace/otel.py index 12ec0b4104..21bb417b4a 100644 --- a/src/phoenix/trace/otel.py +++ b/src/phoenix/trace/otel.py @@ -33,7 +33,6 @@ EXCEPTION_MESSAGE, EXCEPTION_STACKTRACE, EXCEPTION_TYPE, - MimeType, Span, SpanContext, SpanEvent, @@ -61,16 +60,14 @@ def decode(otlp_span: otlp.Span) -> Span: parent_id = _decode_identifier(otlp_span.parent_span_id) start_time = _decode_unix_nano(otlp_span.start_time_unix_nano) - end_time = ( - _decode_unix_nano(otlp_span.end_time_unix_nano) if otlp_span.end_time_unix_nano else None - ) + end_time = _decode_unix_nano(otlp_span.end_time_unix_nano) attributes = dict(_unflatten(_load_json_strings(_decode_key_values(otlp_span.attributes)))) span_kind = SpanKind(attributes.pop(OPENINFERENCE_SPAN_KIND, None)) for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = MimeType(attributes[mime_type]) + attributes[mime_type] = attributes[mime_type] status_code, status_message = _decode_status(otlp_span.status) events = [_decode_event(event) for event in otlp_span.events] @@ -320,7 +317,7 @@ def encode(span: Span) -> otlp.Span: for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = attributes[mime_type].value + attributes[mime_type] = attributes[mime_type] for key, value in span.attributes.items(): if value is None: diff --git a/src/phoenix/trace/schemas.py b/src/phoenix/trace/schemas.py index efebeea439..7caeab9bf4 100644 --- a/src/phoenix/trace/schemas.py +++ 
b/src/phoenix/trace/schemas.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union from uuid import UUID EXCEPTION_TYPE = "exception.type" @@ -142,7 +142,7 @@ class Span: "If the parent_id is None, this is the root span" parent_id: Optional[SpanID] start_time: datetime - end_time: Optional[datetime] + end_time: datetime status_code: SpanStatusCode status_message: str """ @@ -202,3 +202,11 @@ class ComputedAttributes(Enum): CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION = "cumulative_token_count.completion" ERROR_COUNT = "error_count" CUMULATIVE_ERROR_COUNT = "cumulative_error_count" + + +class ComputedValues(NamedTuple): + latency_ms: float + cumulative_error_count: int + cumulative_llm_token_count_prompt: int + cumulative_llm_token_count_completion: int + cumulative_llm_token_count_total: int diff --git a/tests/server/api/types/conftest.py b/tests/server/api/types/conftest.py index c33708ae48..27028ca454 100644 --- a/tests/server/api/types/conftest.py +++ b/tests/server/api/types/conftest.py @@ -42,6 +42,7 @@ def create_context(primary_dataset: Dataset, reference_dataset: Optional[Dataset response=None, model=create_model_from_datasets(primary_dataset, reference_dataset), export_path=Path(TemporaryDirectory().name), + db=None, # TODO(persistence): add mock for db ) return create_context
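Reviewer note, not part of the diff: a minimal usage sketch of the configuration surface this PR adds. It assumes the new `pg` extra is installed for PostgreSQL and that setting the environment variable before `launch_app()` is how the override of `get_env_database_connection_str()` is meant to be used; the URL below is a placeholder, not a documented example.

import os

import phoenix as px

# Assumed placeholder URL; with no env var set, Phoenix falls back to
# sqlite:///<working_dir>/phoenix.db. PostgreSQL needs the new `pg` extra
# (asyncpg), e.g. `pip install 'arize-phoenix[pg]'`.
os.environ["__DANGEROUS__PHOENIX_SQL_DATABASE"] = "postgresql://user:pass@localhost:5432/phoenix"

session = px.launch_app()  # the welcome banner now prints the storage URL
# ... send traces, query the app ...
px.close_app(reset=True)   # reset=True wipes the working directory via reset_all(hard=True);
                           # it does not drop data in an external PostgreSQL database.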
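For reviewers, a rough sketch (assumptions flagged in comments) of how the new pieces compose outside the server: `create_engine` builds the async engine and runs migrations for file-backed databases, `async_sessionmaker` supplies the session factory shared by the GraphQL context and the inserter, and `BulkInserter` is entered as an async context manager that yields the queueing callable installed on the Starlette lifespan state. This mirrors `create_app`/`_db`/`_lifespan` in `phoenix/server/app.py` but is illustrative only; the file path and the empty span list are made up.

import asyncio
import contextlib
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from phoenix.db.bulk_inserter import BulkInserter
from phoenix.db.engines import create_engine

# Hypothetical SQLite file; create_engine() also accepts postgresql:// URLs
# and runs the Alembic migrations for file-backed databases.
engine = create_engine("sqlite:////tmp/phoenix_demo.db")
Session = async_sessionmaker(engine, expire_on_commit=False)


@contextlib.asynccontextmanager
async def db() -> AsyncIterator[AsyncSession]:
    # One transaction per use, mirroring _db() in phoenix/server/app.py.
    async with Session.begin() as session:
        yield session


async def main() -> None:
    # Entering the inserter starts the background flush loop; exiting stops it.
    async with BulkInserter(db, run_interval_in_seconds=0.5) as queue_span:
        # In the server, the OTLP trace handler queues work via
        # request.state.queue_span_for_bulk_insert(span, project_name).
        # Real spans could come from e.g. TraceDataset.to_spans(); an empty
        # list keeps this sketch self-contained and runnable.
        spans: list = []
        for span in spans:
            queue_span(span, "default")
        await asyncio.sleep(1)  # let the flush loop run at least once


asyncio.run(main())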