diff --git a/pyproject.toml b/pyproject.toml index 64f35f13d6..384e9008f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ dev = [ "anthropic", "prometheus_client", "asgi-lifespan", + "Faker>=26.0.0", ] evals = [] experimental = [] @@ -171,6 +172,7 @@ dependencies = [ "nest-asyncio", # for executor testing "astunparse; python_version<'3.9'", # `ast.unparse(...)` is only available starting with Python 3.9 "asgi-lifespan", + "Faker>=26.0.0", ] [tool.hatch.envs.type] diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py index 11de610be9..d0f914c59e 100644 --- a/src/phoenix/db/bulk_inserter.py +++ b/src/phoenix/db/bulk_inserter.py @@ -1,26 +1,36 @@ import asyncio import logging -from asyncio import Queue +from asyncio import Queue, as_completed +from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime, timezone +from functools import singledispatchmethod from itertools import islice from time import perf_counter from typing import ( Any, Awaitable, Callable, + DefaultDict, + Dict, Iterable, List, + Mapping, Optional, Set, Tuple, + Type, cast, ) from cachetools import LRUCache +from sqlalchemy import Select, select from typing_extensions import TypeAlias import phoenix.trace.v1 as pb +from phoenix.db import models +from phoenix.db.insertion.constants import DEFAULT_RETRY_ALLOWANCE, DEFAULT_RETRY_DELAY_SEC +from phoenix.db.insertion.document_annotation import DocumentAnnotationQueueInserter from phoenix.db.insertion.evaluation import ( EvaluationInsertionEvent, InsertEvaluationError, @@ -28,6 +38,9 @@ ) from phoenix.db.insertion.helpers import DataManipulation, DataManipulationEvent from phoenix.db.insertion.span import SpanInsertionEvent, insert_span +from phoenix.db.insertion.span_annotation import SpanAnnotationQueueInserter +from phoenix.db.insertion.trace_annotation import TraceAnnotationQueueInserter +from phoenix.db.insertion.types import Insertables, Precursors from phoenix.server.api.dataloaders import CacheForDataLoaders from phoenix.server.types import DbSessionFactory from phoenix.trace.schemas import Span @@ -55,6 +68,8 @@ def __init__( max_ops_per_transaction: int = 1000, max_queue_size: int = 1000, enable_prometheus: bool = False, + retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC, + retry_allowance: int = DEFAULT_RETRY_ALLOWANCE, ) -> None: """ :param db: A function to initiate a new database session. 
@@ -81,6 +96,9 @@ def __init__( self._last_updated_at_by_project: LRUCache[ProjectRowId, datetime] = LRUCache(maxsize=100) self._cache_for_dataloaders = cache_for_dataloaders self._enable_prometheus = enable_prometheus + self._retry_delay_sec = retry_delay_sec + self._retry_allowance = retry_allowance + self._queue_inserters = _QueueInserters(db, self._retry_delay_sec, self._retry_allowance) def last_updated_at(self, project_rowid: Optional[ProjectRowId] = None) -> Optional[datetime]: if isinstance(project_rowid, ProjectRowId): @@ -90,6 +108,7 @@ def last_updated_at(self, project_rowid: Optional[ProjectRowId] = None) -> Optio async def __aenter__( self, ) -> Tuple[ + Callable[[Any], Awaitable[None]], Callable[[Span, str], Awaitable[None]], Callable[[pb.Evaluation], Awaitable[None]], Callable[[DataManipulation], None], @@ -98,6 +117,7 @@ async def __aenter__( self._operations = Queue(maxsize=self._max_queue_size) self._task = asyncio.create_task(self._bulk_insert()) return ( + self._enqueue, self._queue_span, self._queue_evaluation, self._enqueue_operation, @@ -109,6 +129,9 @@ async def __aexit__(self, *args: Any) -> None: self._task.cancel() self._task = None + async def _enqueue(self, *items: Any) -> None: + await self._queue_inserters.enqueue(*items) + def _enqueue_operation(self, operation: DataManipulation) -> None: cast("Queue[DataManipulation]", self._operations).put_nowait(operation) @@ -124,7 +147,17 @@ async def _bulk_insert(self) -> None: assert isinstance(self._operations, Queue) spans_buffer, evaluations_buffer = None, None # start first insert immediately if the inserter has not run recently - while self._running or not self._operations.empty() or self._spans or self._evaluations: + while ( + self._running + or not self._queue_inserters.empty + or not self._operations.empty() + or self._spans + or self._evaluations + ): + if not self._queue_inserters.empty: + if inserted_ids := await self._queue_inserters.insert(): + for project_rowid in await self._get_project_rowids(inserted_ids): + self._last_updated_at_by_project[project_rowid] = datetime.now(timezone.utc) if self._operations.empty() and not (self._spans or self._evaluations): await asyncio.sleep(self._sleep) continue @@ -244,3 +277,97 @@ async def _insert_evaluations(self, evaluations: List[pb.Evaluation]) -> Transac BULK_LOADER_EXCEPTIONS.inc() logger.exception("Failed to insert evaluations") return transaction_result + + async def _get_project_rowids( + self, + inserted_ids: Mapping[Type[models.Base], List[int]], + ) -> Set[int]: + ans: Set[int] = set() + if not inserted_ids: + return ans + stmt: Select[Tuple[int]] + for table, ids in inserted_ids.items(): + if not ids: + continue + if issubclass(table, models.SpanAnnotation): + stmt = ( + select(models.Project.id) + .join(models.Trace) + .join_from(models.Trace, models.Span) + .join_from(models.Span, models.SpanAnnotation) + .where(models.SpanAnnotation.id.in_(ids)) + ) + elif issubclass(table, models.DocumentAnnotation): + stmt = ( + select(models.Project.id) + .join(models.Trace) + .join_from(models.Trace, models.Span) + .join_from(models.Span, models.DocumentAnnotation) + .where(models.DocumentAnnotation.id.in_(ids)) + ) + elif issubclass(table, models.TraceAnnotation): + stmt = ( + select(models.Project.id) + .join(models.Trace) + .join_from(models.Trace, models.TraceAnnotation) + .where(models.TraceAnnotation.id.in_(ids)) + ) + else: + continue + async with self._db() as session: + project_rowids = [_ async for _ in await session.stream_scalars(stmt)] + 
ans.update(project_rowids) + return ans + + +class _QueueInserters: + def __init__( + self, + db: DbSessionFactory, + retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC, + retry_allowance: int = DEFAULT_RETRY_ALLOWANCE, + ) -> None: + self._db = db + args = (db, retry_delay_sec, retry_allowance) + self._span_annotations = SpanAnnotationQueueInserter(*args) + self._trace_annotations = TraceAnnotationQueueInserter(*args) + self._document_annotations = DocumentAnnotationQueueInserter(*args) + self._queues = ( + self._span_annotations, + self._trace_annotations, + self._document_annotations, + ) + + async def insert(self) -> Dict[Type[models.Base], List[int]]: + ans: DefaultDict[Type[models.Base], List[int]] = defaultdict(list) + for coro in as_completed([q.insert() for q in self._queues]): + table, inserted_ids = await coro + if inserted_ids: + ans[table].extend(inserted_ids) + return ans + + @property + def empty(self) -> bool: + return all(q.empty for q in self._queues) + + async def enqueue(self, *items: Any) -> None: + for item in items: + await self._enqueue(item) + + @singledispatchmethod + async def _enqueue(self, item: Any) -> None: ... + + @_enqueue.register(Precursors.SpanAnnotation) + @_enqueue.register(Insertables.SpanAnnotation) + async def _(self, item: Precursors.SpanAnnotation) -> None: + await self._span_annotations.enqueue(item) + + @_enqueue.register(Precursors.TraceAnnotation) + @_enqueue.register(Insertables.TraceAnnotation) + async def _(self, item: Precursors.TraceAnnotation) -> None: + await self._trace_annotations.enqueue(item) + + @_enqueue.register(Precursors.DocumentAnnotation) + @_enqueue.register(Insertables.DocumentAnnotation) + async def _(self, item: Precursors.DocumentAnnotation) -> None: + await self._document_annotations.enqueue(item) diff --git a/src/phoenix/db/helpers.py b/src/phoenix/db/helpers.py index 14a77e7c9e..a5ce45f0fd 100644 --- a/src/phoenix/db/helpers.py +++ b/src/phoenix/db/helpers.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Optional, Tuple +from typing import Any, Callable, Hashable, Iterable, List, Optional, Set, Tuple, TypeVar from openinference.semconv.trace import ( OpenInferenceSpanKindValues, @@ -80,3 +80,25 @@ def get_project_names_for_experiments(*experiment_ids: int) -> Select[Tuple[Opti .where(models.Experiment.id.in_(set(experiment_ids))) .where(models.Experiment.project_name.isnot(None)) ) + + +_AnyT = TypeVar("_AnyT") +_KeyT = TypeVar("_KeyT", bound=Hashable) + + +def dedup( + items: Iterable[_AnyT], + key: Callable[[_AnyT], _KeyT], +) -> List[_AnyT]: + """ + Discard subsequent duplicates after the first appearance in `items`. 
+ """ + ans = [] + seen: Set[_KeyT] = set() + for item in items: + if (k := key(item)) in seen: + continue + else: + ans.append(item) + seen.add(k) + return ans diff --git a/src/phoenix/db/insertion/constants.py b/src/phoenix/db/insertion/constants.py new file mode 100644 index 0000000000..4462df5d3f --- /dev/null +++ b/src/phoenix/db/insertion/constants.py @@ -0,0 +1,2 @@ +DEFAULT_RETRY_DELAY_SEC: float = 60 +DEFAULT_RETRY_ALLOWANCE: int = 10 diff --git a/src/phoenix/db/insertion/document_annotation.py b/src/phoenix/db/insertion/document_annotation.py new file mode 100644 index 0000000000..b83cb73df0 --- /dev/null +++ b/src/phoenix/db/insertion/document_annotation.py @@ -0,0 +1,157 @@ +from datetime import datetime +from typing import Any, List, Mapping, NamedTuple, Optional, Tuple + +from sqlalchemy import Row, Select, and_, select, tuple_ +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.db.helpers import dedup, num_docs_col +from phoenix.db.insertion.types import ( + Insertables, + Postponed, + Precursors, + QueueInserter, + Received, +) + +_Name: TypeAlias = str +_SpanId: TypeAlias = str +_SpanRowId: TypeAlias = int +_DocumentPosition: TypeAlias = int +_AnnoRowId: TypeAlias = int +_NumDocs: TypeAlias = int + +_Key: TypeAlias = Tuple[_Name, _SpanId, _DocumentPosition] +_UniqueBy: TypeAlias = Tuple[_Name, _SpanRowId, _DocumentPosition] +_Existing: TypeAlias = Tuple[ + _SpanRowId, + _SpanId, + _NumDocs, + Optional[_AnnoRowId], + Optional[_Name], + Optional[_DocumentPosition], + Optional[datetime], +] + + +class DocumentAnnotationQueueInserter( + QueueInserter[ + Precursors.DocumentAnnotation, + Insertables.DocumentAnnotation, + models.DocumentAnnotation, + ], + table=models.DocumentAnnotation, + unique_by=("name", "span_rowid", "document_position"), +): + async def _partition( + self, + session: AsyncSession, + *parcels: Received[Precursors.DocumentAnnotation], + ) -> Tuple[ + List[Received[Insertables.DocumentAnnotation]], + List[Postponed[Precursors.DocumentAnnotation]], + List[Received[Precursors.DocumentAnnotation]], + ]: + to_insert: List[Received[Insertables.DocumentAnnotation]] = [] + to_postpone: List[Postponed[Precursors.DocumentAnnotation]] = [] + to_discard: List[Received[Precursors.DocumentAnnotation]] = [] + + stmt = self._select_existing(*map(_key, parcels)) + existing: List[Row[_Existing]] = [_ async for _ in await session.stream(stmt)] + existing_spans: Mapping[str, _SpanAttr] = { + e.span_id: _SpanAttr(e.span_rowid, e.num_docs) for e in existing + } + existing_annos: Mapping[_Key, _AnnoAttr] = { + (e.name, e.span_id, e.document_position): _AnnoAttr(e.span_rowid, e.id, e.updated_at) + for e in existing + if e.id is not None + and e.name is not None + and e.document_position is not None + and e.updated_at is not None + } + + for p in parcels: + if (anno := existing_annos.get(_key(p))) is not None: + if p.received_at <= anno.updated_at: + to_discard.append(p) + else: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + span_rowid=anno.span_rowid, + id_=anno.id_, + ), + ) + ) + elif (span := existing_spans.get(p.item.span_id)) is not None: + if 0 <= p.item.document_position < span.num_docs: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + span_rowid=span.span_rowid, + ), + ) + ) + else: + to_discard.append(p) + elif isinstance(p, Postponed): + if p.retries_left > 1: + 
to_postpone.append(p.postpone(p.retries_left - 1)) + else: + to_discard.append(p) + elif isinstance(p, Received): + to_postpone.append(p.postpone(self._retry_allowance)) + else: + to_discard.append(p) + + assert len(to_insert) + len(to_postpone) + len(to_discard) == len(parcels) + to_insert = dedup(sorted(to_insert, key=_time, reverse=True), _unique_by)[::-1] + return to_insert, to_postpone, to_discard + + def _select_existing(self, *keys: _Key) -> Select[_Existing]: + anno = self.table + span = ( + select(models.Span.id, models.Span.span_id, num_docs_col(self._db.dialect)) + .where(models.Span.span_id.in_({span_id for _, span_id, *_ in keys})) + .cte() + ) + onclause = and_( + span.c.id == anno.span_rowid, + anno.name.in_({name for name, *_ in keys}), + tuple_(anno.name, span.c.span_id, anno.document_position).in_(keys), + ) + return select( + span.c.id.label("span_rowid"), + span.c.span_id, + span.c.num_docs, + anno.id, + anno.name, + anno.document_position, + anno.updated_at, + ).outerjoin_from(span, anno, onclause) + + +class _SpanAttr(NamedTuple): + span_rowid: _SpanRowId + num_docs: _NumDocs + + +class _AnnoAttr(NamedTuple): + span_rowid: _SpanRowId + id_: _AnnoRowId + updated_at: datetime + + +def _key(p: Received[Precursors.DocumentAnnotation]) -> _Key: + return p.item.obj.name, p.item.span_id, p.item.document_position + + +def _unique_by(p: Received[Insertables.DocumentAnnotation]) -> _UniqueBy: + return p.item.obj.name, p.item.span_rowid, p.item.document_position + + +def _time(p: Received[Any]) -> datetime: + return p.received_at diff --git a/src/phoenix/db/insertion/helpers.py b/src/phoenix/db/insertion/helpers.py index f0f88795b1..fda232a3be 100644 --- a/src/phoenix/db/insertion/helpers.py +++ b/src/phoenix/db/insertion/helpers.py @@ -20,6 +20,7 @@ from sqlalchemy.sql.elements import KeyedColumnElement from typing_extensions import TypeAlias, assert_never +from phoenix.db import models from phoenix.db.helpers import SupportedSQLDialect from phoenix.db.models import Base @@ -93,3 +94,15 @@ def _clean( yield "metadata", v else: yield k, v + + +def as_kv(obj: models.Base) -> Iterator[Tuple[str, Any]]: + for k, c in obj.__table__.c.items(): + if k in ["created_at", "updated_at"]: + continue + k = "metadata_" if k == "metadata" else k + v = getattr(obj, k, None) + if c.primary_key and v is None: + # postgresql disallows None for primary key + continue + yield k, v diff --git a/src/phoenix/db/insertion/span_annotation.py b/src/phoenix/db/insertion/span_annotation.py new file mode 100644 index 0000000000..f3a0af9ace --- /dev/null +++ b/src/phoenix/db/insertion/span_annotation.py @@ -0,0 +1,144 @@ +from datetime import datetime +from typing import Any, List, Mapping, NamedTuple, Optional, Tuple + +from sqlalchemy import Row, Select, and_, select, tuple_ +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.db.helpers import dedup +from phoenix.db.insertion.types import ( + Insertables, + Postponed, + Precursors, + QueueInserter, + Received, +) + +_Name: TypeAlias = str +_SpanId: TypeAlias = str +_SpanRowId: TypeAlias = int +_AnnoRowId: TypeAlias = int + +_Key: TypeAlias = Tuple[_Name, _SpanId] +_UniqueBy: TypeAlias = Tuple[_Name, _SpanRowId] +_Existing: TypeAlias = Tuple[ + _SpanRowId, + _SpanId, + Optional[_AnnoRowId], + Optional[_Name], + Optional[datetime], +] + + +class SpanAnnotationQueueInserter( + QueueInserter[ + Precursors.SpanAnnotation, + Insertables.SpanAnnotation, + models.SpanAnnotation, + 
], + table=models.SpanAnnotation, + unique_by=("name", "span_rowid"), +): + async def _partition( + self, + session: AsyncSession, + *parcels: Received[Precursors.SpanAnnotation], + ) -> Tuple[ + List[Received[Insertables.SpanAnnotation]], + List[Postponed[Precursors.SpanAnnotation]], + List[Received[Precursors.SpanAnnotation]], + ]: + to_insert: List[Received[Insertables.SpanAnnotation]] = [] + to_postpone: List[Postponed[Precursors.SpanAnnotation]] = [] + to_discard: List[Received[Precursors.SpanAnnotation]] = [] + + stmt = self._select_existing(*map(_key, parcels)) + existing: List[Row[_Existing]] = [_ async for _ in await session.stream(stmt)] + existing_spans: Mapping[str, _SpanAttr] = { + e.span_id: _SpanAttr(e.span_rowid) for e in existing + } + existing_annos: Mapping[_Key, _AnnoAttr] = { + (e.name, e.span_id): _AnnoAttr(e.span_rowid, e.id, e.updated_at) + for e in existing + if e.id is not None and e.name is not None and e.updated_at is not None + } + + for p in parcels: + if (anno := existing_annos.get(_key(p))) is not None: + if p.received_at <= anno.updated_at: + to_discard.append(p) + else: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + span_rowid=anno.span_rowid, + id_=anno.id_, + ), + ) + ) + elif (span := existing_spans.get(p.item.span_id)) is not None: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + span_rowid=span.span_rowid, + ), + ) + ) + elif isinstance(p, Postponed): + if p.retries_left > 1: + to_postpone.append(p.postpone(p.retries_left - 1)) + else: + to_discard.append(p) + elif isinstance(p, Received): + to_postpone.append(p.postpone(self._retry_allowance)) + else: + to_discard.append(p) + + assert len(to_insert) + len(to_postpone) + len(to_discard) == len(parcels) + to_insert = dedup(sorted(to_insert, key=_time, reverse=True), _unique_by)[::-1] + return to_insert, to_postpone, to_discard + + def _select_existing(self, *keys: _Key) -> Select[_Existing]: + anno = self.table + span = ( + select(models.Span.id, models.Span.span_id) + .where(models.Span.span_id.in_({span_id for _, span_id in keys})) + .cte() + ) + onclause = and_( + span.c.id == anno.span_rowid, + anno.name.in_({name for name, _ in keys}), + tuple_(anno.name, span.c.span_id).in_(keys), + ) + return select( + span.c.id.label("span_rowid"), + span.c.span_id, + anno.id, + anno.name, + anno.updated_at, + ).outerjoin_from(span, anno, onclause) + + +class _SpanAttr(NamedTuple): + span_rowid: _SpanRowId + + +class _AnnoAttr(NamedTuple): + span_rowid: _SpanRowId + id_: _AnnoRowId + updated_at: datetime + + +def _key(p: Received[Precursors.SpanAnnotation]) -> _Key: + return p.item.obj.name, p.item.span_id + + +def _unique_by(p: Received[Insertables.SpanAnnotation]) -> _UniqueBy: + return p.item.obj.name, p.item.span_rowid + + +def _time(p: Received[Any]) -> datetime: + return p.received_at diff --git a/src/phoenix/db/insertion/trace_annotation.py b/src/phoenix/db/insertion/trace_annotation.py new file mode 100644 index 0000000000..d7199171e8 --- /dev/null +++ b/src/phoenix/db/insertion/trace_annotation.py @@ -0,0 +1,144 @@ +from datetime import datetime +from typing import Any, List, Mapping, NamedTuple, Optional, Tuple + +from sqlalchemy import Row, Select, and_, select, tuple_ +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.db.helpers import dedup +from phoenix.db.insertion.types import ( + Insertables, + Postponed, + Precursors, + 
QueueInserter, + Received, +) + +_Name: TypeAlias = str +_TraceId: TypeAlias = str +_TraceRowId: TypeAlias = int +_AnnoRowId: TypeAlias = int + +_Key: TypeAlias = Tuple[_Name, _TraceId] +_UniqueBy: TypeAlias = Tuple[_Name, _TraceRowId] +_Existing: TypeAlias = Tuple[ + _TraceRowId, + _TraceId, + Optional[_AnnoRowId], + Optional[_Name], + Optional[datetime], +] + + +class TraceAnnotationQueueInserter( + QueueInserter[ + Precursors.TraceAnnotation, + Insertables.TraceAnnotation, + models.TraceAnnotation, + ], + table=models.TraceAnnotation, + unique_by=("name", "trace_rowid"), +): + async def _partition( + self, + session: AsyncSession, + *parcels: Received[Precursors.TraceAnnotation], + ) -> Tuple[ + List[Received[Insertables.TraceAnnotation]], + List[Postponed[Precursors.TraceAnnotation]], + List[Received[Precursors.TraceAnnotation]], + ]: + to_insert: List[Received[Insertables.TraceAnnotation]] = [] + to_postpone: List[Postponed[Precursors.TraceAnnotation]] = [] + to_discard: List[Received[Precursors.TraceAnnotation]] = [] + + stmt = self._select_existing(*map(_key, parcels)) + existing: List[Row[_Existing]] = [_ async for _ in await session.stream(stmt)] + existing_traces: Mapping[str, _TraceAttr] = { + e.trace_id: _TraceAttr(e.trace_rowid) for e in existing + } + existing_annos: Mapping[_Key, _AnnoAttr] = { + (e.name, e.trace_id): _AnnoAttr(e.trace_rowid, e.id, e.updated_at) + for e in existing + if e.id is not None and e.name is not None and e.updated_at is not None + } + + for p in parcels: + if (anno := existing_annos.get(_key(p))) is not None: + if p.received_at <= anno.updated_at: + to_discard.append(p) + else: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + trace_rowid=anno.trace_rowid, + id_=anno.id_, + ), + ) + ) + elif (trace := existing_traces.get(p.item.trace_id)) is not None: + to_insert.append( + Received( + received_at=p.received_at, + item=p.item.as_insertable( + trace_rowid=trace.trace_rowid, + ), + ) + ) + elif isinstance(p, Postponed): + if p.retries_left > 1: + to_postpone.append(p.postpone(p.retries_left - 1)) + else: + to_discard.append(p) + elif isinstance(p, Received): + to_postpone.append(p.postpone(self._retry_allowance)) + else: + to_discard.append(p) + + assert len(to_insert) + len(to_postpone) + len(to_discard) == len(parcels) + to_insert = dedup(sorted(to_insert, key=_time, reverse=True), _unique_by)[::-1] + return to_insert, to_postpone, to_discard + + def _select_existing(self, *keys: _Key) -> Select[_Existing]: + anno = self.table + trace = ( + select(models.Trace.id, models.Trace.trace_id) + .where(models.Trace.trace_id.in_({trace_id for _, trace_id in keys})) + .cte() + ) + onclause = and_( + trace.c.id == anno.trace_rowid, + anno.name.in_({name for name, _ in keys}), + tuple_(anno.name, trace.c.trace_id).in_(keys), + ) + return select( + trace.c.id.label("trace_rowid"), + trace.c.trace_id, + anno.id, + anno.name, + anno.updated_at, + ).outerjoin_from(trace, anno, onclause) + + +class _TraceAttr(NamedTuple): + trace_rowid: _TraceRowId + + +class _AnnoAttr(NamedTuple): + trace_rowid: _TraceRowId + id_: _AnnoRowId + updated_at: datetime + + +def _key(p: Received[Precursors.TraceAnnotation]) -> _Key: + return p.item.obj.name, p.item.trace_id + + +def _unique_by(p: Received[Insertables.TraceAnnotation]) -> _UniqueBy: + return p.item.obj.name, p.item.trace_rowid + + +def _time(p: Received[Any]) -> datetime: + return p.received_at diff --git a/src/phoenix/db/insertion/types.py b/src/phoenix/db/insertion/types.py new 
file mode 100644
index 0000000000..5cd35dbeda
--- /dev/null
+++ b/src/phoenix/db/insertion/types.py
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+from abc import ABC, abstractmethod
+from copy import copy
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import (
+    Any,
+    Generic,
+    List,
+    Mapping,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    cast,
+)
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql.dml import ReturningInsert
+
+from phoenix.db import models
+from phoenix.db.insertion.constants import DEFAULT_RETRY_ALLOWANCE, DEFAULT_RETRY_DELAY_SEC
+from phoenix.db.insertion.helpers import as_kv, insert_on_conflict
+from phoenix.server.types import DbSessionFactory
+
+logger = logging.getLogger(__name__)
+
+
+class Insertable(Protocol):
+    @property
+    def row(self) -> models.Base: ...
+
+
+_AnyT = TypeVar("_AnyT")
+_PrecursorT = TypeVar("_PrecursorT")
+_InsertableT = TypeVar("_InsertableT", bound=Insertable)
+_RowT = TypeVar("_RowT", bound=models.Base)
+
+
+@dataclass(frozen=True)
+class Received(Generic[_AnyT]):
+    item: _AnyT
+    received_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+    def postpone(self, retries_left: int = DEFAULT_RETRY_ALLOWANCE) -> Postponed[_AnyT]:
+        return Postponed(item=self.item, received_at=self.received_at, retries_left=retries_left)
+
+
+@dataclass(frozen=True)
+class Postponed(Received[_AnyT]):
+    retries_left: int = field(default=DEFAULT_RETRY_ALLOWANCE)
+
+
+class QueueInserter(ABC, Generic[_PrecursorT, _InsertableT, _RowT]):
+    table: Type[_RowT]
+    unique_by: Sequence[str]
+
+    def __init_subclass__(
+        cls,
+        table: Type[_RowT],
+        unique_by: Sequence[str],
+    ) -> None:
+        cls.table = table
+        cls.unique_by = unique_by
+
+    def __init__(
+        self,
+        db: DbSessionFactory,
+        retry_delay_sec: float = DEFAULT_RETRY_DELAY_SEC,
+        retry_allowance: int = DEFAULT_RETRY_ALLOWANCE,
+    ) -> None:
+        self._queue: List[Received[_PrecursorT]] = []
+        self._db = db
+        self._retry_delay_sec = retry_delay_sec
+        self._retry_allowance = retry_allowance
+
+    @property
+    def empty(self) -> bool:
+        return not bool(self._queue)
+
+    async def enqueue(self, *items: _PrecursorT) -> None:
+        self._queue.extend([Received(item) for item in items])
+
+    @abstractmethod
+    async def _partition(
+        self,
+        session: AsyncSession,
+        *parcels: Received[_PrecursorT],
+    ) -> Tuple[
+        List[Received[_InsertableT]],
+        List[Postponed[_PrecursorT]],
+        List[Received[_PrecursorT]],
+    ]: ...
+ + async def insert(self) -> Tuple[Type[_RowT], List[int]]: + if not self._queue: + return self.table, [] + parcels = self._queue + self._queue = [] + inserted_ids: List[int] = [] + async with self._db() as session: + to_insert, to_postpone, _ = await self._partition(session, *parcels) + if to_insert: + inserted_ids, to_retry, _ = await self._insert(session, *to_insert) + to_postpone.extend(to_retry) + if to_postpone: + loop = asyncio.get_running_loop() + loop.call_later(self._retry_delay_sec, self._queue.extend, to_postpone) + return self.table, inserted_ids + + def _stmt(self, *records: Mapping[str, Any]) -> ReturningInsert[Tuple[int]]: + pk = next(c for c in self.table.__table__.c if c.primary_key) + return insert_on_conflict( + *records, + table=self.table, + unique_by=self.unique_by, + dialect=self._db.dialect, + ).returning(pk) + + async def _insert( + self, + session: AsyncSession, + *insertions: Received[_InsertableT], + ) -> Tuple[List[int], List[Postponed[_PrecursorT]], List[Received[_InsertableT]]]: + records = [dict(as_kv(ins.item.row)) for ins in insertions] + inserted_ids: List[int] = [] + to_retry: List[Postponed[_PrecursorT]] = [] + failures: List[Received[_InsertableT]] = [] + stmt = self._stmt(*records) + try: + async with session.begin_nested(): + ids = [id_ async for id_ in await session.stream_scalars(stmt)] + inserted_ids.extend(ids) + except BaseException: + logger.exception( + f"Failed to bulk insert for {self.table.__name__}. " + f"Will try to insert ({len(records)} records) individually instead." + ) + for i, record in enumerate(records): + stmt = self._stmt(record) + try: + async with session.begin_nested(): + ids = [id_ async for id_ in await session.stream_scalars(stmt)] + inserted_ids.extend(ids) + except BaseException: + logger.exception(f"Failed to insert for {self.table.__name__}.") + p = insertions[i] + if isinstance(p, Postponed) and p.retries_left == 1: + failures.append(p) + else: + to_retry.append( + Postponed( + item=cast(_PrecursorT, p.item), + received_at=p.received_at, + retries_left=(p.retries_left - 1) + if isinstance(p, Postponed) + else self._retry_allowance, + ) + ) + return inserted_ids, to_retry, failures + + +class Precursors(ABC): + @dataclass(frozen=True) + class SpanAnnotation: + span_id: str + obj: models.SpanAnnotation + + def as_insertable( + self, + span_rowid: int, + id_: Optional[int] = None, + ) -> Insertables.SpanAnnotation: + return Insertables.SpanAnnotation( + span_id=self.span_id, + obj=self.obj, + span_rowid=span_rowid, + id_=id_, + ) + + @dataclass(frozen=True) + class TraceAnnotation: + trace_id: str + obj: models.TraceAnnotation + + def as_insertable( + self, + trace_rowid: int, + id_: Optional[int] = None, + ) -> Insertables.TraceAnnotation: + return Insertables.TraceAnnotation( + trace_id=self.trace_id, + obj=self.obj, + trace_rowid=trace_rowid, + id_=id_, + ) + + @dataclass(frozen=True) + class DocumentAnnotation: + span_id: str + document_position: int + obj: models.DocumentAnnotation + + def as_insertable( + self, + span_rowid: int, + id_: Optional[int] = None, + ) -> Insertables.DocumentAnnotation: + return Insertables.DocumentAnnotation( + span_id=self.span_id, + document_position=self.document_position, + obj=self.obj, + span_rowid=span_rowid, + id_=id_, + ) + + +class Insertables(ABC): + @dataclass(frozen=True) + class SpanAnnotation(Precursors.SpanAnnotation): + span_rowid: int + id_: Optional[int] = None + + @property + def row(self) -> models.SpanAnnotation: + obj = copy(self.obj) + obj.span_rowid = 
self.span_rowid + if self.id_ is not None: + obj.id = self.id_ + return obj + + @dataclass(frozen=True) + class TraceAnnotation(Precursors.TraceAnnotation): + trace_rowid: int + id_: Optional[int] = None + + @property + def row(self) -> models.TraceAnnotation: + obj = copy(self.obj) + obj.trace_rowid = self.trace_rowid + if self.id_ is not None: + obj.id = self.id_ + return obj + + @dataclass(frozen=True) + class DocumentAnnotation(Precursors.DocumentAnnotation): + span_rowid: int + id_: Optional[int] = None + + @property + def row(self) -> models.DocumentAnnotation: + obj = copy(self.obj) + obj.span_rowid = self.span_rowid + if self.id_ is not None: + obj.id = self.id_ + return obj diff --git a/src/phoenix/server/api/routers/v1/evaluations.py b/src/phoenix/server/api/routers/v1/evaluations.py index 3d616dcddd..c4e79ffc1a 100644 --- a/src/phoenix/server/api/routers/v1/evaluations.py +++ b/src/phoenix/server/api/routers/v1/evaluations.py @@ -1,6 +1,6 @@ import gzip from itertools import chain -from typing import Iterator, Optional, Tuple +from typing import Any, Callable, Iterator, Optional, Tuple, Union, cast import pandas as pd import pyarrow as pa @@ -24,10 +24,10 @@ import phoenix.trace.v1 as pb from phoenix.config import DEFAULT_PROJECT_NAME from phoenix.db import models +from phoenix.db.insertion.types import Precursors from phoenix.exceptions import PhoenixEvaluationNameIsMissing from phoenix.server.api.routers.utils import table_to_bytes from phoenix.server.types import DbSessionFactory -from phoenix.session.evaluation import encode_evaluations from phoenix.trace.span_evaluations import ( DocumentEvaluations, Evaluations, @@ -194,8 +194,94 @@ async def _process_pyarrow(request: Request) -> Response: async def _add_evaluations(state: State, evaluations: Evaluations) -> None: - for evaluation in encode_evaluations(evaluations): - await state.queue_evaluation_for_bulk_insert(evaluation) + dataframe = evaluations.dataframe + eval_name = evaluations.eval_name + names = dataframe.index.names + if ( + len(names) == 2 + and "document_position" in names + and ("context.span_id" in names or "span_id" in names) + ): + cls = _document_annotation_factory( + names.index("span_id") if "span_id" in names else names.index("context.span_id"), + names.index("document_position"), + ) + for index, row in dataframe.iterrows(): + score, label, explanation = _get_annotation_result(row) + document_annotation = cls(cast(Union[Tuple[str, int], Tuple[int, str]], index))( + name=eval_name, + annotator_kind="LLM", + score=score, + label=label, + explanation=explanation, + metadata_={}, + ) + await state.enqueue(document_annotation) + elif len(names) == 1 and names[0] in ("context.span_id", "span_id"): + for index, row in dataframe.iterrows(): + score, label, explanation = _get_annotation_result(row) + span_annotation = _span_annotation_factory(cast(str, index))( + name=eval_name, + annotator_kind="LLM", + score=score, + label=label, + explanation=explanation, + metadata_={}, + ) + await state.enqueue(span_annotation) + elif len(names) == 1 and names[0] in ("context.trace_id", "trace_id"): + for index, row in dataframe.iterrows(): + score, label, explanation = _get_annotation_result(row) + trace_annotation = _trace_annotation_factory(cast(str, index))( + name=eval_name, + annotator_kind="LLM", + score=score, + label=label, + explanation=explanation, + metadata_={}, + ) + await state.enqueue(trace_annotation) + + +def _get_annotation_result( + row: "pd.Series[Any]", +) -> Tuple[Optional[float], Optional[str], 
Optional[str]]: + return ( + cast(Optional[float], row.get("score")), + cast(Optional[str], row.get("label")), + cast(Optional[str], row.get("explanation")), + ) + + +def _document_annotation_factory( + span_id_idx: int, + document_position_idx: int, +) -> Callable[ + [Union[Tuple[str, int], Tuple[int, str]]], + Callable[..., Precursors.DocumentAnnotation], +]: + return lambda index: lambda **kwargs: Precursors.DocumentAnnotation( + span_id=str(index[span_id_idx]), + document_position=int(index[document_position_idx]), + obj=models.DocumentAnnotation( + document_position=int(index[document_position_idx]), + **kwargs, + ), + ) + + +def _span_annotation_factory(span_id: str) -> Callable[..., Precursors.SpanAnnotation]: + return lambda **kwargs: Precursors.SpanAnnotation( + span_id=str(span_id), + obj=models.SpanAnnotation(**kwargs), + ) + + +def _trace_annotation_factory(trace_id: str) -> Callable[..., Precursors.TraceAnnotation]: + return lambda **kwargs: Precursors.TraceAnnotation( + trace_id=str(trace_id), + obj=models.TraceAnnotation(**kwargs), + ) def _read_sql_trace_evaluations_into_dataframe( diff --git a/src/phoenix/server/api/routers/v1/spans.py b/src/phoenix/server/api/routers/v1/spans.py index a98cb41225..63c0e29a91 100644 --- a/src/phoenix/server/api/routers/v1/spans.py +++ b/src/phoenix/server/api/routers/v1/spans.py @@ -13,9 +13,9 @@ from phoenix.datetime_utils import normalize_datetime from phoenix.db import models from phoenix.db.helpers import SupportedSQLDialect -from phoenix.db.insertion.helpers import insert_on_conflict +from phoenix.db.insertion.helpers import as_kv, insert_on_conflict +from phoenix.db.insertion.types import Precursors from phoenix.server.api.routers.utils import df_to_bytes -from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.trace.dsl import SpanQuery as SpanQuery_ from .pydantic_compat import V1RoutesBaseModel @@ -143,7 +143,7 @@ class SpanAnnotationResult(V1RoutesBaseModel): class SpanAnnotation(V1RoutesBaseModel): - span_id: str = Field(description="The ID of the span being annotated") + span_id: str = Field(description="OpenTelemetry Span ID (hex format w/o 0x prefix)") name: str = Field(description="The name of the annotation") annotator_kind: Literal["LLM", "HUMAN"] = Field( description="The kind of annotator used for the annotation" @@ -155,6 +155,19 @@ class SpanAnnotation(V1RoutesBaseModel): default=None, description="Metadata for the annotation" ) + def as_precursor(self) -> Precursors.SpanAnnotation: + return Precursors.SpanAnnotation( + self.span_id, + models.SpanAnnotation( + name=self.name, + annotator_kind=self.annotator_kind, + score=self.result.score if self.result else None, + label=self.result.label if self.result else None, + explanation=self.result.explanation if self.result else None, + metadata_=self.metadata or {}, + ), + ) + class AnnotateSpansRequestBody(RequestBody[List[SpanAnnotation]]): data: List[SpanAnnotation] @@ -178,59 +191,36 @@ class AnnotateSpansResponseBody(ResponseBody[List[InsertedSpanAnnotation]]): response_description="Span annotations inserted successfully", ) async def annotate_spans( - request: Request, request_body: AnnotateSpansRequestBody + request: Request, + request_body: AnnotateSpansRequestBody, + sync: bool = Query(default=True, description="If true, fulfill request synchronously."), ) -> AnnotateSpansResponseBody: - span_annotations = request_body.data - span_gids = [GlobalID.from_id(annotation.span_id) for annotation in span_annotations] - - resolved_span_ids = 
[] - for span_gid in span_gids: - try: - resolved_span_ids.append(from_global_id_with_expected_type(span_gid, "Span")) - except ValueError: - raise HTTPException( - detail="Span with ID {span_gid} does not exist", - status_code=HTTP_404_NOT_FOUND, - ) + precursors = [d.as_precursor() for d in request_body.data] + if not sync: + await request.state.enqueue(*precursors) + return AnnotateSpansResponseBody(data=[]) + span_ids = {p.span_id for p in precursors} async with request.app.state.db() as session: - spans = await session.execute( - select(models.Span).filter(models.Span.id.in_(resolved_span_ids)) - ) - existing_span_ids = {span.id for span in spans.scalars()} + existing_spans = { + span.span_id: span.id + async for span in await session.stream_scalars( + select(models.Span).filter(models.Span.span_id.in_(span_ids)) + ) + } - missing_span_ids = set(resolved_span_ids) - existing_span_ids + missing_span_ids = span_ids - set(existing_spans.keys()) if missing_span_ids: - missing_span_gids = [ - str(GlobalID("Span", str(span_gid))) for span_gid in missing_span_ids - ] raise HTTPException( - detail=f"Spans with IDs {', '.join(missing_span_gids)} do not exist.", + detail=f"Spans with IDs {', '.join(missing_span_ids)} do not exist.", status_code=HTTP_404_NOT_FOUND, ) inserted_annotations = [] - for annotation in span_annotations: - span_gid = GlobalID.from_id(annotation.span_id) - span_id = from_global_id_with_expected_type(span_gid, "Span") - name = annotation.name - annotator_kind = annotation.annotator_kind - result = annotation.result - label = result.label if result else None - score = result.score if result else None - explanation = result.explanation if result else None - metadata = annotation.metadata or {} - - values = dict( - span_rowid=span_id, - name=name, - label=label, - score=score, - explanation=explanation, - annotator_kind=annotator_kind, - metadata_=metadata, - ) - dialect = SupportedSQLDialect(session.bind.dialect.name) + + dialect = SupportedSQLDialect(session.bind.dialect.name) + for p in precursors: + values = dict(as_kv(p.as_insertable(existing_spans[p.span_id]).row)) span_annotation_id = await session.scalar( insert_on_conflict( values, diff --git a/src/phoenix/server/api/routers/v1/traces.py b/src/phoenix/server/api/routers/v1/traces.py index 78d4f029c8..b63ddcfe2c 100644 --- a/src/phoenix/server/api/routers/v1/traces.py +++ b/src/phoenix/server/api/routers/v1/traces.py @@ -2,7 +2,7 @@ import zlib from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, BackgroundTasks, Header, HTTPException +from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Query from google.protobuf.message import DecodeError from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( ExportTraceServiceRequest, @@ -22,8 +22,8 @@ from phoenix.db import models from phoenix.db.helpers import SupportedSQLDialect -from phoenix.db.insertion.helpers import insert_on_conflict -from phoenix.server.api.types.node import from_global_id_with_expected_type +from phoenix.db.insertion.helpers import as_kv, insert_on_conflict +from phoenix.db.insertion.types import Precursors from phoenix.trace.otel import decode_otlp_span from phoenix.utilities.project import get_project_name @@ -100,7 +100,7 @@ class TraceAnnotationResult(V1RoutesBaseModel): class TraceAnnotation(V1RoutesBaseModel): - trace_id: str = Field(description="The ID of the trace being annotated") + trace_id: str = Field(description="OpenTelemetry Trace ID (hex format w/o 0x prefix)") name: str = 
Field(description="The name of the annotation") annotator_kind: Literal["LLM", "HUMAN"] = Field( description="The kind of annotator used for the annotation" @@ -112,6 +112,19 @@ class TraceAnnotation(V1RoutesBaseModel): default=None, description="Metadata for the annotation" ) + def as_precursor(self) -> Precursors.TraceAnnotation: + return Precursors.TraceAnnotation( + self.trace_id, + models.TraceAnnotation( + name=self.name, + annotator_kind=self.annotator_kind, + score=self.result.score if self.result else None, + label=self.result.label if self.result else None, + explanation=self.result.explanation if self.result else None, + metadata_=self.metadata or {}, + ), + ) + class AnnotateTracesRequestBody(RequestBody[List[TraceAnnotation]]): data: List[TraceAnnotation] = Field(description="The trace annotations to be upserted") @@ -134,61 +147,36 @@ class AnnotateTracesResponseBody(ResponseBody[List[InsertedTraceAnnotation]]): ), ) async def annotate_traces( - request: Request, request_body: AnnotateTracesRequestBody + request: Request, + request_body: AnnotateTracesRequestBody, + sync: bool = Query(default=True, description="If true, fulfill request synchronously."), ) -> AnnotateTracesResponseBody: - trace_annotations = request_body.data - trace_gids = [GlobalID.from_id(annotation.trace_id) for annotation in trace_annotations] - - resolved_trace_ids = [] - for trace_gid in trace_gids: - try: - resolved_trace_ids.append(from_global_id_with_expected_type(trace_gid, "Trace")) - except ValueError: - raise HTTPException( - detail="Trace with ID {trace_gid} does not exist", - status_code=HTTP_404_NOT_FOUND, - ) + precursors = [d.as_precursor() for d in request_body.data] + if not sync: + await request.state.enqueue(*precursors) + return AnnotateTracesResponseBody(data=[]) + trace_ids = {p.trace_id for p in precursors} async with request.app.state.db() as session: - traces = await session.execute( - select(models.Trace).filter(models.Trace.id.in_(resolved_trace_ids)) - ) - existing_trace_ids = {trace.id for trace in traces.scalars()} + existing_traces = { + trace.trace_id: trace.id + async for trace in await session.stream_scalars( + select(models.Trace).filter(models.Trace.trace_id.in_(trace_ids)) + ) + } - missing_trace_ids = set(resolved_trace_ids) - existing_trace_ids + missing_trace_ids = trace_ids - set(existing_traces.keys()) if missing_trace_ids: - missing_trace_gids = [ - str(GlobalID("Trace", str(trace_gid))) for trace_gid in missing_trace_ids - ] raise HTTPException( - detail=f"Traces with IDs {', '.join(missing_trace_gids)} do not exist.", + detail=f"Traces with IDs {', '.join(missing_trace_ids)} do not exist.", status_code=HTTP_404_NOT_FOUND, ) inserted_annotations = [] - for annotation in trace_annotations: - trace_gid = GlobalID.from_id(annotation.trace_id) - trace_id = from_global_id_with_expected_type(trace_gid, "Trace") - - name = annotation.name - annotator_kind = annotation.annotator_kind - result = annotation.result - label = result.label if result else None - score = result.score if result else None - explanation = result.explanation if result else None - metadata = annotation.metadata or {} - - values = dict( - trace_rowid=trace_id, - name=name, - label=label, - score=score, - explanation=explanation, - annotator_kind=annotator_kind, - metadata_=metadata, - ) - dialect = SupportedSQLDialect(session.bind.dialect.name) + dialect = SupportedSQLDialect(session.bind.dialect.name) + for p in precursors: + values = dict(as_kv(p.as_insertable(existing_traces[p.trace_id]).row)) 
trace_annotation_id = await session.scalar( insert_on_conflict( values, diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 24ae70fc18..2a16e9ede0 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -229,6 +229,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]: global DB_MUTEX DB_MUTEX = asyncio.Lock() if dialect is SupportedSQLDialect.SQLITE else None async with bulk_inserter as ( + enqueue, queue_span, queue_evaluation, enqueue_operation, @@ -239,6 +240,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[Dict[str, Any]]: enable_prometheus=enable_prometheus, ): yield { + "enqueue": enqueue, "queue_span_for_bulk_insert": queue_span, "queue_evaluation_for_bulk_insert": queue_evaluation, "enqueue_operation": enqueue_operation, diff --git a/tests/conftest.py b/tests/conftest.py index e66ef2154d..1f2cfcdd33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,19 @@ import asyncio import contextlib +import time from asyncio import AbstractEventLoop, get_running_loop from functools import partial from importlib.metadata import version -from time import sleep +from random import getrandbits from typing import ( Any, - AsyncContextManager, AsyncIterator, Awaitable, Callable, + Iterator, List, Literal, + Set, Tuple, ) @@ -20,6 +22,7 @@ from _pytest.config import Config, Parser from _pytest.fixtures import SubRequest from asgi_lifespan import LifespanManager +from faker import Faker from httpx import URL, Request, Response from phoenix.config import EXPORT_DIR from phoenix.core.model_schema_adapter import create_model_from_inferences @@ -43,7 +46,6 @@ def pytest_addoption(parser: Parser) -> None: parser.addoption( "--run-postgres", action="store_true", - default=False, help="Run tests that require Postgres", ) @@ -128,11 +130,11 @@ def _db_with_lock(engine: AsyncEngine) -> DbSessionFactory: lock, db = asyncio.Lock(), _db(engine) @contextlib.asynccontextmanager - async def _() -> AsyncIterator[AsyncSession]: + async def factory() -> AsyncIterator[AsyncSession]: async with lock, db() as session: yield session - return _ + return DbSessionFactory(db=factory, dialect=engine.dialect.name) @pytest.fixture @@ -144,15 +146,13 @@ async def project(db: DbSessionFactory) -> None: @pytest.fixture async def app( - dialect: str, - db: Callable[[], AsyncContextManager[AsyncSession]], + db: DbSessionFactory, ) -> AsyncIterator[ASGIApp]: - factory = DbSessionFactory(db=db, dialect=dialect) async with contextlib.AsyncExitStack() as stack: await stack.enter_async_context(patch_bulk_inserter()) await stack.enter_async_context(patch_grpc_server()) app = create_app( - db=factory, + db=db, model=create_model_from_inferences(EMPTY_INFERENCES, None), export_path=EXPORT_DIR, umap_params=get_umap_parameters(None), @@ -178,9 +178,12 @@ def __init__(self, transport: httpx.ASGITransport) -> None: def handle_request(self, request: Request) -> Response: fut = loop.create_task(self.handle_async_request(request)) - while not fut.done(): - sleep(0.01) - return fut.result() + time_cutoff = time.time() + 1 + while not fut.done() and time.time() < time_cutoff: + time.sleep(0.01) + if fut.done(): + return fut.result() + raise TimeoutError async def handle_async_request(self, request: Request) -> Response: response = await self.transport.handle_async_request(request) @@ -211,7 +214,7 @@ def px_client( httpx_clients: Tuple[httpx.Client, httpx.AsyncClient], ) -> Client: sync_client, _ = httpx_clients - client = Client() + client = 
Client(warn_if_server_not_running=False)
     client._client = sync_client
     client._base_url = str(sync_client.base_url)
     sync_client._base_url = URL("")
@@ -239,7 +242,36 @@ async def patch_bulk_inserter() -> AsyncIterator[None]:
     cls = BulkInserter
     original = cls.__init__
     name = original.__name__
-    changes = {"sleep": 0.001}
+    changes = {"sleep": 0.001, "retry_delay_sec": 0.001, "retry_allowance": 1000}
     setattr(cls, name, lambda *_, **__: original(*_, **{**__, **changes}))
     yield
     setattr(cls, name, original)
+
+
+@pytest.fixture
+def fake() -> Faker:
+    return Faker()
+
+
+@pytest.fixture
+def rand_span_id() -> Iterator[str]:
+    def _(seen: Set[str]) -> Iterator[str]:
+        while True:
+            span_id = getrandbits(64).to_bytes(8, "big").hex()
+            if span_id not in seen:
+                seen.add(span_id)
+                yield span_id
+
+    return _(set())
+
+
+@pytest.fixture
+def rand_trace_id() -> Iterator[str]:
+    def _(seen: Set[str]) -> Iterator[str]:
+        while True:
+            trace_id = getrandbits(128).to_bytes(16, "big").hex()
+            if trace_id not in seen:
+                seen.add(trace_id)
+                yield trace_id
+
+    return _(set())
diff --git a/tests/server/api/routers/v1/test_annotations.py b/tests/server/api/routers/v1/test_annotations.py
new file mode 100644
index 0000000000..f62951ae84
--- /dev/null
+++ b/tests/server/api/routers/v1/test_annotations.py
@@ -0,0 +1,124 @@
+from asyncio import gather, sleep
+from itertools import chain
+from typing import Any, Awaitable, Callable, Iterator, List, Union, cast, get_args
+
+import pandas as pd
+from faker import Faker
+from phoenix import Client, TraceDataset
+from phoenix.trace import DocumentEvaluations, SpanEvaluations, TraceEvaluations
+from typing_extensions import TypeAlias, assert_never
+
+_Evals: TypeAlias = Union[SpanEvaluations, TraceEvaluations, DocumentEvaluations]
+
+
+async def test_sending_evaluations_before_span(
+    px_client: Client,
+    span_data_with_documents: Any,
+    acall: Callable[..., Awaitable[Any]],
+    fake: Faker,
+    rand_span_id: Iterator[str],
+    rand_trace_id: Iterator[str],
+) -> None:
+    size = 3
+    eval_names = [fake.pystr() for _ in range(size)]
+    project_names = [fake.pystr() for _ in range(size)]
+    span = cast(pd.DataFrame, await acall(px_client.get_spans_dataframe)).iloc[:1]
+    span_ids, trace_ids, traces = {}, {}, {}
+    for project_name in project_names:
+        span_ids[project_name] = sorted(next(rand_span_id) for _ in range(size))
+        trace_ids[project_name] = sorted(next(rand_trace_id) for _ in range(size))
+        traces[project_name] = pd.concat(
+            [
+                span.assign(**{"context.span_id": span_id, "context.trace_id": trace_id})
+                for span_id, trace_id in zip(span_ids[project_name], trace_ids[project_name])
+            ]
+        ).set_index("context.span_id", drop=False)
+    for i in range(size - 1, -1, -1):
+        s = i * fake.pyfloat()
+        await gather(
+            sleep(0.001),
+            *(
+                acall(
+                    px_client.log_evaluations,
+                    SpanEvaluations(
+                        eval_name,
+                        pd.DataFrame(
+                            chain.from_iterable(
+                                [
+                                    dict(score=j + s, span_id=next(rand_span_id)),
+                                    dict(score=j + s, span_id=span_id),
+                                ]
+                                for j, span_id in enumerate(span_ids[project_name])
+                            )
+                        ).sample(frac=1),
+                    ),
+                    TraceEvaluations(
+                        eval_name,
+                        pd.DataFrame(
+                            chain.from_iterable(
+                                [
+                                    dict(score=j + s, trace_id=next(rand_trace_id)),
+                                    dict(score=j + s, trace_id=trace_id),
+                                ]
+                                for j, trace_id in enumerate(trace_ids[project_name])
+                            )
+                        ).sample(frac=1),
+                    ),
+                    DocumentEvaluations(
+                        eval_name,
+                        pd.DataFrame(
+                            chain.from_iterable(
+                                [
+                                    dict(score=j + s, span_id=next(rand_span_id), position=0),
+                                    dict(score=j + s, span_id=span_id, position=0),
+                                    dict(score=j + s,
span_id=span_id, position=-1), + dict(score=j + s, span_id=span_id, position=999_999_999), + ] + for j, span_id in enumerate(span_ids[project_name]) + ) + ).sample(frac=1), + ), + ) + for eval_name in eval_names + for project_name in project_names + ), + ) + await gather( + sleep(1), + *( + acall( + px_client.log_traces, + TraceDataset(traces[project_name]), + project_name=project_name, + ) + for project_name in project_names + ), + ) + evals = dict( + zip( + project_names, + cast( + List[List[_Evals]], + await gather( + *( + acall(px_client.get_evaluations, project_name=project_name) + for project_name in project_names + ) + ), + ), + ) + ) + for project_name in project_names: + assert len(evals[project_name]) == len(eval_names) * len(get_args(_Evals)) + for e in evals[project_name]: + df = e.dataframe.sort_index() + assert len(df) == size + assert df.score.to_list() == list(range(size)) + if isinstance(e, SpanEvaluations): + assert df.index.to_list() == span_ids[project_name] + elif isinstance(e, TraceEvaluations): + assert df.index.to_list() == trace_ids[project_name] + elif isinstance(e, DocumentEvaluations): + assert df.index.to_list() == [(span_id, 0) for span_id in span_ids[project_name]] + else: + assert_never(e) diff --git a/tests/server/api/routers/v1/test_spans.py b/tests/server/api/routers/v1/test_spans.py index c25929799c..efd8bb1ca2 100644 --- a/tests/server/api/routers/v1/test_spans.py +++ b/tests/server/api/routers/v1/test_spans.py @@ -6,13 +6,12 @@ import httpx import pandas as pd import pytest +from faker import Faker from phoenix import Client, TraceDataset from phoenix.db import models -from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.types import DbSessionFactory from phoenix.trace.dsl import SpanQuery from sqlalchemy import insert, select -from strawberry.relay import GlobalID async def test_span_round_tripping_with_docs( @@ -38,17 +37,20 @@ async def test_span_round_tripping_with_docs( assert new_count == orig_count * 2 +@pytest.mark.parametrize("sync", [False, True]) async def test_rest_span_annotation( db: DbSessionFactory, httpx_client: httpx.AsyncClient, project_with_a_single_trace_and_span: Any, + sync: bool, + fake: Faker, ) -> None: - span_gid = GlobalID("Span", "1") + name = fake.pystr() request_body = { "data": [ { - "span_id": str(span_gid), - "name": "Test Annotation", + "span_id": "7e2f08cb43bbf521", + "name": name, "annotator_kind": "HUMAN", "result": { "label": "True", @@ -60,19 +62,17 @@ async def test_rest_span_annotation( ] } - response = await httpx_client.post("/v1/span_annotations", json=request_body) + response = await httpx_client.post(f"v1/span_annotations?sync={sync}", json=request_body) assert response.status_code == 200 - - data = response.json()["data"] - annotation_gid = GlobalID.from_id(data[0]["id"]) - annotation_id = from_global_id_with_expected_type(annotation_gid, "SpanAnnotation") + if not sync: + await sleep(0.1) async with db() as session: orm_annotation = await session.scalar( - select(models.SpanAnnotation).where(models.SpanAnnotation.id == annotation_id) + select(models.SpanAnnotation).where(models.SpanAnnotation.name == name) ) assert orm_annotation is not None - assert orm_annotation.name == "Test Annotation" + assert orm_annotation.name == name assert orm_annotation.annotator_kind == "HUMAN" assert orm_annotation.label == "True" assert orm_annotation.score == 0.95 @@ -94,7 +94,7 @@ async def project_with_a_single_trace_and_span( trace_id = await session.scalar( 
insert(models.Trace) .values( - trace_id="1", + trace_id="649993371fa95c788177f739b7423818", project_rowid=project_row_id, start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), end_time=datetime.fromisoformat("2021-01-01T00:01:00.000+00:00"), @@ -105,7 +105,7 @@ async def project_with_a_single_trace_and_span( insert(models.Span) .values( trace_rowid=trace_id, - span_id="1", + span_id="7e2f08cb43bbf521", parent_id=None, name="chain span", span_kind="CHAIN", diff --git a/tests/server/api/routers/v1/test_traces.py b/tests/server/api/routers/v1/test_traces.py index 51bdcb1f42..96464ac0f7 100644 --- a/tests/server/api/routers/v1/test_traces.py +++ b/tests/server/api/routers/v1/test_traces.py @@ -1,13 +1,13 @@ +from asyncio import sleep from datetime import datetime from typing import Any import httpx import pytest +from faker import Faker from phoenix.db import models -from phoenix.server.api.types.node import from_global_id_with_expected_type from phoenix.server.types import DbSessionFactory from sqlalchemy import insert, select -from strawberry.relay import GlobalID @pytest.fixture @@ -24,7 +24,7 @@ async def project_with_a_single_trace_and_span( trace_id = await session.scalar( insert(models.Trace) .values( - trace_id="1", + trace_id="82c6c9c33ccc586e0d3bdf46b20db309", project_rowid=project_row_id, start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), end_time=datetime.fromisoformat("2021-01-01T00:01:00.000+00:00"), @@ -35,7 +35,7 @@ async def project_with_a_single_trace_and_span( insert(models.Span) .values( trace_rowid=trace_id, - span_id="1", + span_id="f0d808aedd5591b6", parent_id=None, name="chain span", span_kind="CHAIN", @@ -56,17 +56,20 @@ async def project_with_a_single_trace_and_span( ) +@pytest.mark.parametrize("sync", [False, True]) async def test_rest_trace_annotation( db: DbSessionFactory, httpx_client: httpx.AsyncClient, project_with_a_single_trace_and_span: Any, + sync: bool, + fake: Faker, ) -> None: - trace_gid = GlobalID("Trace", "1") + name = fake.pystr() request_body = { "data": [ { - "trace_id": str(trace_gid), - "name": "Test Annotation", + "trace_id": "82c6c9c33ccc586e0d3bdf46b20db309", + "name": name, "annotator_kind": "HUMAN", "result": { "label": "True", @@ -78,19 +81,17 @@ async def test_rest_trace_annotation( ] } - response = await httpx_client.post("/v1/trace_annotations", json=request_body) + response = await httpx_client.post(f"v1/trace_annotations?sync={sync}", json=request_body) assert response.status_code == 200 - - data = response.json()["data"] - annotation_gid = GlobalID.from_id(data[0]["id"]) - annotation_id = from_global_id_with_expected_type(annotation_gid, "TraceAnnotation") + if not sync: + await sleep(0.1) async with db() as session: orm_annotation = await session.scalar( - select(models.TraceAnnotation).where(models.TraceAnnotation.id == annotation_id) + select(models.TraceAnnotation).where(models.TraceAnnotation.name == name) ) assert orm_annotation is not None - assert orm_annotation.name == "Test Annotation" + assert orm_annotation.name == name assert orm_annotation.annotator_kind == "HUMAN" assert orm_annotation.label == "True" assert orm_annotation.score == 0.95
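
The parcel lifecycle introduced in types.py and shared by all three *AnnotationQueueInserter subclasses works as follows: items arrive wrapped in Received; _partition splits them into rows ready to insert, parcels to postpone (the referenced span or trace has not arrived yet), and parcels to discard (stale, or out of retries); postponed parcels re-enter the queue after retry_delay_sec via loop.call_later. Below is a minimal, self-contained sketch of that control flow, not the shipped implementation; the in-memory existing set stands in for the _select_existing query, and plain strings stand in for annotation precursors.

import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import List, Set

@dataclass(frozen=True)
class Received:
    item: str
    received_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def postpone(self, retries_left: int) -> "Postponed":
        return Postponed(self.item, self.received_at, retries_left)

@dataclass(frozen=True)
class Postponed(Received):
    retries_left: int = 10

class ToyQueueInserter:
    def __init__(self, existing: Set[str], retry_delay_sec: float = 0.01) -> None:
        self._queue: List[Received] = []
        self._existing = existing  # stand-in for rows already present in the database
        self._retry_delay_sec = retry_delay_sec
        self.inserted: List[str] = []

    async def enqueue(self, *items: str) -> None:
        self._queue.extend(Received(item) for item in items)

    async def insert(self) -> None:
        parcels, self._queue = self._queue, []
        to_postpone: List[Postponed] = []
        for p in parcels:
            if p.item in self._existing:  # dependency row exists: insert now
                self.inserted.append(p.item)
            elif isinstance(p, Postponed):  # already waiting: spend one retry
                if p.retries_left > 1:
                    to_postpone.append(p.postpone(p.retries_left - 1))
                # else: retry budget exhausted, parcel is discarded
            else:  # first miss: grant the full retry allowance
                to_postpone.append(p.postpone(retries_left=10))
        if to_postpone:  # re-queue later, mirroring loop.call_later in types.py
            asyncio.get_running_loop().call_later(
                self._retry_delay_sec, self._queue.extend, to_postpone
            )

async def main() -> None:
    q = ToyQueueInserter(existing={"span-a"})
    await q.enqueue("span-a", "span-b")  # "span-b" has no matching row yet
    await q.insert()  # inserts "span-a", postpones "span-b"
    q._existing.add("span-b")  # the missing span arrives afterwards
    await asyncio.sleep(0.02)  # let call_later re-queue the postponed parcel
    await q.insert()
    assert q.inserted == ["span-a", "span-b"]

asyncio.run(main())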
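
QueueInserter._insert first attempts the whole batch inside session.begin_nested() (a SAVEPOINT), so one bad record aborts only the batch rather than the surrounding transaction; it then falls back to inserting records one at a time, each under its own savepoint, postponing failures with a decremented retry budget. Stripped of SQLAlchemy, the fallback logic reduces to the sketch below; flaky_insert is a hypothetical stand-in that raises on a "bad" record the way a constraint violation would.

from typing import Callable, List, Sequence, Tuple

def insert_with_fallback(
    records: Sequence[str],
    bulk_insert: Callable[[Sequence[str]], List[int]],
) -> Tuple[List[int], List[str]]:
    """Try the whole batch first; on failure, retry record by record."""
    try:
        return list(bulk_insert(records)), []
    except Exception:
        inserted_ids: List[int] = []
        failures: List[str] = []
        for record in records:
            try:
                inserted_ids.extend(bulk_insert([record]))
            except Exception:
                failures.append(record)  # real code would postpone or discard here
        return inserted_ids, failures

_next_id = iter(range(1, 100))

def flaky_insert(batch: Sequence[str]) -> List[int]:
    if any(r == "bad" for r in batch):
        raise ValueError("constraint violation")  # stand-in for IntegrityError
    return [next(_next_id) for _ in batch]

ids, failed = insert_with_fallback(["a", "bad", "b"], flaky_insert)
assert ids == [1, 2] and failed == ["bad"]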
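
Each _partition implementation ends with the same idiom: dedup(sorted(to_insert, key=_time, reverse=True), _unique_by)[::-1]. Because dedup keeps only the first appearance of each key, sorting newest-first makes the most recently received parcel win for each unique key, and the trailing [::-1] restores chronological order before insertion. A self-contained illustration, with toy tuples in place of the real parcels:

from typing import Callable, Hashable, Iterable, List, Set, Tuple, TypeVar

_AnyT = TypeVar("_AnyT")
_KeyT = TypeVar("_KeyT", bound=Hashable)

def dedup(items: Iterable[_AnyT], key: Callable[[_AnyT], _KeyT]) -> List[_AnyT]:
    """Discard subsequent duplicates after the first appearance in `items`."""
    ans: List[_AnyT] = []
    seen: Set[_KeyT] = set()
    for item in items:
        if (k := key(item)) in seen:
            continue
        ans.append(item)
        seen.add(k)
    return ans

# (received_at, unique_key, payload) stand-ins for Received[Insertables...]
parcels: List[Tuple[int, str, str]] = [
    (1, "name-a", "old"),
    (3, "name-a", "new"),   # later arrival for the same unique key
    (2, "name-b", "only"),
]
newest_first = sorted(parcels, key=lambda p: p[0], reverse=True)
winners = dedup(newest_first, key=lambda p: p[1])[::-1]  # newest per key, oldest first
assert winners == [(2, "name-b", "only"), (3, "name-a", "new")]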
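
_QueueInserters._enqueue routes incoming items by runtime type using functools.singledispatchmethod. Each handler is registered for both a Precursors.* class and its Insertables.* subclass; since singledispatch already falls back to the nearest registered base class, the subclass registrations are belt-and-braces rather than strictly required. The same dispatch pattern in isolation, with hypothetical stand-in item types:

import asyncio
from dataclasses import dataclass
from functools import singledispatchmethod
from typing import Any, List

@dataclass(frozen=True)
class SpanItem:
    span_id: str

@dataclass(frozen=True)
class TraceItem:
    trace_id: str

class Router:
    def __init__(self) -> None:
        self.span_queue: List[SpanItem] = []
        self.trace_queue: List[TraceItem] = []

    async def enqueue(self, *items: Any) -> None:
        for item in items:
            await self._enqueue(item)

    @singledispatchmethod
    async def _enqueue(self, item: Any) -> None: ...  # unregistered types fall through

    @_enqueue.register(SpanItem)
    async def _(self, item: SpanItem) -> None:
        self.span_queue.append(item)

    @_enqueue.register(TraceItem)
    async def _(self, item: TraceItem) -> None:
        self.trace_queue.append(item)

async def main() -> None:
    r = Router()
    await r.enqueue(SpanItem("a"), TraceItem("b"), object())  # object() is ignored
    assert [i.span_id for i in r.span_queue] == ["a"]
    assert [i.trace_id for i in r.trace_queue] == ["b"]

asyncio.run(main())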
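
The new sync query parameter turns annotate_spans and annotate_traces into fire-and-forget enqueues when set to false: the handler calls request.state.enqueue(*precursors) and returns an empty data list immediately, leaving span/trace ID resolution to the bulk inserter. That is what lets test_sending_evaluations_before_span log annotations before the spans exist. A hypothetical client call follows; the host and port are assumptions, and the payload shape mirrors the SpanAnnotation model above.

import httpx

# Assumed local server address; the route and payload shape come from spans.py above.
resp = httpx.post(
    "http://localhost:6006/v1/span_annotations",
    params={"sync": "false"},  # enqueue for background insertion instead of upserting now
    json={
        "data": [
            {
                "span_id": "7e2f08cb43bbf521",  # OpenTelemetry span ID, hex w/o 0x prefix
                "name": "correctness",
                "annotator_kind": "HUMAN",
                "result": {"label": "True", "score": 0.95, "explanation": "looks right"},
                "metadata": {},
            }
        ]
    },
)
assert resp.status_code == 200
assert resp.json() == {"data": []}  # the async path returns no annotation IDs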