feat(persistence): experimental bulk inserter for spans #2808
@@ -0,0 +1,177 @@
import asyncio
import logging
from itertools import islice
from time import time
from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import func, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models
from phoenix.trace.schemas import Span, SpanStatusCode

logger = logging.getLogger(__name__)

class SpansBulkInserter:
    def __init__(
        self,
        db: Callable[[], AsyncContextManager[AsyncSession]],
        initial_batch: Optional[Iterable[Tuple[Span, str]]] = None,
        sleep_seconds: float = 0.5,
        max_num_spans_per_transaction: int = 100,
    ) -> None:
        self._db = db
        self._running = False
        self._sleep_seconds = sleep_seconds
        self._max_num_spans_per_transaction = max_num_spans_per_transaction
        self._batch: List[Tuple[Span, str]] = [] if initial_batch is None else list(initial_batch)
        self._task: Optional[asyncio.Task[None]] = None
    async def __aenter__(self) -> Callable[[Span, str], None]:
        self._running = True
        self._task = asyncio.create_task(self._insert_spans())
        return self._queue_span

    async def __aexit__(self, *args: Any) -> None:
        self._running = False

[Review thread on __aexit__]
Comment: Set …
Reply: I'll let the garbage collector do it later; it's harmless either way.
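As a sketch of the alternative raised in that thread (hypothetical, not what this PR does): because the flush loop runs while `self._batch or self._running`, awaiting the task on exit would drain the remaining batch deterministically and drop the reference explicitly.

    # Hypothetical alternative (not in this PR): await the background task on
    # exit so the final batch is flushed before the context manager returns.
    async def __aexit__(self, *args: Any) -> None:
        self._running = False
        if self._task is not None:
            await self._task  # the loop exits once the batch is drained
            self._task = None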
    def _queue_span(self, span: Span, project_name: str) -> None:
        self._batch.append((span, project_name))

    async def _insert_spans(self) -> None:
        next_run_at = time() + self._sleep_seconds
        while self._batch or self._running:
            await asyncio.sleep(next_run_at - time())
            next_run_at = time() + self._sleep_seconds
            if not self._batch:
                continue
            batch = self._batch
            self._batch = []
            for i in range(0, len(batch), self._max_num_spans_per_transaction):
                try:
                    async with self._db() as session:
                        for span, project_name in islice(
                            batch, i, i + self._max_num_spans_per_transaction
                        ):
                            try:
                                async with session.begin_nested():
                                    await _insert_span(session, span, project_name)
                            except Exception:
                                logger.exception(
                                    f"Failed to insert span with span_id={span.context.span_id}"
                                )
                except Exception:
                    logger.exception("Failed to insert spans")
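A minimal usage sketch (hypothetical; `db_session_factory` and `my_span` are placeholders, not part of this diff): entering the context manager starts the background flush task, and spans are queued through the returned callable.

# Hypothetical usage sketch; `db_session_factory`, `my_span`, and the project
# name are placeholders, not part of this PR.
async def main() -> None:
    async with SpansBulkInserter(db_session_factory, sleep_seconds=0.5) as queue_span:
        queue_span(my_span, "default")  # non-blocking append to the in-memory batch
    # __aexit__ only sets _running to False; the background task keeps
    # running until the remaining batch is drained.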
async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None:
    if await session.scalar(select(1).where(models.Span.span_id == span.context.span_id)):
        # Span already exists
        return

[Review thread on lines +80 to +82]
Comment: Is there not a setting to ignore inserts if the record already exists, so we don't need to hit the database an extra time?
Reply: Yes, but it'll raise an IntegrityError, which is annoying. On the other hand, this operation is not expensive, because the B-tree is most likely already in the buffer pool.
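For reference, the setting alluded to above is SQLAlchemy's dialect-specific ON CONFLICT support. A hedged sketch of that alternative, assuming SQLAlchemy 1.4+ and a unique constraint on span_id (column values abbreviated):

# Sketch of the ON CONFLICT alternative discussed above (not what this PR
# does); assumes SQLAlchemy 1.4+ and a unique constraint on span_id.
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

stmt = (
    sqlite_insert(models.Span)
    .values(span_id=span.context.span_id)  # plus the remaining columns
    .on_conflict_do_nothing(index_elements=["span_id"])
)
await session.execute(stmt)  # silently skips rows that already exist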
    if not (
        project_rowid := await session.scalar(
            select(models.Project.id).where(models.Project.name == project_name)
        )
    ):
        project_rowid = await session.scalar(
            insert(models.Project).values(name=project_name).returning(models.Project.id)
        )

[Review thread]
Comment: Should we start moving away from the …
Reply: Discussed offline: we currently don't have a good alternative name, but will reconsider later.
    if trace := await session.scalar(
        select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
    ):
        trace_rowid = trace.id
        # TODO(persistence): Figure out how to reliably retrieve timezone-aware
        # datetimes from the (sqlite) database, because all datetimes in our
        # programs should be timezone-aware.
        if span.start_time < trace.start_time or trace.end_time < span.end_time:
            trace.start_time = min(trace.start_time, span.start_time)
            trace.end_time = max(trace.end_time, span.end_time)
            await session.execute(
                update(models.Trace)
                .where(models.Trace.id == trace_rowid)
                .values(
                    start_time=min(trace.start_time, span.start_time),
                    end_time=max(trace.end_time, span.end_time),
                )
            )
    else:
        trace_rowid = cast(
            int,
            await session.scalar(
                insert(models.Trace)
                .values(
                    project_rowid=project_rowid,
                    trace_id=span.context.trace_id,
                    start_time=span.start_time,
                    end_time=span.end_time,
                )
                .returning(models.Trace.id)
            ),
        )
    cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
    cumulative_llm_token_count_prompt = cast(
        int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0)
    )
    cumulative_llm_token_count_completion = cast(
        int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0)
    )
    if accumulation := (
        await session.execute(
            select(
                func.sum(models.Span.cumulative_error_count),
                func.sum(models.Span.cumulative_llm_token_count_prompt),
                func.sum(models.Span.cumulative_llm_token_count_completion),
            ).where(models.Span.parent_span_id == span.context.span_id)
        )
    ).first():
        cumulative_error_count += cast(int, accumulation[0] or 0)
        cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0)
        cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0)
    latency_ms = (span.end_time - span.start_time).total_seconds() * 1000
    session.add(
        models.Span(
            span_id=span.context.span_id,
            trace_rowid=trace_rowid,
            parent_span_id=span.parent_id,
            kind=span.span_kind.value,
            name=span.name,
            start_time=span.start_time,
            end_time=span.end_time,
            attributes=span.attributes,
            events=span.events,
            status=span.status_code.value,
            status_message=span.status_message,
            latency_ms=latency_ms,
            cumulative_error_count=cumulative_error_count,
            cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt,
            cumulative_llm_token_count_completion=cumulative_llm_token_count_completion,
        )
    )
    # Propagate cumulative values to ancestors. This is usually a no-op, since
    # the parent usually arrives after the child. But in the event that a
    # child arrives after its parent, we need to make sure that all the
    # ancestors' cumulative values are updated.
    ancestors = (
        select(models.Span.id, models.Span.parent_span_id)
        .where(models.Span.span_id == span.parent_id)
        .cte(recursive=True)
    )
    child = ancestors.alias()
    ancestors = ancestors.union_all(
        select(models.Span.id, models.Span.parent_span_id).join(
            child, models.Span.span_id == child.c.parent_span_id
        )
    )
    await session.execute(
        update(models.Span)
        .where(models.Span.id.in_(select(ancestors.c.id)))
        .values(
            cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count,
            cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt
            + cumulative_llm_token_count_prompt,
            cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion
            + cumulative_llm_token_count_completion,
        )
    )
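To illustrate the propagation above, here is a toy model (purely illustrative; nothing in it is part of the diff) of cumulative counts walking up the ancestor chain when a child span arrives late:

# Toy model of the late-child case handled by the recursive CTE above;
# purely illustrative, not part of the diff.
parents = {"child": "root", "grandchild": "child"}  # child -> parent span id
cumulative_errors = {"root": 0, "child": 0}  # spans already inserted

def insert_late(span_id: str, own_errors: int) -> None:
    cumulative_errors[span_id] = own_errors
    ancestor = parents.get(span_id)
    while ancestor is not None:  # the same set the CTE's UPDATE targets
        cumulative_errors[ancestor] += own_errors
        ancestor = parents.get(ancestor)

insert_late("grandchild", 1)
assert cumulative_errors == {"root": 1, "child": 1, "grandchild": 1}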
[Review thread on the file]
Comment: Nit: the file location of this feels off if it's specific to spans.
Reply: It could be used for bulk inserting evals too; it'll just need a second queue.
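A hedged sketch of that generalization (hypothetical names; nothing here is in the PR), with one flush pass draining an evals queue alongside the spans queue:

# Hypothetical sketch of the generalization mentioned above (not in this PR):
# one flush pass draining a spans queue and a second queue for evals.
class BulkInserter:
    def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None:
        self._db = db
        self._spans: List[Tuple[Span, str]] = []
        self._evals: List[Any] = []  # second queue; the eval type is a placeholder

    async def _flush(self) -> None:
        spans, self._spans = self._spans, []
        evals, self._evals = self._evals, []
        async with self._db() as session:
            for span, project_name in spans:
                await _insert_span(session, span, project_name)
            for evaluation in evals:
                ...  # an _insert_evaluation counterpart would go here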