Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow retries for insertions #4026

Merged
merged 14 commits into from
Jul 30, 2024
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ dev = [
"anthropic",
"prometheus_client",
"asgi-lifespan",
"Faker>=26.0.0",
]
evals = []
experimental = []
Expand Down Expand Up @@ -171,6 +172,7 @@ dependencies = [
"nest-asyncio", # for executor testing
"astunparse; python_version<'3.9'", # `ast.unparse(...)` is only available starting with Python 3.9
"asgi-lifespan",
"Faker>=26.0.0",
]

[tool.hatch.envs.type]
Expand Down
131 changes: 129 additions & 2 deletions src/phoenix/db/bulk_inserter.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,46 @@
import asyncio
import logging
from asyncio import Queue
from asyncio import Queue, as_completed
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from functools import singledispatchmethod
from itertools import islice
from time import perf_counter
from typing import (
Any,
Awaitable,
Callable,
DefaultDict,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
cast,
)

from cachetools import LRUCache
from sqlalchemy import Select, select
from typing_extensions import TypeAlias

import phoenix.trace.v1 as pb
from phoenix.db import models
from phoenix.db.insertion.constants import DEFAULT_RETRY_ALLOWANCE, DEFAULT_RETRY_DELAY_SEC
from phoenix.db.insertion.document_annotation import DocumentAnnotationQueueInserter
from phoenix.db.insertion.evaluation import (
EvaluationInsertionEvent,
InsertEvaluationError,
insert_evaluation,
)
from phoenix.db.insertion.helpers import DataManipulation, DataManipulationEvent
from phoenix.db.insertion.span import SpanInsertionEvent, insert_span
from phoenix.db.insertion.span_annotation import SpanAnnotationQueueInserter
from phoenix.db.insertion.trace_annotation import TraceAnnotationQueueInserter
from phoenix.db.insertion.types import Insertables, Precursors
from phoenix.server.api.dataloaders import CacheForDataLoaders
from phoenix.server.types import DbSessionFactory
from phoenix.trace.schemas import Span
Expand Down Expand Up @@ -55,6 +68,8 @@ def __init__(
max_ops_per_transaction: int = 1000,
max_queue_size: int = 1000,
enable_prometheus: bool = False,
retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC,
retry_allowance: int = DEFAULT_RETRY_ALLOWANCE,
) -> None:
"""
:param db: A function to initiate a new database session.
Expand All @@ -81,6 +96,9 @@ def __init__(
self._last_updated_at_by_project: LRUCache[ProjectRowId, datetime] = LRUCache(maxsize=100)
self._cache_for_dataloaders = cache_for_dataloaders
self._enable_prometheus = enable_prometheus
self._retry_delay_sec = retry_delay_sec
self._retry_allowance = retry_allowance
self._queue_inserters = _QueueInserters(db, self._retry_delay_sec, self._retry_allowance)

def last_updated_at(self, project_rowid: Optional[ProjectRowId] = None) -> Optional[datetime]:
if isinstance(project_rowid, ProjectRowId):
Expand All @@ -90,6 +108,7 @@ def last_updated_at(self, project_rowid: Optional[ProjectRowId] = None) -> Optio
async def __aenter__(
self,
) -> Tuple[
Callable[[Any], Awaitable[None]],
Callable[[Span, str], Awaitable[None]],
Callable[[pb.Evaluation], Awaitable[None]],
Callable[[DataManipulation], None],
Expand All @@ -98,6 +117,7 @@ async def __aenter__(
self._operations = Queue(maxsize=self._max_queue_size)
self._task = asyncio.create_task(self._bulk_insert())
return (
self._enqueue,
self._queue_span,
self._queue_evaluation,
self._enqueue_operation,
Expand All @@ -109,6 +129,9 @@ async def __aexit__(self, *args: Any) -> None:
self._task.cancel()
self._task = None

async def _enqueue(self, *items: Any) -> None:
await self._queue_inserters.enqueue(*items)

def _enqueue_operation(self, operation: DataManipulation) -> None:
    """
    Put `operation` on the operations queue without waiting.

    :param operation: The data manipulation to queue.

    The `cast` only narrows the static type of `self._operations`; nothing is
    checked at runtime. Because `put_nowait` never waits, a full queue raises
    `asyncio.QueueFull` to the caller.
    """
    operations = cast("Queue[DataManipulation]", self._operations)
    operations.put_nowait(operation)

Expand All @@ -124,7 +147,17 @@ async def _bulk_insert(self) -> None:
assert isinstance(self._operations, Queue)
spans_buffer, evaluations_buffer = None, None
# start first insert immediately if the inserter has not run recently
while self._running or not self._operations.empty() or self._spans or self._evaluations:
while (
self._running
or not self._queue_inserters.empty
or not self._operations.empty()
or self._spans
or self._evaluations
):
if not self._queue_inserters.empty:
if inserted_ids := await self._queue_inserters.insert():
for project_rowid in await self._get_project_rowids(inserted_ids):
self._last_updated_at_by_project[project_rowid] = datetime.now(timezone.utc)
if self._operations.empty() and not (self._spans or self._evaluations):
await asyncio.sleep(self._sleep)
continue
Expand Down Expand Up @@ -244,3 +277,97 @@ async def _insert_evaluations(self, evaluations: List[pb.Evaluation]) -> Transac
BULK_LOADER_EXCEPTIONS.inc()
logger.exception("Failed to insert evaluations")
return transaction_result

async def _get_project_rowids(
    self,
    inserted_ids: Mapping[Type[models.Base], List[int]],
) -> Set[int]:
    """
    Find the projects that own the given inserted annotation rows.

    :param inserted_ids: Row ids of inserted records, grouped by ORM model class.
        Only subclasses of `models.SpanAnnotation`, `models.DocumentAnnotation`
        and `models.TraceAnnotation` are looked up; any other model, and any
        model with an empty id list, is skipped.
    :return: The distinct `models.Project.id` values reached by joining from the
        annotation rows back to their project. Empty if nothing was looked up.
    """
    project_rowids: Set[int] = set()
    if not inserted_ids:
        return project_rowids
    # All lookups start at Project -> Trace; span-level annotations
    # (span and document) additionally pass through Span.
    via_trace = select(models.Project.id).join(models.Trace)
    via_span = via_trace.join_from(models.Trace, models.Span)
    query: Select[Tuple[int]]
    for model, rowids in inserted_ids.items():
        if not rowids:
            continue
        if issubclass(model, models.SpanAnnotation):
            joined = via_span.join_from(models.Span, models.SpanAnnotation)
            query = joined.where(models.SpanAnnotation.id.in_(rowids))
        elif issubclass(model, models.DocumentAnnotation):
            joined = via_span.join_from(models.Span, models.DocumentAnnotation)
            query = joined.where(models.DocumentAnnotation.id.in_(rowids))
        elif issubclass(model, models.TraceAnnotation):
            joined = via_trace.join_from(models.Trace, models.TraceAnnotation)
            query = joined.where(models.TraceAnnotation.id.in_(rowids))
        else:
            # No known path from this model back to a project.
            continue
        # One session is opened per model that needs a lookup.
        async with self._db() as session:
            found = [rowid async for rowid in await session.stream_scalars(query)]
        project_rowids.update(found)
    return project_rowids


class _QueueInserters:
    """
    Groups the span, trace and document annotation queue inserters.

    Items passed to `enqueue` are routed to one of the three inserters by
    their type, and `insert` runs all three inserters together.
    """

    def __init__(
        self,
        db: DbSessionFactory,
        retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC,
        retry_allowance: int = DEFAULT_RETRY_ALLOWANCE,
    ) -> None:
        """
        :param db: Session factory; kept on the instance and given to every inserter.
        :param retry_delay_sec: Passed through unchanged to every inserter.
        :param retry_allowance: Passed through unchanged to every inserter.
        """
        self._db = db
        self._span_annotations = SpanAnnotationQueueInserter(
            db, retry_delay_sec, retry_allowance
        )
        self._trace_annotations = TraceAnnotationQueueInserter(
            db, retry_delay_sec, retry_allowance
        )
        self._document_annotations = DocumentAnnotationQueueInserter(
            db, retry_delay_sec, retry_allowance
        )
        self._queues = (
            self._span_annotations,
            self._trace_annotations,
            self._document_annotations,
        )

    async def insert(self) -> Dict[Type[models.Base], List[int]]:
        """
        Run `insert()` on every inserter and collect the results.

        :return: Inserted row ids grouped by the model class each inserter
            reports. Inserters that report no ids contribute no entry.
        """
        inserted: DefaultDict[Type[models.Base], List[int]] = defaultdict(list)
        pending = [queue.insert() for queue in self._queues]
        # Results are consumed in completion order, not in queue order.
        for finished in as_completed(pending):
            model, rowids = await finished
            if not rowids:
                continue
            inserted[model].extend(rowids)
        return inserted

    @property
    def empty(self) -> bool:
        """True only when every underlying inserter reports itself empty."""
        for queue in self._queues:
            if not queue.empty:
                return False
        return True

    async def enqueue(self, *items: Any) -> None:
        """
        Route each of `items`, one after another, to the matching inserter.

        :param items: Objects to enqueue; dispatch is on each item's type.
        """
        for item in items:
            await self._enqueue(item)

    @singledispatchmethod
    async def _enqueue(self, item: Any) -> None:
        """Fallback for item types with no registered handler: do nothing."""

    @_enqueue.register(Precursors.SpanAnnotation)
    @_enqueue.register(Insertables.SpanAnnotation)
    async def _(self, item: Precursors.SpanAnnotation) -> None:
        # Span annotations, whether precursor or insertable.
        await self._span_annotations.enqueue(item)

    @_enqueue.register(Precursors.TraceAnnotation)
    @_enqueue.register(Insertables.TraceAnnotation)
    async def _(self, item: Precursors.TraceAnnotation) -> None:
        # Trace annotations, whether precursor or insertable.
        await self._trace_annotations.enqueue(item)

    @_enqueue.register(Precursors.DocumentAnnotation)
    @_enqueue.register(Insertables.DocumentAnnotation)
    async def _(self, item: Precursors.DocumentAnnotation) -> None:
        # Document annotations, whether precursor or insertable.
        await self._document_annotations.enqueue(item)
24 changes: 23 additions & 1 deletion src/phoenix/db/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Optional, Tuple
from typing import Any, Callable, Hashable, Iterable, List, Optional, Set, Tuple, TypeVar

from openinference.semconv.trace import (
OpenInferenceSpanKindValues,
Expand Down Expand Up @@ -80,3 +80,25 @@ def get_project_names_for_experiments(*experiment_ids: int) -> Select[Tuple[Opti
.where(models.Experiment.id.in_(set(experiment_ids)))
.where(models.Experiment.project_name.isnot(None))
)


_AnyT = TypeVar("_AnyT")
_KeyT = TypeVar("_KeyT", bound=Hashable)


def dedup(
    items: Iterable[_AnyT],
    key: Callable[[_AnyT], _KeyT],
) -> List[_AnyT]:
    """
    Discard subsequent duplicates after the first appearance in `items`.

    :param items: The values to filter; consumed once, in order.
    :param key: Computes the hashable identity of an item. Two items are
        duplicates when their keys compare equal.
    :return: A new list holding the first item seen for each distinct key,
        in the order those items appeared in `items`.
    """
    ans: List[_AnyT] = []
    seen: Set[_KeyT] = set()
    for item in items:
        k = key(item)
        if k in seen:
            continue
        seen.add(k)
        ans.append(item)
    return ans
2 changes: 2 additions & 0 deletions src/phoenix/db/insertion/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Default for the `retry_delay_sec` argument of the bulk inserter and its queue inserters.
# NOTE(review): presumably the wait, in seconds, before a failed insertion is retried — confirm.
DEFAULT_RETRY_DELAY_SEC: float = 60
# Default for the `retry_allowance` argument of the bulk inserter and its queue inserters.
# NOTE(review): presumably the number of retries allowed per item — confirm.
DEFAULT_RETRY_ALLOWANCE: int = 10
Loading
Loading