feat(persistence): experimental bulk inserter for spans #2808

Merged · 5 commits · Apr 8, 2024
177 changes: 177 additions & 0 deletions src/phoenix/db/bulk_inserter.py
@@ -0,0 +1,177 @@
import asyncio
Contributor: nit: the file location of this feels off if it's specific to spans.

Author: it could be used for bulk inserting evals too; it'll just need a second queue.

import logging
from itertools import islice
from time import time
from typing import Any, AsyncContextManager, Callable, Iterable, List, Optional, Tuple, cast

from openinference.semconv.trace import SpanAttributes
from sqlalchemy import func, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from phoenix.db import models
from phoenix.trace.schemas import Span, SpanStatusCode

logger = logging.getLogger(__name__)


class SpansBulkInserter:
def __init__(
self,
db: Callable[[], AsyncContextManager[AsyncSession]],
initial_batch: Optional[Iterable[Tuple[Span, str]]] = None,
sleep_seconds: float = 0.5,
max_num_spans_per_transaction: int = 100,
) -> None:
self._db = db
self._running = False
self._sleep_seconds = sleep_seconds
self._max_num_spans_per_transaction = max_num_spans_per_transaction
self._batch: List[Tuple[Span, str]] = [] if initial_batch is None else list(initial_batch)
self._task: Optional[asyncio.Task[None]] = None

async def __aenter__(self) -> Callable[[Span, str], None]:
self._running = True
self._task = asyncio.create_task(self._insert_spans())
return self._queue_span

async def __aexit__(self, *args: Any) -> None:
self._running = False
Contributor: Set self._task to None?

Author: I'll let the garbage collector do it later; it's harmless either way.

def _queue_span(self, span: Span, project_name: str) -> None:
self._batch.append((span, project_name))

async def _insert_spans(self) -> None:
next_run_at = time() + self._sleep_seconds
while self._batch or self._running:
await asyncio.sleep(next_run_at - time())
next_run_at = time() + self._sleep_seconds
if not self._batch:
continue
batch = self._batch
self._batch = []
for i in range(0, len(batch), self._max_num_spans_per_transaction):
try:
async with self._db() as session:
for span, project_name in islice(
batch, i, i + self._max_num_spans_per_transaction
):
try:
async with session.begin_nested():
await _insert_span(session, span, project_name)
except Exception:
logger.exception(
f"Failed to insert span with span_id={span.context.span_id}"
)
except Exception:
logger.exception("Failed to insert spans")


async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> None:
if await session.scalar(select(1).where(models.Span.span_id == span.context.span_id)):
# Span already exists
return
Comment on lines +80 to +82

Contributor: Is there not a setting to ignore inserts if the record already exists, so we don't need to hit the database an extra time?

Author: Yes, but it'll raise an IntegrityError, which is annoying. On the other hand, this operation here is not expensive, because the B-tree is most likely already in the buffer pool.
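For illustration, a minimal self-contained sketch of the insert-or-ignore alternative discussed above, using SQLAlchemy's SQLite dialect (a toy table stands in for models.Span; this is not part of the PR):

# Illustrative sketch (not part of this PR): SQLAlchemy's SQLite dialect
# supports INSERT ... ON CONFLICT DO NOTHING, which skips a duplicate key
# silently instead of raising IntegrityError.
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
spans = Table(
    "spans",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("span_id", String, nullable=False, unique=True),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with engine.begin() as conn:
    # The second execution hits the unique constraint on span_id but is
    # skipped silently rather than raising IntegrityError.
    for _ in range(2):
        conn.execute(
            sqlite_insert(spans)
            .values(span_id="abc123")
            .on_conflict_do_nothing(index_elements=["span_id"])
        )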

if not (
project_rowid := await session.scalar(
Contributor: Should we start moving away from the rowid naming convention since we are trying to support Postgres in addition to SQLite?

Author (RogerHYang, Apr 8, 2024): Discussed offline: we currently don't have a good alternative name, but will reconsider later.

select(models.Project.id).where(models.Project.name == project_name)
)
):
project_rowid = await session.scalar(
insert(models.Project).values(name=project_name).returning(models.Project.id)
)
if trace := await session.scalar(
select(models.Trace).where(models.Trace.trace_id == span.context.trace_id)
):
trace_rowid = trace.id
# TODO(persistence): Figure out how to reliably retrieve timezone-aware
# datetimes from the (sqlite) database, because all datetimes in our
# programs should be timezone-aware.
if span.start_time < trace.start_time or trace.end_time < span.end_time:
trace.start_time = min(trace.start_time, span.start_time)
trace.end_time = max(trace.end_time, span.end_time)
await session.execute(
update(models.Trace)
.where(models.Trace.id == trace_rowid)
.values(
start_time=min(trace.start_time, span.start_time),
end_time=max(trace.end_time, span.end_time),
)
)
else:
trace_rowid = cast(
int,
await session.scalar(
insert(models.Trace)
.values(
project_rowid=project_rowid,
trace_id=span.context.trace_id,
start_time=span.start_time,
end_time=span.end_time,
)
.returning(models.Trace.id)
),
)
cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR)
cumulative_llm_token_count_prompt = cast(
int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0)
)
cumulative_llm_token_count_completion = cast(
int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0)
)
if accumulation := (
await session.execute(
select(
func.sum(models.Span.cumulative_error_count),
func.sum(models.Span.cumulative_llm_token_count_prompt),
func.sum(models.Span.cumulative_llm_token_count_completion),
).where(models.Span.parent_span_id == span.context.span_id)
)
).first():
cumulative_error_count += cast(int, accumulation[0] or 0)
cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0)
cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0)
latency_ms = (span.end_time - span.start_time).total_seconds() * 1000
session.add(
models.Span(
span_id=span.context.span_id,
trace_rowid=trace_rowid,
parent_span_id=span.parent_id,
kind=span.span_kind.value,
name=span.name,
start_time=span.start_time,
end_time=span.end_time,
attributes=span.attributes,
events=span.events,
status=span.status_code.value,
status_message=span.status_message,
latency_ms=latency_ms,
cumulative_error_count=cumulative_error_count,
cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt,
cumulative_llm_token_count_completion=cumulative_llm_token_count_completion,
)
)
# Propagate cumulative values to ancestors. This is usually a no-op, since
# the parent usually arrives after the child. But in the event that a
# child arrives after its parent, we need to make sure that all the
# ancestors' cumulative values are updated.
ancestors = (
select(models.Span.id, models.Span.parent_span_id)
.where(models.Span.span_id == span.parent_id)
.cte(recursive=True)
)
child = ancestors.alias()
ancestors = ancestors.union_all(
select(models.Span.id, models.Span.parent_span_id).join(
child, models.Span.span_id == child.c.parent_span_id
)
)
await session.execute(
update(models.Span)
.where(models.Span.id.in_(select(ancestors.c.id)))
.values(
cumulative_error_count=models.Span.cumulative_error_count + cumulative_error_count,
cumulative_llm_token_count_prompt=models.Span.cumulative_llm_token_count_prompt
+ cumulative_llm_token_count_prompt,
cumulative_llm_token_count_completion=models.Span.cumulative_llm_token_count_completion
+ cumulative_llm_token_count_completion,
)
)
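For context, a brief usage sketch of the new context manager (the db factory and the incoming spans are assumed to come from the surrounding application; the names here are illustrative, not part of the diff):

# Illustrative usage sketch (not part of the diff). `db` is assumed to be the
# async-session factory the application already passes around, and
# `incoming_spans` a stream of (Span, project_name) pairs.
async def ingest(db, incoming_spans):
    async with SpansBulkInserter(db, sleep_seconds=0.5) as queue_span:
        # __aenter__ starts the background task and returns _queue_span.
        for span, project_name in incoming_spans:
            queue_span(span, project_name)  # buffered; flushed every ~0.5s
    # __aexit__ sets _running to False; the background loop keeps flushing
    # until its pending batch is empty.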
7 changes: 6 additions & 1 deletion src/phoenix/db/engines.py
@@ -33,7 +33,12 @@ def aiosqlite_engine(
engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps)
event.listen(engine.sync_engine, "connect", set_sqlite_pragma)
if str(database) == ":memory:":
asyncio.run(init_models(engine))
try:
asyncio.get_running_loop()
except RuntimeError:
asyncio.run(init_models(engine))
else:
asyncio.create_task(init_models(engine))
else:
migrate(url)
return engine
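For context on the engines.py change above: asyncio.run() cannot be called while an event loop is already running, which is why the in-memory case now falls back to asyncio.create_task(). A standalone illustration (not part of the diff):

# Standalone illustration (not part of the diff): asyncio.run() raises
# RuntimeError when invoked from inside a running event loop, so the engine
# factory schedules init_models() on the current loop instead.
import asyncio


async def noop() -> None:
    pass


async def main() -> None:
    coro = noop()
    try:
        asyncio.run(coro)  # fails: a loop is already running here
    except RuntimeError as exc:
        coro.close()  # avoid the "coroutine was never awaited" warning
        print(f"cannot nest asyncio.run(): {exc}")
    await asyncio.create_task(noop())  # the fallback used in the diff


asyncio.run(main())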
17 changes: 11 additions & 6 deletions src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
@@ -25,10 +25,15 @@ def upgrade() -> None:
# TODO does the uniqueness constraint need to be named
sa.Column("name", sa.String, nullable=False, unique=True),
sa.Column("description", sa.String, nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(),
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
onupdate=sa.func.now(),
@@ -41,8 +46,8 @@ def upgrade() -> None:
# TODO(mikeldking): might not be the right place for this
sa.Column("session_id", sa.String, nullable=True),
sa.Column("trace_id", sa.String, nullable=False, unique=True),
sa.Column("start_time", sa.DateTime(), nullable=False, index=True),
sa.Column("end_time", sa.DateTime(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False, index=True),
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
)

op.create_table(
@@ -53,8 +58,8 @@
sa.Column("parent_span_id", sa.String, nullable=True, index=True),
sa.Column("name", sa.String, nullable=False),
sa.Column("kind", sa.String, nullable=False),
sa.Column("start_time", sa.DateTime(), nullable=False),
sa.Column("end_time", sa.DateTime(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("attributes", sa.JSON, nullable=False),
sa.Column("events", sa.JSON, nullable=False),
sa.Column(
52 changes: 45 additions & 7 deletions src/phoenix/db/models.py
@@ -1,12 +1,14 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy import (
JSON,
CheckConstraint,
DateTime,
Dialect,
ForeignKey,
MetaData,
TypeDecorator,
UniqueConstraint,
func,
insert,
@@ -21,6 +23,42 @@
)


class UtcTimeStamp(TypeDecorator[datetime]):
"""TODO(persistence): Figure out how to reliably store and retrieve
timezone-aware datetime objects from the (sqlite) database. This is a
workaround to guarantee that the timestamps we fetch from the database are
always timezone-aware, in order to prevent comparisons of timezone-naive
datetimes with timezone-aware ones, because objects in the rest of our
programs are always timezone-aware.
"""

cache_ok = True
impl = DateTime
_LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo

def process_bind_param(
self,
value: Optional[datetime],
dialect: Dialect,
) -> Optional[datetime]:
if not value:
return None
if value.tzinfo is None:
value = value.astimezone(self._LOCAL_TIMEZONE)
return value.astimezone(timezone.utc)

def process_result_value(
self,
value: Optional[Any],
dialect: Dialect,
) -> Optional[datetime]:
if not isinstance(value, datetime):
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)


class Base(DeclarativeBase):
# Enforce best practices for naming constraints
# https://alembic.sqlalchemy.org/en/latest/naming.html#integration-of-naming-conventions-into-operations-autogenerate
@@ -44,9 +82,9 @@ class Project(Base):
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
description: Mapped[Optional[str]]
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now())
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)

traces: WriteOnlyMapped["Trace"] = relationship(
@@ -69,8 +107,8 @@ class Trace(Base):
project_rowid: Mapped[int] = mapped_column(ForeignKey("projects.id"))
session_id: Mapped[Optional[str]]
trace_id: Mapped[str]
start_time: Mapped[datetime] = mapped_column(DateTime(), index=True)
end_time: Mapped[datetime] = mapped_column(DateTime())
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)

project: Mapped["Project"] = relationship(
"Project",
@@ -98,8 +136,8 @@ class Span(Base):
parent_span_id: Mapped[Optional[str]] = mapped_column(index=True)
name: Mapped[str]
kind: Mapped[str]
start_time: Mapped[datetime] = mapped_column(DateTime())
end_time: Mapped[datetime] = mapped_column(DateTime())
start_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
end_time: Mapped[datetime] = mapped_column(UtcTimeStamp)
attributes: Mapped[Dict[str, Any]]
events: Mapped[List[Dict[str, Any]]]
status: Mapped[str] = mapped_column(
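For context on the UtcTimeStamp decorator above, a standalone illustration of the asymmetry between its two directions (not part of the diff; SQLAlchemy calls these hooks internally):

# Standalone illustration (not part of the diff): process_bind_param treats a
# naive datetime as local time and normalizes it to UTC, while
# process_result_value treats a naive datetime coming back from SQLite as
# already being UTC.
from datetime import datetime, timezone

LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo

# Binding (Python -> database): a naive value is assumed local, then UTC.
naive_in = datetime(2024, 4, 8, 12, 0, 0)
bound = naive_in.astimezone(LOCAL_TIMEZONE).astimezone(timezone.utc)

# Fetching (database -> Python): a naive value is assumed to already be UTC.
naive_out = datetime(2024, 4, 8, 19, 0, 0)
fetched = naive_out.replace(tzinfo=timezone.utc)

print(bound.tzinfo, fetched.tzinfo)  # both timezone-aware (UTC)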