refactor: sqlite async session for graphql api #2784

Merged · 6 commits · Apr 5, 2024
app/schema.graphql: 2 changes (1 addition, 1 deletion)
@@ -547,7 +547,7 @@ type Project implements Node {
  latencyMsP50: Float
  latencyMsP99: Float
  trace(traceId: ID!): Trace
-  spans(timeRange: TimeRange, traceIds: [ID!], first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection!
+  spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection!

"""
Names of all available evaluations for traces. (The list contains no duplicates.)
pyproject.toml: 1 change (1 addition, 0 deletions)
@@ -52,6 +52,7 @@ dependencies = [
"openinference-instrumentation-openai>=0.1.4",
"sqlalchemy>=2, <3",
"alembic>=1.3.0, <2",
"aiosqlite",
]
dynamic = ["version"]

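For context on the new dependency: aiosqlite is the driver that SQLAlchemy's asyncio extension uses to talk to SQLite. A minimal sketch of how an async engine and session factory are typically wired up with it (illustrative only; the database URL and the query below are assumptions, not code from this PR):

import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

# The "sqlite+aiosqlite" URL scheme routes SQLAlchemy's asyncio API through the
# aiosqlite driver; the database path here is an illustrative placeholder.
engine = create_async_engine("sqlite+aiosqlite:///example.db")
Session = async_sessionmaker(engine, expire_on_commit=False)


async def main() -> None:
    async with Session() as session:
        result = await session.execute(text("SELECT 1"))
        print(result.scalar_one())  # prints 1
    await engine.dispose()


asyncio.run(main())

Sessions created this way must be awaited inside an event loop, which is presumably why this refactor moves session handling into the async GraphQL layer.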
src/phoenix/core/traces.py: 76 changes (3 additions, 73 deletions)
@@ -1,10 +1,9 @@
import weakref
from collections import defaultdict
-from datetime import datetime
from queue import SimpleQueue
from threading import RLock, Thread
from types import MethodType
-from typing import DefaultDict, Iterator, Optional, Protocol, Tuple, Union
+from typing import DefaultDict, Iterator, Optional, Tuple, Union

from typing_extensions import assert_never

@@ -13,45 +12,16 @@
from phoenix.core.project import (
    END_OF_QUEUE,
    Project,
-    WrappedSpan,
    _ProjectName,
)
-from phoenix.trace.schemas import ComputedAttributes, ComputedValues, Span, TraceID
+from phoenix.trace.schemas import Span

_SpanItem = Tuple[Span, _ProjectName]
_EvalItem = Tuple[pb.Evaluation, _ProjectName]


-class Database(Protocol):
-    def insert_span(self, span: Span, project_name: str) -> None: ...
-
-    def trace_count(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int: ...
-
-    def span_count(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int: ...
-
-    def llm_token_count_total(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int: ...
-
-    def get_trace(self, trace_id: TraceID) -> Iterator[Tuple[Span, ComputedValues]]: ...


class Traces:
-    def __init__(self, database: Database) -> None:
-        self._database = database
+    def __init__(self) -> None:
        self._span_queue: "SimpleQueue[Optional[_SpanItem]]" = SimpleQueue()
        self._eval_queue: "SimpleQueue[Optional[_EvalItem]]" = SimpleQueue()
        # Putting `None` as the sentinel value for queue termination.
@@ -64,45 +34,6 @@ def __init__(self, database: Database) -> None:
        )
        self._start_consumers()

-    def trace_count(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int:
-        return self._database.trace_count(project_name, start_time, stop_time)
-
-    def span_count(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int:
-        return self._database.span_count(project_name, start_time, stop_time)
-
-    def llm_token_count_total(
-        self,
-        project_name: str,
-        start_time: Optional[datetime] = None,
-        stop_time: Optional[datetime] = None,
-    ) -> int:
-        return self._database.llm_token_count_total(project_name, start_time, stop_time)
-
-    def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]:
-        for span, computed_values in self._database.get_trace(trace_id):
-            wrapped_span = WrappedSpan(span)
-            wrapped_span[ComputedAttributes.LATENCY_MS] = computed_values.latency_ms
-            wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_PROMPT] = (
-                computed_values.cumulative_llm_token_count_prompt
-            )
-            wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_COMPLETION] = (
-                computed_values.cumulative_llm_token_count_completion
-            )
-            wrapped_span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] = (
-                computed_values.cumulative_llm_token_count_total
-            )
-            yield wrapped_span

    def get_project(self, project_name: str) -> Optional["Project"]:
        with self._lock:
            return self._projects.get(project_name)
@@ -153,7 +84,6 @@ def _start_consumers(self) -> None:
    def _consume_spans(self, queue: "SimpleQueue[Optional[_SpanItem]]") -> None:
        while (item := queue.get()) is not END_OF_QUEUE:
            span, project_name = item
-            self._database.insert_span(span, project_name=project_name)
            with self._lock:
                project = self._projects[project_name]
                project.add_span(span)
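With the synchronous Database protocol and its pass-through methods removed from Traces, the PR title suggests that reads such as span_count now run as queries against an async SQLAlchemy session in the GraphQL layer instead. A rough sketch of that pattern, keeping the signature of the removed method (the table and column names are illustrative assumptions, not Phoenix's actual models):

from datetime import datetime
from typing import Optional

from sqlalchemy import Column, DateTime, Integer, MetaData, String, Table, func, select
from sqlalchemy.ext.asyncio import AsyncSession

# Illustrative table definition only; the real ORM models are not part of this diff.
metadata = MetaData()
spans = Table(
    "spans",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("project_name", String),
    Column("start_time", DateTime),
)


async def span_count(
    session: AsyncSession,
    project_name: str,
    start_time: Optional[datetime] = None,
    stop_time: Optional[datetime] = None,
) -> int:
    # Same signature as the method deleted above, expressed as one async query
    # rather than a synchronous call through the removed Database protocol.
    stmt = select(func.count()).select_from(spans).where(spans.c.project_name == project_name)
    if start_time is not None:
        stmt = stmt.where(spans.c.start_time >= start_time)
    if stop_time is not None:
        stmt = stmt.where(spans.c.start_time < stop_time)
    return (await session.execute(stmt)).scalar_one()

After this change, Traces only maintains the in-memory project state fed by the span and evaluation queues; persistence and aggregate queries no longer pass through it.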