From fad00090fec91d580fef1e807297f7dd3725f8ae Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:11:36 -0700 Subject: [PATCH] refactor: sqlite3 proof of concept (#2740) * wip --------- Co-authored-by: Mikyo King --- app/schema.graphql | 2 +- src/phoenix/core/traces.py | 76 ++++- src/phoenix/db/__init__.py | 0 src/phoenix/db/database.py | 354 +++++++++++++++++++++++ src/phoenix/server/api/types/MimeType.py | 4 +- src/phoenix/server/api/types/Project.py | 35 ++- src/phoenix/server/api/types/Trace.py | 10 +- src/phoenix/server/main.py | 22 +- src/phoenix/session/session.py | 40 ++- src/phoenix/trace/otel.py | 9 +- src/phoenix/trace/schemas.py | 12 +- 11 files changed, 510 insertions(+), 54 deletions(-) create mode 100644 src/phoenix/db/__init__.py create mode 100644 src/phoenix/db/database.py diff --git a/app/schema.graphql b/app/schema.graphql index 0236fc879f..e5ff7dbcae 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -543,7 +543,7 @@ type Project implements Node { endTime: DateTime recordCount(timeRange: TimeRange): Int! traceCount(timeRange: TimeRange): Int! - tokenCountTotal: Int! + tokenCountTotal(timeRange: TimeRange): Int! latencyMsP50: Float latencyMsP99: Float trace(traceId: ID!): Trace diff --git a/src/phoenix/core/traces.py b/src/phoenix/core/traces.py index 0681e5479e..6fff53e469 100644 --- a/src/phoenix/core/traces.py +++ b/src/phoenix/core/traces.py @@ -1,9 +1,10 @@ import weakref from collections import defaultdict +from datetime import datetime from queue import SimpleQueue from threading import RLock, Thread from types import MethodType -from typing import DefaultDict, Iterator, Optional, Tuple, Union +from typing import DefaultDict, Iterator, Optional, Protocol, Tuple, Union from typing_extensions import assert_never @@ -12,16 +13,45 @@ from phoenix.core.project import ( END_OF_QUEUE, Project, + WrappedSpan, _ProjectName, ) -from phoenix.trace.schemas import Span +from phoenix.trace.schemas import ComputedAttributes, ComputedValues, Span, TraceID _SpanItem = Tuple[Span, _ProjectName] _EvalItem = Tuple[pb.Evaluation, _ProjectName] +class Database(Protocol): + def insert_span(self, span: Span, project_name: str) -> None: ... + + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: ... + + def get_trace(self, trace_id: TraceID) -> Iterator[Tuple[Span, ComputedValues]]: ... + + class Traces: - def __init__(self) -> None: + def __init__(self, database: Database) -> None: + self._database = database self._span_queue: "SimpleQueue[Optional[_SpanItem]]" = SimpleQueue() self._eval_queue: "SimpleQueue[Optional[_EvalItem]]" = SimpleQueue() # Putting `None` as the sentinel value for queue termination. @@ -34,6 +64,45 @@ def __init__(self) -> None: ) self._start_consumers() + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.trace_count(project_name, start_time, stop_time) + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.span_count(project_name, start_time, stop_time) + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + return self._database.llm_token_count_total(project_name, start_time, stop_time) + + def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: + for span, computed_values in self._database.get_trace(trace_id): + wrapped_span = WrappedSpan(span) + wrapped_span[ComputedAttributes.LATENCY_MS] = computed_values.latency_ms + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT] = ( + computed_values.cumulative_llm_token_count_prompt + ) + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION] = ( + computed_values.cumulative_llm_token_count_completion + ) + wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] = ( + computed_values.cumulative_llm_token_count_total + ) + yield wrapped_span + def get_project(self, project_name: str) -> Optional["Project"]: with self._lock: return self._projects.get(project_name) @@ -84,6 +153,7 @@ def _start_consumers(self) -> None: def _consume_spans(self, queue: "SimpleQueue[Optional[_SpanItem]]") -> None: while (item := queue.get()) is not END_OF_QUEUE: span, project_name = item + self._database.insert_span(span, project_name=project_name) with self._lock: project = self._projects[project_name] project.add_span(span) diff --git a/src/phoenix/db/__init__.py b/src/phoenix/db/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/phoenix/db/database.py b/src/phoenix/db/database.py new file mode 100644 index 0000000000..3e3486cca8 --- /dev/null +++ b/src/phoenix/db/database.py @@ -0,0 +1,354 @@ +import json +import sqlite3 +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Iterator, Optional, Tuple, cast + +import numpy as np +from openinference.semconv.trace import SpanAttributes + +from phoenix.trace.schemas import ComputedValues, Span, SpanContext, SpanKind, SpanStatusCode + +_CONFIG = """ +PRAGMA foreign_keys = ON; +PRAGMA journal_mode = WAL; +PRAGMA synchronous = OFF; +PRAGMA cache_size = -32000; +PRAGMA busy_timeout = 10000; +""" + +_INIT_DB = """ +BEGIN; +CREATE TABLE projects ( + rowid INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + description TEXT, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); +INSERT INTO projects(name) VALUES('default'); +CREATE TABLE traces ( + rowid INTEGER PRIMARY KEY, + trace_id TEXT UNIQUE NOT NULL, + project_rowid INTEGER NOT NULL, + session_id TEXT, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + FOREIGN KEY(project_rowid) REFERENCES projects(rowid) +); +CREATE INDEX idx_trace_start_time ON traces(start_time); +CREATE TABLE spans ( + rowid INTEGER PRIMARY KEY, + span_id TEXT UNIQUE NOT NULL, + trace_rowid INTEGER NOT NULL, + parent_span_id TEXT, + kind TEXT NOT NULL, + name TEXT NOT NULL, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + attributes JSON, + events JSON, + status TEXT CHECK(status IN ('UNSET','OK','ERROR')) NOT NULL DEFAULT('UNSET'), + status_message TEXT, + cumulative_error_count INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_prompt INTEGER NOT NULL DEFAULT 0, + cumulative_llm_token_count_completion INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(trace_rowid) REFERENCES traces(rowid) +); +CREATE INDEX idx_parent_span_id ON spans(parent_span_id); +PRAGMA user_version = 1; +COMMIT; +""" + + +class SqliteDatabase: + def __init__(self, database: Optional[Path] = None) -> None: + """ + :param database: The path to the database file to be opened. + """ + self.con = sqlite3.connect( + database=database or ":memory:", + uri=True, + check_same_thread=False, + ) + # self.con.set_trace_callback(print) + cur = self.con.cursor() + cur.executescript(_CONFIG) + if int(cur.execute("PRAGMA user_version;").fetchone()[0]) < 1: + cur.executescript(_INIT_DB) + + def insert_span(self, span: Span, project_name: str) -> None: + cur = self.con.cursor() + cur.execute("BEGIN;") + try: + if not ( + projects := cur.execute( + "SELECT rowid FROM projects WHERE name = ?;", + (project_name,), + ).fetchone() + ): + projects = cur.execute( + "INSERT INTO projects(name) VALUES(?) RETURNING rowid;", + (project_name,), + ).fetchone() + project_rowid = projects[0] + if ( + trace_row := cur.execute( + """ + INSERT INTO traces(trace_id, project_rowid, session_id, start_time, end_time) + VALUES(?,?,?,?,?) + ON CONFLICT DO UPDATE SET + start_time = CASE WHEN excluded.start_time < start_time THEN excluded.start_time ELSE start_time END, + end_time = CASE WHEN end_time < excluded.end_time THEN excluded.end_time ELSE end_time END + WHERE excluded.start_time < start_time OR end_time < excluded.end_time + RETURNING rowid; + """, # noqa E501 + ( + span.context.trace_id, + project_rowid, + None, + span.start_time, + span.end_time, + ), + ).fetchone() + ) is None: + trace_row = cur.execute( + "SELECT rowid from traces where trace_id = ?", (span.context.trace_id,) + ).fetchone() + trace_rowid = trace_row[0] + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) + cumulative_llm_token_count_prompt = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + ) + cumulative_llm_token_count_completion = cast( + int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + ) + if accumulation := cur.execute( + """ + SELECT + sum(cumulative_error_count), + sum(cumulative_llm_token_count_prompt), + sum(cumulative_llm_token_count_completion) + FROM spans + WHERE parent_span_id = ? + """, # noqa E501 + (span.context.span_id,), + ).fetchone(): + cumulative_error_count += cast(int, accumulation[0] or 0) + cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) + cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) + cur.execute( + """ + INSERT INTO spans(span_id, trace_rowid, parent_span_id, kind, name, start_time, end_time, attributes, events, status, status_message, cumulative_error_count, cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion) + VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?,?) + RETURNING rowid; + """, # noqa E501 + ( + span.context.span_id, + trace_rowid, + span.parent_id, + span.span_kind.value, + span.name, + span.start_time, + span.end_time, + json.dumps(span.attributes, cls=_Encoder), + json.dumps(span.events, cls=_Encoder), + span.status_code.value, + span.status_message, + cumulative_error_count, + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion, + ), + ) + parent_id = span.parent_id + while parent_id: + if parent_span := cur.execute( + """ + SELECT rowid, parent_span_id + FROM spans + WHERE span_id = ? + """, + (parent_id,), + ).fetchone(): + rowid, parent_id = parent_span[0], parent_span[1] + cur.execute( + """ + UPDATE spans SET + cumulative_error_count = cumulative_error_count + ?, + cumulative_llm_token_count_prompt = cumulative_llm_token_count_prompt + ?, + cumulative_llm_token_count_completion = cumulative_llm_token_count_completion + ? + WHERE rowid = ?; + """, # noqa E501 + ( + cumulative_error_count, + cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion, + rowid, + ), + ) + else: + break + except Exception: + cur.execute("ROLLBACK;") + else: + cur.execute("COMMIT;") + + def get_projects(self) -> Iterator[Tuple[int, str]]: + cur = self.con.cursor() + for project in cur.execute("SELECT rowid, name FROM projects;").fetchall(): + yield cast(Tuple[int, str], (project[0], project[1])) + + def trace_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT COUNT(*) + FROM traces + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? + """ + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= traces.start_time AND traces.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= traces.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND traces.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def span_count( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT COUNT(*) + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? + """ + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= spans.start_time AND spans.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def llm_token_count_total( + self, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + ) -> int: + query = """ + SELECT + SUM(COALESCE(json_extract(spans.attributes, '$."llm.token_count.prompt"'), 0) + + COALESCE(json_extract(spans.attributes, '$."llm.token_count.completion"'), 0)) + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + JOIN projects ON projects.rowid = traces.project_rowid + WHERE projects.name = ? + """ # noqa E501 + cur = self.con.cursor() + if start_time and stop_time: + cur = cur.execute( + query + " AND ? <= spans.start_time AND spans.start_time < ?;", + (project_name, start_time, stop_time), + ) + elif start_time: + cur = cur.execute(query + " AND ? <= spans.start_time;", (project_name, start_time)) + elif stop_time: + cur = cur.execute(query + " AND spans.start_time < ?;", (project_name, stop_time)) + else: + cur = cur.execute(query + ";", (project_name,)) + if res := cur.fetchone(): + return cast(int, res[0] or 0) + return 0 + + def get_trace(self, trace_id: str) -> Iterator[Tuple[Span, ComputedValues]]: + cur = self.con.cursor() + for span in cur.execute( + """ + SELECT + spans.span_id, + traces.trace_id, + spans.parent_span_id, + spans.kind, + spans.name, + spans.start_time, + spans.end_time, + spans.attributes, + spans.events, + spans.status, + spans.status_message, + spans.cumulative_error_count, + spans.cumulative_llm_token_count_prompt, + spans.cumulative_llm_token_count_completion + FROM spans + JOIN traces ON traces.rowid = spans.trace_rowid + WHERE traces.trace_id = ? + """, + (trace_id,), + ).fetchall(): + start_time = datetime.fromisoformat(span[5]) + end_time = datetime.fromisoformat(span[6]) + latency_ms = (end_time - start_time).total_seconds() * 1000 + yield ( + Span( + name=span[4], + context=SpanContext(trace_id=span[1], span_id=span[0]), + parent_id=span[2], + span_kind=SpanKind(span[3]), + start_time=start_time, + end_time=end_time, + attributes=json.loads(span[7]), + events=json.loads(span[8]), + status_code=SpanStatusCode(span[9]), + status_message=span[10], + conversation=None, + ), + ComputedValues( + latency_ms=latency_ms, + cumulative_error_count=span[11], + cumulative_llm_token_count_prompt=span[12], + cumulative_llm_token_count_completion=span[13], + cumulative_llm_token_count_total=span[12] + span[13], + ), + ) + + +class _Encoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + elif isinstance(obj, Enum): + return obj.value + elif isinstance(obj, np.ndarray): + return list(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + return super().default(obj) diff --git a/src/phoenix/server/api/types/MimeType.py b/src/phoenix/server/api/types/MimeType.py index 4c33572994..9dbe40f477 100644 --- a/src/phoenix/server/api/types/MimeType.py +++ b/src/phoenix/server/api/types/MimeType.py @@ -8,8 +8,8 @@ @strawberry.enum class MimeType(Enum): - text = trace_schemas.MimeType.TEXT - json = trace_schemas.MimeType.JSON + text = trace_schemas.MimeType.TEXT.value + json = trace_schemas.MimeType.JSON.value @classmethod def _missing_(cls, v: Any) -> Optional["MimeType"]: diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 926d28296b..671d74cd93 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -4,9 +4,11 @@ import strawberry from strawberry import ID, UNSET +from strawberry.types import Info from phoenix.core.project import Project as CoreProject from phoenix.metrics.retrieval_metrics import RetrievalMetrics +from phoenix.server.api.context import Context from phoenix.server.api.input_types.SpanSort import SpanSort from phoenix.server.api.input_types.TimeRange import TimeRange from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary @@ -43,24 +45,41 @@ def end_time(self) -> Optional[datetime]: @strawberry.field def record_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.span_count() - return self.project.span_count(time_range.start, time_range.end) + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.span_count(self.name, start_time, stop_time) @strawberry.field def trace_count( self, + info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - if not time_range: - return self.project.trace_count() - return self.project.trace_count(time_range.start, time_range.end) + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.trace_count(self.name, start_time, stop_time) @strawberry.field - def token_count_total(self) -> int: - return self.project.token_count_total + def token_count_total( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + ) -> int: + if not (traces := info.context.traces): + return 0 + start_time, stop_time = ( + (None, None) if not time_range else (time_range.start, time_range.end) + ) + return traces.llm_token_count_total(self.name, start_time, stop_time) @strawberry.field def latency_ms_p50(self) -> Optional[float]: diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 4a0b3331ba..a7d440bc52 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -2,8 +2,10 @@ import strawberry from strawberry import ID, UNSET, Private +from strawberry.types import Info from phoenix.core.project import Project +from phoenix.server.api.context import Context from phoenix.server.api.types.Evaluation import TraceEvaluation from phoenix.server.api.types.pagination import ( Connection, @@ -23,6 +25,7 @@ class Trace: @strawberry.field def spans( self, + info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, after: Optional[Cursor] = UNSET, @@ -34,10 +37,9 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - spans = sorted( - self.project.get_trace(TraceID(self.trace_id)), - key=lambda span: span.start_time, - ) + if not (traces := info.context.traces): + return connection_from_list(data=[], args=args) + spans = traces.get_trace(TraceID(self.trace_id)) data = [to_gql_span(span, self.project) for span in spans] return connection_from_list(data=data, args=args) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 070a745f46..1b43921264 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -16,11 +16,13 @@ get_env_host, get_env_port, get_pids_path, + get_working_dir, ) from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets +from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import ( DEFAULT_MIN_DIST, DEFAULT_N_NEIGHBORS, @@ -48,7 +50,7 @@ ██████╔╝███████║██║ ██║█████╗ ██╔██╗ ██║██║ ╚███╔╝ ██╔═══╝ ██╔══██║██║ ██║██╔══╝ ██║╚██╗██║██║ ██╔██╗ ██║ ██║ ██║╚██████╔╝███████╗██║ ╚████║██║██╔╝ ██╗ -╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{0} +╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝╚═╝╚═╝ ╚═╝ v{version} | | 🌎 Join our Community 🌎 @@ -61,9 +63,9 @@ | https://docs.arize.com/phoenix | | 🚀 Phoenix Server 🚀 -| Phoenix UI: http://{1}:{2} +| Phoenix UI: http://{host}:{port} | Log traces: /v1/traces over HTTP -| +| Storage location: {working_dir} """ @@ -193,7 +195,9 @@ def _load_items( primary_dataset, reference_dataset, ) - traces = Traces() + working_dir = get_working_dir() + db = SqliteDatabase(working_dir / "phoenix.db") + traces = Traces(db) if span_store := get_span_store(): Thread(target=load_traces_data_from_store, args=(traces, span_store), daemon=True).start() if trace_dataset_name is not None: @@ -246,9 +250,13 @@ def _load_items( # Print information about the server phoenix_version = pkg_resources.get_distribution("arize-phoenix").version - print( - _WELCOME_MESSAGE.format(phoenix_version, host if host != "0.0.0.0" else "localhost", port) - ) + config = { + "version": phoenix_version, + "host": host, + "port": port, + "working_dir": working_dir, + } + print(_WELCOME_MESSAGE.format(**config)) # Start the server server.run() diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index 4a69d6da76..d072ff2bc9 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -37,6 +37,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset +from phoenix.db.database import SqliteDatabase from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer @@ -118,27 +119,6 @@ def __init__( self.corpus_dataset = corpus_dataset self.trace_dataset = trace_dataset self.umap_parameters = get_umap_parameters(default_umap_parameters) - self.model = create_model_from_datasets( - primary_dataset, - reference_dataset, - ) - - self.corpus = ( - create_model_from_datasets( - corpus_dataset, - ) - if corpus_dataset is not None - else None - ) - - self.traces = Traces() - if trace_dataset: - for span in trace_dataset.to_spans(): - self.traces.put(span) - for evaluations in trace_dataset.evaluations: - for pb_evaluation in encode_evaluations(evaluations): - self.traces.put(pb_evaluation) - self.host = host or get_env_host() self.port = port or get_env_port() self.temp_dir = TemporaryDirectory() @@ -304,6 +284,24 @@ def __init__( port=port, notebook_env=notebook_env, ) + self.model = create_model_from_datasets( + primary_dataset, + reference_dataset, + ) + self.corpus = ( + create_model_from_datasets( + corpus_dataset, + ) + if corpus_dataset is not None + else None + ) + self.traces = Traces(SqliteDatabase()) + if trace_dataset: + for span in trace_dataset.to_spans(): + self.traces.put(span) + for evaluations in trace_dataset.evaluations: + for pb_evaluation in encode_evaluations(evaluations): + self.traces.put(pb_evaluation) if span_store := get_span_store(): Thread( target=load_traces_data_from_store, diff --git a/src/phoenix/trace/otel.py b/src/phoenix/trace/otel.py index 12ec0b4104..21bb417b4a 100644 --- a/src/phoenix/trace/otel.py +++ b/src/phoenix/trace/otel.py @@ -33,7 +33,6 @@ EXCEPTION_MESSAGE, EXCEPTION_STACKTRACE, EXCEPTION_TYPE, - MimeType, Span, SpanContext, SpanEvent, @@ -61,16 +60,14 @@ def decode(otlp_span: otlp.Span) -> Span: parent_id = _decode_identifier(otlp_span.parent_span_id) start_time = _decode_unix_nano(otlp_span.start_time_unix_nano) - end_time = ( - _decode_unix_nano(otlp_span.end_time_unix_nano) if otlp_span.end_time_unix_nano else None - ) + end_time = _decode_unix_nano(otlp_span.end_time_unix_nano) attributes = dict(_unflatten(_load_json_strings(_decode_key_values(otlp_span.attributes)))) span_kind = SpanKind(attributes.pop(OPENINFERENCE_SPAN_KIND, None)) for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = MimeType(attributes[mime_type]) + attributes[mime_type] = attributes[mime_type] status_code, status_message = _decode_status(otlp_span.status) events = [_decode_event(event) for event in otlp_span.events] @@ -320,7 +317,7 @@ def encode(span: Span) -> otlp.Span: for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): if mime_type in attributes: - attributes[mime_type] = attributes[mime_type].value + attributes[mime_type] = attributes[mime_type] for key, value in span.attributes.items(): if value is None: diff --git a/src/phoenix/trace/schemas.py b/src/phoenix/trace/schemas.py index efebeea439..7caeab9bf4 100644 --- a/src/phoenix/trace/schemas.py +++ b/src/phoenix/trace/schemas.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, NamedTuple, Optional, Union from uuid import UUID EXCEPTION_TYPE = "exception.type" @@ -142,7 +142,7 @@ class Span: "If the parent_id is None, this is the root span" parent_id: Optional[SpanID] start_time: datetime - end_time: Optional[datetime] + end_time: datetime status_code: SpanStatusCode status_message: str """ @@ -202,3 +202,11 @@ class ComputedAttributes(Enum): CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION = "cumulative_token_count.completion" ERROR_COUNT = "error_count" CUMULATIVE_ERROR_COUNT = "cumulative_error_count" + + +class ComputedValues(NamedTuple): + latency_ms: float + cumulative_error_count: int + cumulative_llm_token_count_prompt: int + cumulative_llm_token_count_completion: int + cumulative_llm_token_count_total: int