From afd368d6ba318a2fd8788158efab8f1207b01671 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 1 Apr 2024 11:18:02 -0600 Subject: [PATCH 01/30] feat(persistence): sql persistence --- .github/workflows/python-CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml index d9cc2f35f0..d3744e845a 100644 --- a/.github/workflows/python-CI.yml +++ b/.github/workflows/python-CI.yml @@ -2,7 +2,7 @@ name: Python CI on: push: - branches: [main] + branches: [main, sql] pull_request: paths: - "src/**" From fad00090fec91d580fef1e807297f7dd3725f8ae Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:11:36 -0700 Subject: [PATCH 02/30] refactor: sqlite3 proof of concept (#2740) * wip --------- Co-authored-by: Mikyo King --- app/schema.graphql | 2 +- src/phoenix/core/traces.py | 76 ++++- src/phoenix/db/__init__.py | 0 src/phoenix/db/database.py | 354 +++++++++++++++++++++++ src/phoenix/server/api/types/MimeType.py | 4 +- src/phoenix/server/api/types/Project.py | 35 ++- src/phoenix/server/api/types/Trace.py | 10 +- src/phoenix/server/main.py | 22 +- src/phoenix/session/session.py | 40 ++- src/phoenix/trace/otel.py | 9 +- src/phoenix/trace/schemas.py | 12 +- 11 files changed, 510 insertions(+), 54 deletions(-) create mode 100644 src/phoenix/db/__init__.py create mode 100644 src/phoenix/db/database.py diff --git a/app/schema.graphql b/app/schema.graphql index 0236fc879f..e5ff7dbcae 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -543,7 +543,7 @@ type Project implements Node { endTime: DateTime recordCount(timeRange: TimeRange): Int! traceCount(timeRange: TimeRange): Int! - tokenCountTotal: Int! + tokenCountTotal(timeRange: TimeRange): Int! latencyMsP50: Float latencyMsP99: Float trace(traceId: ID!): Trace diff --git a/src/phoenix/core/traces.py b/src/phoenix/core/traces.py index 0681e5479e..6fff53e469 100644 --- a/src/phoenix/core/traces.py +++ b/src/phoenix/core/traces.py @@ -1,9 +1,10 @@ import weakref from collections import defaultdict +from datetime import datetime from queue import SimpleQueue from threading import RLock, Thread from types import MethodType -from typing import DefaultDict, Iterator, Optional, Tuple, Union +from typing import DefaultDict, Iterator, Optional, Protocol, Tuple, Union from typing_extensions import assert_never @@ -12,16 +13,45 @@ from phoenix.core.project import ( END_OF_QUEUE, Project, + WrappedSpan, _ProjectName, ) -from phoenix.trace.schemas import Span +from phoenix.trace.schemas import ComputedAttributes, ComputedValues, Span, TraceID _SpanItem = Tuple[Span, _ProjectName] _EvalItem = Tuple[pb.Evaluation, _ProjectName] +class Database(Protocol): + def insert_span(self, span: Span, project_name: str) -> None: ... + + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def get_trace(self, trace_id: TraceID) -> Iterator[Tuple[Span, ComputedValues]]: ... 
+ + class Traces: - def __init__(self) -> None: + def __init__(self, database: Database) -> None: + self._database = database self._span_queue: "SimpleQueue[Optional[_SpanItem]]" = SimpleQueue() self._eval_queue: "SimpleQueue[Optional[_EvalItem]]" = SimpleQueue() # Putting `None` as the sentinel value for queue termination. @@ -34,6 +64,45 @@ def __init__(self) -> None: ) self._start_consumers() + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.trace_count(project_name, start_time, stop_time) + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.span_count(project_name, start_time, stop_time) + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.llm_token_count_total(project_name, start_time, stop_time) + + def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: + for span, computed_values in self._database.get_trace(trace_id): + wrapped_span = WrappedSpan(span) + wrapped_span[ComputedAttributes.LATENCY_MS] = computed_values.latency_ms + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT] = ( + computed_values.cumulative_llm_token_count_prompt + ) + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION] = ( + computed_values.cumulative_llm_token_count_completion + ) + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] = ( + computed_values.cumulative_llm_token_count_total + ) + yield wrapped_span + def get_project(self, project_name: str) -> Optional["Project"]: with self._lock: return self._projects.get(project_name) @@ -84,6 +153,7 @@ def _start_consumers(self) -> None: def _consume_spans(self, queue: "SimpleQueue[Optional[_SpanItem]]") -> None: while (item := queue.get()) is not END_OF_QUEUE: span, project_name = item + self._database.insert_span(span, project_name=project_name) with self._lock: project = self._projects[project_name] project.add_span(span) diff --git a/src/phoenix/db/__init__.py b/src/phoenix/db/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py new file mode 100644 index 0000000000..3e3486cca8 --- /dev/null +++ b/src/phoenix/db/database.py @@ -0,0 +1,354 @@ +import json +import sqlite3 +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Iterator, Optional, Tuple, cast + +import numpy as np +from openinference.semconv.trace import SpanAttributes + +from phoenix.trace.schemas import ComputedValues, Span, SpanContext, SpanKind, SpanStatusCode + +_CONFIG = """ +PRAGMA foreign_keys = ON; +PRAGMA journal_mode = WAL; +PRAGMA synchronous = OFF; +PRAGMA cache_size = -32000; +PRAGMA busy_timeout = 10000; +""" + +_INIT_DB = """ +BEGIN; +CREATE TABLE projects ( + rowid INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +INSERT INTO projects(name) VALUES('default'); +CREATE TABLE traces ( + rowid INTEGER PRIMARY KEY, + trace_id TEXT UNIQUE NOT NULL, + project_rowid INTEGER NOT NULL, + session_id TEXT, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + FOREIGN KEY(project_rowid) REFERENCES 
projects(rowid) +); +CREATE INDEX idx_trace_start_time ON traces(start_time); +CREATE TABLE spans ( + rowid INTEGER PRIMARY KEY, + span_id TEXT UNIQUE NOT NULL, + trace_rowid INTEGER NOT NULL, + parent_span_id TEXT, + kind TEXT NOT NULL, + name TEXT NOT NULL, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + attributes JSON, + events JSON, + status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), + status_message TEXT, + cumulative_error_count INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(trace_rowid) REFERENCES traces(rowid) +); +CREATE INDEX idx_parent_span_id ON spans(parent_span_id); +PRAGMA user_version = 1; +COMMIT; +""" + + +class SqliteDatabase: + def __init__(self, database: Optional[Path] = None) -> None: + """ + :param database: The path to the database file to be opened. + """ + self.con = sqlite3.connect( + database=database or ":memory:", + uri=True, + check_same_thread=False, + ) + # self.con.set_trace_callback(print) + cur = self.con.cursor() + cur.executescript(_CONFIG) + if int(cur.execute("PRAGMA user_version;").fetchone()[0]) < 1: + cur.executescript(_INIT_DB) + + def insert_span(self, span: Span, project_name: str) -> None: + cur = self.con.cursor() + cur.execute("BEGIN;") + try: + if not ( + projects := cur.execute( + "SELECT rowid FROM projects WHERE name = ?;", + (project_name,), + ).fetchone() + ): + projects = cur.execute( + "INSERT INTO projects(name) VALUES(?) RETURNING rowid;", + (project_name,), + ).fetchone() + project_rowid = projects[0] + if ( + trace_row := cur.execute( + """ + INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) + VALUES(?,?,?,?,?) + ON CONFLICT DO UPDATE SET + start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, + end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END + WHERE excluded.start_time < start_time OR end_time < excluded.end_time + RETURNING rowid; + """, # noqa E501 + ( + span.context.trace_id, + project_rowid, + None, + span.start_time, + span.end_time, + ), + ).fetchone() + ) is None: + trace_row = cur.execute( + "SELECT rowid from traces where trace_id = ?", (span.context.trace_id,) + ).fetchone() + trace_rowid = trace_row[0] + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := cur.execute( + """ + SELECT + sum(cumulative_error_count), + sum(cumulative_llm_token_count_prompt), + sum(cumulative_llm_token_count_completion) + FROM spans + WHERE parent_span_id = ? + """, # noqa E501 + (span.context.span_id,), + ).fetchone(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + cur.execute( + """ + INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) + VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?) 
+ RETURNING rowid; + """, # noqa E501 + ( + span.context.span_id, + trace_rowid, + span.parent_id, + span.span_kind.value, + span.name, + span.start_time, + span.end_time, + json.dumps(span.attributes, cls=_Encoder), + json.dumps(span.events, cls=_Encoder), + span.status_code.value, + span.status_message, + cumulative_error_count, + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion, + ), + ) + parent_id = span.parent_id + while parent_id: + if parent_span := cur.execute( + """ + SELECT rowid, parent_span_id + FROM spans + WHERE span_id = ? + """, + (parent_id,), + ).fetchone(): + rowid, parent_id = parent_span[0], parent_span[1] + cur.execute( + """ + UPDATE spans SET + cumulative_error_count = cumulative_error_count + ?, + cumulative_llm_token_count_prompt = cumulative_llm_token_count_prompt + ?, + cumulative_llm_token_count_completion = cumulative_llm_token_count_completion + ? + WHERE rowid = ?; + """, # noqa E501 + ( + cumulative_error_count, + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion, + rowid, + ), + ) + else: + break + except Exception: + cur.execute("ROLLBACK;") + else: + cur.execute("COMMIT;") + + def get_projects(self) -> Iterator[Tuple[int, str]]: + cur = self.con.cursor() + for project in cur.execute("SELECT rowid, name FROM projects;").fetchall(): + yield cast(Tuple[int, str], (project[0], project[1])) + + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT COUNT(*) + FROM traces + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? + """ + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= traces.start_time AND traces.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= traces.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND traces.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT COUNT(*) + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? + """ + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= spans.start_time AND spans.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT + SUM(COALESCE(json_extract(spans.attributes, '$."llm.token_count.prompt"'), 0) + + COALESCE(json_extract(spans.attributes, '$."llm.token_count.completion"'), 0)) + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? 
+ """ # noqa E501 + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= spans.start_time AND spans.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def get_trace(self, trace_id: str) -> Iterator[Tuple[Span, ComputedValues]]: + cur = self.con.cursor() + for span in cur.execute( + """ + SELECT + spans.span_id, + traces.trace_id, + spans.parent_span_id, + spans.kind, + spans.name, + spans.start_time, + spans.end_time, + spans.attributes, + spans.events, + spans.status, + spans.status_message, + spans.cumulative_error_count, + spans.cumulative_llm_token_count_prompt, + spans.cumulative_llm_token_count_completion + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + WHERE traces.trace_id = ? + """, + (trace_id,), + ).fetchall(): + start_time = datetime.fromisoformat(span[5]) + end_time = datetime.fromisoformat(span[6]) + latency_ms = (end_time - start_time).total_seconds() * 1000 + yield ( + Span( + name=span[4], + context=SpanContext(trace_id=span[1], span_id=span[0]), + parent_id=span[2], + span_kind=SpanKind(span[3]), + start_time=start_time, + end_time=end_time, + attributes=json.loads(span[7]), + events=json.loads(span[8]), + status_code=SpanStatusCode(span[9]), + status_message=span[10], + conversation=None, + ), + ComputedValues( + latency_ms=latency_ms, + cumulative_error_count=span[11], + cumulative_llm_token_count_prompt=span[12], + cumulative_llm_token_count_completion=span[13], + cumulative_llm_token_count_total=span[12] + span[13], + ), + ) + + +class _Encoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + return super().default(obj) diff --git a/src/phoenix/server/api/types/MimeType.py b/src/phoenix/server/api/types/MimeType.py index 4c33572994..9dbe40f477 100644 --- a/src/phoenix/server/api/types/MimeType.py +++ b/src/phoenix/server/api/types/MimeType.py @@ -8,8 +8,8 @@ @strawberry.enum class MimeType(Enum): - text = trace_schemas.MimeType.TEXT - json = trace_schemas.MimeType.JSON + text = trace_schemas.MimeType.TEXT.value + json = trace_schemas.MimeType.JSON.value @classmethod def _missing_(cls, v: Any) -> Optional["MimeType"]: diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 926d28296b..671d74cd93 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -4,9 +4,11 @@ import strawberry from strawberry import ID, UNSET +from strawberry.types import Info from phoenix.core.project import Project as CoreProject from phoenix.metrics.retrieval_metrics import RetrievalMetrics +from phoenix.server.api.context import Context from phoenix.server.api.input_types.SpanSort import SpanSort from phoenix.server.api.input_types.TimeRange import TimeRange from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary @@ -43,24 +45,41 @@ def end_time(self) -> Optional[datetime]: 
@strawberry.field def record_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.span_count() - return self.project.span_count(time_range.start, time_range.end) + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.span_count(self.name, start_time, stop_time) @strawberry.field def trace_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.trace_count() - return self.project.trace_count(time_range.start, time_range.end) + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.trace_count(self.name, start_time, stop_time) @strawberry.field - def token_count_total(self) -> int: - return self.project.token_count_total + def token_count_total( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + ) -> int: + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.llm_token_count_total(self.name, start_time, stop_time) @strawberry.field def latency_ms_p50(self) -> Optional[float]: diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 4a0b3331ba..a7d440bc52 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -2,8 +2,10 @@ import strawberry from strawberry import ID, UNSET, Private +from strawberry.types import Info from phoenix.core.project import Project +from phoenix.server.api.context import Context from phoenix.server.api.types.Evaluation import TraceEvaluation from phoenix.server.api.types.pagination import ( Connection, @@ -23,6 +25,7 @@ class Trace: @strawberry.field def spans( self, + info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -34,10 +37,9 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - spans = sorted( - self.project.get_trace(TraceID(self.trace_id)), - key=lambda span: span.start_time, - ) + if not (traces := info.context.traces): + return connection_from_list(data=[], args=args) + spans = traces.get_trace(TraceID(self.trace_id)) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 070a745f46..1b43921264 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -16,11 +16,13 @@ get_env_host, get_env_port, get_pids_path, + get_working_dir, ) from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets +from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -48,7 +50,7 @@ ██████╔╝███████║██║ ██║█████╗ ██╔██╗ ██║██║ ╚███╔╝ ██╔═══╝ ██╔══██║██║ ██║██╔══╝ ██║╚██╗██║██║ ██╔██╗ ██║ ██║ ██║╚██████╔╝███████╗██║ ╚████║██║██╔╝ ██╗ -╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{0} +╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{version} | | 🌎 Join our Community 🌎 
@@ -61,9 +63,9 @@ | https://docs.arize.com/phoenix | | 🚀 Phoenix Server 🚀 -| Phoenix UI: http://{1}:{2} +| Phoenix UI: http://{host}:{port} | Log traces: /v1/traces over HTTP -| +| Storage location: {working_dir} """ @@ -193,7 +195,9 @@ def _load_items( primary_dataset, reference_dataset, ) - traces = Traces() + working_dir = get_working_dir() + db = SqliteDatabase(working_dir / "phoenix.db") + traces = Traces(db) if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() if trace_dataset_name is not None: @@ -246,9 +250,13 @@ def _load_items( # Print information about the server phoenix_version = pkg_resources.get_distribution("arize-phoenix").version - print( - _WELCOME_MESSAGE.format(phoenix_version, host if host != "0.0.0.0" else "localhost", port) - ) + config = { + "version": phoenix_version, + "host": host, + "port": port, + "working_dir": working_dir, + } + print(_WELCOME_MESSAGE.format(**config)) # Start the server server.run() diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index 4a69d6da76..d072ff2bc9 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -37,6 +37,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset +from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer @@ -118,27 +119,6 @@ def __init__( self.corpus_dataset = corpus_dataset self.trace_dataset = trace_dataset self.umap_parameters = get_umap_parameters(default_umap_parameters) - self.model = create_model_from_datasets( - primary_dataset, - reference_dataset, - ) - - self.corpus = ( - create_model_from_datasets( - corpus_dataset, - ) - if corpus_dataset is not None - else None - ) - - self.traces = Traces() - if trace_dataset: - for span in trace_dataset.to_spans(): - self.traces.put(span) - for evaluations in trace_dataset.evaluations: - for pb_evaluation in encode_evaluations(evaluations): - self.traces.put(pb_evaluation) - self.host = host or get_env_host() self.port = port or get_env_port() self.temp_dir = TemporaryDirectory() @@ -304,6 +284,24 @@ def __init__( port=port, notebook_env=notebook_env, ) + self.model = create_model_from_datasets( + primary_dataset, + reference_dataset, + ) + self.corpus = ( + create_model_from_datasets( + corpus_dataset, + ) + if corpus_dataset is not None + else None + ) + self.traces = Traces(SqliteDatabase()) + if trace_dataset: + for span in trace_dataset.to_spans(): + self.traces.put(span) + for evaluations in trace_dataset.evaluations: + for pb_evaluation in encode_evaluations(evaluations): + self.traces.put(pb_evaluation) if span_store := get_span_store(): Thread( target=load_traces_data_from_store, diff --git a/src/phoenix/trace/otel.py b/src/phoenix/trace/otel.py index 12ec0b4104..21bb417b4a 100644 --- a/src/phoenix/trace/otel.py +++ b/src/phoenix/trace/otel.py @@ -33,7 +33,6 @@ EXCEPTION_MESSAGE, EXCEPTION_STACKTRACE, EXCEPTION_TYPE, - MimeType, Span, SpanContext, SpanEvent, @@ -61,16 +60,14 @@ def decode(otlp_span: otlp.Span) -> Span: parent_id = _decode_identifier(otlp_span.parent_span_id) start_time = _decode_unix_nano(otlp_span.start_time_unix_nano) - end_time = ( - _decode_unix_nano(otlp_span.end_time_unix_nano) if otlp_span.end_time_unix_nano else None - ) + end_time 
= _decode_unix_nano(otlp_span.end_time_unix_nano) attributes = dict(_unflatten(_load_json_strings(_decode_key_values(otlp_span.attributes)))) span_kind = SpanKind(attributes.pop(OPENINFERENCE_SPAN_KIND, None)) for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = MimeType(attributes[mime_type]) + attributes[mime_type] = attributes[mime_type] status_code, status_message = _decode_status(otlp_span.status) events = [_decode_event(event) for event in otlp_span.events] @@ -320,7 +317,7 @@ def encode(span: Span) -> otlp.Span: for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = attributes[mime_type].value + attributes[mime_type] = attributes[mime_type] for key, value in span.attributes.items(): if value is None: diff --git a/src/phoenix/trace/schemas.py b/src/phoenix/trace/schemas.py index efebeea439..7caeab9bf4 100644 --- a/src/phoenix/trace/schemas.py +++ b/src/phoenix/trace/schemas.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union from uuid import UUID EXCEPTION_TYPE = "exception.type" @@ -142,7 +142,7 @@ class Span: "If the parent_id is None, this is the root span" parent_id: Optional[SpanID] start_time: datetime - end_time: Optional[datetime] + end_time: datetime status_code: SpanStatusCode status_message: str """ @@ -202,3 +202,11 @@ class ComputedAttributes(Enum): CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION = "cumulative_token_count.completion" ERROR_COUNT = "error_count" CUMULATIVE_ERROR_COUNT = "cumulative_error_count" + + +class ComputedValues(NamedTuple): + latency_ms: float + cumulative_error_count: int + cumulative_llm_token_count_prompt: int + cumulative_llm_token_count_completion: int + cumulative_llm_token_count_total: int From d05454b784b1f34b1fcd741e03ca0076bb0cd76f Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Wed, 3 Apr 2024 14:42:11 -0700 Subject: [PATCH 03/30] refactor: sqlite with SQLAlchemy (#2747) * wip --- pyproject.toml | 1 + src/phoenix/db/database.py | 150 ++++++++++++++++++++++--------------- src/phoenix/db/models.py | 103 +++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 60 deletions(-) create mode 100644 src/phoenix/db/models.py diff --git a/pyproject.toml b/pyproject.toml index cab2a0a29c..513b7b4438 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "openinference-instrumentation-langchain>=0.1.12", "openinference-instrumentation-llama-index>=1.2.0", "openinference-instrumentation-openai>=0.1.4", + "sqlalchemy>=2, <3", ] dynamic = ["version"] diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index 3e3486cca8..f97ca07b69 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -3,12 +3,23 @@ from datetime import datetime from enum import Enum from pathlib import Path +from sqlite3 import Connection from typing import Any, Iterator, Optional, Tuple, cast import numpy as np from openinference.semconv.trace import SpanAttributes +from sqlalchemy import Engine, create_engine, event +from sqlalchemy.orm import sessionmaker -from phoenix.trace.schemas import ComputedValues, Span, SpanContext, SpanKind, SpanStatusCode +from phoenix.db.models import Base, Trace +from phoenix.trace.schemas import ( + ComputedValues, + Span, + SpanContext, + SpanEvent, + SpanKind, + 
SpanStatusCode, +) _CONFIG = """ PRAGMA foreign_keys = ON; @@ -21,7 +32,7 @@ _INIT_DB = """ BEGIN; CREATE TABLE projects ( - rowid INTEGER PRIMARY KEY, + id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE, description TEXT, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -29,17 +40,17 @@ ); INSERT INTO projects(name) VALUES('default'); CREATE TABLE traces ( - rowid INTEGER PRIMARY KEY, + id INTEGER PRIMARY KEY, trace_id TEXT UNIQUE NOT NULL, project_rowid INTEGER NOT NULL, session_id TEXT, start_time TIMESTAMP NOT NULL, end_time TIMESTAMP NOT NULL, - FOREIGN KEY(project_rowid) REFERENCES projects(rowid) + FOREIGN KEY(project_rowid) REFERENCES projects(id) ); CREATE INDEX idx_trace_start_time ON traces(start_time); CREATE TABLE spans ( - rowid INTEGER PRIMARY KEY, + id INTEGER PRIMARY KEY, span_id TEXT UNIQUE NOT NULL, trace_rowid INTEGER NOT NULL, parent_span_id TEXT, @@ -51,10 +62,11 @@ events JSON, status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), status_message TEXT, + latency_ms REAL, cumulative_error_count INTEGER NOT NULL DEFAULT 0, cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, - FOREIGN KEY(trace_rowid) REFERENCES traces(rowid) + FOREIGN KEY(trace_rowid) REFERENCES traces(id) ); CREATE INDEX idx_parent_span_id ON spans(parent_span_id); PRAGMA user_version = 1; @@ -62,13 +74,31 @@ """ +_MEM_DB_STR = "file::memory:?cache=shared" + + +def _mem_db_creator() -> Any: + return sqlite3.connect(_MEM_DB_STR, uri=True) + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection: Connection, _: Any) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys = ON;") + cursor.execute("PRAGMA journal_mode = WAL;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA cache_size = -32000;") + cursor.execute("PRAGMA busy_timeout = 10000;") + cursor.close() + + class SqliteDatabase: - def __init__(self, database: Optional[Path] = None) -> None: + def __init__(self, db_path: Optional[Path] = None) -> None: """ - :param database: The path to the database file to be opened. + :param db_path: The path to the database file to be opened. 
""" self.con = sqlite3.connect( - database=database or ":memory:", + database=db_path or _MEM_DB_STR, uri=True, check_same_thread=False, ) @@ -78,6 +108,18 @@ def __init__(self, database: Optional[Path] = None) -> None: if int(cur.execute("PRAGMA user_version;").fetchone()[0]) < 1: cur.executescript(_INIT_DB) + engine = ( + create_engine(f"sqlite:///{db_path}", echo=True) + if db_path + else create_engine( + "sqlite:///:memory:", + echo=True, + creator=_mem_db_creator, + ) + ) + Base.metadata.create_all(engine) + self.Session = sessionmaker(bind=engine) + def insert_span(self, span: Span, project_name: str) -> None: cur = self.con.cursor() cur.execute("BEGIN;") @@ -138,10 +180,11 @@ def insert_span(self, span: Span, project_name: str) -> None: cumulative_error_count += cast(int, accumulation[0] or 0) cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 cur.execute( """ - INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) - VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?) + INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, latency_ms, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) + VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) RETURNING rowid; """, # noqa E501 ( @@ -156,6 +199,7 @@ def insert_span(self, span: Span, project_name: str) -> None: json.dumps(span.events, cls=_Encoder), span.status_code.value, span.status_message, + latency_ms, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion, @@ -288,55 +332,41 @@ def llm_token_count_total( return 0 def get_trace(self, trace_id: str) -> Iterator[Tuple[Span, ComputedValues]]: - cur = self.con.cursor() - for span in cur.execute( - """ - SELECT - spans.span_id, - traces.trace_id, - spans.parent_span_id, - spans.kind, - spans.name, - spans.start_time, - spans.end_time, - spans.attributes, - spans.events, - spans.status, - spans.status_message, - spans.cumulative_error_count, - spans.cumulative_llm_token_count_prompt, - spans.cumulative_llm_token_count_completion - FROM spans - JOIN traces ON traces.rowid = spans.trace_rowid - WHERE traces.trace_id = ? 
- """, - (trace_id,), - ).fetchall(): - start_time = datetime.fromisoformat(span[5]) - end_time = datetime.fromisoformat(span[6]) - latency_ms = (end_time - start_time).total_seconds() * 1000 - yield ( - Span( - name=span[4], - context=SpanContext(trace_id=span[1], span_id=span[0]), - parent_id=span[2], - span_kind=SpanKind(span[3]), - start_time=start_time, - end_time=end_time, - attributes=json.loads(span[7]), - events=json.loads(span[8]), - status_code=SpanStatusCode(span[9]), - status_message=span[10], - conversation=None, - ), - ComputedValues( - latency_ms=latency_ms, - cumulative_error_count=span[11], - cumulative_llm_token_count_prompt=span[12], - cumulative_llm_token_count_completion=span[13], - cumulative_llm_token_count_total=span[12] + span[13], - ), - ) + with self.Session.begin() as session: + trace = session.query(Trace).where(Trace.trace_id == trace_id).one_or_none() + if not trace: + return + for span in trace.spans: + yield ( + Span( + name=span.name, + context=SpanContext(trace_id=span.trace.trace_id, span_id=span.span_id), + parent_id=span.parent_span_id, + span_kind=SpanKind(span.kind), + start_time=span.start_time, + end_time=span.end_time, + attributes=span.attributes, + events=[ + SpanEvent( + name=obj["name"], + attributes=obj["attributes"], + timestamp=obj["timestamp"], + ) + for obj in span.events + ], + status_code=SpanStatusCode(span.status), + status_message=span.status_message, + conversation=None, + ), + ComputedValues( + latency_ms=span.latency_ms, + cumulative_error_count=span.cumulative_error_count, + cumulative_llm_token_count_prompt=span.cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=span.cumulative_llm_token_count_completion, + cumulative_llm_token_count_total=span.cumulative_llm_token_count_prompt + + span.cumulative_llm_token_count_completion, + ), + ) class _Encoder(json.JSONEncoder): diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py new file mode 100644 index 0000000000..62bc032347 --- /dev/null +++ b/src/phoenix/db/models.py @@ -0,0 +1,103 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from sqlalchemy import ( + JSON, + DateTime, + ForeignKey, + UniqueConstraint, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + WriteOnlyMapped, + mapped_column, + relationship, +) + + +class Base(DeclarativeBase): + type_annotation_map = { + Dict[str, Any]: JSON, + List[Dict[str, Any]]: JSON, + } + + +class Project(Base): + __tablename__ = "projects" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + description: Mapped[Optional[str]] + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + + traces: WriteOnlyMapped["Trace"] = relationship( + "Trace", + back_populates="project", + cascade="all, delete-orphan", + ) + __table_args__ = ( + UniqueConstraint( + "name", + name="project_name_unique", + sqlite_on_conflict="IGNORE", + ), + ) + + +class Trace(Base): + __tablename__ = "traces" + id: Mapped[int] = mapped_column(primary_key=True) + project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) + session_id: Mapped[Optional[str]] + trace_id: Mapped[str] + start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + + project: Mapped["Project"] = relationship( + "Project", + back_populates="traces", + ) + spans: Mapped[List["Span"]] = relationship( + "Span", + 
back_populates="trace", + cascade="all, delete-orphan", + ) + __table_args__ = ( + UniqueConstraint( + "trace_id", + name="trace_id_unique", + sqlite_on_conflict="IGNORE", + ), + ) + + +class Span(Base): + __tablename__ = "spans" + id: Mapped[int] = mapped_column(primary_key=True) + trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id")) + span_id: Mapped[str] + parent_span_id: Mapped[Optional[str]] + name: Mapped[str] + kind: Mapped[str] + start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + attributes: Mapped[Dict[str, Any]] + events: Mapped[List[Dict[str, Any]]] + status: Mapped[str] + status_message: Mapped[str] + + latency_ms: Mapped[float] + cumulative_error_count: Mapped[int] + cumulative_llm_token_count_prompt: Mapped[int] + cumulative_llm_token_count_completion: Mapped[int] + + trace: Mapped["Trace"] = relationship("Trace", back_populates="spans") + + __table_args__ = ( + UniqueConstraint( + "span_id", + name="trace_id_unique", + sqlite_on_conflict="IGNORE", + ), + ) From 8f79e290eb79e73101a05175f5de4b8f8e747c6f Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Wed, 3 Apr 2024 19:37:00 -0600 Subject: [PATCH 04/30] initialize alembic refactor --- pyproject.toml | 1 + src/phoenix/db/alembic.ini | 116 +++++++++++++++++++++++ src/phoenix/db/migrations/README | 1 + src/phoenix/db/migrations/env.py | 74 +++++++++++++++ src/phoenix/db/migrations/script.py.mako | 26 +++++ 5 files changed, 218 insertions(+) create mode 100644 src/phoenix/db/alembic.ini create mode 100644 src/phoenix/db/migrations/README create mode 100644 src/phoenix/db/migrations/env.py create mode 100644 src/phoenix/db/migrations/script.py.mako diff --git a/pyproject.toml b/pyproject.toml index 513b7b4438..3c7b1cea88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "openinference-instrumentation-llama-index>=1.2.0", "openinference-instrumentation-openai>=0.1.4", "sqlalchemy>=2, <3", + "alembic>=1.3.0, <2", ] dynamic = ["version"] diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini new file mode 100644 index 0000000000..cd413a2606 --- /dev/null +++ b/src/phoenix/db/alembic.ini @@ -0,0 +1,116 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. 
+# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/phoenix/db/migrations/README b/src/phoenix/db/migrations/README new file mode 100644 index 0000000000..98e4f9c44e --- /dev/null +++ b/src/phoenix/db/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. 
\ No newline at end of file diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py new file mode 100644 index 0000000000..9dd6c6ceca --- /dev/null +++ b/src/phoenix/db/migrations/env.py @@ -0,0 +1,74 @@ +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = None + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/phoenix/db/migrations/script.py.mako b/src/phoenix/db/migrations/script.py.mako new file mode 100644 index 0000000000..fbc4b07dce --- /dev/null +++ b/src/phoenix/db/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} From 4272f5cd13c14e4ba276a7dbe71a8072ca66876b Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Wed, 3 Apr 2024 19:48:39 -0600 Subject: [PATCH 05/30] make alembic path to DB dynamic --- src/phoenix/db/alembic.ini | 4 +++- src/phoenix/db/migrations/env.py | 14 +++++------ .../migrations/versions/cf03bd6bae1d_init.py | 23 +++++++++++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index cd413a2606..5aeabbcc9d 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -60,7 +60,9 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # are written from script.py.mako # output_encoding = utf-8 -sqlalchemy.url = driver://user:pass@localhost/dbname +# NB: This is commented out intentionally as it is dynamic +# See migrations/env.py +# sqlalchemy.url = driver://user:pass@localhost/dbname [post_write_hooks] diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py index 9dd6c6ceca..ad1f2e3887 100644 --- a/src/phoenix/db/migrations/env.py +++ b/src/phoenix/db/migrations/env.py @@ -1,7 +1,8 @@ from logging.config import fileConfig from alembic import context -from sqlalchemy import engine_from_config, pool +from phoenix.config import get_working_dir +from sqlalchemy import create_engine # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -23,6 +24,10 @@ # my_important_option = config.get_main_option("my_important_option") # ... etc. +# Since the working directory is dynamic, we need to get it from the config +working_dir = get_working_dir() +url = f"sqlite:///{working_dir}/phoenix.db" + def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -36,7 +41,6 @@ def run_migrations_offline() -> None: script output. """ - url = config.get_main_option("sqlalchemy.url") context.configure( url=url, target_metadata=target_metadata, @@ -55,11 +59,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) + connectable = create_engine(url) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py new file mode 100644 index 0000000000..6c7f693dc2 --- /dev/null +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -0,0 +1,23 @@ +"""init + +Revision ID: cf03bd6bae1d +Revises: +Create Date: 2024-04-03 19:41:48.871555 + +""" + +from typing import Sequence, Union + +# revision identifiers, used by Alembic. 
+revision: str = "cf03bd6bae1d" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass From 2dbe17fb6b249632a4d6d8853baf4a72b98459d6 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Wed, 3 Apr 2024 22:15:33 -0600 Subject: [PATCH 06/30] move migrations over --- cspell.json | 1 + src/phoenix/db/__init__.py | 3 + src/phoenix/db/alembic.ini | 4 +- src/phoenix/db/database.py | 2 - src/phoenix/db/migrate.py | 22 +++++++ .../migrations/versions/cf03bd6bae1d_init.py | 59 ++++++++++++++++++- src/phoenix/db/models.py | 2 + src/phoenix/server/main.py | 4 ++ 8 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 src/phoenix/db/migrate.py diff --git a/cspell.json b/cspell.json index af4d274d9a..67bedffd61 100644 --- a/cspell.json +++ b/cspell.json @@ -37,6 +37,7 @@ "respx", "rgba", "seafoam", + "sqlalchemy", "templating", "tensorboard", "testset", diff --git a/src/phoenix/db/__init__.py b/src/phoenix/db/__init__.py index e69de29bb2..2848bf384b 100644 --- a/src/phoenix/db/__init__.py +++ b/src/phoenix/db/__init__.py @@ -0,0 +1,3 @@ +from .migrate import migrate + +__all__ = ["migrate"] diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index 5aeabbcc9d..70e3c84046 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -93,12 +93,12 @@ keys = console keys = generic [logger_root] -level = WARN +level = DEBUG handlers = console qualname = [logger_sqlalchemy] -level = WARN +level = DEBUG handlers = qualname = sqlalchemy.engine diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index f97ca07b69..2269a31538 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -105,8 +105,6 @@ def __init__(self, db_path: Optional[Path] = None) -> None: # self.con.set_trace_callback(print) cur = self.con.cursor() cur.executescript(_CONFIG) - if int(cur.execute("PRAGMA user_version;").fetchone()[0]) < 1: - cur.executescript(_INIT_DB) engine = ( create_engine(f"sqlite:///{db_path}", echo=True) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py new file mode 100644 index 0000000000..9f189f8d75 --- /dev/null +++ b/src/phoenix/db/migrate.py @@ -0,0 +1,22 @@ +import logging +import os +from pathlib import Path + +import alembic.config + +logger = logging.getLogger(__name__) + + +def migrate() -> None: + """ + Runs migrations on the database. + """ + config_path = os.path.normpath(str(Path(__file__).parent.resolve()) + os.sep + "alembic.ini") + alembicArgs = [ + "--config", + config_path, + "--raiseerr", + "upgrade", + "head", + ] + alembic.config.main(argv=alembicArgs) diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 6c7f693dc2..26163ae0c6 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -8,6 +8,9 @@ from typing import Sequence, Union +import sqlalchemy as sa +from alembic import op + # revision identifiers, used by Alembic. revision: str = "cf03bd6bae1d" down_revision: Union[str, None] = None @@ -16,8 +19,60 @@ def upgrade() -> None: - pass + projects_table = op.create_table( + "projects", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String, nullable=False), + sa.Column("description", sa.String, nullable=True), + # TODO(mikeldking): is timezone=True necessary? 
+ sa.Column( + "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ), + sa.Column( + "updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ), + ) + op.create_table( + "traces", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False), + sa.Column("session_id", sa.String, nullable=True), + sa.Column("trace_id", sa.String, nullable=False), + ) + + op.create_table( + "spans", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), + sa.Column("span_id", sa.String, nullable=False), + sa.Column("parent_span_id", sa.String, nullable=True), + sa.Column("name", sa.String, nullable=False), + sa.Column("kind", sa.String, nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("attributes", sa.JSON, nullable=False), + sa.Column("events", sa.JSON, nullable=False), + sa.Column("status", sa.String, nullable=False), + sa.Column("latency_ms", sa.Float, nullable=False), + sa.Column("cumulative_error_count", sa.Integer, nullable=False), + sa.Column("cumulative_llm_token_count_prompt", sa.Integer, nullable=False), + sa.Column("cumulative_llm_token_count_completion", sa.Integer, nullable=False), + ) + + op.create_index("idx_trace_start_time", "traces", ["start_time"]) + op.create_index("idx_parent_span_id", "spans", ["parent_span_id"]) + op.bulk_insert( + projects_table, + [ + {"name": "default", "description": "Default project"}, + ], + ) def downgrade() -> None: - pass + op.drop_index("idx_trace_start_time") + op.drop_index("idx_parent_span_id", "spans") + + op.drop_table("projects") + op.drop_table("traces") + op.drop_table("spans") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 62bc032347..44b98f3b08 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -51,6 +51,7 @@ class Trace(Base): project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) session_id: Mapped[Optional[str]] trace_id: Mapped[str] + # TODO(mikeldking): why is the start and end time necessary? just filtering? 
start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) @@ -87,6 +88,7 @@ class Span(Base): status: Mapped[str] status_message: Mapped[str] + # TODO(mikeldking): is computed columns possible here latency_ms: Mapped[float] cumulative_error_count: Mapped[int] cumulative_llm_token_count_prompt: Mapped[int] diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 1b43921264..b5add55a0a 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -22,6 +22,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets +from phoenix.db import migrate from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, @@ -197,6 +198,9 @@ def _load_items( ) working_dir = get_working_dir() db = SqliteDatabase(working_dir / "phoenix.db") + # Run migrations + migrate() + traces = Traces(db) if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() From 595400f981679c46edb2a1d08e4da865e9b3bdcc Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Wed, 3 Apr 2024 22:50:15 -0600 Subject: [PATCH 07/30] remove init_db --- src/phoenix/db/database.py | 45 -------------------------------------- 1 file changed, 45 deletions(-) diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index 2269a31538..89f9a1a97b 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -29,51 +29,6 @@ PRAGMA busy_timeout = 10000; """ -_INIT_DB = """ -BEGIN; -CREATE TABLE projects ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - description TEXT, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); -INSERT INTO projects(name) VALUES('default'); -CREATE TABLE traces ( - id INTEGER PRIMARY KEY, - trace_id TEXT UNIQUE NOT NULL, - project_rowid INTEGER NOT NULL, - session_id TEXT, - start_time TIMESTAMP NOT NULL, - end_time TIMESTAMP NOT NULL, - FOREIGN KEY(project_rowid) REFERENCES projects(id) -); -CREATE INDEX idx_trace_start_time ON traces(start_time); -CREATE TABLE spans ( - id INTEGER PRIMARY KEY, - span_id TEXT UNIQUE NOT NULL, - trace_rowid INTEGER NOT NULL, - parent_span_id TEXT, - kind TEXT NOT NULL, - name TEXT NOT NULL, - start_time TIMESTAMP NOT NULL, - end_time TIMESTAMP NOT NULL, - attributes JSON, - events JSON, - status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), - status_message TEXT, - latency_ms REAL, - cumulative_error_count INTEGER NOT NULL DEFAULT 0, - cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, - cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, - FOREIGN KEY(trace_rowid) REFERENCES traces(id) -); -CREATE INDEX idx_parent_span_id ON spans(parent_span_id); -PRAGMA user_version = 1; -COMMIT; -""" - - _MEM_DB_STR = "file::memory:?cache=shared" From 3ef525de76119c03e929a4f35b09e54b29d4397a Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 00:47:09 -0600 Subject: [PATCH 08/30] finish out the migration --- src/phoenix/db/alembic.ini | 1 + src/phoenix/db/database.py | 56 ++++++++++++++++++- src/phoenix/db/migrate.py | 20 ++++--- src/phoenix/db/migrations/env.py | 7 ++- .../migrations/versions/cf03bd6bae1d_init.py | 18 +++++- src/phoenix/server/main.py | 3 - 6 files changed, 86 insertions(+), 19 deletions(-) diff 
--git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index 70e3c84046..f584199f26 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -2,6 +2,7 @@ [alembic] # path to migration scripts +# Note this is overridden in migrations/env.py for programmatic use script_location = migrations # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index 89f9a1a97b..49291090af 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -21,6 +21,8 @@ SpanStatusCode, ) +from .migrate import migrate + _CONFIG = """ PRAGMA foreign_keys = ON; PRAGMA journal_mode = WAL; @@ -29,6 +31,51 @@ PRAGMA busy_timeout = 10000; """ +_INIT_DB = """ +BEGIN; +CREATE TABLE projects ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +INSERT INTO projects(name) VALUES('default'); +CREATE TABLE traces ( + id INTEGER PRIMARY KEY, + trace_id TEXT UNIQUE NOT NULL, + project_rowid INTEGER NOT NULL, + session_id TEXT, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + FOREIGN KEY(project_rowid) REFERENCES projects(id) +); +CREATE INDEX idx_trace_start_time ON traces(start_time); +CREATE TABLE spans ( + id INTEGER PRIMARY KEY, + span_id TEXT UNIQUE NOT NULL, + trace_rowid INTEGER NOT NULL, + parent_span_id TEXT, + kind TEXT NOT NULL, + name TEXT NOT NULL, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + attributes JSON, + events JSON, + status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), + status_message TEXT, + latency_ms REAL, + cumulative_error_count INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(trace_rowid) REFERENCES traces(id) +); +CREATE INDEX idx_parent_span_id ON spans(parent_span_id); +PRAGMA user_version = 1; +COMMIT; +""" + + _MEM_DB_STR = "file::memory:?cache=shared" @@ -70,7 +117,14 @@ def __init__(self, db_path: Optional[Path] = None) -> None: creator=_mem_db_creator, ) ) - Base.metadata.create_all(engine) + + # TODO this should be moved out + # Migrate the database if a path is provided + if db_path: + migrate() + else: + Base.metadata.create_all(engine) + self.Session = sessionmaker(bind=engine) def insert_span(self, span: Span, project_name: str) -> None: diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 9f189f8d75..d9a25fc10d 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -2,7 +2,8 @@ import os from pathlib import Path -import alembic.config +from alembic import command +from alembic.config import Config logger = logging.getLogger(__name__) @@ -12,11 +13,12 @@ def migrate() -> None: Runs migrations on the database. 
""" config_path = os.path.normpath(str(Path(__file__).parent.resolve()) + os.sep + "alembic.ini") - alembicArgs = [ - "--config", - config_path, - "--raiseerr", - "upgrade", - "head", - ] - alembic.config.main(argv=alembicArgs) + alembic_cfg = Config(config_path) + + # Explicitly set the migration directory + scripts_location = os.path.normpath( + str(Path(__file__).parent.resolve()) + os.sep + "migrations" + ) + alembic_cfg.set_main_option("script_location", scripts_location) + + command.upgrade(alembic_cfg, "head") diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py index ad1f2e3887..3b929bdcd8 100644 --- a/src/phoenix/db/migrations/env.py +++ b/src/phoenix/db/migrations/env.py @@ -2,6 +2,7 @@ from alembic import context from phoenix.config import get_working_dir +from phoenix.db.models import Base from sqlalchemy import create_engine # this is the Alembic Config object, which provides @@ -15,9 +16,8 @@ # add your model's MetaData object here # for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None + +target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, # can be acquired: @@ -59,6 +59,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ + print(config) connectable = create_engine(url) with connectable.connect() as connection: diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 26163ae0c6..cf3a7740d7 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -36,8 +36,11 @@ def upgrade() -> None: "traces", sa.Column("id", sa.Integer, primary_key=True), sa.Column("project_rowid", sa.Integer, sa.ForeignKey("projects.id"), nullable=False), + # TODO(mikeldking): might not be the right place for this sa.Column("session_id", sa.String, nullable=True), - sa.Column("trace_id", sa.String, nullable=False), + sa.Column("trace_id", sa.String, nullable=False, unique=True), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), ) op.create_table( @@ -52,8 +55,17 @@ def upgrade() -> None: sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), sa.Column("attributes", sa.JSON, nullable=False), sa.Column("events", sa.JSON, nullable=False), - sa.Column("status", sa.String, nullable=False), - sa.Column("latency_ms", sa.Float, nullable=False), + sa.Column( + "status", + sa.String, + # TODO(mikeldking): this doesn't seem to work... 
+ sa.CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')"), + nullable=False, + default="UNSET", + server_default="UNSET", + ), + sa.Column("status_message", sa.String, nullable=False), + sa.Column("latency_ms", sa.REAL, nullable=False), sa.Column("cumulative_error_count", sa.Integer, nullable=False), sa.Column("cumulative_llm_token_count_prompt", sa.Integer, nullable=False), sa.Column("cumulative_llm_token_count_completion", sa.Integer, nullable=False), diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index b5add55a0a..ae78c49e15 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -22,7 +22,6 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db import migrate from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, @@ -198,8 +197,6 @@ def _load_items( ) working_dir = get_working_dir() db = SqliteDatabase(working_dir / "phoenix.db") - # Run migrations - migrate() traces = Traces(db) if span_store := get_span_store(): From 8c10e4677ad244a75a8e266631584e6076133e14 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 00:49:31 -0600 Subject: [PATCH 09/30] WIP' --- src/phoenix/db/alembic.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index f584199f26..07950f4c0a 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -2,7 +2,7 @@ [alembic] # path to migration scripts -# Note this is overridden in migrations/env.py for programmatic use +# Note this is overridden in .migrate during programatic migrations script_location = migrations # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s From ca5a23d904a771210e5e586f94a0a0d765f8e690 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 02:15:46 -0600 Subject: [PATCH 10/30] make it work --- src/phoenix/db/alembic.ini | 6 +- src/phoenix/db/database.py | 60 ++++--------------- src/phoenix/db/migrate.py | 8 ++- src/phoenix/db/migrations/env.py | 15 +++-- .../migrations/versions/cf03bd6bae1d_init.py | 6 +- src/phoenix/db/models.py | 7 ++- 6 files changed, 37 insertions(+), 65 deletions(-) diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index 07950f4c0a..27aff4ddf9 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -94,17 +94,17 @@ keys = console keys = generic [logger_root] -level = DEBUG +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = DEBUG +level = WARN handlers = qualname = sqlalchemy.engine [logger_alembic] -level = INFO +level = DEBUG handlers = qualname = alembic diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index 49291090af..86d9f761dc 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -8,10 +8,10 @@ import numpy as np from openinference.semconv.trace import SpanAttributes -from sqlalchemy import Engine, create_engine, event +from sqlalchemy import Engine, create_engine, event, insert from sqlalchemy.orm import sessionmaker -from phoenix.db.models import Base, Trace +from phoenix.db.models import Base, Project, Trace from phoenix.trace.schemas import ( ComputedValues, Span, @@ -31,50 +31,6 @@ PRAGMA busy_timeout = 10000; """ -_INIT_DB = """ -BEGIN; -CREATE TABLE projects ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - description TEXT, - 
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -); -INSERT INTO projects(name) VALUES('default'); -CREATE TABLE traces ( - id INTEGER PRIMARY KEY, - trace_id TEXT UNIQUE NOT NULL, - project_rowid INTEGER NOT NULL, - session_id TEXT, - start_time TIMESTAMP NOT NULL, - end_time TIMESTAMP NOT NULL, - FOREIGN KEY(project_rowid) REFERENCES projects(id) -); -CREATE INDEX idx_trace_start_time ON traces(start_time); -CREATE TABLE spans ( - id INTEGER PRIMARY KEY, - span_id TEXT UNIQUE NOT NULL, - trace_rowid INTEGER NOT NULL, - parent_span_id TEXT, - kind TEXT NOT NULL, - name TEXT NOT NULL, - start_time TIMESTAMP NOT NULL, - end_time TIMESTAMP NOT NULL, - attributes JSON, - events JSON, - status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), - status_message TEXT, - latency_ms REAL, - cumulative_error_count INTEGER NOT NULL DEFAULT 0, - cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, - cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, - FOREIGN KEY(trace_rowid) REFERENCES traces(id) -); -CREATE INDEX idx_parent_span_id ON spans(parent_span_id); -PRAGMA user_version = 1; -COMMIT; -""" - _MEM_DB_STR = "file::memory:?cache=shared" @@ -108,22 +64,26 @@ def __init__(self, db_path: Optional[Path] = None) -> None: cur = self.con.cursor() cur.executescript(_CONFIG) + db_url = f"sqlite:///{db_path}" if db_path else "sqlite:///:memory:" engine = ( - create_engine(f"sqlite:///{db_path}", echo=True) + create_engine(db_url, echo=True) if db_path else create_engine( - "sqlite:///:memory:", + db_url, echo=True, creator=_mem_db_creator, ) ) # TODO this should be moved out - # Migrate the database if a path is provided if db_path: - migrate() + migrate(db_url) else: Base.metadata.create_all(engine) + # Create the default project + with engine.connect() as conn: + conn.execute(insert(Project).values(name="default", description="default project")) + conn.commit() self.Session = sessionmaker(bind=engine) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index d9a25fc10d..4a1bf52450 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -8,10 +8,15 @@ logger = logging.getLogger(__name__) -def migrate() -> None: +def migrate(url: str) -> None: """ Runs migrations on the database. + NB: Migrate only works on non-memory databases. + + Args: + url: The database URL. """ + logger.warning("Running migrations on the database") config_path = os.path.normpath(str(Path(__file__).parent.resolve()) + os.sep + "alembic.ini") alembic_cfg = Config(config_path) @@ -20,5 +25,6 @@ def migrate() -> None: str(Path(__file__).parent.resolve()) + os.sep + "migrations" ) alembic_cfg.set_main_option("script_location", scripts_location) + alembic_cfg.set_main_option("sqlalchemy.url", url) command.upgrade(alembic_cfg, "head") diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py index 3b929bdcd8..6e3c4d0d3c 100644 --- a/src/phoenix/db/migrations/env.py +++ b/src/phoenix/db/migrations/env.py @@ -1,9 +1,8 @@ from logging.config import fileConfig from alembic import context -from phoenix.config import get_working_dir from phoenix.db.models import Base -from sqlalchemy import create_engine +from sqlalchemy import engine_from_config, pool # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -24,10 +23,6 @@ # my_important_option = config.get_main_option("my_important_option") # ... etc. 
-# Since the working directory is dynamic, we need to get it from the config -working_dir = get_working_dir() -url = f"sqlite:///{working_dir}/phoenix.db" - def run_migrations_offline() -> None: """Run migrations in 'offline' mode. @@ -41,6 +36,7 @@ def run_migrations_offline() -> None: script output. """ + url = config.get_main_option("sqlalchemy.url") context.configure( url=url, target_metadata=target_metadata, @@ -59,8 +55,11 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - print(config) - connectable = create_engine(url) + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index cf3a7740d7..181d791306 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -29,7 +29,11 @@ def upgrade() -> None: "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() ), sa.Column( - "updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + onupdate=sa.func.now(), ), ) op.create_table( diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 44b98f3b08..e457406782 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -6,6 +6,7 @@ DateTime, ForeignKey, UniqueConstraint, + func, ) from sqlalchemy.orm import ( DeclarativeBase, @@ -28,8 +29,10 @@ class Project(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] description: Mapped[Optional[str]] - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) traces: WriteOnlyMapped["Trace"] = relationship( "Trace", From fdf2507e1935f940b9d549bc2672d84f280b2f15 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 09:25:24 -0600 Subject: [PATCH 11/30] add contraints naming conventions --- src/phoenix/db/alembic.ini | 4 ++-- .../migrations/versions/cf03bd6bae1d_init.py | 18 +++++++-------- src/phoenix/db/models.py | 23 +++++++++++++++---- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index 27aff4ddf9..c8b09b0273 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -94,12 +94,12 @@ keys = console keys = generic [logger_root] -level = WARN +level = DEBUG handlers = console qualname = [logger_sqlalchemy] -level = WARN +level = DEBUG handlers = qualname = sqlalchemy.engine diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 181d791306..217bfaf2f6 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -22,15 +22,13 @@ def upgrade() -> None: projects_table = op.create_table( "projects", sa.Column("id", sa.Integer, primary_key=True), - 
sa.Column("name", sa.String, nullable=False), + # TODO does the uniqueness constraint need to be named + sa.Column("name", sa.String, nullable=False, unique=True), sa.Column("description", sa.String, nullable=True), - # TODO(mikeldking): is timezone=True necessary? - sa.Column( - "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() - ), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), sa.Column( "updated_at", - sa.DateTime(timezone=True), + sa.DateTime(), nullable=False, server_default=sa.func.now(), onupdate=sa.func.now(), @@ -43,15 +41,15 @@ def upgrade() -> None: # TODO(mikeldking): might not be the right place for this sa.Column("session_id", sa.String, nullable=True), sa.Column("trace_id", sa.String, nullable=False, unique=True), - sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), - sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("start_time", sa.DateTime(), nullable=False), + sa.Column("end_time", sa.DateTime(), nullable=False), ) op.create_table( "spans", sa.Column("id", sa.Integer, primary_key=True), sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), - sa.Column("span_id", sa.String, nullable=False), + sa.Column("span_id", sa.String, nullable=False, unique=True), sa.Column("parent_span_id", sa.String, nullable=True), sa.Column("name", sa.String, nullable=False), sa.Column("kind", sa.String, nullable=False), @@ -63,7 +61,7 @@ def upgrade() -> None: "status", sa.String, # TODO(mikeldking): this doesn't seem to work... - sa.CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')"), + sa.CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", "valid_status"), nullable=False, default="UNSET", server_default="UNSET", diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index e457406782..8971b7d940 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -3,8 +3,10 @@ from sqlalchemy import ( JSON, + CheckConstraint, DateTime, ForeignKey, + MetaData, UniqueConstraint, func, ) @@ -18,6 +20,17 @@ class Base(DeclarativeBase): + # Enforce best practices for naming constraints + # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate + metadata = MetaData( + naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_`%(constraint_name)s`", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", + } + ) type_annotation_map = { Dict[str, Any]: JSON, List[Dict[str, Any]]: JSON, @@ -42,7 +55,7 @@ class Project(Base): __table_args__ = ( UniqueConstraint( "name", - name="project_name_unique", + name="uq_projects_name", sqlite_on_conflict="IGNORE", ), ) @@ -70,7 +83,7 @@ class Trace(Base): __table_args__ = ( UniqueConstraint( "trace_id", - name="trace_id_unique", + name="uq_traces_trace_id", sqlite_on_conflict="IGNORE", ), ) @@ -88,7 +101,9 @@ class Span(Base): end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) attributes: Mapped[Dict[str, Any]] events: Mapped[List[Dict[str, Any]]] - status: Mapped[str] + status: Mapped[str] = mapped_column( + CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", "valid_status") + ) status_message: Mapped[str] # TODO(mikeldking): is computed columns possible here @@ -102,7 +117,7 @@ class Span(Base): __table_args__ = ( UniqueConstraint( "span_id", - name="trace_id_unique", + name="uq_spans_trace_id", 
sqlite_on_conflict="IGNORE", ), ) From 27a87e8d00b84145f69bcdb6e590f11429750b58 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 09:26:39 -0600 Subject: [PATCH 12/30] Update src/phoenix/db/migrate.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> --- src/phoenix/db/migrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 4a1bf52450..8746cfb694 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -17,7 +17,7 @@ def migrate(url: str) -> None: url: The database URL. """ logger.warning("Running migrations on the database") - config_path = os.path.normpath(str(Path(__file__).parent.resolve()) + os.sep + "alembic.ini") + config_path = str(Path(__file__).parent.resolve() / "alembic.ini") alembic_cfg = Config(config_path) # Explicitly set the migration directory From 5f2630240cddfcf783aa476863dbb34dd972b5a7 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 09:26:44 -0600 Subject: [PATCH 13/30] Update src/phoenix/db/migrate.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> --- src/phoenix/db/migrate.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 8746cfb694..7e1f942ee4 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -21,9 +21,7 @@ def migrate(url: str) -> None: alembic_cfg = Config(config_path) # Explicitly set the migration directory - scripts_location = os.path.normpath( - str(Path(__file__).parent.resolve()) + os.sep + "migrations" - ) + scripts_location = str(Path(__file__).parent.resolve()) / "migrations") alembic_cfg.set_main_option("script_location", scripts_location) alembic_cfg.set_main_option("sqlalchemy.url", url) From a4d4c8827250d04074f1b42885ab813455a4df96 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Thu, 4 Apr 2024 09:46:29 -0600 Subject: [PATCH 14/30] restore log level --- src/phoenix/db/alembic.ini | 6 +++--- src/phoenix/db/database.py | 2 +- src/phoenix/db/migrate.py | 3 +-- .../db/migrations/versions/cf03bd6bae1d_init.py | 14 ++++---------- src/phoenix/db/models.py | 11 +++++------ 5 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index c8b09b0273..402ff9e576 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -94,17 +94,17 @@ keys = console keys = generic [logger_root] -level = DEBUG +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = DEBUG +level = WARN handlers = qualname = sqlalchemy.engine [logger_alembic] -level = DEBUG +level = WARN handlers = qualname = alembic diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py index 86d9f761dc..be23975d6e 100644 --- a/src/phoenix/db/database.py +++ b/src/phoenix/db/database.py @@ -80,7 +80,7 @@ def __init__(self, db_path: Optional[Path] = None) -> None: migrate(db_url) else: Base.metadata.create_all(engine) - # Create the default project + # Create the default project and setup indexes with engine.connect() as conn: conn.execute(insert(Project).values(name="default", description="default project")) conn.commit() diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 7e1f942ee4..35d640300a 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -1,5 +1,4 @@ import logging -import os from pathlib import Path from alembic import command @@ -21,7 +20,7 @@ def migrate(url: str) -> 
None: alembic_cfg = Config(config_path) # Explicitly set the migration directory - scripts_location = str(Path(__file__).parent.resolve()) / "migrations") + scripts_location = str(Path(__file__).parent.resolve() / "migrations") alembic_cfg.set_main_option("script_location", scripts_location) alembic_cfg.set_main_option("sqlalchemy.url", url) diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 217bfaf2f6..e0a997531c 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -41,7 +41,7 @@ def upgrade() -> None: # TODO(mikeldking): might not be the right place for this sa.Column("session_id", sa.String, nullable=True), sa.Column("trace_id", sa.String, nullable=False, unique=True), - sa.Column("start_time", sa.DateTime(), nullable=False), + sa.Column("start_time", sa.DateTime(), nullable=False, index=True), sa.Column("end_time", sa.DateTime(), nullable=False), ) @@ -50,11 +50,11 @@ def upgrade() -> None: sa.Column("id", sa.Integer, primary_key=True), sa.Column("trace_rowid", sa.Integer, sa.ForeignKey("traces.id"), nullable=False), sa.Column("span_id", sa.String, nullable=False, unique=True), - sa.Column("parent_span_id", sa.String, nullable=True), + sa.Column("parent_span_id", sa.String, nullable=True, index=True), sa.Column("name", sa.String, nullable=False), sa.Column("kind", sa.String, nullable=False), - sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), - sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("start_time", sa.DateTime(), nullable=False), + sa.Column("end_time", sa.DateTime(), nullable=False), sa.Column("attributes", sa.JSON, nullable=False), sa.Column("events", sa.JSON, nullable=False), sa.Column( @@ -72,9 +72,6 @@ def upgrade() -> None: sa.Column("cumulative_llm_token_count_prompt", sa.Integer, nullable=False), sa.Column("cumulative_llm_token_count_completion", sa.Integer, nullable=False), ) - - op.create_index("idx_trace_start_time", "traces", ["start_time"]) - op.create_index("idx_parent_span_id", "spans", ["parent_span_id"]) op.bulk_insert( projects_table, [ @@ -84,9 +81,6 @@ def upgrade() -> None: def downgrade() -> None: - op.drop_index("idx_trace_start_time") - op.drop_index("idx_parent_span_id", "spans") - op.drop_table("projects") op.drop_table("traces") op.drop_table("spans") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 8971b7d940..b52428f1f8 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -67,9 +67,8 @@ class Trace(Base): project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) session_id: Mapped[Optional[str]] trace_id: Mapped[str] - # TODO(mikeldking): why is the start and end time necessary? just filtering? 
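This commit also drops the explicit op.create_index calls in favor of index=True on the columns; together with the MetaData naming_convention declared on Base a few commits back, constraint and index names (uq_*, ix_*, ck_*) come out deterministic without being spelled out everywhere. A standalone illustration of the convention — assumed, not taken from the patch:

    from sqlalchemy import Column, Integer, MetaData, String, Table, UniqueConstraint
    from sqlalchemy.schema import CreateTable

    # With this convention an unnamed UniqueConstraint on projects.name is
    # emitted as uq_projects_name, matching the explicit names used in
    # models.py and the migration; index=True columns get ix_* names the
    # same way.
    metadata = MetaData(naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"})
    projects = Table(
        "projects",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String, nullable=False),
        UniqueConstraint("name"),
    )
    print(CreateTable(projects))  # ... CONSTRAINT uq_projects_name UNIQUE (name)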
- start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) - end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + start_time: Mapped[datetime] = mapped_column(DateTime(), index=True) + end_time: Mapped[datetime] = mapped_column(DateTime()) project: Mapped["Project"] = relationship( "Project", @@ -94,11 +93,11 @@ class Span(Base): id: Mapped[int] = mapped_column(primary_key=True) trace_rowid: Mapped[int] = mapped_column(ForeignKey("traces.id")) span_id: Mapped[str] - parent_span_id: Mapped[Optional[str]] + parent_span_id: Mapped[Optional[str]] = mapped_column(index=True) name: Mapped[str] kind: Mapped[str] - start_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) - end_time: Mapped[datetime] = mapped_column(DateTime(timezone=True)) + start_time: Mapped[datetime] = mapped_column(DateTime()) + end_time: Mapped[datetime] = mapped_column(DateTime()) attributes: Mapped[Dict[str, Any]] events: Mapped[List[Dict[str, Any]]] status: Mapped[str] = mapped_column( From 30b2ab0b7d568e5656425ef43d4b94557e55394f Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Fri, 5 Apr 2024 16:26:36 -0700 Subject: [PATCH 15/30] refactor: sqlite async session for graphql api (#2784) --- app/schema.graphql | 2 +- pyproject.toml | 1 + src/phoenix/core/traces.py | 76 +--- src/phoenix/db/database.py | 351 ------------------ src/phoenix/db/engines.py | 58 +++ src/phoenix/db/migrate.py | 5 +- src/phoenix/db/migrations/env.py | 44 ++- src/phoenix/db/models.py | 15 +- src/phoenix/server/api/context.py | 4 +- .../server/api/routers/trace_handler.py | 116 +++++- src/phoenix/server/api/types/Project.py | 162 +++++--- src/phoenix/server/api/types/Span.py | 102 ++--- src/phoenix/server/api/types/Trace.py | 15 +- src/phoenix/server/app.py | 39 +- src/phoenix/server/main.py | 47 ++- src/phoenix/session/session.py | 5 +- tests/server/api/types/conftest.py | 1 + 17 files changed, 487 insertions(+), 556 deletions(-) delete mode 100644 src/phoenix/db/database.py create mode 100644 src/phoenix/db/engines.py diff --git a/app/schema.graphql b/app/schema.graphql index e5ff7dbcae..26cc2d66f5 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -547,7 +547,7 @@ type Project implements Node { latencyMsP50: Float latencyMsP99: Float trace(traceId: ID!): Trace - spans(timeRange: TimeRange, traceIds: [ID!], first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! """ Names of all available evaluations for traces. (The list contains no duplicates.) 
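The headline change in this commit is that GraphQL reads stop going through the in-memory Traces/Database protocol and instead issue queries through an async SQLAlchemy session backed by aiosqlite. A rough sketch of that wiring, with assumed names and an in-memory database (the real factory lands in the engines and app modules below, and assumes the schema has already been created, e.g. via init_models):

    from sqlalchemy import func, select
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    from phoenix.db import models

    # Assumed wiring, for illustration only: an aiosqlite engine plus a
    # session factory comparable to the `db` callable this commit puts on
    # the GraphQL context.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    Session = async_sessionmaker(engine, expire_on_commit=False)


    async def span_count() -> int:
        async with Session() as session:
            return (await session.scalar(select(func.count(models.Span.id)))) or 0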
diff --git a/pyproject.toml b/pyproject.toml index 3c7b1cea88..6726f90194 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "openinference-instrumentation-openai>=0.1.4", "sqlalchemy>=2, <3", "alembic>=1.3.0, <2", + "aiosqlite", ] dynamic = ["version"] diff --git a/src/phoenix/core/traces.py b/src/phoenix/core/traces.py index 6fff53e469..0681e5479e 100644 --- a/src/phoenix/core/traces.py +++ b/src/phoenix/core/traces.py @@ -1,10 +1,9 @@ import weakref from collections import defaultdict -from datetime import datetime from queue import SimpleQueue from threading import RLock, Thread from types import MethodType -from typing import DefaultDict, Iterator, Optional, Protocol, Tuple, Union +from typing import DefaultDict, Iterator, Optional, Tuple, Union from typing_extensions import assert_never @@ -13,45 +12,16 @@ from phoenix.core.project import ( END_OF_QUEUE, Project, - WrappedSpan, _ProjectName, ) -from phoenix.trace.schemas import ComputedAttributes, ComputedValues, Span, TraceID +from phoenix.trace.schemas import Span _SpanItem = Tuple[Span, _ProjectName] _EvalItem = Tuple[pb.Evaluation, _ProjectName] -class Database(Protocol): - def insert_span(self, span: Span, project_name: str) -> None: ... - - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: ... - - def get_trace(self, trace_id: TraceID) -> Iterator[Tuple[Span, ComputedValues]]: ... - - class Traces: - def __init__(self, database: Database) -> None: - self._database = database + def __init__(self) -> None: self._span_queue: "SimpleQueue[Optional[_SpanItem]]" = SimpleQueue() self._eval_queue: "SimpleQueue[Optional[_EvalItem]]" = SimpleQueue() # Putting `None` as the sentinel value for queue termination. 
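The comment above refers to the shutdown pattern used by the span and eval consumers: each consumer thread drains its queue until it sees the agreed-upon sentinel, then exits. A toy, self-contained version of the pattern (not Phoenix code):

    from queue import SimpleQueue
    from threading import Thread

    END_OF_QUEUE = None  # sentinel; the consumer compares with `is`


    def consume(queue: "SimpleQueue[object]") -> None:
        # Drain items until the sentinel arrives, then let the thread exit.
        while (item := queue.get()) is not END_OF_QUEUE:
            print("processing", item)


    q: "SimpleQueue[object]" = SimpleQueue()
    Thread(target=consume, args=(q,), daemon=True).start()
    q.put("span-1")
    q.put(END_OF_QUEUE)  # signals the consumer to stop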
@@ -64,45 +34,6 @@ def __init__(self, database: Database) -> None: ) self._start_consumers() - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.trace_count(project_name, start_time, stop_time) - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.span_count(project_name, start_time, stop_time) - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - return self._database.llm_token_count_total(project_name, start_time, stop_time) - - def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: - for span, computed_values in self._database.get_trace(trace_id): - wrapped_span = WrappedSpan(span) - wrapped_span[ComputedAttributes.LATENCY_MS] = computed_values.latency_ms - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT] = ( - computed_values.cumulative_llm_token_count_prompt - ) - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION] = ( - computed_values.cumulative_llm_token_count_completion - ) - wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] = ( - computed_values.cumulative_llm_token_count_total - ) - yield wrapped_span - def get_project(self, project_name: str) -> Optional["Project"]: with self._lock: return self._projects.get(project_name) @@ -153,7 +84,6 @@ def _start_consumers(self) -> None: def _consume_spans(self, queue: "SimpleQueue[Optional[_SpanItem]]") -> None: while (item := queue.get()) is not END_OF_QUEUE: span, project_name = item - self._database.insert_span(span, project_name=project_name) with self._lock: project = self._projects[project_name] project.add_span(span) diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py deleted file mode 100644 index be23975d6e..0000000000 --- a/src/phoenix/db/database.py +++ /dev/null @@ -1,351 +0,0 @@ -import json -import sqlite3 -from datetime import datetime -from enum import Enum -from pathlib import Path -from sqlite3 import Connection -from typing import Any, Iterator, Optional, Tuple, cast - -import numpy as np -from openinference.semconv.trace import SpanAttributes -from sqlalchemy import Engine, create_engine, event, insert -from sqlalchemy.orm import sessionmaker - -from phoenix.db.models import Base, Project, Trace -from phoenix.trace.schemas import ( - ComputedValues, - Span, - SpanContext, - SpanEvent, - SpanKind, - SpanStatusCode, -) - -from .migrate import migrate - -_CONFIG = """ -PRAGMA foreign_keys = ON; -PRAGMA journal_mode = WAL; -PRAGMA synchronous = OFF; -PRAGMA cache_size = -32000; -PRAGMA busy_timeout = 10000; -""" - - -_MEM_DB_STR = "file::memory:?cache=shared" - - -def _mem_db_creator() -> Any: - return sqlite3.connect(_MEM_DB_STR, uri=True) - - -@event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection: Connection, _: Any) -> None: - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys = ON;") - cursor.execute("PRAGMA journal_mode = WAL;") - cursor.execute("PRAGMA synchronous = OFF;") - cursor.execute("PRAGMA cache_size = -32000;") - cursor.execute("PRAGMA busy_timeout = 10000;") - cursor.close() - - -class SqliteDatabase: - def __init__(self, db_path: Optional[Path] = None) -> None: - """ - :param db_path: The path to the database file to be opened. 
- """ - self.con = sqlite3.connect( - database=db_path or _MEM_DB_STR, - uri=True, - check_same_thread=False, - ) - # self.con.set_trace_callback(print) - cur = self.con.cursor() - cur.executescript(_CONFIG) - - db_url = f"sqlite:///{db_path}" if db_path else "sqlite:///:memory:" - engine = ( - create_engine(db_url, echo=True) - if db_path - else create_engine( - db_url, - echo=True, - creator=_mem_db_creator, - ) - ) - - # TODO this should be moved out - if db_path: - migrate(db_url) - else: - Base.metadata.create_all(engine) - # Create the default project and setup indexes - with engine.connect() as conn: - conn.execute(insert(Project).values(name="default", description="default project")) - conn.commit() - - self.Session = sessionmaker(bind=engine) - - def insert_span(self, span: Span, project_name: str) -> None: - cur = self.con.cursor() - cur.execute("BEGIN;") - try: - if not ( - projects := cur.execute( - "SELECT rowid FROM projects WHERE name = ?;", - (project_name,), - ).fetchone() - ): - projects = cur.execute( - "INSERT INTO projects(name) VALUES(?) RETURNING rowid;", - (project_name,), - ).fetchone() - project_rowid = projects[0] - if ( - trace_row := cur.execute( - """ - INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) - VALUES(?,?,?,?,?) - ON CONFLICT DO UPDATE SET - start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, - end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END - WHERE excluded.start_time < start_time OR end_time < excluded.end_time - RETURNING rowid; - """, # noqa E501 - ( - span.context.trace_id, - project_rowid, - None, - span.start_time, - span.end_time, - ), - ).fetchone() - ) is None: - trace_row = cur.execute( - "SELECT rowid from traces where trace_id = ?", (span.context.trace_id,) - ).fetchone() - trace_rowid = trace_row[0] - cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) - cumulative_llm_token_count_prompt = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) - ) - cumulative_llm_token_count_completion = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) - ) - if accumulation := cur.execute( - """ - SELECT - sum(cumulative_error_count), - sum(cumulative_llm_token_count_prompt), - sum(cumulative_llm_token_count_completion) - FROM spans - WHERE parent_span_id = ? - """, # noqa E501 - (span.context.span_id,), - ).fetchone(): - cumulative_error_count += cast(int, accumulation[0] or 0) - cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) - cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) - latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 - cur.execute( - """ - INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, latency_ms, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) - VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?) 
- RETURNING rowid; - """, # noqa E501 - ( - span.context.span_id, - trace_rowid, - span.parent_id, - span.span_kind.value, - span.name, - span.start_time, - span.end_time, - json.dumps(span.attributes, cls=_Encoder), - json.dumps(span.events, cls=_Encoder), - span.status_code.value, - span.status_message, - latency_ms, - cumulative_error_count, - cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion, - ), - ) - parent_id = span.parent_id - while parent_id: - if parent_span := cur.execute( - """ - SELECT rowid, parent_span_id - FROM spans - WHERE span_id = ? - """, - (parent_id,), - ).fetchone(): - rowid, parent_id = parent_span[0], parent_span[1] - cur.execute( - """ - UPDATE spans SET - cumulative_error_count = cumulative_error_count + ?, - cumulative_llm_token_count_prompt = cumulative_llm_token_count_prompt + ?, - cumulative_llm_token_count_completion = cumulative_llm_token_count_completion + ? - WHERE rowid = ?; - """, # noqa E501 - ( - cumulative_error_count, - cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion, - rowid, - ), - ) - else: - break - except Exception: - cur.execute("ROLLBACK;") - else: - cur.execute("COMMIT;") - - def get_projects(self) -> Iterator[Tuple[int, str]]: - cur = self.con.cursor() - for project in cur.execute("SELECT rowid, name FROM projects;").fetchall(): - yield cast(Tuple[int, str], (project[0], project[1])) - - def trace_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT COUNT(*) - FROM traces - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= traces.start_time AND traces.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? <= traces.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND traces.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def span_count( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT COUNT(*) - FROM spans - JOIN traces ON traces.rowid = spans.trace_rowid - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= spans.start_time AND spans.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? 
<= spans.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def llm_token_count_total( - self, - project_name: str, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - ) -> int: - query = """ - SELECT - SUM(COALESCE(json_extract(spans.attributes, '$."llm.token_count.prompt"'), 0) + - COALESCE(json_extract(spans.attributes, '$."llm.token_count.completion"'), 0)) - FROM spans - JOIN traces ON traces.rowid = spans.trace_rowid - JOIN projects ON projects.rowid = traces.project_rowid - WHERE projects.name = ? - """ # noqa E501 - cur = self.con.cursor() - if start_time and stop_time: - cur = cur.execute( - query + " AND ? <= spans.start_time AND spans.start_time < ?;", - (project_name, start_time, stop_time), - ) - elif start_time: - cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) - elif stop_time: - cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) - else: - cur = cur.execute(query + ";", (project_name,)) - if res := cur.fetchone(): - return cast(int, res[0] or 0) - return 0 - - def get_trace(self, trace_id: str) -> Iterator[Tuple[Span, ComputedValues]]: - with self.Session.begin() as session: - trace = session.query(Trace).where(Trace.trace_id == trace_id).one_or_none() - if not trace: - return - for span in trace.spans: - yield ( - Span( - name=span.name, - context=SpanContext(trace_id=span.trace.trace_id, span_id=span.span_id), - parent_id=span.parent_span_id, - span_kind=SpanKind(span.kind), - start_time=span.start_time, - end_time=span.end_time, - attributes=span.attributes, - events=[ - SpanEvent( - name=obj["name"], - attributes=obj["attributes"], - timestamp=obj["timestamp"], - ) - for obj in span.events - ], - status_code=SpanStatusCode(span.status), - status_message=span.status_message, - conversation=None, - ), - ComputedValues( - latency_ms=span.latency_ms, - cumulative_error_count=span.cumulative_error_count, - cumulative_llm_token_count_prompt=span.cumulative_llm_token_count_prompt, - cumulative_llm_token_count_completion=span.cumulative_llm_token_count_completion, - cumulative_llm_token_count_total=span.cumulative_llm_token_count_prompt - + span.cumulative_llm_token_count_completion, - ), - ) - - -class _Encoder(json.JSONEncoder): - def default(self, obj: Any) -> Any: - if isinstance(obj, datetime): - return obj.isoformat() - elif isinstance(obj, Enum): - return obj.value - elif isinstance(obj, np.ndarray): - return list(obj) - elif isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - return super().default(obj) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py new file mode 100644 index 0000000000..001aa8a352 --- /dev/null +++ b/src/phoenix/db/engines.py @@ -0,0 +1,58 @@ +import asyncio +import json +from datetime import datetime +from enum import Enum +from pathlib import Path +from sqlite3 import Connection +from typing import Any, Union + +import numpy as np +from sqlalchemy import URL, event +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine + +from phoenix.db.migrate import migrate +from phoenix.db.models import init_models + + +def set_sqlite_pragma(connection: Connection, _: Any) -> None: + cursor = connection.cursor() + cursor.execute("PRAGMA foreign_keys = ON;") 
+ cursor.execute("PRAGMA journal_mode = WAL;") + cursor.execute("PRAGMA synchronous = OFF;") + cursor.execute("PRAGMA cache_size = -32000;") + cursor.execute("PRAGMA busy_timeout = 10000;") + cursor.close() + + +def aiosqlite_engine( + database: Union[str, Path] = ":memory:", + echo: bool = False, +) -> AsyncEngine: + driver_name = "sqlite+aiosqlite" + url = URL.create(driver_name, database=str(database)) + engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) + event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + if str(database) == ":memory:": + asyncio.run(init_models(engine)) + else: + migrate(url) + return engine + + +def _dumps(obj: Any) -> str: + return json.dumps(obj, cls=_Encoder) + + +class _Encoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + return super().default(obj) diff --git a/src/phoenix/db/migrate.py b/src/phoenix/db/migrate.py index 35d640300a..990c390e14 100644 --- a/src/phoenix/db/migrate.py +++ b/src/phoenix/db/migrate.py @@ -3,11 +3,12 @@ from alembic import command from alembic.config import Config +from sqlalchemy import URL logger = logging.getLogger(__name__) -def migrate(url: str) -> None: +def migrate(url: URL) -> None: """ Runs migrations on the database. NB: Migrate only works on non-memory databases. @@ -22,6 +23,6 @@ def migrate(url: str) -> None: # Explicitly set the migration directory scripts_location = str(Path(__file__).parent.resolve() / "migrations") alembic_cfg.set_main_option("script_location", scripts_location) - alembic_cfg.set_main_option("sqlalchemy.url", url) + alembic_cfg.set_main_option("sqlalchemy.url", str(url)) command.upgrade(alembic_cfg, "head") diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py index 6e3c4d0d3c..2cc51224e6 100644 --- a/src/phoenix/db/migrations/env.py +++ b/src/phoenix/db/migrations/env.py @@ -1,8 +1,10 @@ +import asyncio from logging.config import fileConfig from alembic import context from phoenix.db.models import Base -from sqlalchemy import engine_from_config, pool +from sqlalchemy import Connection, engine_from_config, pool +from sqlalchemy.ext.asyncio import AsyncEngine # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -55,17 +57,37 @@ def run_migrations_online() -> None: and associate a connection with the context. 
""" - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, + connectable = context.config.attributes.get("connection", None) + if connectable is None: + connectable = AsyncEngine( + engine_from_config( + context.config.get_section(context.config.config_ini_section) or {}, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + future=True, + ) + ) + + if isinstance(connectable, AsyncEngine): + asyncio.run(run_async_migrations(connectable)) + else: + run_migrations(connectable) + + +async def run_async_migrations(engine: AsyncEngine) -> None: + async with engine.connect() as connection: + await connection.run_sync(run_migrations) + await engine.dispose() + + +def run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, ) - - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() + with context.begin_transaction(): + context.run_migrations() if context.is_offline_mode(): diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index b52428f1f8..5f10f21931 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -9,7 +9,9 @@ MetaData, UniqueConstraint, func, + insert, ) +from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -116,7 +118,18 @@ class Span(Base): __table_args__ = ( UniqueConstraint( "span_id", - name="uq_spans_trace_id", + name="uq_spans_span_id", sqlite_on_conflict="IGNORE", ), ) + + +async def init_models(engine: AsyncEngine) -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + await conn.execute( + insert(Project).values( + name="default", + description="default project", + ) + ) diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 2d6a95be7e..bb89e09f72 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from pathlib import Path -from typing import Optional, Union +from typing import AsyncContextManager, Callable, Optional, Union +from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket @@ -14,6 +15,7 @@ class Context: request: Union[Request, WebSocket] response: Optional[Response] + db: Callable[[], AsyncContextManager[AsyncSession]] model: Model export_path: Path corpus: Optional[Model] = None diff --git a/src/phoenix/server/api/routers/trace_handler.py b/src/phoenix/server/api/routers/trace_handler.py index 221f40ad99..507bb0df40 100644 --- a/src/phoenix/server/api/routers/trace_handler.py +++ b/src/phoenix/server/api/routers/trace_handler.py @@ -1,25 +1,31 @@ import asyncio import gzip import zlib -from typing import Optional +from typing import AsyncContextManager, Callable, Optional, cast from google.protobuf.message import DecodeError +from openinference.semconv.trace import SpanAttributes from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( ExportTraceServiceRequest, ) from opentelemetry.proto.trace.v1.trace_pb2 import TracesData +from sqlalchemy import func, insert, select, text, update +from sqlalchemy.ext.asyncio import AsyncSession from starlette.endpoints import HTTPEndpoint from starlette.requests 
import Request from starlette.responses import Response from starlette.status import HTTP_415_UNSUPPORTED_MEDIA_TYPE, HTTP_422_UNPROCESSABLE_ENTITY from phoenix.core.traces import Traces +from phoenix.db import models from phoenix.storage.span_store import SpanStore from phoenix.trace.otel import decode +from phoenix.trace.schemas import Span, SpanStatusCode from phoenix.utilities.project import get_project_name class TraceHandler(HTTPEndpoint): + db: Callable[[], AsyncContextManager[AsyncSession]] traces: Traces store: Optional[SpanStore] @@ -54,7 +60,111 @@ async def post(self, request: Request) -> Response: for resource_spans in req.resource_spans: project_name = get_project_name(resource_spans.resource.attributes) for scope_span in resource_spans.scope_spans: - for span in scope_span.spans: - self.traces.put(decode(span), project_name=project_name) + for otlp_span in scope_span.spans: + span = decode(otlp_span) + async with self.db() as session: + await _insert_span(session, span, project_name) + self.traces.put(span, project_name=project_name) await asyncio.sleep(0) return Response() + + +async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: + if not ( + project_rowid := await session.scalar( + select(models.Project.id).filter(models.Project.name == project_name) + ) + ): + project_rowid = await session.scalar( + insert(models.Project).values(name=project_name).returning(models.Project.id) + ) + if ( + trace_rowid := await session.scalar( + text( + """ + INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) + VALUES(:trace_id, :project_rowid, :session_id, :start_time, :end_time) + ON CONFLICT DO UPDATE SET + start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, + end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END + WHERE excluded.start_time < start_time OR end_time < excluded.end_time + RETURNING rowid; + """ # noqa E501 + ), + { + "trace_id": span.context.trace_id, + "project_rowid": project_rowid, + "session_id": None, + "start_time": span.start_time, + "end_time": span.end_time, + }, + ) + ) is None: + trace_rowid = await session.scalar( + select(models.Trace.id).filter(models.Trace.trace_id == span.context.trace_id) + ) + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := ( + await session.execute( + select( + func.sum(models.Span.cumulative_error_count), + func.sum(models.Span.cumulative_llm_token_count_prompt), + func.sum(models.Span.cumulative_llm_token_count_completion), + ).where(models.Span.parent_span_id == span.context.span_id) + ) + ).first(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 + session.add( + models.Span( + span_id=span.context.span_id, + trace_rowid=trace_rowid, + parent_span_id=span.parent_id, + kind=span.span_kind.value, + name=span.name, + start_time=span.start_time, + end_time=span.end_time, + attributes=span.attributes, + events=span.events, + status=span.status_code.value, + 
status_message=span.status_message, + latency_ms=latency_ms, + cumulative_error_count=cumulative_error_count, + cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, + ) + ) + # Propagate cumulative values to ancestors. This is usually a no-op, since + # the parent usually arrives after the child. But in the event that a + # child arrives after its parent, we need to make sure the all the + # ancestors' cumulative values are updated. + ancestors = ( + select(models.Span.id, models.Span.parent_span_id) + .where(models.Span.span_id == span.parent_id) + .cte(recursive=True) + ) + child = ancestors.alias() + ancestors = ancestors.union_all( + select(models.Span.id, models.Span.parent_span_id).join( + child, models.Span.span_id == child.c.parent_span_id + ) + ) + await session.execute( + update(models.Span) + .where(models.Span.id.in_(select(ancestors.c.id))) + .values( + cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count, + cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt + + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion + + cumulative_llm_token_count_completion, + ) + ) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 671d74cd93..19c152ecb4 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,12 +1,17 @@ from datetime import datetime -from itertools import chain from typing import List, Optional import strawberry +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import Integer, and_, cast, func, select +from sqlalchemy.orm import contains_eager +from sqlalchemy.sql.functions import coalesce from strawberry import ID, UNSET from strawberry.types import Info from phoenix.core.project import Project as CoreProject +from phoenix.datetime_utils import right_open_time_range +from phoenix.db import models from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.input_types.SpanSort import SpanSort @@ -33,53 +38,104 @@ class Project(Node): project: strawberry.Private[CoreProject] @strawberry.field - def start_time(self) -> Optional[datetime]: - start_time, _ = self.project.right_open_time_range + async def start_time( + self, + info: Info[Context, None], + ) -> Optional[datetime]: + stmt = ( + select(func.min(models.Trace.start_time)) + .join(models.Project) + .where(models.Project.name == self.name) + ) + async with info.context.db() as session: + start_time = await session.scalar(stmt) + start_time, _ = right_open_time_range(start_time, None) return start_time @strawberry.field - def end_time(self) -> Optional[datetime]: - _, end_time = self.project.right_open_time_range + async def end_time( + self, + info: Info[Context, None], + ) -> Optional[datetime]: + stmt = ( + select(func.max(models.Trace.end_time)) + .join(models.Project) + .where(models.Project.name == self.name) + ) + async with info.context.db() as session: + end_time = await session.scalar(stmt) + _, end_time = right_open_time_range(None, end_time) return end_time @strawberry.field - def record_count( + async def record_count( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not (traces := info.context.traces): - return 0 - start_time, stop_time = ( - (None, None) 
if not time_range else (time_range.start, time_range.end) + stmt = ( + select(func.count(models.Span.id)) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.span_count(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def trace_count( + async def trace_count( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not (traces := info.context.traces): - return 0 - start_time, stop_time = ( - (None, None) if not time_range else (time_range.start, time_range.end) + stmt = ( + select(func.count(models.Trace.id)) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.trace_count(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Trace.start_time, + models.Trace.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field - def token_count_total( + async def token_count_total( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not (traces := info.context.traces): - return 0 - start_time, stop_time = ( - (None, None) if not time_range else (time_range.start, time_range.end) + prompt = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT] + completion = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION] + stmt = ( + select( + coalesce(func.sum(cast(prompt, Integer)), 0) + + coalesce(func.sum(cast(completion, Integer)), 0) + ) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) ) - return traces.llm_token_count_total(self.name, start_time, stop_time) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + models.Span.start_time < time_range.end, + ) + ) + async with info.context.db() as session: + return (await session.scalar(stmt)) or 0 @strawberry.field def latency_ms_p50(self) -> Optional[float]: @@ -96,10 +152,10 @@ def trace(self, trace_id: ID) -> Optional[Trace]: return None @strawberry.field - def spans( + async def spans( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, - trace_ids: Optional[List[ID]] = UNSET, first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -114,36 +170,32 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - start_time = time_range.start if time_range else None - stop_time = time_range.end if time_range else None - if not (project := self.project).span_count( - start_time=start_time, - stop_time=stop_time, - ): - return connection_from_list(data=[], args=args) - predicate = ( - SpanFilter( - condition=filter_condition, - evals=project, - ) - if filter_condition - else None + stmt = ( + select(models.Span) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == self.name) + .options(contains_eager(models.Span.trace)) ) - if not trace_ids: - spans = project.get_spans( - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - ) - else: - spans = chain.from_iterable( - project.get_trace(trace_id) for trace_id in map(TraceID, trace_ids) + if time_range: + stmt = stmt.where( + and_( + time_range.start <= models.Span.start_time, + 
models.Span.start_time < time_range.end, + ) ) - if predicate: - spans = filter(predicate, spans) - if sort: - spans = sort(spans, evals=project) - data = [to_gql_span(span, project) for span in spans] + if root_spans_only: + # A root span is any span whose parent span is missing in the + # database, even if its `parent_span_id` may not be NULL. + parent = select(models.Span.span_id).alias() + stmt = stmt.outerjoin( + parent, + models.Span.parent_span_id == parent.c.span_id, + ).where(parent.c.span_id.is_(None)) + # TODO(persistence): enable sort and filter + async with info.context.db() as session: + spans = await session.scalars(stmt) + data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) @strawberry.field( @@ -305,3 +357,7 @@ def validate_span_filter_condition(self, condition: str) -> ValidationResult: is_valid=False, error_message=e.msg, ) + + +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 52eebc7786..9025ab4ca9 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -7,17 +7,20 @@ import numpy as np import strawberry from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info import phoenix.trace.schemas as trace_schema -from phoenix.core.project import Project, WrappedSpan +from phoenix.core.project import Project +from phoenix.db import models from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation from phoenix.server.api.types.MimeType import MimeType -from phoenix.trace.schemas import ComputedAttributes, SpanID +from phoenix.trace.schemas import SpanID EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR @@ -40,14 +43,14 @@ class SpanKind(Enum): NB: this is actively under construction """ - chain = trace_schema.SpanKind.CHAIN - tool = trace_schema.SpanKind.TOOL - llm = trace_schema.SpanKind.LLM - retriever = trace_schema.SpanKind.RETRIEVER - embedding = trace_schema.SpanKind.EMBEDDING - agent = trace_schema.SpanKind.AGENT - reranker = trace_schema.SpanKind.RERANKER - unknown = trace_schema.SpanKind.UNKNOWN + chain = "CHAIN" + tool = "TOOL" + llm = "LLM" + retriever = "RETRIEVER" + embedding = "EMBEDDING" + agent = "AGENT" + reranker = "RERANKER" + unknown = "UNKNOWN" @classmethod def _missing_(cls, v: Any) -> Optional["SpanKind"]: @@ -68,9 +71,9 @@ class SpanIOValue: @strawberry.enum class SpanStatusCode(Enum): - OK = trace_schema.SpanStatusCode.OK - ERROR = trace_schema.SpanStatusCode.ERROR - UNSET = trace_schema.SpanStatusCode.UNSET + OK = "OK" + ERROR = "ERROR" + UNSET = "UNSET" @classmethod def _missing_(cls, v: Any) -> Optional["SpanStatusCode"]: @@ -84,13 +87,13 @@ class SpanEvent: timestamp: datetime @staticmethod - def from_event( - event: trace_schema.SpanEvent, + def from_dict( + event: Mapping[str, Any], ) -> "SpanEvent": return SpanEvent( - name=event.name, - message=cast(str, event.attributes.get(trace_schema.EXCEPTION_MESSAGE) or ""), - 
timestamp=event.timestamp, + name=event["name"], + message=cast(str, event["attributes"].get(trace_schema.EXCEPTION_MESSAGE) or ""), + timestamp=event["timestamp"], ) @@ -202,18 +205,35 @@ def document_retrieval_metrics( @strawberry.field( description="All descendant spans (children, grandchildren, etc.)", ) # type: ignore - def descendants( + async def descendants( self, info: Info[Context, None], ) -> List["Span"]: - return [ - to_gql_span(span, self.project) - for span in self.project.get_descendant_spans(SpanID(self.context.span_id)) - ] + # TODO(persistence): add dataloader (to avoid N+1 queries) or change how this is fetched + async with info.context.db() as session: + descendant_ids = ( + select(models.Span.id, models.Span.span_id) + .filter(models.Span.parent_span_id == str(self.context.span_id)) + .cte(recursive=True) + ) + parent_ids = descendant_ids.alias() + descendant_ids = descendant_ids.union_all( + select(models.Span.id, models.Span.span_id).join( + parent_ids, + models.Span.parent_span_id == parent_ids.c.span_id, + ) + ) + spans = await session.scalars( + select(models.Span) + .join(descendant_ids, models.Span.id == descendant_ids.c.id) + .join(models.Trace) + .options(contains_eager(models.Span.trace)) + ) + return [to_gql_span(span, self.project) for span in spans] -def to_gql_span(span: WrappedSpan, project: Project) -> "Span": - events: List[SpanEvent] = list(map(SpanEvent.from_event, span.events)) +def to_gql_span(span: models.Span, project: Project) -> Span: + events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events)) input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE)) output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE)) retrieval_documents = span.attributes.get(RETRIEVAL_DOCUMENTS) @@ -221,16 +241,16 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": return Span( project=project, name=span.name, - status_code=SpanStatusCode(span.status_code), + status_code=SpanStatusCode(span.status), status_message=span.status_message, - parent_id=cast(Optional[ID], span.parent_id), - span_kind=SpanKind(span.span_kind), + parent_id=cast(Optional[ID], span.parent_span_id), + span_kind=SpanKind(span.kind), start_time=span.start_time, end_time=span.end_time, - latency_ms=cast(Optional[float], span[ComputedAttributes.LATENCY_MS]), + latency_ms=span.latency_ms, context=SpanContext( - trace_id=cast(ID, span.context.trace_id), - span_id=cast(ID, span.context.span_id), + trace_id=cast(ID, span.trace.trace_id), + span_id=cast(ID, span.span_id), ), attributes=json.dumps( _nested_attributes(_hide_embedding_vectors(span.attributes)), @@ -250,22 +270,12 @@ def to_gql_span(span: WrappedSpan, project: Project) -> "Span": Optional[int], span.attributes.get(LLM_TOKEN_COUNT_COMPLETION), ), - cumulative_token_count_total=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL], - ), - cumulative_token_count_prompt=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT], - ), - cumulative_token_count_completion=cast( - Optional[int], - span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION], - ), + cumulative_token_count_total=span.cumulative_llm_token_count_prompt + + span.cumulative_llm_token_count_completion, + cumulative_token_count_prompt=span.cumulative_llm_token_count_prompt, + cumulative_token_count_completion=span.cumulative_llm_token_count_completion, propagated_status_code=( - SpanStatusCode.ERROR - if span[ComputedAttributes.CUMULATIVE_ERROR_COUNT] - else 
SpanStatusCode(span.status_code) + SpanStatusCode.ERROR if span.cumulative_error_count else SpanStatusCode(span.status) ), events=events, input=( diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index a7d440bc52..075928ba8a 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -1,10 +1,13 @@ from typing import List, Optional import strawberry +from sqlalchemy import select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET, Private from strawberry.types import Info from phoenix.core.project import Project +from phoenix.db import models from phoenix.server.api.context import Context from phoenix.server.api.types.Evaluation import TraceEvaluation from phoenix.server.api.types.pagination import ( @@ -23,7 +26,7 @@ class Trace: project: Private[Project] @strawberry.field - def spans( + async def spans( self, info: Info[Context, None], first: Optional[int] = 50, @@ -37,9 +40,13 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - if not (traces := info.context.traces): - return connection_from_list(data=[], args=args) - spans = traces.get_trace(TraceID(self.trace_id)) + async with info.context.db() as session: + spans = await session.scalars( + select(models.Span) + .join(models.Trace) + .filter(models.Trace.trace_id == self.trace_id) + .options(contains_eager(models.Span.trace)) + ) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 2b814ce04f..04e9c17421 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -1,7 +1,21 @@ +import contextlib import logging from pathlib import Path -from typing import Any, NamedTuple, Optional, Union +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + Callable, + NamedTuple, + Optional, + Union, +) +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, +) from starlette.applications import Starlette from starlette.datastructures import QueryParams from starlette.endpoints import HTTPEndpoint @@ -96,12 +110,14 @@ class GraphQLWithContext(GraphQL): # type: ignore def __init__( self, schema: BaseSchema, + db: Callable[[], AsyncContextManager[AsyncSession]], model: Model, export_path: Path, graphiql: bool = False, corpus: Optional[Model] = None, traces: Optional[Traces] = None, ) -> None: + self.db = db self.model = model self.corpus = corpus self.traces = traces @@ -116,6 +132,7 @@ async def get_context( return Context( request=request, response=response, + db=self.db, model=self.model, corpus=self.corpus, traces=self.traces, @@ -142,7 +159,19 @@ async def version(_: Request) -> PlainTextResponse: return PlainTextResponse(f"{phoenix.__version__}") +def _db(engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]: + Session = async_sessionmaker(engine, expire_on_commit=False) + + @contextlib.asynccontextmanager + async def factory() -> AsyncIterator[AsyncSession]: + async with Session.begin() as session: + yield session + + return factory + + def create_app( + engine: AsyncEngine, export_path: Path, model: Model, umap_params: UMAPParameters, @@ -153,7 +182,9 @@ def create_app( read_only: bool = False, enable_prometheus: bool = False, ) -> Starlette: + db = _db(engine) graphql = GraphQLWithContext( + db=db, schema=schema, model=model, corpus=corpus, @@ -183,7 +214,11 @@ def create_app( ), Route( "/v1/traces", - 
type("TraceEndpoint", (TraceHandler,), {"traces": traces, "store": span_store}), + type( + "TraceEndpoint", + (TraceHandler,), + {"db": staticmethod(db), "traces": traces, "store": span_store}, + ), ), Route( "/v1/evaluations", diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index ae78c49e15..14cda43050 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -1,4 +1,5 @@ import atexit +import gzip import logging import os from argparse import ArgumentParser @@ -9,6 +10,9 @@ from typing import Iterable, Optional, Protocol, TypeVar import pkg_resources +import requests +from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest +from opentelemetry.proto.trace.v1.trace_pb2 import ResourceSpans, ScopeSpans from uvicorn import Config, Server from phoenix.config import ( @@ -22,7 +26,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db.database import SqliteDatabase +from phoenix.db.engines import aiosqlite_engine from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -38,6 +42,7 @@ get_evals_from_fixture, ) from phoenix.trace.otel import decode, encode +from phoenix.trace.schemas import Span from phoenix.trace.span_json_decoder import json_string_to_span from phoenix.utilities.span_store import get_span_store, load_traces_data_from_store @@ -108,6 +113,28 @@ def _load_items( queue.put(item) +def _send_spans(spans: Iterable[Span], url: str) -> None: + # TODO(persistence): Ingest fixtures without networking for read-only deployments + sleep(5) # Wait for the server to start + session = requests.session() + for span in spans: + req = ExportTraceServiceRequest( + resource_spans=[ResourceSpans(scope_spans=[ScopeSpans(spans=[encode(span)])])] + ) + session.post( + url=url, + headers={ + "content-type": "application/x-protobuf", + "content-encoding": "gzip", + }, + data=gzip.compress(req.SerializeToString()), + ) + # TODO(persistence): If ingestion rate is too high it can crash the UI, because + # sqlite is not designed for high concurrency, especially for disk + # persistence. 
+ sleep(0.2) + + DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}" if __name__ == "__main__": @@ -191,14 +218,15 @@ def _load_items( trace_dataset_name = args.trace_fixture simulate_streaming = args.simulate_streaming + host = args.host or get_env_host() + port = args.port or get_env_port() + model = create_model_from_datasets( primary_dataset, reference_dataset, ) - working_dir = get_working_dir() - db = SqliteDatabase(working_dir / "phoenix.db") - traces = Traces(db) + traces = Traces() if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() if trace_dataset_name is not None: @@ -215,6 +243,11 @@ def _load_items( args=(traces, fixture_spans, simulate_streaming), daemon=True, ).start() + Thread( + target=_send_spans, + args=(fixture_spans, f"http://{host}:{port}/v1/traces"), + daemon=True, + ).start() fixture_evals = list(get_evals_from_fixture(trace_dataset_name)) Thread( target=_load_items, @@ -233,7 +266,11 @@ def _load_items( from phoenix.server.prometheus import start_prometheus start_prometheus() + + working_dir = get_working_dir().resolve() + engine = aiosqlite_engine(working_dir / "phoenix.db") app = create_app( + engine=engine, export_path=export_path, model=model, umap_params=umap_params, @@ -244,8 +281,6 @@ def _load_items( span_store=span_store, enable_prometheus=enable_prometheus, ) - host = args.host or get_env_host() - port = args.port or get_env_port() server = Server(config=Config(app, host=host, port=port)) Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index d072ff2bc9..c402cbac23 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -37,7 +37,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset -from phoenix.db.database import SqliteDatabase +from phoenix.db.engines import aiosqlite_engine from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer @@ -295,7 +295,7 @@ def __init__( if corpus_dataset is not None else None ) - self.traces = Traces(SqliteDatabase()) + self.traces = Traces() if trace_dataset: for span in trace_dataset.to_spans(): self.traces.put(span) @@ -310,6 +310,7 @@ def __init__( ).start() # Initialize an app service that keeps the server running self.app = create_app( + engine=aiosqlite_engine(), export_path=self.export_path, model=self.model, corpus=self.corpus, diff --git a/tests/server/api/types/conftest.py b/tests/server/api/types/conftest.py index c33708ae48..27028ca454 100644 --- a/tests/server/api/types/conftest.py +++ b/tests/server/api/types/conftest.py @@ -42,6 +42,7 @@ def create_context(primary_dataset: Dataset, reference_dataset: Optional[Dataset response=None, model=create_model_from_datasets(primary_dataset, reference_dataset), export_path=Path(TemporaryDirectory().name), + db=None, # TODO(persistence): add mock for db ) return create_context From a2657d4a99f89aa9beb9b2529c624d88c1727ae7 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Sat, 6 Apr 2024 11:25:28 -0600 Subject: [PATCH 16/30] feat(experimental): postgres support --- src/phoenix/db/engines.py | 7 +++++-- src/phoenix/server/main.py | 7 ++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git 
a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index 001aa8a352..df90e67bc2 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -24,12 +24,15 @@ def set_sqlite_pragma(connection: Connection, _: Any) -> None: cursor.close() +def get_db_url(driver: str = "sqlite+aiosqlite", database: Union[str, Path] = ":memory:") -> URL: + return URL.create(driver, database=str(database)) + + def aiosqlite_engine( database: Union[str, Path] = ":memory:", echo: bool = False, ) -> AsyncEngine: - driver_name = "sqlite+aiosqlite" - url = URL.create(driver_name, database=str(database)) + url = get_db_url(driver="sqlite+aiosqlite", database=database) engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) event.listen(engine.sync_engine, "connect", set_sqlite_pragma) if str(database) == ":memory:": diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 14cda43050..119494a396 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -26,7 +26,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db.engines import aiosqlite_engine +from phoenix.db.engines import aiosqlite_engine, get_db_url from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -70,7 +70,7 @@ | 🚀 Phoenix Server 🚀 | Phoenix UI: http://{host}:{port} | Log traces: /v1/traces over HTTP -| Storage location: {working_dir} +| Storage: {storage} """ @@ -268,6 +268,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: start_prometheus() working_dir = get_working_dir().resolve() + sql_url = get_db_url(database=working_dir / "phoenix.db") engine = aiosqlite_engine(working_dir / "phoenix.db") app = create_app( engine=engine, @@ -290,7 +291,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: "version": phoenix_version, "host": host, "port": port, - "working_dir": working_dir, + "storage": sql_url, } print(_WELCOME_MESSAGE.format(**config)) From 529963447a18568bf44d1722b193a32f944e8abe Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 11:18:10 -0600 Subject: [PATCH 17/30] WIP --- src/phoenix/config.py | 12 ++++++++++++ src/phoenix/db/engines.py | 10 ++++++++++ src/phoenix/server/main.py | 8 +++++--- 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/phoenix/config.py b/src/phoenix/config.py index b952da4d28..3af2c3ffa7 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -22,6 +22,10 @@ """ The project name to use when logging traces and evals. defaults to 'default'. """ +ENV_PHOENIX_DATABASE_URL = "PHOENIX_DATABASE_URL" +""" +The database URL to use when logging traces and evals. 
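+For example (illustrative values), a SQLAlchemy-style URL such as
+"sqlite:///<working_dir>/phoenix.db" (the default when unset) or
+"postgresql://user:password@host:5432/dbname".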
+""" ENV_SPAN_STORAGE_TYPE = "__DANGEROUS__PHOENIX_SPAN_STORAGE_TYPE" """ **EXPERIMENTAL** @@ -152,6 +156,14 @@ def get_env_span_storage_type() -> Optional["SpanStorageType"]: ) +def get_env_database_url_str() -> str: + env_url = os.getenv(ENV_PHOENIX_DATABASE_URL) + if env_url is None: + working_dir = get_working_dir() + return f"sqlite:///{working_dir}/phoenix.db" + return env_url + + class SpanStorageType(Enum): TEXT_FILES = "text-files" diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index df90e67bc2..cb7113cd5a 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -14,6 +14,12 @@ from phoenix.db.models import init_models +# Enum for the the different sql drivers +class SQLDriver(Enum): + SQLITE = "sqlite" + POSTGRES = "postgres" + + def set_sqlite_pragma(connection: Connection, _: Any) -> None: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON;") @@ -28,6 +34,10 @@ def get_db_url(driver: str = "sqlite+aiosqlite", database: Union[str, Path] = ": return URL.create(driver, database=str(database)) +def db_url_from_str(url_str: str) -> URL: + return URL.create(url_str) + + def aiosqlite_engine( database: Union[str, Path] = ":memory:", echo: bool = False, diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 119494a396..1fed9c95d3 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -26,7 +26,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db.engines import aiosqlite_engine, get_db_url +from phoenix.db.engines import aiosqlite_engine from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -268,7 +268,9 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: start_prometheus() working_dir = get_working_dir().resolve() - sql_url = get_db_url(database=working_dir / "phoenix.db") + # db_url_str = get_env_database_url_str() + # Run postgres + db_url_str = "postgresql://localhost:5432/postgres" engine = aiosqlite_engine(working_dir / "phoenix.db") app = create_app( engine=engine, @@ -291,7 +293,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: "version": phoenix_version, "host": host, "port": port, - "storage": sql_url, + "storage": db_url_str, } print(_WELCOME_MESSAGE.format(**config)) From 058b1fe3f9fdc451820331c87cd7f4c0646ca22c Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 12:26:58 -0600 Subject: [PATCH 18/30] WIP --- src/phoenix/db/engines.py | 9 +++++++++ src/phoenix/server/main.py | 9 +++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index cb7113cd5a..d20ba34804 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -38,6 +38,15 @@ def db_url_from_str(url_str: str) -> URL: return URL.create(url_str) +def create_engine(url_str: str, echo: bool = False) -> AsyncEngine: + url = db_url_from_str(url_str) + if "sqlite" in url.drivername: + return aiosqlite_engine(database=url.database, echo=echo) + if "postgres" in url.drivername: + return create_async_engine(url=url, echo=echo) + raise ValueError(f"Unsupported driver: {url.drivername}") + + def aiosqlite_engine( database: Union[str, Path] = ":memory:", echo: bool = False, diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 1fed9c95d3..911a01f74b 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -17,6 
+17,7 @@ from phoenix.config import ( EXPORT_DIR, + get_env_database_url_str, get_env_host, get_env_port, get_pids_path, @@ -26,7 +27,7 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets -from phoenix.db.engines import aiosqlite_engine +from phoenix.db.engines import create_engine from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -268,10 +269,10 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: start_prometheus() working_dir = get_working_dir().resolve() - # db_url_str = get_env_database_url_str() + db_url_str = get_env_database_url_str() # Run postgres - db_url_str = "postgresql://localhost:5432/postgres" - engine = aiosqlite_engine(working_dir / "phoenix.db") + # db_url_str = "postgresql://localhost:5432/postgres" + engine = create_engine(db_url_str) app = create_app( engine=engine, export_path=export_path, From 3b364736673f6e067b342fa9d1040c342d8627ed Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 13:47:24 -0600 Subject: [PATCH 19/30] postgres support --- Dockerfile | 2 +- pyproject.toml | 5 ++++ src/phoenix/config.py | 6 ++--- src/phoenix/db/engines.py | 48 ++++++++++++++++++++++++---------- src/phoenix/server/main.py | 10 +++---- src/phoenix/session/session.py | 4 +-- 6 files changed, 50 insertions(+), 25 deletions(-) diff --git a/Dockerfile b/Dockerfile index 0caeedfd44..109eeaf887 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,7 @@ COPY ./ /phoenix/ COPY --from=frontend-builder /phoenix/src/phoenix/server/static/ /phoenix/src/phoenix/server/static/ # Delete symbolic links used during development. RUN find src/ -xtype l -delete -RUN pip install --target ./env .[container] +RUN pip install --target ./env .[container, pg] # The production image is distroless, meaning that it is a minimal image that # contains only the necessary dependencies to run the application. This is diff --git a/pyproject.toml b/pyproject.toml index 6726f90194..e1bae67614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,11 @@ llama-index = [ "llama-index-callbacks-arize-phoenix>=0.1.2", "openinference-instrumentation-llama-index>=1.2.0", ] + +pg = [ + "asyncpg", +] + container = [ "prometheus-client", ] diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 3af2c3ffa7..b3275c1c80 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -22,7 +22,7 @@ """ The project name to use when logging traces and evals. defaults to 'default'. """ -ENV_PHOENIX_DATABASE_URL = "PHOENIX_DATABASE_URL" +ENV_PHOENIX_SQL_DATABASE = "PHOENIX_SQL_DATABASE" """ The database URL to use when logging traces and evals. 
""" @@ -156,8 +156,8 @@ def get_env_span_storage_type() -> Optional["SpanStorageType"]: ) -def get_env_database_url_str() -> str: - env_url = os.getenv(ENV_PHOENIX_DATABASE_URL) +def get_env_database_connection_str() -> str: + env_url = os.getenv(ENV_PHOENIX_SQL_DATABASE) if env_url is None: working_dir = get_working_dir() return f"sqlite:///{working_dir}/phoenix.db" diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index d20ba34804..303a9312c1 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -7,7 +7,7 @@ from typing import Any, Union import numpy as np -from sqlalchemy import URL, event +from sqlalchemy import URL, event, make_url from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from phoenix.db.migrate import migrate @@ -20,7 +20,7 @@ class SQLDriver(Enum): POSTGRES = "postgres" -def set_sqlite_pragma(connection: Connection, _: Any) -> None: +def set_pragma(connection: Connection, _: Any) -> None: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON;") cursor.execute("PRAGMA journal_mode = WAL;") @@ -34,30 +34,50 @@ def get_db_url(driver: str = "sqlite+aiosqlite", database: Union[str, Path] = ": return URL.create(driver, database=str(database)) -def db_url_from_str(url_str: str) -> URL: - return URL.create(url_str) - - -def create_engine(url_str: str, echo: bool = False) -> AsyncEngine: - url = db_url_from_str(url_str) +def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine: + """ + Factory to create a SQLAlchemy engine from a URL string. + """ + print("connection_str: " + connection_str) + url = make_url(connection_str) + if not url.database: + raise ValueError("Failed to parse database from connection string") if "sqlite" in url.drivername: - return aiosqlite_engine(database=url.database, echo=echo) - if "postgres" in url.drivername: - return create_async_engine(url=url, echo=echo) + # Split the URL to get the database name + database = url.database + + if not database: + raise ValueError("Database is required for SQLite") + print("Creating sqlite engine: " + database) + return aio_sqlite_engine(database=database, echo=echo) + if "postgresql" in url.drivername: + print("Creating postgres engine") + return aio_postgresql_engine(database=url.database, echo=echo) raise ValueError(f"Unsupported driver: {url.drivername}") -def aiosqlite_engine( +def aio_sqlite_engine( database: Union[str, Path] = ":memory:", echo: bool = False, ) -> AsyncEngine: url = get_db_url(driver="sqlite+aiosqlite", database=database) engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) - event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + event.listen(engine.sync_engine, "connect", set_pragma) if str(database) == ":memory:": asyncio.run(init_models(engine)) else: - migrate(url) + migrate(engine.url) + return engine + + +def aio_postgresql_engine( + database: Union[str, Path], + echo: bool = False, +) -> AsyncEngine: + url = get_db_url(driver="postgresql+asyncpg", database=database) + engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) + event.listen(engine.sync_engine, "connect", set_pragma) + migrate(engine.url) return engine diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 911a01f74b..ba6af0f636 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -17,7 +17,7 @@ from phoenix.config import ( EXPORT_DIR, - get_env_database_url_str, + get_env_database_connection_str, get_env_host, get_env_port, get_pids_path, @@ -269,10 
+269,10 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: start_prometheus() working_dir = get_working_dir().resolve() - db_url_str = get_env_database_url_str() + db_connection_str = get_env_database_connection_str() # Run postgres - # db_url_str = "postgresql://localhost:5432/postgres" - engine = create_engine(db_url_str) + db_connection_str = "postgresql://localhost:5432/postgres" + engine = create_engine(db_connection_str) app = create_app( engine=engine, export_path=export_path, @@ -294,7 +294,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: "version": phoenix_version, "host": host, "port": port, - "storage": db_url_str, + "storage": db_connection_str, } print(_WELCOME_MESSAGE.format(**config)) diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index c402cbac23..c137ba0f01 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -37,7 +37,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset -from phoenix.db.engines import aiosqlite_engine +from phoenix.db.engines import aio_sqlite_engine from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer @@ -310,7 +310,7 @@ def __init__( ).start() # Initialize an app service that keeps the server running self.app = create_app( - engine=aiosqlite_engine(), + engine=aio_sqlite_engine(), export_path=self.export_path, model=self.model, corpus=self.corpus, From c4b70898bb5412295770a4743834b790b704e302 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 14:01:56 -0600 Subject: [PATCH 20/30] add pg support --- src/phoenix/config.py | 2 +- src/phoenix/db/alembic.ini | 6 +++--- src/phoenix/db/engines.py | 6 +++--- src/phoenix/server/main.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/phoenix/config.py b/src/phoenix/config.py index b3275c1c80..bd278d2a90 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -22,7 +22,7 @@ """ The project name to use when logging traces and evals. defaults to 'default'. """ -ENV_PHOENIX_SQL_DATABASE = "PHOENIX_SQL_DATABASE" +ENV_PHOENIX_SQL_DATABASE = "__DANGEROUS__PHOENIX_SQL_DATABASE" """ The database URL to use when logging traces and evals. 
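 This mirrors the "__DANGEROUS__" prefix used for other experimental settings
 and should be treated as experimental; an illustrative value is
 "postgresql://user:password@localhost:5432/phoenix".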
""" diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index 402ff9e576..c8b09b0273 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -94,17 +94,17 @@ keys = console keys = generic [logger_root] -level = WARN +level = DEBUG handlers = console qualname = [logger_sqlalchemy] -level = WARN +level = DEBUG handlers = qualname = sqlalchemy.engine [logger_alembic] -level = WARN +level = DEBUG handlers = qualname = alembic diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index 303a9312c1..369843b68a 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -20,7 +20,7 @@ class SQLDriver(Enum): POSTGRES = "postgres" -def set_pragma(connection: Connection, _: Any) -> None: +def set_sqlite_pragma(connection: Connection, _: Any) -> None: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON;") cursor.execute("PRAGMA journal_mode = WAL;") @@ -62,7 +62,7 @@ def aio_sqlite_engine( ) -> AsyncEngine: url = get_db_url(driver="sqlite+aiosqlite", database=database) engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) - event.listen(engine.sync_engine, "connect", set_pragma) + event.listen(engine.sync_engine, "connect", set_sqlite_pragma) if str(database) == ":memory:": asyncio.run(init_models(engine)) else: @@ -76,7 +76,7 @@ def aio_postgresql_engine( ) -> AsyncEngine: url = get_db_url(driver="postgresql+asyncpg", database=database) engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) - event.listen(engine.sync_engine, "connect", set_pragma) + # event.listen(engine.sync_engine, "connect", set_pragma) migrate(engine.url) return engine diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index ba6af0f636..859a9246ce 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -271,7 +271,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: working_dir = get_working_dir().resolve() db_connection_str = get_env_database_connection_str() # Run postgres - db_connection_str = "postgresql://localhost:5432/postgres" + db_connection_str = "postgresql://localhost:5432/mikeldking" engine = create_engine(db_connection_str) app = create_app( engine=engine, From f977bf53c4a295f479f5fe61a6dd0b17e52d1133 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 14:04:34 -0600 Subject: [PATCH 21/30] Cleanup --- src/phoenix/db/engines.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index 369843b68a..4905218dce 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -14,12 +14,6 @@ from phoenix.db.models import init_models -# Enum for the the different sql drivers -class SQLDriver(Enum): - SQLITE = "sqlite" - POSTGRES = "postgres" - - def set_sqlite_pragma(connection: Connection, _: Any) -> None: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON;") @@ -38,20 +32,13 @@ def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine: """ Factory to create a SQLAlchemy engine from a URL string. 
""" - print("connection_str: " + connection_str) url = make_url(connection_str) if not url.database: raise ValueError("Failed to parse database from connection string") if "sqlite" in url.drivername: # Split the URL to get the database name - database = url.database - - if not database: - raise ValueError("Database is required for SQLite") - print("Creating sqlite engine: " + database) - return aio_sqlite_engine(database=database, echo=echo) + return aio_sqlite_engine(database=url.database, echo=echo) if "postgresql" in url.drivername: - print("Creating postgres engine") return aio_postgresql_engine(database=url.database, echo=echo) raise ValueError(f"Unsupported driver: {url.drivername}") From b8b01f8a7155124fff786cf362de9db23973ca04 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 14:06:37 -0600 Subject: [PATCH 22/30] remove hardcoded db --- src/phoenix/server/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 859a9246ce..23f250c8bb 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -270,8 +270,6 @@ def _send_spans(spans: Iterable[Span], url: str) -> None: working_dir = get_working_dir().resolve() db_connection_str = get_env_database_connection_str() - # Run postgres - db_connection_str = "postgresql://localhost:5432/mikeldking" engine = create_engine(db_connection_str) app = create_app( engine=engine, From 46521efcdb95e06d650e297aca92113ab0bec118 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 14:28:05 -0600 Subject: [PATCH 23/30] make url preserve password etc. --- src/phoenix/db/engines.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index 4905218dce..d8354724d6 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -39,7 +39,7 @@ def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine: # Split the URL to get the database name return aio_sqlite_engine(database=url.database, echo=echo) if "postgresql" in url.drivername: - return aio_postgresql_engine(database=url.database, echo=echo) + return aio_postgresql_engine(url=url, echo=echo) raise ValueError(f"Unsupported driver: {url.drivername}") @@ -58,11 +58,13 @@ def aio_sqlite_engine( def aio_postgresql_engine( - database: Union[str, Path], + url: URL, echo: bool = False, ) -> AsyncEngine: - url = get_db_url(driver="postgresql+asyncpg", database=database) - engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) + # Swap out the engine + async_url = url.set(drivername="postgresql+asyncpg") + engine = create_async_engine(url=async_url, echo=echo, json_serializer=_dumps) + # TODO(persistence): figure out the postgres pragma # event.listen(engine.sync_engine, "connect", set_pragma) migrate(engine.url) return engine From b6d508d5aa286768e6fc87b58ed901b3c2f8222c Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Mon, 8 Apr 2024 15:04:59 -0600 Subject: [PATCH 24/30] fix: fix docker build for sql --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 109eeaf887..40f2b2fa31 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,7 @@ COPY ./ /phoenix/ COPY --from=frontend-builder /phoenix/src/phoenix/server/static/ /phoenix/src/phoenix/server/static/ # Delete symbolic links used during development. 
RUN find src/ -xtype l -delete -RUN pip install --target ./env .[container, pg] +RUN pip install --target ./env ".[container, pg]" # The production image is distroless, meaning that it is a minimal image that # contains only the necessary dependencies to run the application. This is From 9ce841eb1c9d4f248cae482992ab67447ae53fee Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Mon, 8 Apr 2024 14:50:14 -0700 Subject: [PATCH 25/30] feat(persistence): bulk inserter for spans (#2808) --- src/phoenix/db/alembic.ini | 6 +- src/phoenix/db/bulk_inserter.py | 187 ++++++++++++++++++ src/phoenix/db/engines.py | 7 +- .../migrations/versions/cf03bd6bae1d_init.py | 17 +- src/phoenix/db/models.py | 52 ++++- .../server/api/routers/trace_handler.py | 118 +---------- src/phoenix/server/app.py | 37 +++- src/phoenix/server/main.py | 34 +--- src/phoenix/session/session.py | 1 + 9 files changed, 293 insertions(+), 166 deletions(-) create mode 100644 src/phoenix/db/bulk_inserter.py diff --git a/src/phoenix/db/alembic.ini b/src/phoenix/db/alembic.ini index c8b09b0273..402ff9e576 100644 --- a/src/phoenix/db/alembic.ini +++ b/src/phoenix/db/alembic.ini @@ -94,17 +94,17 @@ keys = console keys = generic [logger_root] -level = DEBUG +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = DEBUG +level = WARN handlers = qualname = sqlalchemy.engine [logger_alembic] -level = DEBUG +level = WARN handlers = qualname = alembic diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py new file mode 100644 index 0000000000..b863949d52 --- /dev/null +++ b/src/phoenix/db/bulk_inserter.py @@ -0,0 +1,187 @@ +import asyncio +import logging +from itertools import islice +from time import time +from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast + +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import func, insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from phoenix.db import models +from phoenix.trace.schemas import Span, SpanStatusCode + +logger = logging.getLogger(__name__) + + +class BulkInserter: + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None, + run_interval_in_seconds: float = 0.5, + max_num_per_transaction: int = 100, + ) -> None: + """ + :param db: A function to initiate a new database session. + :param initial_batch_of_spans: Initial batch of spans to insert. + :param run_interval_in_seconds: The time interval between the starts of each + bulk insert. If there's nothing to insert, the inserter goes back to sleep. + :param max_num_per_transaction: The maximum number of items to insert in a single + transaction. Multiple transactions will be used if there are more items in the batch. 
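+
+        A rough usage sketch (illustrative, assuming `db` is a session factory
+        such as the one produced by `_db(engine)` in app.py):
+
+            async with BulkInserter(db) as queue_span:
+                queue_span(span, "default")
+
+        Queued spans are written on the next run of the background task.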
+ """ + self._db = db + self._running = False + self._run_interval_seconds = run_interval_in_seconds + self._max_num_per_transaction = max_num_per_transaction + self._spans: List[Tuple[Span, str]] = ( + [] if initial_batch_of_spans is None else list(initial_batch_of_spans) + ) + self._task: Optional[asyncio.Task[None]] = None + + async def __aenter__(self) -> Callable[[Span, str], None]: + self._running = True + self._task = asyncio.create_task(self._bulk_insert()) + return self._queue_span + + async def __aexit__(self, *args: Any) -> None: + self._running = False + + def _queue_span(self, span: Span, project_name: str) -> None: + self._spans.append((span, project_name)) + + async def _bulk_insert(self) -> None: + next_run_at = time() + self._run_interval_seconds + while self._spans or self._running: + await asyncio.sleep(next_run_at - time()) + next_run_at = time() + self._run_interval_seconds + if self._spans: + await self._insert_spans() + + async def _insert_spans(self) -> None: + spans = self._spans + self._spans = [] + for i in range(0, len(spans), self._max_num_per_transaction): + try: + async with self._db() as session: + for span, project_name in islice(spans, i, i + self._max_num_per_transaction): + try: + async with session.begin_nested(): + await _insert_span(session, span, project_name) + except Exception: + logger.exception( + f"Failed to insert span with span_id={span.context.span_id}" + ) + except Exception: + logger.exception("Failed to insert spans") + + +async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: + if await session.scalar(select(1).where(models.Span.span_id == span.context.span_id)): + # Span already exists + return + if not ( + project_rowid := await session.scalar( + select(models.Project.id).where(models.Project.name == project_name) + ) + ): + project_rowid = await session.scalar( + insert(models.Project).values(name=project_name).returning(models.Project.id) + ) + if trace := await session.scalar( + select(models.Trace).where(models.Trace.trace_id == span.context.trace_id) + ): + trace_rowid = trace.id + # TODO(persistence): Figure out how to reliably retrieve timezone-aware + # datetime from the (sqlite) database, because all datetime in our + # programs should be timezone-aware. 
+ if span.start_time < trace.start_time or trace.end_time < span.end_time: + trace.start_time = min(trace.start_time, span.start_time) + trace.end_time = max(trace.end_time, span.end_time) + await session.execute( + update(models.Trace) + .where(models.Trace.id == trace_rowid) + .values( + start_time=min(trace.start_time, span.start_time), + end_time=max(trace.end_time, span.end_time), + ) + ) + else: + trace_rowid = cast( + int, + await session.scalar( + insert(models.Trace) + .values( + project_rowid=project_rowid, + trace_id=span.context.trace_id, + start_time=span.start_time, + end_time=span.end_time, + ) + .returning(models.Trace.id) + ), + ) + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := ( + await session.execute( + select( + func.sum(models.Span.cumulative_error_count), + func.sum(models.Span.cumulative_llm_token_count_prompt), + func.sum(models.Span.cumulative_llm_token_count_completion), + ).where(models.Span.parent_span_id == span.context.span_id) + ) + ).first(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 + session.add( + models.Span( + span_id=span.context.span_id, + trace_rowid=trace_rowid, + parent_span_id=span.parent_id, + kind=span.span_kind.value, + name=span.name, + start_time=span.start_time, + end_time=span.end_time, + attributes=span.attributes, + events=span.events, + status=span.status_code.value, + status_message=span.status_message, + latency_ms=latency_ms, + cumulative_error_count=cumulative_error_count, + cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, + ) + ) + # Propagate cumulative values to ancestors. This is usually a no-op, since + # the parent usually arrives after the child. But in the event that a + # child arrives after its parent, we need to make sure the all the + # ancestors' cumulative values are updated. 
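+    # The statement built below is roughly equivalent to this SQL (illustrative;
+    # names are simplified):
+    #   WITH RECURSIVE ancestors(id, parent_span_id) AS (
+    #     SELECT id, parent_span_id FROM spans WHERE span_id = :parent_id
+    #     UNION ALL
+    #     SELECT s.id, s.parent_span_id FROM spans s
+    #       JOIN ancestors a ON s.span_id = a.parent_span_id
+    #   )
+    #   UPDATE spans SET ... WHERE id IN (SELECT id FROM ancestors);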
+ ancestors = ( + select(models.Span.id, models.Span.parent_span_id) + .where(models.Span.span_id == span.parent_id) + .cte(recursive=True) + ) + child = ancestors.alias() + ancestors = ancestors.union_all( + select(models.Span.id, models.Span.parent_span_id).join( + child, models.Span.span_id == child.c.parent_span_id + ) + ) + await session.execute( + update(models.Span) + .where(models.Span.id.in_(select(ancestors.c.id))) + .values( + cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count, + cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt + + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion + + cumulative_llm_token_count_completion, + ) + ) diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index d8354724d6..2006de3dd6 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -51,7 +51,12 @@ def aio_sqlite_engine( engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) event.listen(engine.sync_engine, "connect", set_sqlite_pragma) if str(database) == ":memory:": - asyncio.run(init_models(engine)) + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.run(init_models(engine)) + else: + asyncio.create_task(init_models(engine)) else: migrate(engine.url) return engine diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index e0a997531c..de4a4fc917 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -25,10 +25,15 @@ def upgrade() -> None: # TODO does the uniqueness constraint need to be named sa.Column("name", sa.String, nullable=False, unique=True), sa.Column("description", sa.String, nullable=True), - sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), sa.Column( "updated_at", - sa.DateTime(), + sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now(), onupdate=sa.func.now(), @@ -41,8 +46,8 @@ def upgrade() -> None: # TODO(mikeldking): might not be the right place for this sa.Column("session_id", sa.String, nullable=True), sa.Column("trace_id", sa.String, nullable=False, unique=True), - sa.Column("start_time", sa.DateTime(), nullable=False, index=True), - sa.Column("end_time", sa.DateTime(), nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False, index=True), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), ) op.create_table( @@ -53,8 +58,8 @@ def upgrade() -> None: sa.Column("parent_span_id", sa.String, nullable=True, index=True), sa.Column("name", sa.String, nullable=False), sa.Column("kind", sa.String, nullable=False), - sa.Column("start_time", sa.DateTime(), nullable=False), - sa.Column("end_time", sa.DateTime(), nullable=False), + sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_time", sa.DateTime(timezone=True), nullable=False), sa.Column("attributes", sa.JSON, nullable=False), sa.Column("events", sa.JSON, nullable=False), sa.Column( diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 5f10f21931..1ba56ef6f2 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -1,12 +1,14 @@ -from datetime import datetime +from datetime import datetime, timezone from 
typing import Any, Dict, List, Optional from sqlalchemy import ( JSON, CheckConstraint, DateTime, + Dialect, ForeignKey, MetaData, + TypeDecorator, UniqueConstraint, func, insert, @@ -21,6 +23,42 @@ ) +class UtcTimeStamp(TypeDecorator[datetime]): + """TODO(persistence): Figure out how to reliably store and retrieve + timezone-aware datetime objects from the (sqlite) database. Below is a + workaround to guarantee that the timestamps we fetch from the database is + always timezone-aware, in order to prevent comparisons of timezone-naive + datetime with timezone-aware datetime, because objects in the rest of our + programs are always timezone-aware. + """ + + cache_ok = True + impl = DateTime + _LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo + + def process_bind_param( + self, + value: Optional[datetime], + dialect: Dialect, + ) -> Optional[datetime]: + if not value: + return None + if value.tzinfo is None: + value = value.astimezone(self._LOCAL_TIMEZONE) + return value.astimezone(timezone.utc) + + def process_result_value( + self, + value: Optional[Any], + dialect: Dialect, + ) -> Optional[datetime]: + if not isinstance(value, datetime): + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + class Base(DeclarativeBase): # Enforce best practices for naming constraints # https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate @@ -44,9 +82,9 @@ class Project(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] description: Mapped[Optional[str]] - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + UtcTimeStamp, server_default=func.now(), onupdate=func.now() ) traces: WriteOnlyMapped["Trace"] = relationship( @@ -69,8 +107,8 @@ class Trace(Base): project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id")) session_id: Mapped[Optional[str]] trace_id: Mapped[str] - start_time: Mapped[datetime] = mapped_column(DateTime(), index=True) - end_time: Mapped[datetime] = mapped_column(DateTime()) + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) + end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) project: Mapped["Project"] = relationship( "Project", @@ -98,8 +136,8 @@ class Span(Base): parent_span_id: Mapped[Optional[str]] = mapped_column(index=True) name: Mapped[str] kind: Mapped[str] - start_time: Mapped[datetime] = mapped_column(DateTime()) - end_time: Mapped[datetime] = mapped_column(DateTime()) + start_time: Mapped[datetime] = mapped_column(UtcTimeStamp) + end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) attributes: Mapped[Dict[str, Any]] events: Mapped[List[Dict[str, Any]]] status: Mapped[str] = mapped_column( diff --git a/src/phoenix/server/api/routers/trace_handler.py b/src/phoenix/server/api/routers/trace_handler.py index 507bb0df40..c50efcae96 100644 --- a/src/phoenix/server/api/routers/trace_handler.py +++ b/src/phoenix/server/api/routers/trace_handler.py @@ -1,31 +1,25 @@ import asyncio import gzip import zlib -from typing import AsyncContextManager, Callable, Optional, cast +from typing import Optional from google.protobuf.message import DecodeError -from openinference.semconv.trace import SpanAttributes from 
opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( ExportTraceServiceRequest, ) from opentelemetry.proto.trace.v1.trace_pb2 import TracesData -from sqlalchemy import func, insert, select, text, update -from sqlalchemy.ext.asyncio import AsyncSession from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.status import HTTP_415_UNSUPPORTED_MEDIA_TYPE, HTTP_422_UNPROCESSABLE_ENTITY from phoenix.core.traces import Traces -from phoenix.db import models from phoenix.storage.span_store import SpanStore from phoenix.trace.otel import decode -from phoenix.trace.schemas import Span, SpanStatusCode from phoenix.utilities.project import get_project_name class TraceHandler(HTTPEndpoint): - db: Callable[[], AsyncContextManager[AsyncSession]] traces: Traces store: Optional[SpanStore] @@ -62,109 +56,13 @@ async def post(self, request: Request) -> Response: for scope_span in resource_spans.scope_spans: for otlp_span in scope_span.spans: span = decode(otlp_span) - async with self.db() as session: - await _insert_span(session, span, project_name) + # TODO(persistence): Decide which one is better: delayed + # bulk-insert or insert each request immediately, i.e. one + # transaction per request. The bulk-insert is more efficient, + # but it queues data in volatile (buffer) memory (for a short + # period of time), so the 200 response is not a genuine + # confirmation of data persistence. + request.state.queue_span_for_bulk_insert(span, project_name) self.traces.put(span, project_name=project_name) await asyncio.sleep(0) return Response() - - -async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None: - if not ( - project_rowid := await session.scalar( - select(models.Project.id).filter(models.Project.name == project_name) - ) - ): - project_rowid = await session.scalar( - insert(models.Project).values(name=project_name).returning(models.Project.id) - ) - if ( - trace_rowid := await session.scalar( - text( - """ - INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) - VALUES(:trace_id, :project_rowid, :session_id, :start_time, :end_time) - ON CONFLICT DO UPDATE SET - start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, - end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END - WHERE excluded.start_time < start_time OR end_time < excluded.end_time - RETURNING rowid; - """ # noqa E501 - ), - { - "trace_id": span.context.trace_id, - "project_rowid": project_rowid, - "session_id": None, - "start_time": span.start_time, - "end_time": span.end_time, - }, - ) - ) is None: - trace_rowid = await session.scalar( - select(models.Trace.id).filter(models.Trace.trace_id == span.context.trace_id) - ) - cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) - cumulative_llm_token_count_prompt = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) - ) - cumulative_llm_token_count_completion = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) - ) - if accumulation := ( - await session.execute( - select( - func.sum(models.Span.cumulative_error_count), - func.sum(models.Span.cumulative_llm_token_count_prompt), - func.sum(models.Span.cumulative_llm_token_count_completion), - ).where(models.Span.parent_span_id == span.context.span_id) - ) - ).first(): - cumulative_error_count += cast(int, accumulation[0] or 0) - 
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index 04e9c17421..e14f0e538d 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -6,8 +6,11 @@
     AsyncContextManager,
     AsyncIterator,
     Callable,
+    Dict,
+    Iterable,
     NamedTuple,
     Optional,
+    Tuple,
     Union,
 )
 
@@ -27,15 +30,16 @@
 from starlette.routing import Mount, Route
 from starlette.staticfiles import StaticFiles
 from starlette.templating import Jinja2Templates
-from starlette.types import Scope
+from starlette.types import Scope, StatefulLifespan
 from starlette.websockets import WebSocket
 from strawberry.asgi import GraphQL
 from strawberry.schema import BaseSchema
 
 import phoenix
-from phoenix.config import SERVER_DIR
+from phoenix.config import DEFAULT_PROJECT_NAME, SERVER_DIR
 from phoenix.core.model_schema import Model
 from phoenix.core.traces import Traces
+from phoenix.db.bulk_inserter import BulkInserter
 from phoenix.pointcloud.umap_parameters import UMAPParameters
 from phoenix.server.api.context import Context
 from phoenix.server.api.routers.evaluation_handler import EvaluationHandler
@@ -43,6 +47,7 @@
 from phoenix.server.api.routers.trace_handler import TraceHandler
 from phoenix.server.api.schema import schema
 from phoenix.storage.span_store import SpanStore
+from phoenix.trace.schemas import Span
 
 logger = logging.getLogger(__name__)
 
@@ -170,6 +175,18 @@ async def factory() -> AsyncIterator[AsyncSession]:
     return factory
 
 
+def _lifespan(
+    db: Callable[[], AsyncContextManager[AsyncSession]],
+    initial_batch_of_spans: Optional[Iterable[Tuple[Span, str]]] = None,
+) -> StatefulLifespan[Starlette]:
+    @contextlib.asynccontextmanager
+    async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
+        async with BulkInserter(db, initial_batch_of_spans) as queue_span:
+            yield {"queue_span_for_bulk_insert": queue_span}
+
+    return lifespan
+
+
 def create_app(
     engine: AsyncEngine,
     export_path: Path,
@@ -181,7 +198,16 @@ def create_app(
     debug: bool = False,
     read_only: bool = False,
     enable_prometheus: bool = False,
+    initial_spans: Optional[Iterable[Union[Span, Tuple[Span, str]]]] = None,
 ) -> Starlette:
+    initial_batch_of_spans: Iterable[Tuple[Span, str]] = (
+        ()
+        if initial_spans is None
+        else (
+            ((item, DEFAULT_PROJECT_NAME) if isinstance(item, Span) else item)
+            for item in initial_spans
+        )
+    )
     db = _db(engine)
     graphql = GraphQLWithContext(
         db=db,
@@ -199,6 +225,7 @@ def create_app(
     else:
         prometheus_middlewares = []
     return Starlette(
+        lifespan=_lifespan(db, initial_batch_of_spans),
        middleware=[
             Middleware(HeadersMiddleware),
             *prometheus_middlewares,
@@ -214,11 +241,7 @@
             ),
             Route(
                 "/v1/traces",
-                type(
-                    "TraceEndpoint",
-                    (TraceHandler,),
-                    {"db": staticmethod(db), "traces": traces, "store": span_store},
-                ),
+                type("TraceEndpoint", (TraceHandler,), {"traces": traces, "store": span_store}),
             ),
             Route(
                 "/v1/evaluations",
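`BulkInserter` is imported from `phoenix.db.bulk_inserter`, whose implementation is not included in this excerpt. A hypothetical sketch of the buffering pattern it appears to implement, an async context manager that hands back an enqueue callable and flushes batches in the background (names and behavior below are illustrative assumptions, not the real class):

# Hypothetical sketch only; the real BulkInserter may differ.
import asyncio
from typing import Any, Callable, Iterable, List, Optional, Tuple


class SketchBulkInserter:
    def __init__(
        self,
        db: Callable[[], Any],
        initial_batch: Optional[Iterable[Tuple[Any, str]]] = None,
    ) -> None:
        self._db = db
        self._buffer: List[Tuple[Any, str]] = list(initial_batch or ())
        self._task: Optional[asyncio.Task] = None

    async def __aenter__(self) -> Callable[[Any, str], None]:
        self._task = asyncio.create_task(self._flush_periodically())
        return self._enqueue  # handlers only ever see this callable

    async def __aexit__(self, *exc: Any) -> None:
        if self._task is not None:
            self._task.cancel()
        await self._flush()

    def _enqueue(self, span: Any, project_name: str) -> None:
        self._buffer.append((span, project_name))

    async def _flush_periodically(self, interval: float = 1.0) -> None:
        while True:
            await asyncio.sleep(interval)
            await self._flush()

    async def _flush(self) -> None:
        batch, self._buffer = self._buffer, []
        if not batch:
            return
        async with self._db() as session:  # one transaction per batch
            for span, project_name in batch:
                ...  # INSERT the span rows via the session here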
diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py
index 23f250c8bb..bc9674a015 100644
--- a/src/phoenix/server/main.py
+++ b/src/phoenix/server/main.py
@@ -1,5 +1,4 @@
 import atexit
-import gzip
 import logging
 import os
 from argparse import ArgumentParser
@@ -10,9 +9,6 @@
 from typing import Iterable, Optional, Protocol, TypeVar
 
 import pkg_resources
-import requests
-from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest
-from opentelemetry.proto.trace.v1.trace_pb2 import ResourceSpans, ScopeSpans
 from uvicorn import Config, Server
 
 from phoenix.config import (
@@ -43,7 +39,6 @@
     get_evals_from_fixture,
 )
 from phoenix.trace.otel import decode, encode
-from phoenix.trace.schemas import Span
 from phoenix.trace.span_json_decoder import json_string_to_span
 from phoenix.utilities.span_store import get_span_store, load_traces_data_from_store
 
@@ -114,28 +109,6 @@ def _load_items(
         queue.put(item)
 
 
-def _send_spans(spans: Iterable[Span], url: str) -> None:
-    # TODO(persistence): Ingest fixtures without networking for read-only deployments
-    sleep(5)  # Wait for the server to start
-    session = requests.session()
-    for span in spans:
-        req = ExportTraceServiceRequest(
-            resource_spans=[ResourceSpans(scope_spans=[ScopeSpans(spans=[encode(span)])])]
-        )
-        session.post(
-            url=url,
-            headers={
-                "content-type": "application/x-protobuf",
-                "content-encoding": "gzip",
-            },
-            data=gzip.compress(req.SerializeToString()),
-        )
-        # TODO(persistence): If ingestion rate is too high it can crash the UI, because
-        # sqlite is not designed for high concurrency, especially for disk
-        # persistence.
-        sleep(0.2)
-
-
 DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}"
 
 if __name__ == "__main__":
@@ -230,6 +203,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None:
     traces = Traces()
     if span_store := get_span_store():
         Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start()
+    fixture_spans = []
     if trace_dataset_name is not None:
         fixture_spans = list(
             # Apply `encode` here because legacy jsonl files contain UUIDs as strings.
@@ -244,11 +218,6 @@ def _send_spans(spans: Iterable[Span], url: str) -> None:
             args=(traces, fixture_spans, simulate_streaming),
             daemon=True,
         ).start()
-        Thread(
-            target=_send_spans,
-            args=(fixture_spans, f"http://{host}:{port}/v1/traces"),
-            daemon=True,
-        ).start()
         fixture_evals = list(get_evals_from_fixture(trace_dataset_name))
         Thread(
             target=_load_items,
@@ -282,6 +251,7 @@ def _send_spans(spans: Iterable[Span], url: str) -> None:
         read_only=read_only,
         span_store=span_store,
        enable_prometheus=enable_prometheus,
+        initial_spans=fixture_spans,
     )
     server = Server(config=Config(app, host=host, port=port))
     Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start()
diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py
index c137ba0f01..cfd8885406 100644
--- a/src/phoenix/session/session.py
+++ b/src/phoenix/session/session.py
@@ -317,6 +317,7 @@ def __init__(
             traces=self.traces,
             umap_params=self.umap_parameters,
             span_store=span_store,
+            initial_spans=trace_dataset.to_spans() if trace_dataset else None,
         )
         self.server = ThreadServer(
             app=self.app,
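With `initial_spans`, fixture spans are handed to the app directly instead of being POSTed over HTTP at startup. The generator in `create_app` accepts either bare spans or `(span, project_name)` pairs; a self-contained illustration of that normalization idiom (`FakeSpan` is a stand-in for the real `Span` type):

# Self-contained illustration of the initial_spans normalization above.
from typing import Iterable, Tuple, Union

DEFAULT_PROJECT_NAME = "default"


class FakeSpan:  # stand-in for phoenix.trace.schemas.Span
    pass


def normalize(
    items: Iterable[Union[FakeSpan, Tuple[FakeSpan, str]]],
) -> Iterable[Tuple[FakeSpan, str]]:
    return (
        (item, DEFAULT_PROJECT_NAME) if isinstance(item, FakeSpan) else item
        for item in items
    )


items = [FakeSpan(), (FakeSpan(), "my-project")]
assert [project for _, project in normalize(items)] == [DEFAULT_PROJECT_NAME, "my-project"]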
From 7f907f006eefde2f3c3429ea4dd13f5106814601 Mon Sep 17 00:00:00 2001
From: Mikyo King
Date: Mon, 8 Apr 2024 16:33:49 -0600
Subject: [PATCH 26/30] add a reset

---
 src/phoenix/__init__.py        | 10 +++++++++-
 src/phoenix/session/session.py | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/src/phoenix/__init__.py b/src/phoenix/__init__.py
index 58319bf2f2..14ca8627a4 100644
--- a/src/phoenix/__init__.py
+++ b/src/phoenix/__init__.py
@@ -9,7 +9,14 @@
 from .datasets.schema import EmbeddingColumnNames, RetrievalEmbeddingColumnNames, Schema
 from .session.client import Client
 from .session.evaluation import log_evaluations
-from .session.session import NotebookEnvironment, Session, active_session, close_app, launch_app
+from .session.session import (
+    NotebookEnvironment,
+    Session,
+    active_session,
+    close_app,
+    launch_app,
+    reset,
+)
 from .trace.fixtures import load_example_traces
 from .trace.trace_dataset import TraceDataset
 from .version import __version__
@@ -41,6 +48,7 @@
     "active_session",
     "close_app",
     "launch_app",
+    "reset",
     "Session",
     "load_example_traces",
     "TraceDataset",
diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py
index cfd8885406..8f97289be0 100644
--- a/src/phoenix/session/session.py
+++ b/src/phoenix/session/session.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil
 import warnings
 from abc import ABC, abstractmethod
 from collections import UserList
@@ -33,6 +34,7 @@
     get_env_port,
     get_env_project_name,
     get_exported_files,
+    get_working_dir,
 )
 from phoenix.core.model_schema_adapter import create_model_from_datasets
 from phoenix.core.traces import Traces
@@ -423,6 +425,24 @@ def get_evaluations(
     return project.export_evaluations()
 
 
+def reset(hard: Optional[bool] = False) -> None:
+    """
+    Resets everything to the initial state.
+    """
+    global _session
+    if _session is not None:
+        if not hard:
+            input("Active session detected. Press Enter to close the session")
+        close_app()
+    working_dir = get_working_dir()
+
+    # See if the working directory exists
+    if working_dir.exists():
+        if not hard:
+            input(f"Working directory exists at {working_dir}. Press Enter to delete the directory")
+        shutil.rmtree(working_dir)
+
+
 def launch_app(
     primary: Optional[Dataset] = None,
     reference: Optional[Dataset] = None,
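In a notebook this surfaces as `px.reset`; a short usage sketch (hypothetical session, behavior as defined in the function above):

import phoenix as px

px.reset()           # prompts before closing the session and deleting the working directory
px.reset(hard=True)  # same cleanup, but skips the interactive confirmation prompts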
From 9d53d1261fffb680386f17565e3afe6b6613192f Mon Sep 17 00:00:00 2001
From: Mikyo King
Date: Mon, 8 Apr 2024 16:55:07 -0600
Subject: [PATCH 27/30] mount sqlite

---
 src/phoenix/session/session.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py
index 8f97289be0..078c8eb3a7 100644
--- a/src/phoenix/session/session.py
+++ b/src/phoenix/session/session.py
@@ -30,6 +30,7 @@
     ENV_PHOENIX_COLLECTOR_ENDPOINT,
     ENV_PHOENIX_HOST,
     ENV_PHOENIX_PORT,
+    get_env_database_connection_str,
     get_env_host,
     get_env_port,
     get_env_project_name,
@@ -39,7 +40,7 @@
 from phoenix.core.model_schema_adapter import create_model_from_datasets
 from phoenix.core.traces import Traces
 from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
-from phoenix.db.engines import aio_sqlite_engine
+from phoenix.db.engines import create_engine
 from phoenix.pointcloud.umap_parameters import get_umap_parameters
 from phoenix.server.app import create_app
 from phoenix.server.thread_server import ThreadServer
@@ -266,6 +267,7 @@ def get_evaluations(
 class ThreadSession(Session):
     def __init__(
         self,
+        database: str,
         primary_dataset: Dataset,
         reference_dataset: Optional[Dataset] = None,
         corpus_dataset: Optional[Dataset] = None,
@@ -312,7 +314,7 @@ def __init__(
         ).start()
         # Initialize an app service that keeps the server running
         self.app = create_app(
-            engine=aio_sqlite_engine(),
+            engine=create_engine(database),
             export_path=self.export_path,
             model=self.model,
             corpus=self.corpus,
@@ -553,9 +555,11 @@ def launch_app(
 
     host = host or get_env_host()
     port = port or get_env_port()
+    database = get_env_database_connection_str()
 
     if run_in_thread:
         _session = ThreadSession(
+            database,
             primary,
             reference,
             corpus,
@@ -588,7 +592,7 @@ def launch_app(
         return None
 
     print(f"🌍 To view the Phoenix app in your browser, visit {_session.url}")
-    print("📺 To view the Phoenix app in a notebook, run `px.active_session().view()`")
+    print(f"💽 Your data is being persisted to {database}")
     print("📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix")
 
     return _session
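`phoenix.db.engines.create_engine` itself is not shown in this excerpt; presumably it turns the connection string returned by `get_env_database_connection_str` into an async SQLAlchemy engine, roughly along these lines (a sketch under that assumption, not the actual implementation):

# Sketch only: illustrates one common way to build an async engine from a
# SQLite connection string. The real phoenix.db.engines.create_engine may differ.
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


def sketch_create_engine(connection_str: str) -> AsyncEngine:
    url = make_url(connection_str)
    if url.drivername == "sqlite":
        # e.g. "sqlite:///phoenix.db" -> "sqlite+aiosqlite:///phoenix.db"
        url = url.set(drivername="sqlite+aiosqlite")
    return create_async_engine(url)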
From d0cfc8eb3a1b13dd48d5c20b4c254eb117b35895 Mon Sep 17 00:00:00 2001
From: Mikyo King
Date: Mon, 8 Apr 2024 17:18:51 -0600
Subject: [PATCH 28/30] WIP

---
 cspell.json                    | 1 +
 src/phoenix/server/app.py      | 4 +++-
 src/phoenix/server/main.py     | 4 +---
 src/phoenix/session/session.py | 3 +--
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/cspell.json b/cspell.json
index 67bedffd61..22d6f408bf 100644
--- a/cspell.json
+++ b/cspell.json
@@ -38,6 +38,7 @@
     "rgba",
     "seafoam",
     "sqlalchemy",
+    "Starlette",
     "templating",
     "tensorboard",
     "testset",
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py
index e14f0e538d..545fab15b7 100644
--- a/src/phoenix/server/app.py
+++ b/src/phoenix/server/app.py
@@ -40,6 +40,7 @@
 from phoenix.core.model_schema import Model
 from phoenix.core.traces import Traces
 from phoenix.db.bulk_inserter import BulkInserter
+from phoenix.db.engines import create_engine
 from phoenix.pointcloud.umap_parameters import UMAPParameters
 from phoenix.server.api.context import Context
 from phoenix.server.api.routers.evaluation_handler import EvaluationHandler
@@ -188,7 +189,7 @@ async def lifespan(_: Starlette) -> AsyncIterator[Dict[str, Any]]:
 def create_app(
-    engine: AsyncEngine,
+    database: str,
     export_path: Path,
     model: Model,
     umap_params: UMAPParameters,
@@ -208,6 +209,7 @@ def create_app(
             for item in initial_spans
         )
     )
+    engine = create_engine(database)
     db = _db(engine)
     graphql = GraphQLWithContext(
         db=db,
diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py
index bc9674a015..5b4e5d964e 100644
--- a/src/phoenix/server/main.py
+++ b/src/phoenix/server/main.py
@@ -23,7 +23,6 @@
 from phoenix.core.traces import Traces
 from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
 from phoenix.datasets.fixtures import FIXTURES, get_datasets
-from phoenix.db.engines import create_engine
 from phoenix.pointcloud.umap_parameters import (
     DEFAULT_MIN_DIST,
     DEFAULT_N_NEIGHBORS,
@@ -239,9 +238,8 @@ def _load_items(
     working_dir = get_working_dir().resolve()
     db_connection_str = get_env_database_connection_str()
-    engine = create_engine(db_connection_str)
     app = create_app(
-        engine=engine,
+        database=db_connection_str,
         export_path=export_path,
         model=model,
         umap_params=umap_params,
diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py
index 078c8eb3a7..f1ce06604f 100644
--- a/src/phoenix/session/session.py
+++ b/src/phoenix/session/session.py
@@ -40,7 +40,6 @@
 from phoenix.core.model_schema_adapter import create_model_from_datasets
 from phoenix.core.traces import Traces
 from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
-from phoenix.db.engines import create_engine
 from phoenix.pointcloud.umap_parameters import get_umap_parameters
 from phoenix.server.app import create_app
 from phoenix.server.thread_server import ThreadServer
@@ -314,7 +313,7 @@ def __init__(
         ).start()
         # Initialize an app service that keeps the server running
         self.app = create_app(
-            engine=create_engine(database),
+            database=database,
             export_path=self.export_path,
             model=self.model,
             corpus=self.corpus,

From 6a8d40ac9f9a1088cc5d430a2173e7da1d446954 Mon Sep 17 00:00:00 2001
From: Mikyo King
Date: Tue, 9 Apr 2024 11:19:31 -0600
Subject: [PATCH 29/30] run async as task

---
 src/phoenix/db/migrations/env.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/phoenix/db/migrations/env.py b/src/phoenix/db/migrations/env.py
index 2cc51224e6..601b7873a2 100644
--- a/src/phoenix/db/migrations/env.py
+++ b/src/phoenix/db/migrations/env.py
@@ -69,7 +69,12 @@ def run_migrations_online() -> None:
     )
 
     if isinstance(connectable, AsyncEngine):
-        asyncio.run(run_async_migrations(connectable))
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            asyncio.run(run_async_migrations(connectable))
+        else:
+            asyncio.create_task(run_async_migrations(connectable))
     else:
         run_migrations(connectable)
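The guard above avoids calling `asyncio.run` when an event loop is already running, for example when migrations are triggered from inside the running server or from a notebook. A generic illustration of the idiom with a stand-in coroutine (not the actual Alembic migration routine):

# Generic "run or schedule" idiom.
import asyncio


async def do_work() -> None:
    await asyncio.sleep(0)


def run_or_schedule() -> None:
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running (plain script / CLI): safe to start one and block.
        asyncio.run(do_work())
    else:
        # Already inside a loop (server, notebook): schedule instead of blocking.
        # Note that the caller does not wait for completion in this branch.
        loop.create_task(do_work())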
From 1c60e75f36a5fd5b54239501ac43ea7428f22b96 Mon Sep 17 00:00:00 2001
From: Mikyo King
Date: Tue, 9 Apr 2024 12:14:44 -0600
Subject: [PATCH 30/30] feat: use non-memory sqlite in the notebook

---
 src/phoenix/__init__.py        |  4 ++--
 src/phoenix/session/session.py | 17 ++++++++++-------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/phoenix/__init__.py b/src/phoenix/__init__.py
index 14ca8627a4..2dece2695c 100644
--- a/src/phoenix/__init__.py
+++ b/src/phoenix/__init__.py
@@ -15,7 +15,7 @@
     active_session,
     close_app,
     launch_app,
-    reset,
+    reset_all,
 )
 from .trace.fixtures import load_example_traces
 from .trace.trace_dataset import TraceDataset
@@ -48,7 +48,7 @@
     "active_session",
     "close_app",
     "launch_app",
-    "reset",
+    "reset_all",
     "Session",
     "load_example_traces",
     "TraceDataset",
diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py
index f1ce06604f..19239609de 100644
--- a/src/phoenix/session/session.py
+++ b/src/phoenix/session/session.py
@@ -426,15 +426,10 @@ def get_evaluations(
     return project.export_evaluations()
 
 
-def reset(hard: Optional[bool] = False) -> None:
+def reset_all(hard: Optional[bool] = False) -> None:
     """
     Resets everything to the initial state.
     """
-    global _session
-    if _session is not None:
-        if not hard:
-            input("Active session detected. Press Enter to close the session")
-        close_app()
     working_dir = get_working_dir()
 
     # See if the working directory exists
@@ -605,10 +600,15 @@ def active_session() -> Optional[Session]:
     return None
 
 
-def close_app() -> None:
+def close_app(reset: bool = False) -> None:
     """
     Closes the phoenix application.
     The application server is shut down and will no longer be accessible.
+
+    Parameters
+    ----------
+    reset : bool, optional
+        Whether to reset the working directory. Defaults to False.
     """
     global _session
     if _session is None:
@@ -617,6 +617,9 @@ def close_app() -> None:
     _session.end()
     _session = None
     logger.info("Session closed")
+    if reset:
+        logger.info("Resetting working directory")
+        reset_all(hard=True)
 
 
 def _get_url(host: str, port: int, notebook_env: NotebookEnvironment) -> str:
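Taken together, the notebook teardown path now looks roughly like this (usage sketch; behavior as defined by the functions above):

import phoenix as px

session = px.launch_app()   # spans are persisted to the sqlite database on disk
# ... send and inspect traces ...
px.close_app(reset=True)    # stop the server and wipe the working directory
# or keep the server stopped but delete the data later, without prompts:
# px.reset_all(hard=True)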