diff --git a/app/schema.graphql b/app/schema.graphql index e5ff7dbcae..26cc2d66f5 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -547,7 +547,7 @@ type Project implements Node { latencyMsP50: Float latencyMsP99: Float trace(traceId: ID!): Trace - spans(timeRange: TimeRange, traceIds: [ID!], first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! """ Names of all available evaluations for traces. (The list contains no duplicates.) diff --git a/pyproject.toml b/pyproject.toml index 3c7b1cea88..6726f90194 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "openinference-instrumentation-openai>=0.1.4", "sqlalchemy>=2, <3", "alembic>=1.3.0, <2", + "aiosqlite", ] dynamic = ["version"] diff --git a/src/phoenix/core/traces.py b/src/phoenix/core/traces.py index 6fff53e469..0681e5479e 100644 --- a/src/phoenix/core/traces.py +++ b/src/phoenix/core/traces.py @@ -1,10 +1,9 @@ import weakref from collections import defaultdict -from datetime import datetime from queue import SimpleQueue from threading import RLock, Thread from types import MethodType -from typing import DefaultDict, Iterator, Optional, Protocol, Tuple, Union +from typing import DefaultDict, Iterator, Optional, Tuple, Union from typing_extensions import assert_never @@ -13,45 +12,16 @@ from phoenix.core.project import ( END_OF_QUEUE, Project, - WrappedSpan, _ProjectName, ) -from phoenix.trace.schemas import ComputedAttributes, ComputedValues, Span, TraceID +from phoenix.trace.schemas import Span _SpanItem = Tuple[Span, _ProjectName] _EvalItem = Tuple[pb.Evaluation, _ProjectName] -class Database(Protocol): - def insert_span(self, span: Span, project_name: str) -> None: ... - - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def get_trace(self, trace_id: TraceID) -> Iterator[Tuple[Span, ComputedValues]]: ... - - class Traces: - def __init__(self, database: Database) -> None: - self._database = database + def __init__(self) -> None: self._span_queue: "SimpleQueue[Optional[_SpanItem]]" = SimpleQueue() self._eval_queue: "SimpleQueue[Optional[_EvalItem]]" = SimpleQueue() # Putting `None` as the sentinel value for queue termination. 
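[Reviewer note] The synchronous `Database` protocol deleted above is superseded by SQLAlchemy's async session API backed by the new `aiosqlite` dependency added in pyproject.toml. For readers new to that stack, here is a minimal, self-contained sketch of the pattern (illustrative only; not code from this diff):

```python
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine


async def main() -> None:
    # The sqlite+aiosqlite driver executes SQLite queries without blocking the event loop.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    async with session_factory.begin() as session:
        # Awaitable queries like this replace the synchronous Database
        # protocol methods (span_count, trace_count, ...) deleted in this file.
        table_count = await session.scalar(text("SELECT COUNT(*) FROM sqlite_master"))
        print(table_count)
    await engine.dispose()


asyncio.run(main())
```

The same factory shape appears later in this diff as `_db(engine)` in src/phoenix/server/app.py.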
@@ -64,45 +34,6 @@ def __init__(self, database: Database) -> None: ) self._start_consumers() - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.trace_count(project_name, start_time, stop_time) - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.span_count(project_name, start_time, stop_time) - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.llm_token_count_total(project_name, start_time, stop_time) - - def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: - for span, computed_values in self._database.get_trace(trace_id): - wrapped_span = WrappedSpan(span) - wrapped_span[ComputedAttributes.LATENCY_MS] = computed_values.latency_ms - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT] = ( - computed_values.cumulative_llm_token_count_prompt - ) - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION] = ( - computed_values.cumulative_llm_token_count_completion - ) - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] = ( - computed_values.cumulative_llm_token_count_total - ) - yield wrapped_span - def get_project(self, project_name: str) -> Optional["Project"]: with self._lock: return self._projects.get(project_name) @@ -153,7 +84,6 @@ def _start_consumers(self) -> None: def _consume_spans(self, queue: "SimpleQueue[Optional[_SpanItem]]") -> None: while (item := queue.get()) is not END_OF_QUEUE: span, project_name = item - self._database.insert_span(span, project_name=project_name) with self._lock: project = self._projects[project_name] project.add_span(span) diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py deleted file mode 100644 index be23975d6e..0000000000 --- a/src/phoenix/db/database.py +++ /dev/null @@ -1,351 +0,0 @@ -import json -import sqlite3 -from datetime import datetime -from enum import Enum -from pathlib import Path -from sqlite3 import Connection -from typing import Any, Iterator, Optional, Tuple, cast - -import numpy as np -from openinference.semconv.trace import SpanAttributes -from sqlalchemy import Engine, create_engine, event, insert -from sqlalchemy.orm import sessionmaker - -from phoenix.db.models import Base, Project, Trace -from phoenix.trace.schemas import ( - ComputedValues, - Span, - SpanContext, - SpanEvent, - SpanKind, - SpanStatusCode, -) - -from .migrate import migrate - -_CONFIG = """ -PRAGMA foreign_keys = ON; -PRAGMA journal_mode = WAL; -PRAGMA synchronous = OFF; -PRAGMA cache_size = -32000; -PRAGMA busy_timeout = 10000; -""" - - -_MEM_DB_STR = "file::memory:?cache=shared" - - -def _mem_db_creator() -> Any: - return sqlite3.connect(_MEM_DB_STR, uri=True) - - -@event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection: Connection, _: Any) -> None: - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys = ON;") - cursor.execute("PRAGMA journal_mode = WAL;") - cursor.execute("PRAGMA synchronous = OFF;") - cursor.execute("PRAGMA cache_size = -32000;") - cursor.execute("PRAGMA busy_timeout = 10000;") - cursor.close() - - -class SqliteDatabase: - def __init__(self, db_path: Optional[Path] = None) -> None: - """ - :param db_path: The path to the database file to be opened. 
- """ - self.con = sqlite3.connect( - database=db_path or _MEM_DB_STR, - uri=True, - check_same_thread=False, - ) - # self.con.set_trace_callback(print) - cur = self.con.cursor() - cur.executescript(_CONFIG) - - db_url = f"sqlite:///{db_path}" if db_path else "sqlite:///:memory:" - engine = ( - create_engine(db_url, echo=True) - if db_path - else create_engine( - db_url, - echo=True, - creator=_mem_db_creator, - ) - ) - - # TODO this should be moved out - if db_path: - migrate(db_url) - else: - Base.metadata.create_all(engine) - # Create the default project and setup indexes - with engine.connect() as conn: - conn.execute(insert(Project).values(name="default", description="default project")) - conn.commit() - - self.Session = sessionmaker(bind=engine) - - def insert_span(self, span: Span, project_name: str) -> None: - cur = self.con.cursor() - cur.execute("BEGIN;") - try: - if not ( - projects := cur.execute( - "SELECT rowid FROM projects WHERE name = ?;", - (project_name,), - ).fetchone() - ): - projects = cur.execute( - "INSERT INTO projects(name) VALUES(?) RETURNING rowid;", - (project_name,), - ).fetchone() - project_rowid = projects[0] - if ( - trace_row := cur.execute( - """ - INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) - VALUES(?,?,?,?,?) - ON CONFLICT DO UPDATE SET - start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, - end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END - WHERE excluded.start_time < start_time OR end_time < excluded.end_time - RETURNING rowid; - """, # noqa E501 - ( - span.context.trace_id, - project_rowid, - None, - span.start_time, - span.end_time, - ), - ).fetchone() - ) is None: - trace_row = cur.execute( - "SELECT rowid from traces where trace_id = ?", (span.context.trace_id,) - ).fetchone() - trace_rowid = trace_row[0] - cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) - cumulative_llm_token_count_prompt = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) - ) - cumulative_llm_token_count_completion = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) - ) - if accumulation := cur.execute( - """ - SELECT - sum(cumulative_error_count), - sum(cumulative_llm_token_count_prompt), - sum(cumulative_llm_token_count_completion) - FROM spans - WHERE parent_span_id = ? - """, # noqa E501 - (span.context.span_id,), - ).fetchone(): - cumulative_error_count += cast(int, accumulation[0] or 0) - cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) - cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) - latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 - cur.execute( - """ - INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, latency_ms, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) - VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) 
- RETURNING rowid; - """, # noqa E501 - ( - span.context.span_id, - trace_rowid, - span.parent_id, - span.span_kind.value, - span.name, - span.start_time, - span.end_time, - json.dumps(span.attributes, cls=_Encoder), - json.dumps(span.events, cls=_Encoder), - span.status_code.value, - span.status_message, - latency_ms, - cumulative_error_count, - cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion, - ), - ) - parent_id = span.parent_id - while parent_id: - if parent_span := cur.execute( - """ - SELECT rowid, parent_span_id - FROM spans - WHERE span_id = ? - """, - (parent_id,), - ).fetchone(): - rowid, parent_id = parent_span[0], parent_span[1] - cur.execute( - """ - UPDATE spans SET - cumulative_error_count = cumulative_error_count + ?, - cumulative_llm_token_count_prompt = cumulative_llm_token_count_prompt + ?, - cumulative_llm_token_count_completion = cumulative_llm_token_count_completion + ? - WHERE rowid = ?; - """, # noqa E501 - ( - cumulative_error_count, - cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion, - rowid, - ), - ) - else: - break - except Exception: - cur.execute("ROLLBACK;") - else: - cur.execute("COMMIT;") - - def get_projects(self) -> Iterator[Tuple[int, str]]: - cur = self.con.cursor() - for project in cur.execute("SELECT rowid, name FROM projects;").fetchall(): - yield cast(Tuple[int, str], (project[0], project[1])) - - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT COUNT(*) - FROM traces - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= traces.start_time AND traces.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? <= traces.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND traces.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT COUNT(*) - FROM spans - JOIN traces ON traces.rowid = spans.trace_rowid - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= spans.start_time AND spans.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? 
<= spans.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT - SUM(COALESCE(json_extract(spans.attributes, '$."llm.token_count.prompt"'), 0) + - COALESCE(json_extract(spans.attributes, '$."llm.token_count.completion"'), 0)) - FROM spans - JOIN traces ON traces.rowid = spans.trace_rowid - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ # noqa E501 - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= spans.start_time AND spans.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def get_trace(self, trace_id: str) -> Iterator[Tuple[Span, ComputedValues]]: - with self.Session.begin() as session: - trace = session.query(Trace).where(Trace.trace_id == trace_id).one_or_none() - if not trace: - return - for span in trace.spans: - yield ( - Span( - name=span.name, - context=SpanContext(trace_id=span.trace.trace_id, span_id=span.span_id), - parent_id=span.parent_span_id, - span_kind=SpanKind(span.kind), - start_time=span.start_time, - end_time=span.end_time, - attributes=span.attributes, - events=[ - SpanEvent( - name=obj["name"], - attributes=obj["attributes"], - timestamp=obj["timestamp"], - ) - for obj in span.events - ], - status_code=SpanStatusCode(span.status), - status_message=span.status_message, - conversation=None, - ), - ComputedValues( - latency_ms=span.latency_ms, - cumulative_error_count=span.cumulative_error_count, - cumulative_llm_token_count_prompt=span.cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion=span.cumulative_llm_token_count_completion, - cumulative_llm_token_count_total=span.cumulative_llm_token_count_prompt - + span.cumulative_llm_token_count_completion, - ), - ) - - -class _Encoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, datetime): - return obj.isoformat() - elif isinstance(obj, Enum): - return obj.value - elif isinstance(obj, np.ndarray): - return list(obj) - elif isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - return super().default(obj) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py new file mode 100644 index 0000000000..001aa8a352 --- /dev/null +++ b/src/phoenix/db/engines.py @@ -0,0 +1,58 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from pathlib import Path +from sqlite3 import Connection +from typing import Any, Union + +import numpy as np +from sqlalchemy import URL, event +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from phoenix.db.migrate import migrate +from phoenix.db.models import init_models + + +def set_sqlite_pragma(connection: Connection, _: Any) -> None: + cursor = connection.cursor() + cursor.execute("PRAGMA foreign_keys = ON;") 
+ cursor.execute("PRAGMA journal_mode = WAL;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA cache_size = -32000;") + cursor.execute("PRAGMA busy_timeout = 10000;") + cursor.close() + + +def aiosqlite_engine( + database: Union[str, Path] = ":memory:", + echo: bool = False, +) -> AsyncEngine: + driver_name = "sqlite+aiosqlite" + url = URL.create(driver_name, database=str(database)) + engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) + event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + if str(database) == ":memory:": + asyncio.run(init_models(engine)) + else: + migrate(url) + return engine + + +def _dumps(obj: Any) -> str: + return json.dumps(obj, cls=_Encoder) + + +class _Encoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + return super().default(obj) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 35d640300a..990c390e14 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -3,11 +3,12 @@ from alembic import command from alembic.config import Config +from sqlalchemy import URL logger = logging.getLogger(__name__) -def migrate(url: str) -> None: +def migrate(url: URL) -> None: """ Runs migrations on the database. NB: Migrate only works on non-memory databases. @@ -22,6 +23,6 @@ def migrate(url: str) -> None: # Explicitly set the migration directory scripts_location = str(Path(__file__).parent.resolve() / "migrations") alembic_cfg.set_main_option("script_location", scripts_location) - alembic_cfg.set_main_option("sqlalchemy.url", url) + alembic_cfg.set_main_option("sqlalchemy.url", str(url)) command.upgrade(alembic_cfg, "head") diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py index 6e3c4d0d3c..2cc51224e6 100644 --- a/src/phoenix/db/migrations/env.py +++ b/src/phoenix/db/migrations/env.py @@ -1,8 +1,10 @@ +import asyncio from logging.config import fileConfig from alembic import context from phoenix.db.models import Base -from sqlalchemy import engine_from_config, pool +from sqlalchemy import Connection, engine_from_config, pool +from sqlalchemy.ext.asyncio import AsyncEngine # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -55,17 +57,37 @@ def run_migrations_online() -> None: and associate a connection with the context. 
""" - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, + connectable = context.config.attributes.get("connection", None) + if connectable is None: + connectable = AsyncEngine( + engine_from_config( + context.config.get_section(context.config.config_ini_section) or {}, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + future=True, + ) + ) + + if isinstance(connectable, AsyncEngine): + asyncio.run(run_async_migrations(connectable)) + else: + run_migrations(connectable) + + +async def run_async_migrations(engine: AsyncEngine) -> None: + async with engine.connect() as connection: + await connection.run_sync(run_migrations) + await engine.dispose() + + +def run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, ) - - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() + with context.begin_transaction(): + context.run_migrations() if context.is_offline_mode(): diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index b52428f1f8..5f10f21931 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -9,7 +9,9 @@ MetaData, UniqueConstraint, func, + insert, ) +from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -116,7 +118,18 @@ class Span(Base): __table_args__ = ( UniqueConstraint( "span_id", - name="uq_spans_trace_id", + name="uq_spans_span_id", sqlite_on_conflict="IGNORE", ), ) + + +async def init_models(engine: AsyncEngine) -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await conn.execute( + insert(Project).values( + name="default", + description="default project", + ) + ) diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 2d6a95be7e..bb89e09f72 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional, Union +from typing import AsyncContextManager, Callable, Optional, Union +from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket @@ -14,6 +15,7 @@ class Context: request: Union[Request, WebSocket] response: Optional[Response] + db: Callable[[], AsyncContextManager[AsyncSession]] model: Model export_path: Path corpus: Optional[Model] = None diff --git a/src/phoenix/server/api/routers/trace_handler.py b/src/phoenix/server/api/routers/trace_handler.py index 221f40ad99..507bb0df40 100644 --- a/src/phoenix/server/api/routers/trace_handler.py +++ b/src/phoenix/server/api/routers/trace_handler.py @@ -1,25 +1,31 @@ import asyncio import gzip import zlib -from typing import Optional +from typing import AsyncContextManager, Callable, Optional, cast from google.protobuf.message import DecodeError +from openinference.semconv.trace import SpanAttributes from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( ExportTraceServiceRequest, ) from opentelemetry.proto.trace.v1.trace_pb2 import TracesData +from sqlalchemy import func, insert, select, text, update +from sqlalchemy.ext.asyncio import AsyncSession from starlette.endpoints import HTTPEndpoint from starlette.requests 
import Request from starlette.responses import Response from starlette.status import HTTP_415_UNSUPPORTED_MEDIA_TYPE, HTTP_422_UNPROCESSABLE_ENTITY from phoenix.core.traces import Traces +from phoenix.db import models from phoenix.storage.span_store import SpanStore from phoenix.trace.otel import decode +from phoenix.trace.schemas import Span, SpanStatusCode from phoenix.utilities.project import get_project_name class TraceHandler(HTTPEndpoint): + db: Callable[[], AsyncContextManager[AsyncSession]] traces: Traces store: Optional[SpanStore] @@ -54,7 +60,111 @@ async def post(self, request: Request) -> Response: for resource_spans in req.resource_spans: project_name = get_project_name(resource_spans.resource.attributes) for scope_span in resource_spans.scope_spans: - for span in scope_span.spans: - self.traces.put(decode(span), project_name=project_name) + for otlp_span in scope_span.spans: + span = decode(otlp_span) + async with self.db() as session: + await _insert_span(session, span, project_name) + self.traces.put(span, project_name=project_name) await asyncio.sleep(0) return Response() + + +async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: + if not ( + project_rowid := await session.scalar( + select(models.Project.id).filter(models.Project.name == project_name) + ) + ): + project_rowid = await session.scalar( + insert(models.Project).values(name=project_name).returning(models.Project.id) + ) + if ( + trace_rowid := await session.scalar( + text( + """ + INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) + VALUES(:trace_id, :project_rowid, :session_id, :start_time, :end_time) + ON CONFLICT DO UPDATE SET + start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, + end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END + WHERE excluded.start_time < start_time OR end_time < excluded.end_time + RETURNING rowid; + """ # noqa E501 + ), + { + "trace_id": span.context.trace_id, + "project_rowid": project_rowid, + "session_id": None, + "start_time": span.start_time, + "end_time": span.end_time, + }, + ) + ) is None: + trace_rowid = await session.scalar( + select(models.Trace.id).filter(models.Trace.trace_id == span.context.trace_id) + ) + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := ( + await session.execute( + select( + func.sum(models.Span.cumulative_error_count), + func.sum(models.Span.cumulative_llm_token_count_prompt), + func.sum(models.Span.cumulative_llm_token_count_completion), + ).where(models.Span.parent_span_id == span.context.span_id) + ) + ).first(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 + session.add( + models.Span( + span_id=span.context.span_id, + trace_rowid=trace_rowid, + parent_span_id=span.parent_id, + kind=span.span_kind.value, + name=span.name, + start_time=span.start_time, + end_time=span.end_time, + attributes=span.attributes, + events=span.events, + status=span.status_code.value, + 
status_message=span.status_message,
+            latency_ms=latency_ms,
+            cumulative_error_count=cumulative_error_count,
+            cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt,
+            cumulative_llm_token_count_completion=cumulative_llm_token_count_completion,
+        )
+    )
+    # Propagate cumulative values to ancestors. This is usually a no-op, since
+    # the parent usually arrives after the child. But in the event that a
+    # child arrives after its parent, we need to make sure all the
+    # ancestors' cumulative values are updated.
+    ancestors = (
+        select(models.Span.id, models.Span.parent_span_id)
+        .where(models.Span.span_id == span.parent_id)
+        .cte(recursive=True)
+    )
+    child = ancestors.alias()
+    ancestors = ancestors.union_all(
+        select(models.Span.id, models.Span.parent_span_id).join(
+            child, models.Span.span_id == child.c.parent_span_id
+        )
+    )
+    await session.execute(
+        update(models.Span)
+        .where(models.Span.id.in_(select(ancestors.c.id)))
+        .values(
+            cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count,
+            cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt
+            + cumulative_llm_token_count_prompt,
+            cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion
+            + cumulative_llm_token_count_completion,
+        )
+    )
diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py
index 671d74cd93..19c152ecb4 100644
--- a/src/phoenix/server/api/types/Project.py
+++ b/src/phoenix/server/api/types/Project.py
@@ -1,12 +1,17 @@
 from datetime import datetime
-from itertools import chain
 from typing import List, Optional
 
 import strawberry
+from openinference.semconv.trace import SpanAttributes
+from sqlalchemy import Integer, and_, cast, func, select
+from sqlalchemy.orm import contains_eager
+from sqlalchemy.sql.functions import coalesce
 from strawberry import ID, UNSET
 from strawberry.types import Info
 
 from phoenix.core.project import Project as CoreProject
+from phoenix.datetime_utils import right_open_time_range
+from phoenix.db import models
 from phoenix.metrics.retrieval_metrics import RetrievalMetrics
 from phoenix.server.api.context import Context
 from phoenix.server.api.input_types.SpanSort import SpanSort
@@ -33,53 +38,104 @@ class Project(Node):
     project: strawberry.Private[CoreProject]
 
     @strawberry.field
-    def start_time(self) -> Optional[datetime]:
-        start_time, _ = self.project.right_open_time_range
+    async def start_time(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[datetime]:
+        stmt = (
+            select(func.min(models.Trace.start_time))
+            .join(models.Project)
+            .where(models.Project.name == self.name)
+        )
+        async with info.context.db() as session:
+            start_time = await session.scalar(stmt)
+        start_time, _ = right_open_time_range(start_time, None)
         return start_time
 
     @strawberry.field
-    def end_time(self) -> Optional[datetime]:
-        _, end_time = self.project.right_open_time_range
+    async def end_time(
+        self,
+        info: Info[Context, None],
+    ) -> Optional[datetime]:
+        stmt = (
+            select(func.max(models.Trace.end_time))
+            .join(models.Project)
+            .where(models.Project.name == self.name)
+        )
+        async with info.context.db() as session:
+            end_time = await session.scalar(stmt)
+        _, end_time = right_open_time_range(None, end_time)
         return end_time
 
     @strawberry.field
-    def record_count(
+    async def record_count(
         self,
         info: Info[Context, None],
         time_range: Optional[TimeRange] = UNSET,
     ) -> int:
-        if not (traces := info.context.traces):
-            return 0
-        start_time, stop_time = (
-            (None, None)
if not time_range else (time_range.start, time_range.end) + stmt = ( + select(func.count(models.Span.id)) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.span_count(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def trace_count( + async def trace_count( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not (traces := info.context.traces): - return 0 - start_time, stop_time = ( - (None, None) if not time_range else (time_range.start, time_range.end) + stmt = ( + select(func.count(models.Trace.id)) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.trace_count(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Trace.start_time, + models.Trace.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def token_count_total( + async def token_count_total( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not (traces := info.context.traces): - return 0 - start_time, stop_time = ( - (None, None) if not time_range else (time_range.start, time_range.end) + prompt = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT] + completion = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION] + stmt = ( + select( + coalesce(func.sum(cast(prompt, Integer)), 0) + + coalesce(func.sum(cast(completion, Integer)), 0) + ) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.llm_token_count_total(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field def latency_ms_p50(self) -> Optional[float]: @@ -96,10 +152,10 @@ def trace(self, trace_id: ID) -> Optional[Trace]: return None @strawberry.field - def spans( + async def spans( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, - trace_ids: Optional[List[ID]] = UNSET, first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -114,36 +170,32 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - start_time = time_range.start if time_range else None - stop_time = time_range.end if time_range else None - if not (project := self.project).span_count( - start_time=start_time, - stop_time=stop_time, - ): - return connection_from_list(data=[], args=args) - predicate = ( - SpanFilter( - condition=filter_condition, - evals=project, - ) - if filter_condition - else None + stmt = ( + select(models.Span) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) + .options(contains_eager(models.Span.trace)) ) - if not trace_ids: - spans = project.get_spans( - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - ) - else: - spans = chain.from_iterable( - project.get_trace(trace_id) for trace_id in map(TraceID, trace_ids) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + 
models.Span.start_time < time_range.end, + ) ) - if predicate: - spans = filter(predicate, spans) - if sort: - spans = sort(spans, evals=project) - data = [to_gql_span(span, project) for span in spans] + if root_spans_only: + # A root span is any span whose parent span is missing in the + # database, even if its `parent_span_id` may not be NULL. + parent = select(models.Span.span_id).alias() + stmt = stmt.outerjoin( + parent, + models.Span.parent_span_id == parent.c.span_id, + ).where(parent.c.span_id.is_(None)) + # TODO(persistence): enable sort and filter + async with info.context.db() as session: + spans = await session.scalars(stmt) + data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) @strawberry.field( @@ -305,3 +357,7 @@ def validate_span_filter_condition(self, condition: str) -> ValidationResult: is_valid=False, error_message=e.msg, ) + + +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 52eebc7786..9025ab4ca9 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -7,17 +7,20 @@ import numpy as np import strawberry from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info import phoenix.trace.schemas as trace_schema -from phoenix.core.project import Project, WrappedSpan +from phoenix.core.project import Project +from phoenix.db import models from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation from phoenix.server.api.types.MimeType import MimeType -from phoenix.trace.schemas import ComputedAttributes, SpanID +from phoenix.trace.schemas import SpanID EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR @@ -40,14 +43,14 @@ class SpanKind(Enum): NB: this is actively under construction """ - chain = trace_schema.SpanKind.CHAIN - tool = trace_schema.SpanKind.TOOL - llm = trace_schema.SpanKind.LLM - retriever = trace_schema.SpanKind.RETRIEVER - embedding = trace_schema.SpanKind.EMBEDDING - agent = trace_schema.SpanKind.AGENT - reranker = trace_schema.SpanKind.RERANKER - unknown = trace_schema.SpanKind.UNKNOWN + chain = "CHAIN" + tool = "TOOL" + llm = "LLM" + retriever = "RETRIEVER" + embedding = "EMBEDDING" + agent = "AGENT" + reranker = "RERANKER" + unknown = "UNKNOWN" @classmethod def _missing_(cls, v: Any) -> Optional["SpanKind"]: @@ -68,9 +71,9 @@ class SpanIOValue: @strawberry.enum class SpanStatusCode(Enum): - OK = trace_schema.SpanStatusCode.OK - ERROR = trace_schema.SpanStatusCode.ERROR - UNSET = trace_schema.SpanStatusCode.UNSET + OK = "OK" + ERROR = "ERROR" + UNSET = "UNSET" @classmethod def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]: @@ -84,13 +87,13 @@ class SpanEvent: timestamp: datetime @staticmethod - def from_event( - event: trace_schema.SpanEvent, + def from_dict( + event: Mapping[str, Any], ) -> "SpanEvent": return SpanEvent( - name=event.name, - message=cast(str, event.attributes.get(trace_schema.EXCEPTION_MESSAGE) or ""), - 
timestamp=event.timestamp, + name=event["name"], + message=cast(str, event["attributes"].get(trace_schema.EXCEPTION_MESSAGE) or ""), + timestamp=event["timestamp"], ) @@ -202,18 +205,35 @@ def document_retrieval_metrics( @strawberry.field( description="All descendant spans (children, grandchildren, etc.)", ) # type: ignore - def descendants( + async def descendants( self, info: Info[Context, None], ) -> List["Span"]: - return [ - to_gql_span(span, self.project) - for span in self.project.get_descendant_spans(SpanID(self.context.span_id)) - ] + # TODO(persistence): add dataloader (to avoid N+1 queries) or change how this is fetched + async with info.context.db() as session: + descendant_ids = ( + select(models.Span.id, models.Span.span_id) + .filter(models.Span.parent_span_id == str(self.context.span_id)) + .cte(recursive=True) + ) + parent_ids = descendant_ids.alias() + descendant_ids = descendant_ids.union_all( + select(models.Span.id, models.Span.span_id).join( + parent_ids, + models.Span.parent_span_id == parent_ids.c.span_id, + ) + ) + spans = await session.scalars( + select(models.Span) + .join(descendant_ids, models.Span.id == descendant_ids.c.id) + .join(models.Trace) + .options(contains_eager(models.Span.trace)) + ) + return [to_gql_span(span, self.project) for span in spans] -def to_gql_span(span: WrappedSpan, project: Project) -> "Span": - events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events)) +def to_gql_span(span: models.Span, project: Project) -> Span: + events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events)) input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE)) output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE)) retrieval_documents = span.attributes.get(RETRIEVAL_DOCUMENTS) @@ -221,16 +241,16 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": return Span( project=project, name=span.name, - status_code=SpanStatusCode(span.status_code), + status_code=SpanStatusCode(span.status), status_message=span.status_message, - parent_id=cast(Optional[ID], span.parent_id), - span_kind=SpanKind(span.span_kind), + parent_id=cast(Optional[ID], span.parent_span_id), + span_kind=SpanKind(span.kind), start_time=span.start_time, end_time=span.end_time, - latency_ms=cast(Optional[float], span[ComputedAttributes.LATENCY_MS]), + latency_ms=span.latency_ms, context=SpanContext( - trace_id=cast(ID, span.context.trace_id), - span_id=cast(ID, span.context.span_id), + trace_id=cast(ID, span.trace.trace_id), + span_id=cast(ID, span.span_id), ), attributes=json.dumps( _nested_attributes(_hide_embedding_vectors(span.attributes)), @@ -250,22 +270,12 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": Optional[int], span.attributes.get(LLM_TOKEN_COUNT_COMPLETION), ), - cumulative_token_count_total=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL], - ), - cumulative_token_count_prompt=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT], - ), - cumulative_token_count_completion=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION], - ), + cumulative_token_count_total=span.cumulative_llm_token_count_prompt + + span.cumulative_llm_token_count_completion, + cumulative_token_count_prompt=span.cumulative_llm_token_count_prompt, + cumulative_token_count_completion=span.cumulative_llm_token_count_completion, propagated_status_code=( - SpanStatusCode.ERROR - if span[ComputedAttributes.CUMULATIVE_ERROR_COUNT] - else 
SpanStatusCode(span.status_code) + SpanStatusCode.ERROR if span.cumulative_error_count else SpanStatusCode(span.status) ), events=events, input=( diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index a7d440bc52..075928ba8a 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -1,10 +1,13 @@ from typing import List, Optional import strawberry +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET, Private from strawberry.types import Info from phoenix.core.project import Project +from phoenix.db import models from phoenix.server.api.context import Context from phoenix.server.api.types.Evaluation import TraceEvaluation from phoenix.server.api.types.pagination import ( @@ -23,7 +26,7 @@ class Trace: project: Private[Project] @strawberry.field - def spans( + async def spans( self, info: Info[Context, None], first: Optional[int] = 50, @@ -37,9 +40,13 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - if not (traces := info.context.traces): - return connection_from_list(data=[], args=args) - spans = traces.get_trace(TraceID(self.trace_id)) + async with info.context.db() as session: + spans = await session.scalars( + select(models.Span) + .join(models.Trace) + .filter(models.Trace.trace_id == self.trace_id) + .options(contains_eager(models.Span.trace)) + ) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2b814ce04f..04e9c17421 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -1,7 +1,21 @@ +import contextlib import logging from pathlib import Path -from typing import Any, NamedTuple, Optional, Union +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + Callable, + NamedTuple, + Optional, + Union, +) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, +) from starlette.applications import Starlette from starlette.datastructures import QueryParams from starlette.endpoints import HTTPEndpoint @@ -96,12 +110,14 @@ class GraphQLWithContext(GraphQL): # type: ignore def __init__( self, schema: BaseSchema, + db: Callable[[], AsyncContextManager[AsyncSession]], model: Model, export_path: Path, graphiql: bool = False, corpus: Optional[Model] = None, traces: Optional[Traces] = None, ) -> None: + self.db = db self.model = model self.corpus = corpus self.traces = traces @@ -116,6 +132,7 @@ async def get_context( return Context( request=request, response=response, + db=self.db, model=self.model, corpus=self.corpus, traces=self.traces, @@ -142,7 +159,19 @@ async def version(_: Request) -> PlainTextResponse: return PlainTextResponse(f"{phoenix.__version__}") +def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]: + Session = async_sessionmaker(engine, expire_on_commit=False) + + @contextlib.asynccontextmanager + async def factory() -> AsyncIterator[AsyncSession]: + async with Session.begin() as session: + yield session + + return factory + + def create_app( + engine: AsyncEngine, export_path: Path, model: Model, umap_params: UMAPParameters, @@ -153,7 +182,9 @@ def create_app( read_only: bool = False, enable_prometheus: bool = False, ) -> Starlette: + db = _db(engine) graphql = GraphQLWithContext( + db=db, schema=schema, model=model, corpus=corpus, @@ -183,7 +214,11 @@ def create_app( ), Route( "/v1/traces", - 
type("TraceEndpoint", (TraceHandler,), {"traces": traces, "store": span_store}), + type( + "TraceEndpoint", + (TraceHandler,), + {"db": staticmethod(db), "traces": traces, "store": span_store}, + ), ), Route( "/v1/evaluations", diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index ae78c49e15..14cda43050 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -1,4 +1,5 @@ import atexit +import gzip import logging import os from argparse import ArgumentParser @@ -9,6 +10,9 @@ from typing import Iterable, Optional, Protocol, TypeVar import pkg_resources +import requests +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest +from opentelemetry.proto.trace.v1.trace_pb2 import ResourceSpans, ScopeSpans from uvicorn import Config, Server from phoenix.config import ( @@ -22,7 +26,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db.database import SqliteDatabase +from phoenix.db.engines import aiosqlite_engine from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -38,6 +42,7 @@ get_evals_from_fixture, ) from phoenix.trace.otel import decode, encode +from phoenix.trace.schemas import Span from phoenix.trace.span_json_decoder import json_string_to_span from phoenix.utilities.span_store import get_span_store, load_traces_data_from_store @@ -108,6 +113,28 @@ def _load_items( queue.put(item) +def _send_spans(spans: Iterable[Span], url: str) -> None: + # TODO(persistence): Ingest fixtures without networking for read-only deployments + sleep(5) # Wait for the server to start + session = requests.session() + for span in spans: + req = ExportTraceServiceRequest( + resource_spans=[ResourceSpans(scope_spans=[ScopeSpans(spans=[encode(span)])])] + ) + session.post( + url=url, + headers={ + "content-type": "application/x-protobuf", + "content-encoding": "gzip", + }, + data=gzip.compress(req.SerializeToString()), + ) + # TODO(persistence): If ingestion rate is too high it can crash the UI, because + # sqlite is not designed for high concurrency, especially for disk + # persistence. 
+ sleep(0.2) + + DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}" if __name__ == "__main__": @@ -191,14 +218,15 @@ def _load_items( trace_dataset_name = args.trace_fixture simulate_streaming = args.simulate_streaming + host = args.host or get_env_host() + port = args.port or get_env_port() + model = create_model_from_datasets( primary_dataset, reference_dataset, ) - working_dir = get_working_dir() - db = SqliteDatabase(working_dir / "phoenix.db") - traces = Traces(db) + traces = Traces() if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() if trace_dataset_name is not None: @@ -215,6 +243,11 @@ def _load_items( args=(traces, fixture_spans, simulate_streaming), daemon=True, ).start() + Thread( + target=_send_spans, + args=(fixture_spans, f"http://{host}:{port}/v1/traces"), + daemon=True, + ).start() fixture_evals = list(get_evals_from_fixture(trace_dataset_name)) Thread( target=_load_items, @@ -233,7 +266,11 @@ def _load_items( from phoenix.server.prometheus import start_prometheus start_prometheus() + + working_dir = get_working_dir().resolve() + engine = aiosqlite_engine(working_dir / "phoenix.db") app = create_app( + engine=engine, export_path=export_path, model=model, umap_params=umap_params, @@ -244,8 +281,6 @@ def _load_items( span_store=span_store, enable_prometheus=enable_prometheus, ) - host = args.host or get_env_host() - port = args.port or get_env_port() server = Server(config=Config(app, host=host, port=port)) Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index d072ff2bc9..c402cbac23 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -37,7 +37,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset -from phoenix.db.database import SqliteDatabase +from phoenix.db.engines import aiosqlite_engine from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer @@ -295,7 +295,7 @@ def __init__( if corpus_dataset is not None else None ) - self.traces = Traces(SqliteDatabase()) + self.traces = Traces() if trace_dataset: for span in trace_dataset.to_spans(): self.traces.put(span) @@ -310,6 +310,7 @@ def __init__( ).start() # Initialize an app service that keeps the server running self.app = create_app( + engine=aiosqlite_engine(), export_path=self.export_path, model=self.model, corpus=self.corpus, diff --git a/tests/server/api/types/conftest.py b/tests/server/api/types/conftest.py index c33708ae48..27028ca454 100644 --- a/tests/server/api/types/conftest.py +++ b/tests/server/api/types/conftest.py @@ -42,6 +42,7 @@ def create_context(primary_dataset: Dataset, reference_dataset: Optional[Dataset response=None, model=create_model_from_datasets(primary_dataset, reference_dataset), export_path=Path(TemporaryDirectory().name), + db=None, # TODO(persistence): add mock for db ) return create_context
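[Reviewer note] One way the `TODO(persistence): add mock for db` above could be addressed: give the test Context a real `db` factory backed by a throwaway in-memory engine, mirroring `_db` in src/phoenix/server/app.py. A sketch under that assumption (the helper name and fixture wiring are hypothetical, not part of this diff):

```python
import contextlib
from typing import AsyncContextManager, AsyncIterator, Callable

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from phoenix.db.models import Base


async def make_test_db() -> Callable[[], AsyncContextManager[AsyncSession]]:
    # Hypothetical test helper: an isolated in-memory database with the schema
    # created up front, so each test context gets its own empty tables.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    @contextlib.asynccontextmanager
    async def db() -> AsyncIterator[AsyncSession]:
        async with session_factory.begin() as session:
            yield session

    return db
```

The fixture would then pass `db=db` instead of `db=None` when constructing the Context.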