From 4709ddb292ee57a0554c1f14d21072f7c36dd547 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 30 Apr 2024 20:54:33 -0700 Subject: [PATCH 01/74] add test notebook --- integration-tests/pagination_queries.ipynb | 220 +++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 integration-tests/pagination_queries.ipynb diff --git a/integration-tests/pagination_queries.ipynb b/integration-tests/pagination_queries.ipynb new file mode 100644 index 0000000000..20cad1ce70 --- /dev/null +++ b/integration-tests/pagination_queries.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example Queries for Cursor-Based Pagination" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from phoenix.db import models\n", + "from sqlalchemy import and_, create_engine, select\n", + "from sqlalchemy.orm import aliased, sessionmaker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PostgresSession = sessionmaker(\n", + " create_engine(\n", + " \"postgresql+psycopg://localhost:5432/postgres?user=postgres&password=mysecretpassword\",\n", + " echo=True,\n", + " ),\n", + " expire_on_commit=False,\n", + ")\n", + "SqliteSession = sessionmaker(\n", + " create_engine(\"sqlite:////Users/xandersong/.phoenix/phoenix.db\", echo=True),\n", + " expire_on_commit=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- filter: None\n", + "- sort: None\n", + "- after: None\n", + "- before: None\n", + "- first: 10\n", + "- last: None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "page_size = 10\n", + "with SqliteSession() as session:\n", + " span_ids = session.scalars(select(models.Span.id).limit(page_size)).all()\n", + "span_ids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- filter: None\n", + "- sort: None\n", + "- after: 5\n", + "- before: None\n", + "- first: 10\n", + "- last: None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cursor = 10\n", + "page_size = 10\n", + "with SqliteSession() as session:\n", + " span_ids = session.scalars(\n", + " select(models.Span.id).where(models.Span.id >= cursor).limit(page_size)\n", + " ).all()\n", + "span_ids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- filter: \"\"\"span_kind == 'LLM'\"\"\"\n", + "- sort: None\n", + "- after: 10\n", + "- before: None\n", + "- first: 10\n", + "- last: None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cursor = 10\n", + "page_size = 10\n", + "with SqliteSession() as session:\n", + " span_ids = session.scalars(\n", + " select(models.Span.id)\n", + " .where(models.Span.span_kind == \"LLM\")\n", + " .where(models.Span.id >= cursor)\n", + " .limit(page_size)\n", + " ).all()\n", + "span_ids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- filter: None\n", + "- sort: prompt token count\n", + "- after: 10\n", + "- before: None\n", + "- first: 10\n", + "- last: None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cursor = 100\n", + "page_size = 10\n", + "with SqliteSession() as session:\n", + " prompt_tokens = models.Span.attributes[[\"llm\", \"token_count\", 
\"prompt\"]]\n", + " for index, span in enumerate(\n", + " session.execute(\n", + " select(models.Span.id, prompt_tokens)\n", + " .where(models.Span.id >= cursor)\n", + " .order_by(prompt_tokens, models.Span.id)\n", + " .limit(page_size)\n", + " )\n", + " ):\n", + " print(f\"{index=} {span=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- filter: attributes[\"llm.prompt.tokens] < 50\n", + "- sort: evals[\"Q&A Correctness\"].score DESC\n", + "- after: 10\n", + "- before: None\n", + "- first: 10\n", + "- last: None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "span" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "span.attributes[\"llm\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cursor = 100\n", + "page_size = 10\n", + "A = aliased(models.SpanAnnotation, name=\"A\")\n", + "with SqliteSession() as session:\n", + " for index, (span, score) in enumerate(\n", + " session.execute(\n", + " select(models.Span, A.score)\n", + " .join(\n", + " A,\n", + " onclause=and_(\n", + " A.span_rowid == models.Span.id,\n", + " A.name == \"Q&A Correctness\",\n", + " A.annotator_kind == \"LLM\",\n", + " ),\n", + " )\n", + " .where(models.Span.id >= cursor)\n", + " .order_by(A.score.desc(), models.Span.id)\n", + " )\n", + " ):\n", + " print(f\"{index=} {span.id=} {score=}\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 15bc0cb95dfbd9e9e56d60c5f072a9cad87e6abf Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Wed, 1 May 2024 09:38:01 -0700 Subject: [PATCH 02/74] refactor: improved summaries with dataloaders (#3039) --- app/schema.graphql | 7 +- .../graphql_query_performance.ipynb | 418 ++++++++++++++++++ pyproject.toml | 1 + src/phoenix/db/bulk_inserter.py | 19 +- src/phoenix/db/engines.py | 60 ++- src/phoenix/db/helpers.py | 47 ++ src/phoenix/db/models.py | 5 +- src/phoenix/server/api/context.py | 39 +- .../server/api/dataloaders/__init__.py | 14 +- .../document_evaluation_summaries.py | 124 ++++++ .../api/dataloaders/document_evaluations.py | 22 +- .../dataloaders/document_retrieval_metrics.py | 70 +-- .../api/dataloaders/evaluation_summaries.py | 123 ++++++ .../api/dataloaders/latency_ms_quantile.py | 200 ++++++--- .../dataloaders/min_start_or_max_end_times.py | 75 ++++ .../server/api/dataloaders/record_counts.py | 102 +++++ .../api/dataloaders/span_descendants.py | 18 +- .../api/dataloaders/span_evaluations.py | 22 +- .../server/api/dataloaders/token_counts.py | 115 +++++ .../api/dataloaders/trace_evaluations.py | 22 +- .../server/api/types/EvaluationSummary.py | 63 ++- src/phoenix/server/api/types/Project.py | 207 +++------ src/phoenix/server/app.py | 18 +- src/phoenix/trace/dsl/query.py | 52 ++- tests/conftest.py | 81 +++- tests/server/api/dataloaders/conftest.py | 89 ++++ .../dataloaders/test_evaluation_summaries.py | 75 ++++ .../dataloaders/test_latency_ms_quantiles.py | 63 +++ .../api/dataloaders/test_record_counts.py | 55 +++ .../api/dataloaders/test_token_counts.py | 55 +++ 30 files changed, 1839 insertions(+), 422 deletions(-) create mode 100644 integration-tests/graphql_query_performance.ipynb create mode 100644 src/phoenix/db/helpers.py create mode 100644 src/phoenix/server/api/dataloaders/document_evaluation_summaries.py create mode 
100644 src/phoenix/server/api/dataloaders/evaluation_summaries.py create mode 100644 src/phoenix/server/api/dataloaders/min_start_or_max_end_times.py create mode 100644 src/phoenix/server/api/dataloaders/record_counts.py create mode 100644 src/phoenix/server/api/dataloaders/token_counts.py create mode 100644 tests/server/api/dataloaders/conftest.py create mode 100644 tests/server/api/dataloaders/test_evaluation_summaries.py create mode 100644 tests/server/api/dataloaders/test_latency_ms_quantiles.py create mode 100644 tests/server/api/dataloaders/test_record_counts.py create mode 100644 tests/server/api/dataloaders/test_token_counts.py diff --git a/app/schema.graphql b/app/schema.graphql index 10fd3f7875..e127cb7f9c 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -543,10 +543,13 @@ type Project implements Node { gradientEndColor: String! startTime: DateTime endTime: DateTime - recordCount(timeRange: TimeRange): Int! + recordCount(timeRange: TimeRange, filterCondition: String): Int! traceCount(timeRange: TimeRange): Int! - tokenCountTotal(timeRange: TimeRange): Int! + tokenCountTotal(timeRange: TimeRange, filterCondition: String): Int! + tokenCountPrompt(timeRange: TimeRange, filterCondition: String): Int! + tokenCountCompletion(timeRange: TimeRange, filterCondition: String): Int! latencyMsQuantile(probability: Float!, timeRange: TimeRange): Float + spanLatencyMsQuantile(probability: Float!, timeRange: TimeRange, filterCondition: String): Float trace(traceId: ID!): Trace spans(timeRange: TimeRange, first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! diff --git a/integration-tests/graphql_query_performance.ipynb b/integration-tests/graphql_query_performance.ipynb new file mode 100644 index 0000000000..faf4ac780e --- /dev/null +++ b/integration-tests/graphql_query_performance.ipynb @@ -0,0 +1,418 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "from dictdiffer import diff\n", + "from gql import Client, gql\n", + "from gql.transport.requests import RequestsHTTPTransport\n", + "\n", + "new_url = \"http://127.0.0.1:6006/graphql\"\n", + "old_url = \"http://127.0.0.1:6005/graphql\"\n", + "\n", + "client_new_url = Client(\n", + " transport=RequestsHTTPTransport(url=new_url, timeout=60), fetch_schema_from_transport=True\n", + ")\n", + "client_old_url = Client(\n", + " transport=RequestsHTTPTransport(url=old_url, timeout=60), fetch_schema_from_transport=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Span Evaluation Summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "span_eval_smry = gql(\n", + " \"\"\"\n", + " query MyQuery {\n", + " projects {\n", + " edges {\n", + " node {\n", + " hallucination: spanEvaluationSummary(evaluationName: \"Hallucination\") {\n", + " count\n", + " labelCount\n", + " labelFractions {\n", + " fraction\n", + " label\n", + " }\n", + " labels\n", + " meanScore\n", + " scoreCount\n", + " }\n", + " qa_correctness: spanEvaluationSummary(evaluationName: \"Q&A Correctness\") {\n", + " count\n", + " labelCount\n", + " labelFractions {\n", + " fraction\n", + " label\n", + " }\n", + " labels\n", + " meanScore\n", + " scoreCount\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": 
[ + "### Check Diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list(\n", + " diff(\n", + " client_old_url.execute(span_eval_smry),\n", + " client_new_url.execute(span_eval_smry),\n", + " tolerance=0.0001,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(json.dumps(client_new_url.execute(span_eval_smry), indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_new_url.execute(span_eval_smry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_old_url.execute(span_eval_smry)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Document Evaluation Summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "doc_eval_smry = gql(\n", + " \"\"\"\n", + " query MyQuery {\n", + " projects {\n", + " edges {\n", + " node {\n", + " documentEvaluationSummary(evaluationName: \"Relevance\") {\n", + " averageNdcg\n", + " averagePrecision\n", + " countHit\n", + " countNdcg\n", + " countPrecision\n", + " countReciprocalRank\n", + " evaluationName\n", + " hitRate\n", + " meanReciprocalRank\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check Diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list(\n", + " diff(\n", + " client_old_url.execute(doc_eval_smry),\n", + " client_new_url.execute(doc_eval_smry),\n", + " tolerance=0.0001,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(json.dumps(client_new_url.execute(doc_eval_smry), indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_new_url.execute(doc_eval_smry)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_old_url.execute(doc_eval_smry)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Latency Ms Quantiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "latency_ms_qtl = gql(\n", + " \"\"\"\n", + " query MyQuery {\n", + " projects {\n", + " edges {\n", + " node {\n", + " _1: latencyMsQuantile(probability: 0.1)\n", + " _2: latencyMsQuantile(probability: 0.2)\n", + " _3: latencyMsQuantile(probability: 0.3)\n", + " _4: latencyMsQuantile(probability: 0.4)\n", + " _5: latencyMsQuantile(probability: 0.5)\n", + " _6: latencyMsQuantile(probability: 0.6)\n", + " _7: latencyMsQuantile(probability: 0.7)\n", + " _8: latencyMsQuantile(probability: 0.8)\n", + " _9: latencyMsQuantile(probability: 0.9)\n", + " }\n", + " }\n", + " }\n", + " }\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check Diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "list(\n", + " diff(\n", + " client_old_url.execute(latency_ms_qtl),\n", + " client_new_url.execute(latency_ms_qtl),\n", + " tolerance=0.0001,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(json.dumps(client_new_url.execute(latency_ms_qtl), indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_new_url.execute(latency_ms_qtl)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_old_url.execute(latency_ms_qtl)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Start and End Times" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "start_end_times = gql(\n", + " \"\"\"\n", + " query MyQuery {\n", + " projects {\n", + " edges {\n", + " node {\n", + " endTime\n", + " startTime\n", + " }\n", + " }\n", + " }\n", + " }\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Check Diff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list(\n", + " diff(\n", + " client_old_url.execute(start_end_times),\n", + " client_new_url.execute(start_end_times),\n", + " tolerance=0.0001,\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(json.dumps(client_new_url.execute(start_end_times), indent=2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Runtimes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_new_url.execute(start_end_times)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "_ = client_old_url.execute(start_end_times)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 3e4d54e8e7..1096063b45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "sqlalchemy[asyncio]>=2.0.4, <3", "alembic>=1.3.0, <2", "aiosqlite", + "aioitertools", "sqlean.py>=3.45.1", ] dynamic = ["version"] diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py index 2f2364b0fd..8d4de9c868 100644 --- a/src/phoenix/db/bulk_inserter.py +++ b/src/phoenix/db/bulk_inserter.py @@ -23,6 +23,7 @@ import phoenix.trace.v1 as pb from phoenix.db import models +from phoenix.db.helpers import SupportedSQLDialect, num_docs_col from phoenix.exceptions import PhoenixException from phoenix.trace.attributes import get_attribute_value from phoenix.trace.schemas import Span, SpanStatusCode @@ -198,19 +199,23 @@ async def _insert_evaluation(session: AsyncSession, evaluation: pb.Evaluation) - ) elif evaluation_kind == "document_retrieval_id": span_id = evaluation.subject_id.document_retrieval_id.span_id - if not ( - span_rowid := await session.scalar( - select(models.Span.id).where(models.Span.span_id == span_id) - ) - ): + dialect = SupportedSQLDialect(session.bind.dialect.name) 
+ stmt = select(models.Span.id, num_docs_col(dialect)).where(models.Span.span_id == span_id) + if not (row := (await session.execute(stmt)).first()): raise InsertEvaluationError( f"Cannot insert a document evaluation for a missing span: {span_id=}" ) + document_position = evaluation.subject_id.document_retrieval_id.document_position + if row.num_docs is None or row.num_docs <= document_position: + raise InsertEvaluationError( + f"Cannot insert a document evaluation for a non-existent " + f"document position: {span_id=}, {document_position=}" + ) await session.scalar( insert(models.DocumentAnnotation) .values( - span_rowid=span_rowid, - document_position=evaluation.subject_id.document_retrieval_id.document_position, + span_rowid=row.id, + document_position=document_position, name=evaluation_name, label=label, score=score, diff --git a/src/phoenix/db/engines.py b/src/phoenix/db/engines.py index ef93db7ba7..24843e94ec 100644 --- a/src/phoenix/db/engines.py +++ b/src/phoenix/db/engines.py @@ -10,15 +10,13 @@ import sqlean from sqlalchemy import URL, event, make_url from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from typing_extensions import assert_never +from phoenix.db.helpers import SupportedSQLDialect from phoenix.db.migrate import migrate_in_thread from phoenix.db.models import init_models -# supported backends -_SQLITE = "sqlite" -_POSTGRESQL = "postgresql" - -sqlean.extensions.enable("text") +sqlean.extensions.enable("text", "stats") def set_sqlite_pragma(connection: Connection, _: Any) -> None: @@ -42,12 +40,10 @@ def get_async_db_url(connection_str: str) -> URL: url = make_url(connection_str) if not url.database: raise ValueError("Failed to parse database from connection string") - backend = url.get_backend_name() - if backend == _SQLITE: - if url.database.startswith(":memory:"): - url = url.set(query={"cache": "shared"}) + backend = SupportedSQLDialect(url.get_backend_name()) + if backend is SupportedSQLDialect.SQLITE: return url.set(drivername="sqlite+aiosqlite") - if backend == _POSTGRESQL: + elif backend is SupportedSQLDialect.POSTGRESQL: url = url.set(drivername="postgresql+asyncpg") # For some reason username and password cannot be parsed from the typical slot # So we need to parse them out manually @@ -58,30 +54,43 @@ def get_async_db_url(connection_str: str) -> URL: username=None, ) return url - raise ValueError(f"Unsupported backend: {backend}") + else: + assert_never(backend) -def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine: +def create_engine( + connection_str: str, + migrate: bool = True, + echo: bool = False, +) -> AsyncEngine: """ Factory to create a SQLAlchemy engine from a URL string. 
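+    Set ``migrate=False`` to create the engine without running migrations.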
""" url = make_url(connection_str) if not url.database: raise ValueError("Failed to parse database from connection string") - backend = url.get_backend_name() - if backend == _SQLITE: - return aio_sqlite_engine(url=url, echo=echo) - if backend == _POSTGRESQL: - return aio_postgresql_engine(url=url, echo=echo) - raise ValueError(f"Unsupported backend: {backend}") + backend = SupportedSQLDialect(url.get_backend_name()) + url = get_async_db_url(url.render_as_string(hide_password=False)) + if backend is SupportedSQLDialect.SQLITE: + return aio_sqlite_engine(url=url, migrate=migrate, echo=echo) + elif backend is SupportedSQLDialect.POSTGRESQL: + return aio_postgresql_engine(url=url, migrate=migrate, echo=echo) + else: + assert_never(backend) def aio_sqlite_engine( url: URL, + migrate: bool = True, echo: bool = False, + shared_cache: bool = True, ) -> AsyncEngine: - async_url = get_async_db_url(url.render_as_string()) - database = async_url.render_as_string().partition("///")[-1] # includes query + database = url.database or ":memory:" + if database.startswith("file:"): + database = database[5:] + if database.startswith(":memory:") and shared_cache: + url = url.set(query={**url.query, "cache": "shared"}, database=":memory:") + database = url.render_as_string().partition("///")[-1] def async_creator() -> aiosqlite.Connection: conn = aiosqlite.Connection( @@ -92,12 +101,14 @@ def async_creator() -> aiosqlite.Connection: return conn engine = create_async_engine( - url=async_url, + url=url, echo=echo, json_serializer=_dumps, async_creator=async_creator, ) event.listen(engine.sync_engine, "connect", set_sqlite_pragma) + if not migrate: + return engine if database.startswith(":memory:"): try: asyncio.get_running_loop() @@ -112,13 +123,14 @@ def async_creator() -> aiosqlite.Connection: def aio_postgresql_engine( url: URL, + migrate: bool = True, echo: bool = False, ) -> AsyncEngine: - # Swap out the engine - async_url = get_async_db_url(url.render_as_string(hide_password=False)) - engine = create_async_engine(url=async_url, echo=echo, json_serializer=_dumps) + engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps) # TODO(persistence): figure out the postgres pragma # event.listen(engine.sync_engine, "connect", set_pragma) + if not migrate: + return engine migrate_in_thread(engine.url) return engine diff --git a/src/phoenix/db/helpers.py b/src/phoenix/db/helpers.py new file mode 100644 index 0000000000..fd5e46bc33 --- /dev/null +++ b/src/phoenix/db/helpers.py @@ -0,0 +1,47 @@ +from enum import Enum +from typing import Any + +from openinference.semconv.trace import ( + OpenInferenceSpanKindValues, + RerankerAttributes, + SpanAttributes, +) +from sqlalchemy import Integer, SQLColumnExpression, case, func +from typing_extensions import assert_never + +from phoenix.db import models + + +class SupportedSQLDialect(Enum): + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + + @classmethod + def _missing_(cls, value: Any) -> "SupportedSQLDialect": + if isinstance(value, str) and value and value.isascii(): + return cls(value.lower()) + raise ValueError(f"`{value}` is not a supported SQL backend/dialect.") + + +def num_docs_col(dialect: SupportedSQLDialect) -> SQLColumnExpression[Integer]: + if dialect is SupportedSQLDialect.POSTGRESQL: + array_length = func.jsonb_array_length + elif dialect is SupportedSQLDialect.SQLITE: + array_length = func.json_array_length + else: + assert_never(dialect) + retrieval_docs = models.Span.attributes[_RETRIEVAL_DOCUMENTS] + num_retrieval_docs = 
array_length(retrieval_docs) + reranker_docs = models.Span.attributes[_RERANKER_OUTPUT_DOCUMENTS] + num_reranker_docs = array_length(reranker_docs) + return case( + ( + func.upper(models.Span.span_kind) == OpenInferenceSpanKindValues.RERANKER.value.upper(), + num_reranker_docs, + ), + else_=num_retrieval_docs, + ).label("num_docs") + + +_RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS.split(".") +_RERANKER_OUTPUT_DOCUMENTS = RerankerAttributes.RERANKER_OUTPUT_DOCUMENTS.split(".") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index af12d7c04a..f78fffcf8f 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -247,7 +247,8 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any: # See https://docs.sqlalchemy.org/en/20/core/compiler.html start_time, end_time = list(element.clauses) return compiler.process( - (func.extract("EPOCH", end_time) - func.extract("EPOCH", start_time)) * 1000, **kw + func.round((func.extract("EPOCH", end_time) - func.extract("EPOCH", start_time)) * 1000, 1), + **kw, ) @@ -259,7 +260,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any: # FIXME: We don't know why sqlite returns a slightly different value. # postgresql is correct because it matches the value computed by Python. # unixepoch() gives the same results. - (func.julianday(end_time) - func.julianday(start_time)) * 86_400_000, + func.round((func.julianday(end_time) - func.julianday(start_time)) * 86_400_000, 1), **kw, ) diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index a1729f8d0a..a3ac1e52a0 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -1,31 +1,42 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import AsyncContextManager, Callable, List, Optional, Tuple, Union +from typing import AsyncContextManager, Callable, Optional, Union from sqlalchemy.ext.asyncio import AsyncSession from starlette.requests import Request from starlette.responses import Response from starlette.websockets import WebSocket -from strawberry.dataloader import DataLoader from phoenix.core.model_schema import Model -from phoenix.db import models -from phoenix.server.api.input_types.TimeRange import TimeRange -from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics -from phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation, TraceEvaluation +from phoenix.server.api.dataloaders import ( + DocumentEvaluationsDataLoader, + DocumentEvaluationSummaryDataLoader, + DocumentRetrievalMetricsDataLoader, + EvaluationSummaryDataLoader, + LatencyMsQuantileDataLoader, + MinStartOrMaxEndTimeDataLoader, + RecordCountDataLoader, + SpanDescendantsDataLoader, + SpanEvaluationsDataLoader, + TokenCountDataLoader, + TraceEvaluationsDataLoader, +) @dataclass class DataLoaders: - latency_ms_quantile: DataLoader[Tuple[int, Optional[TimeRange], float], Optional[float]] - span_evaluations: DataLoader[int, List[SpanEvaluation]] - document_evaluations: DataLoader[int, List[DocumentEvaluation]] - trace_evaluations: DataLoader[int, List[TraceEvaluation]] - document_retrieval_metrics: DataLoader[ - Tuple[int, Optional[str], int], List[DocumentRetrievalMetrics] - ] - span_descendants: DataLoader[str, List[models.Span]] + document_evaluation_summaries: DocumentEvaluationSummaryDataLoader + document_evaluations: DocumentEvaluationsDataLoader + document_retrieval_metrics: DocumentRetrievalMetricsDataLoader + evaluation_summaries: 
EvaluationSummaryDataLoader + latency_ms_quantile: LatencyMsQuantileDataLoader + min_start_or_max_end_times: MinStartOrMaxEndTimeDataLoader + record_counts: RecordCountDataLoader + span_descendants: SpanDescendantsDataLoader + span_evaluations: SpanEvaluationsDataLoader + token_counts: TokenCountDataLoader + trace_evaluations: TraceEvaluationsDataLoader @dataclass diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index efa1acf78b..fa84039f1c 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -1,13 +1,25 @@ +from .document_evaluation_summaries import DocumentEvaluationSummaryDataLoader from .document_evaluations import DocumentEvaluationsDataLoader from .document_retrieval_metrics import DocumentRetrievalMetricsDataLoader +from .evaluation_summaries import EvaluationSummaryDataLoader from .latency_ms_quantile import LatencyMsQuantileDataLoader +from .min_start_or_max_end_times import MinStartOrMaxEndTimeDataLoader +from .record_counts import RecordCountDataLoader +from .span_descendants import SpanDescendantsDataLoader from .span_evaluations import SpanEvaluationsDataLoader +from .token_counts import TokenCountDataLoader from .trace_evaluations import TraceEvaluationsDataLoader __all__ = [ + "DocumentEvaluationSummaryDataLoader", "DocumentEvaluationsDataLoader", + "DocumentRetrievalMetricsDataLoader", + "EvaluationSummaryDataLoader", "LatencyMsQuantileDataLoader", + "MinStartOrMaxEndTimeDataLoader", + "RecordCountDataLoader", + "SpanDescendantsDataLoader", "SpanEvaluationsDataLoader", + "TokenCountDataLoader", "TraceEvaluationsDataLoader", - "DocumentRetrievalMetricsDataLoader", ] diff --git a/src/phoenix/server/api/dataloaders/document_evaluation_summaries.py b/src/phoenix/server/api/dataloaders/document_evaluation_summaries.py new file mode 100644 index 0000000000..89025a71d8 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/document_evaluation_summaries.py @@ -0,0 +1,124 @@ +from collections import defaultdict +from datetime import datetime +from typing import ( + Any, + AsyncContextManager, + Callable, + DefaultDict, + List, + Optional, + Tuple, +) + +import numpy as np +from aioitertools.itertools import groupby +from sqlalchemy import Select, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.db.helpers import SupportedSQLDialect, num_docs_col +from phoenix.metrics.retrieval_metrics import RetrievalMetrics +from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary +from phoenix.trace.dsl import SpanFilter + +ProjectRowId: TypeAlias = int +TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] +FilterCondition: TypeAlias = Optional[str] +EvalName: TypeAlias = str + +Segment: TypeAlias = Tuple[ProjectRowId, TimeInterval, FilterCondition] +Param: TypeAlias = EvalName + +Key: TypeAlias = Tuple[ProjectRowId, Optional[TimeRange], FilterCondition, EvalName] +Result: TypeAlias = Optional[DocumentEvaluationSummary] +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = None + + +def _cache_key_fn(key: Key) -> Tuple[Segment, Param]: + project_rowid, time_range, filter_condition, eval_name = key + interval = ( + (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None) + ) 
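+    # Keys that differ only in evaluation name share a segment, so all
+    # summaries for a segment can be fetched with a single query below.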
+ return (project_rowid, interval, filter_condition), eval_name + + +class DocumentEvaluationSummaryDataLoader(DataLoader[Key, Result]): + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_key_fn=_cache_key_fn, + cache_map=cache_map, + ) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) + arguments: DefaultDict[ + Segment, + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = _cache_key_fn(key) + arguments[segment][param].append(position) + for segment, params in arguments.items(): + async with self._db() as session: + dialect = SupportedSQLDialect(session.bind.dialect.name) + stmt = _get_stmt(dialect, segment, *params.keys()) + data = await session.stream(stmt) + async for eval_name, group in groupby(data, lambda d: d.name): + metrics_collection = [] + async for (_, num_docs), subgroup in groupby( + group, lambda g: (g.id, g.num_docs) + ): + scores = [np.nan] * num_docs + for row in subgroup: + scores[row.document_position] = row.score + metrics_collection.append(RetrievalMetrics(scores)) + summary = DocumentEvaluationSummary( + evaluation_name=eval_name, + metrics_collection=metrics_collection, + ) + for position in params[eval_name]: + results[position] = summary + return results + + +def _get_stmt( + dialect: SupportedSQLDialect, + segment: Segment, + *eval_names: Param, +) -> Select[Any]: + project_rowid, (start_time, end_time), filter_condition = segment + mda = models.DocumentAnnotation + stmt = ( + select( + mda.name, + models.Span.id, + num_docs_col(dialect), + mda.score, + mda.document_position, + ) + .join(models.Trace) + .where(models.Trace.project_rowid == project_rowid) + .join(mda) + .where(mda.name.in_(eval_names)) + .where(mda.annotator_kind == "LLM") + .where(mda.score.is_not(None)) + .order_by(mda.name, models.Span.id) + ) + if start_time: + stmt = stmt.where(start_time <= models.Span.start_time) + if end_time: + stmt = stmt.where(models.Span.start_time < end_time) + if filter_condition: + span_filter = SpanFilter(condition=filter_condition) + stmt = span_filter(stmt) + return stmt diff --git a/src/phoenix/server/api/dataloaders/document_evaluations.py b/src/phoenix/server/api/dataloaders/document_evaluations.py index f45ea60a98..13c7582c6d 100644 --- a/src/phoenix/server/api/dataloaders/document_evaluations.py +++ b/src/phoenix/server/api/dataloaders/document_evaluations.py @@ -6,7 +6,7 @@ List, ) -from sqlalchemy import and_, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from strawberry.dataloader import DataLoader from typing_extensions import TypeAlias @@ -15,24 +15,22 @@ from phoenix.server.api.types.Evaluation import DocumentEvaluation Key: TypeAlias = int +Result: TypeAlias = List[DocumentEvaluation] -class DocumentEvaluationsDataLoader(DataLoader[Key, List[DocumentEvaluation]]): +class DocumentEvaluationsDataLoader(DataLoader[Key, Result]): def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: super().__init__(load_fn=self._load_fn) self._db = db - async def _load_fn(self, keys: List[Key]) -> List[List[DocumentEvaluation]]: - document_evaluations_by_id: DefaultDict[Key, List[DocumentEvaluation]] = defaultdict(list) + async def _load_fn(self, keys: List[Key]) -> List[Result]: 
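+        # Stream the matching annotations and bucket them by span rowid so
+        # that every requested key maps to its own list of evaluations.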
+ document_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list) + mda = models.DocumentAnnotation async with self._db() as session: - for document_evaluation in await session.scalars( - select(models.DocumentAnnotation).where( - and_( - models.DocumentAnnotation.span_rowid.in_(keys), - models.DocumentAnnotation.annotator_kind == "LLM", - ) - ) - ): + data = await session.stream_scalars( + select(mda).where(mda.span_rowid.in_(keys)).where(mda.annotator_kind == "LLM") + ) + async for document_evaluation in data: document_evaluations_by_id[document_evaluation.span_rowid].append( DocumentEvaluation.from_sql_document_annotation(document_evaluation) ) diff --git a/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py b/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py index 5af058a12e..fde8b332e2 100644 --- a/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py +++ b/src/phoenix/server/api/dataloaders/document_retrieval_metrics.py @@ -1,5 +1,4 @@ from collections import defaultdict -from itertools import groupby from typing import ( AsyncContextManager, Callable, @@ -12,6 +11,7 @@ ) import numpy as np +from aioitertools.itertools import groupby from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from strawberry.dataloader import DataLoader @@ -24,15 +24,17 @@ RowId: TypeAlias = int NumDocs: TypeAlias = int EvalName: TypeAlias = Optional[str] + Key: TypeAlias = Tuple[RowId, EvalName, NumDocs] +Result: TypeAlias = List[DocumentRetrievalMetrics] -class DocumentRetrievalMetricsDataLoader(DataLoader[Key, List[DocumentRetrievalMetrics]]): +class DocumentRetrievalMetricsDataLoader(DataLoader[Key, Result]): def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: super().__init__(load_fn=self._load_fn) self._db = db - async def _load_fn(self, keys: List[Key]) -> List[List[DocumentRetrievalMetrics]]: + async def _load_fn(self, keys: List[Key]) -> List[Result]: mda = models.DocumentAnnotation stmt = ( select( @@ -57,40 +59,40 @@ async def _load_fn(self, keys: List[Key]) -> List[List[DocumentRetrievalMetrics] stmt = stmt.where(mda.name.in_(all_eval_names)) max_position = max(num_docs for _, _, num_docs in keys) stmt = stmt.where(mda.document_position < max_position) - async with self._db() as session: - data = await session.execute(stmt) - if not data: - return [[] for _ in keys] - results: Dict[Key, List[DocumentRetrievalMetrics]] = {key: [] for key in keys} + results: Dict[Key, Result] = {key: [] for key in keys} requested_num_docs: DefaultDict[Tuple[RowId, EvalName], Set[NumDocs]] = defaultdict(set) for row_id, eval_name, num_docs in results.keys(): requested_num_docs[(row_id, eval_name)].add(num_docs) - for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)): - # We need to fulfill two types of potential requests: 1. when it - # specifies an evaluation name, and 2. when it doesn't care about - # the evaluation name by specifying None. - max_requested_num_docs = max( - ( - num_docs - for eval_name in (name, None) - for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ()) - ), - default=0, - ) - if max_requested_num_docs <= 0: - # We have over-fetched. Skip this group. - continue - scores = [np.nan] * max_requested_num_docs - for row in group: - # Length check is necessary due to over-fetching. 
- if row.document_position < len(scores): - scores[row.document_position] = row.score - for eval_name in (name, None): - for num_docs in requested_num_docs.get((span_rowid, eval_name)) or (): - metrics = RetrievalMetrics(scores[:num_docs]) - doc_metrics = DocumentRetrievalMetrics(evaluation_name=name, metrics=metrics) - key = (span_rowid, eval_name, num_docs) - results[key].append(doc_metrics) + async with self._db() as session: + data = await session.stream(stmt) + async for (span_rowid, name), group in groupby(data, lambda r: (r.span_rowid, r.name)): + # We need to fulfill two types of potential requests: 1. when it + # specifies an evaluation name, and 2. when it doesn't care about + # the evaluation name by specifying None. + max_requested_num_docs = max( + ( + num_docs + for eval_name in (name, None) + for num_docs in (requested_num_docs.get((span_rowid, eval_name)) or ()) + ), + default=0, + ) + if max_requested_num_docs <= 0: + # We have over-fetched. Skip this group. + continue + scores = [np.nan] * max_requested_num_docs + for row in group: + # Length check is necessary due to over-fetching. + if row.document_position < len(scores): + scores[row.document_position] = row.score + for eval_name in (name, None): + for num_docs in requested_num_docs.get((span_rowid, eval_name)) or (): + metrics = RetrievalMetrics(scores[:num_docs]) + doc_metrics = DocumentRetrievalMetrics( + evaluation_name=name, metrics=metrics + ) + key = (span_rowid, eval_name, num_docs) + results[key].append(doc_metrics) # Make sure to copy the result, so we don't return the same list # object to two different requesters. return [results[key].copy() for key in keys] diff --git a/src/phoenix/server/api/dataloaders/evaluation_summaries.py b/src/phoenix/server/api/dataloaders/evaluation_summaries.py new file mode 100644 index 0000000000..5b526fceda --- /dev/null +++ b/src/phoenix/server/api/dataloaders/evaluation_summaries.py @@ -0,0 +1,123 @@ +from collections import defaultdict +from datetime import datetime +from typing import ( + Any, + AsyncContextManager, + Callable, + DefaultDict, + List, + Literal, + Optional, + Tuple, +) + +import pandas as pd +from aioitertools.itertools import groupby +from sqlalchemy import Select, func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias, assert_never + +from phoenix.db import models +from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.server.api.types.EvaluationSummary import EvaluationSummary +from phoenix.trace.dsl import SpanFilter + +Kind: TypeAlias = Literal["span", "trace"] +ProjectRowId: TypeAlias = int +TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] +FilterCondition: TypeAlias = Optional[str] +EvalName: TypeAlias = str + +Segment: TypeAlias = Tuple[Kind, ProjectRowId, TimeInterval, FilterCondition] +Param: TypeAlias = EvalName + +Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, EvalName] +Result: TypeAlias = Optional[EvaluationSummary] +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = None + + +def _cache_key_fn(key: Key) -> Tuple[Segment, Param]: + kind, project_rowid, time_range, filter_condition, eval_name = key + interval = ( + (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None) + ) + return (kind, project_rowid, interval, filter_condition), eval_name + + +class EvaluationSummaryDataLoader(DataLoader[Key, Result]): + def 
__init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_key_fn=_cache_key_fn, + cache_map=cache_map, + ) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) + arguments: DefaultDict[ + Segment, + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = _cache_key_fn(key) + arguments[segment][param].append(position) + for segment, params in arguments.items(): + stmt = _get_stmt(segment, *params.keys()) + async with self._db() as session: + data = await session.stream(stmt) + async for eval_name, group in groupby(data, lambda row: row.name): + summary = EvaluationSummary(pd.DataFrame(group)) + for position in params[eval_name]: + results[position] = summary + return results + + +def _get_stmt( + segment: Segment, + *eval_names: Param, +) -> Select[Any]: + kind, project_rowid, (start_time, end_time), filter_condition = segment + stmt = select() + if kind == "span": + msa = models.SpanAnnotation + name_column, label_column, score_column = msa.name, msa.label, msa.score + annotator_kind_column = msa.annotator_kind + time_column = models.Span.start_time + stmt = stmt.join(models.Span).join_from(models.Span, models.Trace) + if filter_condition: + sf = SpanFilter(filter_condition) + stmt = sf(stmt) + elif kind == "trace": + mta = models.TraceAnnotation + name_column, label_column, score_column = mta.name, mta.label, mta.score + annotator_kind_column = mta.annotator_kind + time_column = models.Trace.start_time + stmt = stmt.join(models.Trace) + else: + assert_never(kind) + stmt = stmt.add_columns( + name_column, + label_column, + func.count().label("record_count"), + func.count(label_column).label("label_count"), + func.count(score_column).label("score_count"), + func.sum(score_column).label("score_sum"), + ) + stmt = stmt.group_by(name_column, label_column) + stmt = stmt.order_by(name_column, label_column) + stmt = stmt.where(models.Trace.project_rowid == project_rowid) + stmt = stmt.where(annotator_kind_column == "LLM") + stmt = stmt.where(or_(score_column.is_not(None), label_column.is_not(None))) + stmt = stmt.where(name_column.in_(eval_names)) + if start_time: + stmt = stmt.where(start_time <= time_column) + if end_time: + stmt = stmt.where(time_column < end_time) + return stmt diff --git a/src/phoenix/server/api/dataloaders/latency_ms_quantile.py b/src/phoenix/server/api/dataloaders/latency_ms_quantile.py index bd7b3cacd9..246982031a 100644 --- a/src/phoenix/server/api/dataloaders/latency_ms_quantile.py +++ b/src/phoenix/server/api/dataloaders/latency_ms_quantile.py @@ -3,87 +3,173 @@ from typing import ( Any, AsyncContextManager, + AsyncIterator, Callable, DefaultDict, List, + Literal, + Mapping, Optional, Tuple, + cast, ) -from ddsketch.ddsketch import DDSketch -from sqlalchemy import and_, select +from sqlalchemy import ( + ARRAY, + Float, + Integer, + Select, + SQLColumnExpression, + Values, + column, + func, + select, + values, +) from sqlalchemy.ext.asyncio import AsyncSession -from strawberry.dataloader import DataLoader -from typing_extensions import TypeAlias +from sqlalchemy.sql.functions import percentile_cont +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias, assert_never from phoenix.db import models +from phoenix.db.helpers 
import SupportedSQLDialect from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.trace.dsl import SpanFilter -ProjectId: TypeAlias = int +Kind: TypeAlias = Literal["span", "trace"] +ProjectRowId: TypeAlias = int TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] -Segment: TypeAlias = Tuple[ProjectId, TimeInterval] +FilterCondition: TypeAlias = Optional[str] Probability: TypeAlias = float -Key: TypeAlias = Tuple[ProjectId, Optional[TimeRange], Probability] -ResultPosition: TypeAlias = int QuantileValue: TypeAlias = float -OrmExpression: TypeAlias = Any + +Segment: TypeAlias = Tuple[Kind, TimeInterval, FilterCondition] +Param: TypeAlias = Tuple[ProjectRowId, Probability] + +Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition, Probability] +Result: TypeAlias = Optional[QuantileValue] +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = None + +FloatCol: TypeAlias = SQLColumnExpression[Float[float]] + + +def _cache_key_fn(key: Key) -> Tuple[Segment, Param]: + kind, project_rowid, time_range, filter_condition, probability = key + interval = ( + (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None) + ) + return (kind, interval, filter_condition), (project_rowid, probability) -class LatencyMsQuantileDataLoader(DataLoader[Key, Optional[QuantileValue]]): - def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: - super().__init__(load_fn=self._load_fn, cache_key_fn=self._cache_key_fn) +class LatencyMsQuantileDataLoader(DataLoader[Key, Result]): + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_key_fn=_cache_key_fn, + cache_map=cache_map, + ) self._db = db - @staticmethod - def _cache_key_fn(key: Key) -> Tuple[Segment, Probability]: - if isinstance(key[1], TimeRange): - return (key[0], (key[1].start, key[1].end)), key[2] - return (key[0], (None, None)), key[2] - - async def _load_fn(self, keys: List[Key]) -> List[Optional[QuantileValue]]: - # We use ddsketch here because sqlite doesn't have percentile functions - # unless we compile it with the percentile.c extension, like how it's - # done in the Python package https://github.com/nalgeon/sqlean.py - results: List[Optional[QuantileValue]] = [None] * len(keys) + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) arguments: DefaultDict[ Segment, - List[Tuple[ResultPosition, Probability]], - ] = defaultdict(list) - sketches: DefaultDict[Segment, DDSketch] = defaultdict(DDSketch) - for i, key in enumerate(keys): - segment, probability = self._cache_key_fn(key) - arguments[segment].append((i, probability)) + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = _cache_key_fn(key) + arguments[segment][param].append(position) async with self._db() as session: - for segment, probabilities in arguments.items(): - stmt = ( - select(models.Trace.latency_ms) - .join(models.Project) - .where(_get_filter_condition(segment)) - ) - sketch = sketches[segment] - async for val in await session.stream_scalars(stmt): - sketch.add(val) - for i, p in probabilities: - results[i] = sketch.get_quantile_value(p) + dialect = SupportedSQLDialect(session.bind.dialect.name) + for segment, params in arguments.items(): + async for position, 
quantile_value in _get_results( + dialect, session, segment, params + ): + results[position] = quantile_value return results -def _get_filter_condition(segment: Segment) -> OrmExpression: - id_, (start_time, end_time) = segment - if start_time and end_time: - return and_( - models.Project.id == id_, - start_time <= models.Trace.start_time, - models.Trace.start_time < end_time, - ) +async def _get_results( + dialect: SupportedSQLDialect, + session: AsyncSession, + segment: Segment, + params: Mapping[Param, List[ResultPosition]], +) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]: + kind, (start_time, end_time), filter_condition = segment + stmt = select(models.Trace.project_rowid) + if kind == "trace": + latency_column = cast(FloatCol, models.Trace.latency_ms) + time_column = models.Trace.start_time + elif kind == "span": + latency_column = cast(FloatCol, models.Span.latency_ms) + time_column = models.Span.start_time + stmt = stmt.join(models.Span) + if filter_condition: + sf = SpanFilter(filter_condition) + stmt = sf(stmt) + else: + assert_never(kind) if start_time: - return and_( - models.Project.id == id_, - start_time <= models.Trace.start_time, - ) + stmt = stmt.where(start_time <= time_column) if end_time: - return and_( - models.Project.id == id_, - models.Trace.start_time < end_time, - ) - return models.Project.id == id_ + stmt = stmt.where(time_column < end_time) + if dialect is SupportedSQLDialect.POSTGRESQL: + results = _get_results_postgresql(session, stmt, latency_column, params) + elif dialect is SupportedSQLDialect.SQLITE: + results = _get_results_sqlite(session, stmt, latency_column, params) + else: + assert_never(dialect) + async for position, quantile_value in results: + yield position, quantile_value + + +async def _get_results_sqlite( + session: AsyncSession, + base_stmt: Select[Any], + latency_column: FloatCol, + params: Mapping[Param, List[ResultPosition]], +) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]: + projects_per_prob: DefaultDict[Probability, List[ProjectRowId]] = defaultdict(list) + for project_rowid, probability in params.keys(): + projects_per_prob[probability].append(project_rowid) + pid = models.Trace.project_rowid + for probability, project_rowids in projects_per_prob.items(): + pctl: FloatCol = func.percentile(latency_column, probability * 100) + stmt = base_stmt.add_columns(pctl) + stmt = stmt.where(pid.in_(project_rowids)) + stmt = stmt.group_by(pid) + data = await session.stream(stmt) + async for project_rowid, quantile_value in data: + for position in params[(project_rowid, probability)]: + yield position, quantile_value + + +async def _get_results_postgresql( + session: AsyncSession, + base_stmt: Select[Any], + latency_column: FloatCol, + params: Mapping[Param, List[ResultPosition]], +) -> AsyncIterator[Tuple[ResultPosition, QuantileValue]]: + probs_per_project: DefaultDict[ProjectRowId, List[Probability]] = defaultdict(list) + for project_rowid, probability in params.keys(): + probs_per_project[project_rowid].append(probability) + pp: Values = values( + column("project_rowid", Integer), + column("probabilities", ARRAY(Float[float])), + name="project_probabilities", + ).data(probs_per_project.items()) # type: ignore + pid = models.Trace.project_rowid + pctl: FloatCol = percentile_cont(pp.c.probabilities).within_group(latency_column) + stmt = base_stmt.add_columns(pp.c.probabilities, pctl) + stmt = stmt.join(pp, pid == pp.c.project_rowid) + stmt = stmt.group_by(pid, pp.c.probabilities) + data = await session.stream(stmt) + async for 
project_rowid, probabilities, quantile_values in data: + for probability, quantile_value in zip(probabilities, quantile_values): + for position in params[(project_rowid, probability)]: + yield position, quantile_value diff --git a/src/phoenix/server/api/dataloaders/min_start_or_max_end_times.py b/src/phoenix/server/api/dataloaders/min_start_or_max_end_times.py new file mode 100644 index 0000000000..835cb40623 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/min_start_or_max_end_times.py @@ -0,0 +1,75 @@ +from collections import defaultdict +from datetime import datetime +from typing import ( + AsyncContextManager, + Callable, + DefaultDict, + List, + Literal, + Optional, + Tuple, +) + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias, assert_never + +from phoenix.db import models + +Kind: TypeAlias = Literal["start", "end"] +ProjectRowId: TypeAlias = int + +Segment: TypeAlias = ProjectRowId +Param: TypeAlias = Kind + +Key: TypeAlias = Tuple[ProjectRowId, Kind] +Result: TypeAlias = Optional[datetime] +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = None + + +class MinStartOrMaxEndTimeDataLoader(DataLoader[Key, Result]): + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_map=cache_map, + ) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) + arguments: DefaultDict[ + Segment, + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = key + arguments[segment][param].append(position) + pid = models.Trace.project_rowid + stmt = ( + select( + pid, + func.min(models.Trace.start_time).label("min_start"), + func.max(models.Trace.end_time).label("max_end"), + ) + .where(pid.in_(arguments.keys())) + .group_by(pid) + ) + async with self._db() as session: + data = await session.stream(stmt) + async for project_rowid, min_start, max_end in data: + for kind, positions in arguments[project_rowid].items(): + if kind == "start": + for position in positions: + results[position] = min_start + elif kind == "end": + for position in positions: + results[position] = max_end + else: + assert_never(kind) + return results diff --git a/src/phoenix/server/api/dataloaders/record_counts.py b/src/phoenix/server/api/dataloaders/record_counts.py new file mode 100644 index 0000000000..186c8d7534 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/record_counts.py @@ -0,0 +1,102 @@ +from collections import defaultdict +from datetime import datetime +from typing import ( + Any, + AsyncContextManager, + Callable, + DefaultDict, + List, + Literal, + Optional, + Tuple, +) + +from sqlalchemy import Select, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias, assert_never + +from phoenix.db import models +from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.trace.dsl import SpanFilter + +Kind: TypeAlias = Literal["span", "trace"] +ProjectRowId: TypeAlias = int +TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] +FilterCondition: TypeAlias = Optional[str] +SpanCount: TypeAlias = int + +Segment: TypeAlias = 
Tuple[Kind, TimeInterval, FilterCondition] +Param: TypeAlias = ProjectRowId + +Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition] +Result: TypeAlias = SpanCount +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = 0 + + +def _cache_key_fn(key: Key) -> Tuple[Segment, Param]: + kind, project_rowid, time_range, filter_condition = key + interval = ( + (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None) + ) + return (kind, interval, filter_condition), project_rowid + + +class RecordCountDataLoader(DataLoader[Key, Result]): + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_key_fn=_cache_key_fn, + cache_map=cache_map, + ) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) + arguments: DefaultDict[ + Segment, + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = _cache_key_fn(key) + arguments[segment][param].append(position) + async with self._db() as session: + for segment, params in arguments.items(): + stmt = _get_stmt(segment, *params.keys()) + data = await session.stream(stmt) + async for project_rowid, count in data: + for position in params[project_rowid]: + results[position] = count + return results + + +def _get_stmt( + segment: Segment, + *project_rowids: Param, +) -> Select[Any]: + kind, (start_time, end_time), filter_condition = segment + pid = models.Trace.project_rowid + stmt = select(pid) + if kind == "span": + time_column = models.Span.start_time + stmt = stmt.join(models.Span) + if filter_condition: + sf = SpanFilter(filter_condition) + stmt = sf(stmt) + elif kind == "trace": + time_column = models.Trace.start_time + else: + assert_never(kind) + stmt = stmt.add_columns(func.count().label("count")) + stmt = stmt.where(pid.in_(project_rowids)) + stmt = stmt.group_by(pid) + if start_time: + stmt = stmt.where(start_time <= time_column) + if end_time: + stmt = stmt.where(time_column < end_time) + return stmt diff --git a/src/phoenix/server/api/dataloaders/span_descendants.py b/src/phoenix/server/api/dataloaders/span_descendants.py index 8ea3ab1d29..e98bd7cbe8 100644 --- a/src/phoenix/server/api/dataloaders/span_descendants.py +++ b/src/phoenix/server/api/dataloaders/span_descendants.py @@ -1,4 +1,3 @@ -from itertools import groupby from random import randint from typing import ( AsyncContextManager, @@ -7,6 +6,7 @@ List, ) +from aioitertools.itertools import groupby from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager @@ -16,15 +16,17 @@ from phoenix.db import models SpanId: TypeAlias = str + Key: TypeAlias = SpanId +Result: TypeAlias = List[models.Span] -class SpanDescendantsDataLoader(DataLoader[Key, List[models.Span]]): +class SpanDescendantsDataLoader(DataLoader[Key, Result]): def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: super().__init__(load_fn=self._load_fn) self._db = db - async def _load_fn(self, keys: List[Key]) -> List[List[models.Span]]: + async def _load_fn(self, keys: List[Key]) -> List[Result]: root_ids = set(keys) root_id_label = f"root_id_{randint(0, 10**6):06}" descendant_ids = ( @@ -54,11 +56,9 @@ async def _load_fn(self, keys: List[Key]) -> List[List[models.Span]]: 
.options(contains_eager(models.Span.trace)) .order_by(descendant_ids.c[root_id_label]) ) + results: Dict[SpanId, Result] = {key: [] for key in keys} async with self._db() as session: - data = await session.execute(stmt) - if not data: - return [[] for _ in keys] - results: Dict[SpanId, List[models.Span]] = {key: [] for key in keys} - for root_id, group in groupby(data, key=lambda d: d[0]): - results[root_id].extend(span for _, span in group) + data = await session.stream(stmt) + async for root_id, group in groupby(data, key=lambda d: d[0]): + results[root_id].extend(span for _, span in group) return [results[key].copy() for key in keys] diff --git a/src/phoenix/server/api/dataloaders/span_evaluations.py b/src/phoenix/server/api/dataloaders/span_evaluations.py index 7bc124e052..015e71f1d8 100644 --- a/src/phoenix/server/api/dataloaders/span_evaluations.py +++ b/src/phoenix/server/api/dataloaders/span_evaluations.py @@ -6,7 +6,7 @@ List, ) -from sqlalchemy import and_, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from strawberry.dataloader import DataLoader from typing_extensions import TypeAlias @@ -15,24 +15,22 @@ from phoenix.server.api.types.Evaluation import SpanEvaluation Key: TypeAlias = int +Result: TypeAlias = List[SpanEvaluation] -class SpanEvaluationsDataLoader(DataLoader[Key, List[SpanEvaluation]]): +class SpanEvaluationsDataLoader(DataLoader[Key, Result]): def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: super().__init__(load_fn=self._load_fn) self._db = db - async def _load_fn(self, keys: List[Key]) -> List[List[SpanEvaluation]]: - span_evaluations_by_id: DefaultDict[Key, List[SpanEvaluation]] = defaultdict(list) + async def _load_fn(self, keys: List[Key]) -> List[Result]: + span_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list) + msa = models.SpanAnnotation async with self._db() as session: - for span_evaluation in await session.scalars( - select(models.SpanAnnotation).where( - and_( - models.SpanAnnotation.span_rowid.in_(keys), - models.SpanAnnotation.annotator_kind == "LLM", - ) - ) - ): + data = await session.stream_scalars( + select(msa).where(msa.span_rowid.in_(keys)).where(msa.annotator_kind == "LLM") + ) + async for span_evaluation in data: span_evaluations_by_id[span_evaluation.span_rowid].append( SpanEvaluation.from_sql_span_annotation(span_evaluation) ) diff --git a/src/phoenix/server/api/dataloaders/token_counts.py b/src/phoenix/server/api/dataloaders/token_counts.py new file mode 100644 index 0000000000..ecbfa675a6 --- /dev/null +++ b/src/phoenix/server/api/dataloaders/token_counts.py @@ -0,0 +1,115 @@ +from collections import defaultdict +from datetime import datetime +from typing import ( + Any, + AsyncContextManager, + Callable, + DefaultDict, + List, + Literal, + Optional, + Tuple, +) + +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import Select, func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.functions import coalesce +from strawberry.dataloader import AbstractCache, DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.api.input_types.TimeRange import TimeRange +from phoenix.trace.dsl import SpanFilter + +Kind: TypeAlias = Literal["prompt", "completion", "total"] +ProjectRowId: TypeAlias = int +TimeInterval: TypeAlias = Tuple[Optional[datetime], Optional[datetime]] +FilterCondition: TypeAlias = Optional[str] +TokenCount: TypeAlias = int + +Segment: TypeAlias = 
Tuple[TimeInterval, FilterCondition] +Param: TypeAlias = Tuple[ProjectRowId, Kind] + +Key: TypeAlias = Tuple[Kind, ProjectRowId, Optional[TimeRange], FilterCondition] +Result: TypeAlias = TokenCount +ResultPosition: TypeAlias = int +DEFAULT_VALUE: Result = 0 + + +def _cache_key_fn(key: Key) -> Tuple[Segment, Param]: + kind, project_rowid, time_range, filter_condition = key + interval = ( + (time_range.start, time_range.end) if isinstance(time_range, TimeRange) else (None, None) + ) + return (interval, filter_condition), (project_rowid, kind) + + +class TokenCountDataLoader(DataLoader[Key, Result]): + def __init__( + self, + db: Callable[[], AsyncContextManager[AsyncSession]], + cache_map: Optional[AbstractCache[Key, Result]] = None, + ) -> None: + super().__init__( + load_fn=self._load_fn, + cache_key_fn=_cache_key_fn, + cache_map=cache_map, + ) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + results: List[Result] = [DEFAULT_VALUE] * len(keys) + arguments: DefaultDict[ + Segment, + DefaultDict[Param, List[ResultPosition]], + ] = defaultdict(lambda: defaultdict(list)) + for position, key in enumerate(keys): + segment, param = _cache_key_fn(key) + arguments[segment][param].append(position) + async with self._db() as session: + for segment, params in arguments.items(): + stmt = _get_stmt(segment, *params.keys()) + data = await session.stream(stmt) + async for project_rowid, prompt, completion, total in data: + for position in params[(project_rowid, "prompt")]: + results[position] = prompt + for position in params[(project_rowid, "completion")]: + results[position] = completion + for position in params[(project_rowid, "total")]: + results[position] = total + return results + + +def _get_stmt( + segment: Segment, + *params: Param, +) -> Select[Any]: + (start_time, end_time), filter_condition = segment + prompt = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_PROMPT].as_float()) + completion = func.sum(models.Span.attributes[_LLM_TOKEN_COUNT_COMPLETION].as_float()) + total = coalesce(prompt, 0) + coalesce(completion, 0) + pid = models.Trace.project_rowid + stmt: Select[Any] = ( + select( + pid, + prompt.label("prompt"), + completion.label("completion"), + total.label("total"), + ) + .join_from(models.Trace, models.Span) + .group_by(pid) + ) + if start_time: + stmt = stmt.where(start_time <= models.Span.start_time) + if end_time: + stmt = stmt.where(models.Span.start_time < end_time) + if filter_condition: + sf = SpanFilter(filter_condition) + stmt = sf(stmt) + stmt = stmt.where(pid.in_([rowid for rowid, _ in params])) + return stmt + + +_LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".") +_LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".") diff --git a/src/phoenix/server/api/dataloaders/trace_evaluations.py b/src/phoenix/server/api/dataloaders/trace_evaluations.py index 85db431041..babd7a9055 100644 --- a/src/phoenix/server/api/dataloaders/trace_evaluations.py +++ b/src/phoenix/server/api/dataloaders/trace_evaluations.py @@ -6,7 +6,7 @@ List, ) -from sqlalchemy import and_, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from strawberry.dataloader import DataLoader from typing_extensions import TypeAlias @@ -15,24 +15,22 @@ from phoenix.server.api.types.Evaluation import TraceEvaluation Key: TypeAlias = int +Result: TypeAlias = List[TraceEvaluation] -class TraceEvaluationsDataLoader(DataLoader[Key, List[TraceEvaluation]]): +class 
TraceEvaluationsDataLoader(DataLoader[Key, Result]): def __init__(self, db: Callable[[], AsyncContextManager[AsyncSession]]) -> None: super().__init__(load_fn=self._load_fn) self._db = db - async def _load_fn(self, keys: List[Key]) -> List[List[TraceEvaluation]]: - trace_evaluations_by_id: DefaultDict[Key, List[TraceEvaluation]] = defaultdict(list) + async def _load_fn(self, keys: List[Key]) -> List[Result]: + trace_evaluations_by_id: DefaultDict[Key, Result] = defaultdict(list) + mta = models.TraceAnnotation async with self._db() as session: - for trace_evaluation in await session.scalars( - select(models.TraceAnnotation).where( - and_( - models.TraceAnnotation.trace_rowid.in_(keys), - models.TraceAnnotation.annotator_kind == "LLM", - ) - ) - ): + data = await session.stream_scalars( + select(mta).where(mta.trace_rowid.in_(keys)).where(mta.annotator_kind == "LLM") + ) + async for trace_evaluation in data: trace_evaluations_by_id[trace_evaluation.trace_rowid].append( TraceEvaluation.from_sql_trace_annotation(trace_evaluation) ) diff --git a/src/phoenix/server/api/types/EvaluationSummary.py b/src/phoenix/server/api/types/EvaluationSummary.py index 5fcef3cd12..664d0f2094 100644 --- a/src/phoenix/server/api/types/EvaluationSummary.py +++ b/src/phoenix/server/api/types/EvaluationSummary.py @@ -1,10 +1,7 @@ -import math -from functools import cached_property -from typing import List, Optional, Sequence, Union, cast +from typing import List, Optional, Union, cast import pandas as pd import strawberry -from pandas.api.types import CategoricalDtype from strawberry import Private from phoenix.db import models @@ -20,52 +17,44 @@ class LabelFraction: @strawberry.type class EvaluationSummary: - count: int - labels: Sequence[str] - annotations: Private[Sequence[AnnotationType]] + df: Private[pd.DataFrame] - def __init__( - self, - annotations: Sequence[AnnotationType], - labels: Sequence[str], - ) -> None: - self.annotations = annotations - self.labels = labels - self.count = len(annotations) + def __init__(self, dataframe: pd.DataFrame) -> None: + self.df = dataframe + + @strawberry.field + def count(self) -> int: + return cast(int, self.df.record_count.sum()) + + @strawberry.field + def labels(self) -> List[str]: + return self.df.label.dropna().tolist() @strawberry.field def label_fractions(self) -> List[LabelFraction]: - if not self.labels or not (n := len(self._eval_labels)): + if not (n := self.df.label_count.sum()): return [] - counts = self._eval_labels.value_counts(dropna=True) return [ - LabelFraction(label=cast(str, label), fraction=count / n) - for label, count in counts.items() + LabelFraction( + label=cast(str, row.label), + fraction=row.label_count / n, + ) + for row in self.df.loc[ + self.df.label.notna(), + ["label", "label_count"], + ].itertuples() ] @strawberry.field def mean_score(self) -> Optional[float]: - value = self._eval_scores.mean() - return None if math.isnan(value) else value + if not (n := self.df.score_count.sum()): + return None + return cast(float, self.df.score_sum.sum() / n) @strawberry.field def score_count(self) -> int: - return self._eval_scores.count() + return cast(int, self.df.score_count.sum()) @strawberry.field def label_count(self) -> int: - return self._eval_labels.count() - - @cached_property - def _eval_scores(self) -> "pd.Series[float]": - return pd.Series( - (evaluation.score for evaluation in self.annotations), - dtype=float, - ) - - @cached_property - def _eval_labels(self) -> "pd.Series[CategoricalDtype]": - return pd.Series( - (evaluation.label for 
evaluation in self.annotations), - dtype=CategoricalDtype(categories=self.labels), # type: ignore - ) + return cast(int, self.df.label_count.sum()) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index b0dd3d15a2..628c34927a 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,18 +1,15 @@ from datetime import datetime -from typing import List, Optional, cast +from typing import List, Optional -import numpy as np import strawberry from openinference.semconv.trace import SpanAttributes -from sqlalchemy import ScalarResult, and_, distinct, func, select -from sqlalchemy.orm import contains_eager, selectinload -from sqlalchemy.sql.functions import coalesce +from sqlalchemy import and_, distinct, select +from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info from phoenix.datetime_utils import right_open_time_range from phoenix.db import models -from phoenix.metrics.retrieval_metrics import RetrievalMetrics from phoenix.server.api.context import Context from phoenix.server.api.input_types.SpanSort import SpanSort from phoenix.server.api.input_types.TimeRange import TimeRange @@ -42,11 +39,9 @@ async def start_time( self, info: Info[Context, None], ) -> Optional[datetime]: - stmt = select(func.min(models.Trace.start_time)).where( - models.Trace.project_rowid == self.id_attr + start_time = await info.context.data_loaders.min_start_or_max_end_times.load( + (self.id_attr, "start"), ) - async with info.context.db() as session: - start_time = await session.scalar(stmt) start_time, _ = right_open_time_range(start_time, None) return start_time @@ -55,11 +50,9 @@ async def end_time( self, info: Info[Context, None], ) -> Optional[datetime]: - stmt = select(func.max(models.Trace.end_time)).where( - models.Trace.project_rowid == self.id_attr + end_time = await info.context.data_loaders.min_start_or_max_end_times.load( + (self.id_attr, "end"), ) - async with info.context.db() as session: - end_time = await session.scalar(stmt) _, end_time = right_open_time_range(None, end_time) return end_time @@ -68,21 +61,11 @@ async def record_count( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, + filter_condition: Optional[str] = UNSET, ) -> int: - stmt = ( - select(func.count(models.Span.id)) - .join(models.Trace) - .where(models.Trace.project_rowid == self.id_attr) + return await info.context.data_loaders.record_counts.load( + ("span", self.id_attr, time_range, filter_condition), ) - if time_range: - stmt = stmt.where( - and_( - time_range.start <= models.Span.start_time, - models.Span.start_time < time_range.end, - ) - ) - async with info.context.db() as session: - return (await session.scalar(stmt)) or 0 @strawberry.field async def trace_count( @@ -90,39 +73,42 @@ async def trace_count( info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, ) -> int: - stmt = select(func.count(models.Trace.id)).where(models.Trace.project_rowid == self.id_attr) - if time_range: - stmt = stmt.where( - and_( - time_range.start <= models.Trace.start_time, - models.Trace.start_time < time_range.end, - ) - ) - async with info.context.db() as session: - return (await session.scalar(stmt)) or 0 + return await info.context.data_loaders.record_counts.load( + ("trace", self.id_attr, time_range, None), + ) @strawberry.field async def token_count_total( self, info: Info[Context, None], time_range: Optional[TimeRange] = UNSET, + filter_condition: 
Optional[str] = UNSET, ) -> int: - prompt = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float() - completion = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float() - stmt = ( - select(coalesce(func.sum(prompt), 0) + coalesce(func.sum(completion), 0)) - .join(models.Trace) - .where(models.Trace.project_rowid == self.id_attr) + return await info.context.data_loaders.token_counts.load( + ("total", self.id_attr, time_range, filter_condition), + ) + + @strawberry.field + async def token_count_prompt( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + filter_condition: Optional[str] = UNSET, + ) -> int: + return await info.context.data_loaders.token_counts.load( + ("prompt", self.id_attr, time_range, filter_condition), + ) + + @strawberry.field + async def token_count_completion( + self, + info: Info[Context, None], + time_range: Optional[TimeRange] = UNSET, + filter_condition: Optional[str] = UNSET, + ) -> int: + return await info.context.data_loaders.token_counts.load( + ("completion", self.id_attr, time_range, filter_condition), ) - if time_range: - stmt = stmt.where( - and_( - time_range.start <= models.Span.start_time, - models.Span.start_time < time_range.end, - ) - ) - async with info.context.db() as session: - return (await session.scalar(stmt)) or 0 @strawberry.field async def latency_ms_quantile( @@ -132,7 +118,19 @@ async def latency_ms_quantile( time_range: Optional[TimeRange] = UNSET, ) -> Optional[float]: return await info.context.data_loaders.latency_ms_quantile.load( - (self.id_attr, time_range, probability) + ("trace", self.id_attr, time_range, None, probability), + ) + + @strawberry.field + async def span_latency_ms_quantile( + self, + info: Info[Context, None], + probability: float, + time_range: Optional[TimeRange] = UNSET, + filter_condition: Optional[str] = UNSET, + ) -> Optional[float]: + return await info.context.data_loaders.latency_ms_quantile.load( + ("span", self.id_attr, time_range, filter_condition, probability), ) @strawberry.field @@ -263,37 +261,9 @@ async def trace_evaluation_summary( evaluation_name: str, time_range: Optional[TimeRange] = UNSET, ) -> Optional[EvaluationSummary]: - base_query = ( - select(models.TraceAnnotation) - .join(models.Trace) - .where(models.Trace.project_rowid == self.id_attr) - .where(models.TraceAnnotation.annotator_kind == "LLM") - .where(models.TraceAnnotation.name == evaluation_name) + return await info.context.data_loaders.evaluation_summaries.load( + ("trace", self.id_attr, time_range, None, evaluation_name), ) - unfiltered = base_query - filtered = base_query - if time_range: - filtered = filtered.where( - and_( - time_range.start <= models.Span.start_time, - models.Span.start_time < time_range.end, - ) - ) - - # todo: implement filter condition - async with info.context.db() as session: - evaluations = list(await session.scalars(filtered)) - if not evaluations: - return None - labels = cast( - ScalarResult[str], - await session.scalars( - unfiltered.with_only_columns(distinct(models.TraceAnnotation.label)).where( - models.TraceAnnotation.label.is_not(None) - ) - ), - ) - return EvaluationSummary(evaluations, list(labels)) @strawberry.field async def span_evaluation_summary( @@ -303,38 +273,9 @@ async def span_evaluation_summary( time_range: Optional[TimeRange] = UNSET, filter_condition: Optional[str] = UNSET, ) -> Optional[EvaluationSummary]: - base_query = ( - select(models.SpanAnnotation) - .join(models.Span) - .join(models.Trace, models.Span.trace_rowid == models.Trace.id) - 
.where(models.Trace.project_rowid == self.id_attr) - .where(models.SpanAnnotation.annotator_kind == "LLM") - .where(models.SpanAnnotation.name == evaluation_name) + return await info.context.data_loaders.evaluation_summaries.load( + ("span", self.id_attr, time_range, filter_condition, evaluation_name), ) - unfiltered = base_query - filtered = base_query - if time_range: - filtered = filtered.where( - and_( - time_range.start <= models.Span.start_time, - models.Span.start_time < time_range.end, - ) - ) - - # todo: implement filter condition - async with info.context.db() as session: - evaluations = list(await session.scalars(filtered)) - if not evaluations: - return None - labels = cast( - ScalarResult[str], - await session.scalars( - unfiltered.with_only_columns(distinct(models.SpanAnnotation.label)).where( - models.SpanAnnotation.label.is_not(None) - ) - ), - ) - return EvaluationSummary(evaluations, list(labels)) @strawberry.field async def document_evaluation_summary( @@ -344,42 +285,8 @@ async def document_evaluation_summary( time_range: Optional[TimeRange] = UNSET, filter_condition: Optional[str] = UNSET, ) -> Optional[DocumentEvaluationSummary]: - stmt = ( - select(models.Span) - .join(models.Trace) - .where( - models.Trace.project_rowid == self.id_attr, - ) - .options(selectinload(models.Span.document_annotations)) - .options(contains_eager(models.Span.trace)) - ) - if time_range: - stmt = stmt.where( - and_( - time_range.start <= models.Span.start_time, - models.Span.start_time < time_range.end, - ) - ) - # todo: add filter_condition - async with info.context.db() as session: - sql_spans = await session.scalars(stmt) - metrics_collection = [] - for sql_span in sql_spans: - span = to_gql_span(sql_span) - if not (num_documents := span.num_documents): - continue - evaluation_scores: List[float] = [np.nan] * num_documents - for annotation in sql_span.document_annotations: - if (score := annotation.score) is not None and ( - document_position := annotation.document_position - ) < num_documents: - evaluation_scores[document_position] = score - metrics_collection.append(RetrievalMetrics(evaluation_scores)) - if not metrics_collection: - return None - return DocumentEvaluationSummary( - evaluation_name=evaluation_name, - metrics_collection=metrics_collection, + return await info.context.data_loaders.document_evaluation_summaries.load( + (self.id_attr, time_range, filter_condition, evaluation_name), ) @strawberry.field diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index ae18b0885b..6d4b549238 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -56,12 +56,17 @@ from phoenix.server.api.context import Context, DataLoaders from phoenix.server.api.dataloaders import ( DocumentEvaluationsDataLoader, + DocumentEvaluationSummaryDataLoader, DocumentRetrievalMetricsDataLoader, + EvaluationSummaryDataLoader, LatencyMsQuantileDataLoader, + MinStartOrMaxEndTimeDataLoader, + RecordCountDataLoader, + SpanDescendantsDataLoader, SpanEvaluationsDataLoader, + TokenCountDataLoader, TraceEvaluationsDataLoader, ) -from phoenix.server.api.dataloaders.span_descendants import SpanDescendantsDataLoader from phoenix.server.api.routers.v1 import V1_ROUTES from phoenix.server.api.schema import schema from phoenix.server.grpc_server import GrpcServer @@ -170,12 +175,17 @@ async def get_context( export_path=self.export_path, streaming_last_updated_at=self.streaming_last_updated_at, data_loaders=DataLoaders( - latency_ms_quantile=LatencyMsQuantileDataLoader(self.db), - 
span_evaluations=SpanEvaluationsDataLoader(self.db), + document_evaluation_summaries=DocumentEvaluationSummaryDataLoader(self.db), document_evaluations=DocumentEvaluationsDataLoader(self.db), - trace_evaluations=TraceEvaluationsDataLoader(self.db), document_retrieval_metrics=DocumentRetrievalMetricsDataLoader(self.db), + evaluation_summaries=EvaluationSummaryDataLoader(self.db), + latency_ms_quantile=LatencyMsQuantileDataLoader(self.db), + min_start_or_max_end_times=MinStartOrMaxEndTimeDataLoader(self.db), + record_counts=RecordCountDataLoader(self.db), span_descendants=SpanDescendantsDataLoader(self.db), + span_evaluations=SpanEvaluationsDataLoader(self.db), + token_counts=TokenCountDataLoader(self.db), + trace_evaluations=TraceEvaluationsDataLoader(self.db), ), ) diff --git a/src/phoenix/trace/dsl/query.py b/src/phoenix/trace/dsl/query.py index b0b62ad937..9d082f8c37 100644 --- a/src/phoenix/trace/dsl/query.py +++ b/src/phoenix/trace/dsl/query.py @@ -12,7 +12,6 @@ Dict, Iterable, List, - Literal, Mapping, Optional, Sequence, @@ -24,9 +23,11 @@ from sqlalchemy import JSON, Column, Label, Select, SQLColumnExpression, and_, func, select from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.orm import Session, aliased +from typing_extensions import assert_never from phoenix.config import DEFAULT_PROJECT_NAME from phoenix.db import models +from phoenix.db.helpers import SupportedSQLDialect from phoenix.trace.attributes import ( JSON_STRING_ATTRIBUTES, SEMANTIC_CONVENTIONS, @@ -41,10 +42,6 @@ DEFAULT_SPAN_LIMIT = 1000 -# supported SQL dialects -_SQLITE: Literal["sqlite"] = "sqlite" -_POSTGRESQL: Literal["postgresql"] = "postgresql" - RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS _SPAN_ID = "context.span_id" @@ -151,10 +148,10 @@ def with_primary_index_key(self, _: str) -> "Explosion": def update_sql( self, stmt: Select[Any], - dialect: Literal["sqlite", "postgresql"], + dialect: SupportedSQLDialect, ) -> Select[Any]: array = self() - if dialect == _SQLITE: + if dialect is SupportedSQLDialect.SQLITE: # Because sqlite doesn't support `WITH ORDINALITY`, the order of # the returned (table) values is not guaranteed. So we resort to # post hoc processing using pandas. @@ -164,7 +161,7 @@ def update_sql( array.label(self._array_tmp_col_label), ) return stmt - elif dialect == _POSTGRESQL: + elif dialect is SupportedSQLDialect.POSTGRESQL: element = ( func.jsonb_array_elements(array) .table_valued( @@ -190,12 +187,13 @@ def update_sql( .add_columns(position_label, *columns) ) return stmt - raise NotImplementedError(f"Unsupported dialect: {dialect}") + else: + assert_never(dialect) def update_df( self, df: pd.DataFrame, - dialect: Literal["sqlite", "postgresql"], + dialect: SupportedSQLDialect, ) -> pd.DataFrame: df = df.rename(self._remove_tmp_suffix, axis=1) if df.empty: @@ -210,10 +208,10 @@ def update_df( ) df = pd.DataFrame(columns=columns).set_index(self.index_keys) return df - if dialect != _SQLITE and self.kwargs: + if dialect != SupportedSQLDialect.SQLITE and self.kwargs: df = df.set_index(self.index_keys) return df - if dialect == _SQLITE: + if dialect is SupportedSQLDialect.SQLITE: # Because sqlite doesn't support `WITH ORDINALITY`, the order of # the returned (table) values is not guaranteed. So we resort to # post hoc processing using pandas. 
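Reviewer note on the dialect refactor above: the string-literal checks ("sqlite" / "postgresql") are replaced throughout query.py by the SupportedSQLDialect enum from phoenix.db.helpers plus typing_extensions.assert_never, so a type checker rejects any dispatch that forgets a dialect. A minimal standalone sketch of the idiom, assuming only the two member names visible in this diff (the real helper may carry additional behavior):

from enum import Enum

from typing_extensions import assert_never


class SupportedSQLDialect(Enum):
    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"


def supports_with_ordinality(dialect: SupportedSQLDialect) -> bool:
    # Exhaustive dispatch: adding a new enum member makes mypy flag this
    # function until the new case is handled explicitly.
    if dialect is SupportedSQLDialect.SQLITE:
        return False  # hence the post hoc pandas processing in the hunks below
    elif dialect is SupportedSQLDialect.POSTGRESQL:
        return True
    else:
        assert_never(dialect)


# Construction by value mirrors `SupportedSQLDialect(session.bind.dialect.name)`
# in the __call__ hunk further down, e.g. SupportedSQLDialect("sqlite").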
@@ -241,17 +239,21 @@ def _extract_values(array: List[Any]) -> List[Dict[str, Any]]: return res records = df.loc[:, self._array_tmp_col_label].dropna().map(_extract_values).explode() - else: + elif dialect is SupportedSQLDialect.POSTGRESQL: records = df.loc[:, self._array_tmp_col_label].dropna().map(flatten).map(dict) + else: + assert_never(dialect) df = df.drop(self._array_tmp_col_label, axis=1) if records.empty: df = df.set_index(self.index_keys[0]) return df df_explode = pd.DataFrame.from_records(records.to_list(), index=records.index) - if dialect == _SQLITE: + if dialect is SupportedSQLDialect.SQLITE: df = _outer_join(df, df_explode) - else: + elif dialect is SupportedSQLDialect.POSTGRESQL: df = pd.concat([df, df_explode], axis=1) + else: + assert_never(dialect) df = df.set_index(self.index_keys) return df @@ -301,10 +303,10 @@ def with_separator(self, separator: str = "\n\n") -> "Concatenation": def update_sql( self, stmt: Select[Any], - dialect: Literal["sqlite", "postgresql"], + dialect: SupportedSQLDialect, ) -> Select[Any]: array = self() - if dialect == _SQLITE: + if dialect is SupportedSQLDialect.SQLITE: # Because SQLite doesn't support `WITH ORDINALITY`, the order of # the returned table-values is not guaranteed. So we resort to # post hoc processing using pandas. @@ -314,7 +316,7 @@ def update_sql( array.label(self._array_tmp_col_label), ) return stmt - if dialect == _POSTGRESQL: + elif dialect is SupportedSQLDialect.POSTGRESQL: element = ( ( func.jsonb_array_elements(array) @@ -355,12 +357,13 @@ def update_sql( .group_by(*stmt.columns.keys()) ) return stmt - raise NotImplementedError(f"Unsupported dialect: {dialect}") + else: + assert_never(dialect) def update_df( self, df: pd.DataFrame, - dialect: Literal["sqlite", "postgresql"], + dialect: SupportedSQLDialect, ) -> pd.DataFrame: df = df.rename(self._remove_tmp_suffix, axis=1) if df.empty: @@ -373,7 +376,7 @@ def update_df( ) ) return pd.DataFrame(columns=columns, index=df.index) - if dialect == _SQLITE: + if dialect is SupportedSQLDialect.SQLITE: # Because SQLite doesn't support `WITH ORDINALITY`, the order of # the returned table-values is not guaranteed. So we resort to # post hoc processing using pandas. 
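The sqlite branches in both Explosion.update_df and Concatenation.update_df lean on the same pandas recipe: the JSON array arrives as a single cell per row, is exploded into one record per element, and is joined back onto the frame by index. A toy illustration with made-up data, not Phoenix code:

import pandas as pd

df = pd.DataFrame(
    {
        "context.span_id": ["s1", "s2"],
        "docs": [
            [{"document_content": "a", "document_score": 1.0}],
            [
                {"document_content": "b", "document_score": 0.5},
                {"document_content": "c", "document_score": 0.2},
            ],
        ],
    }
).set_index("context.span_id")

records = df["docs"].dropna().explode()  # one row per array element
exploded = pd.DataFrame.from_records(records.to_list(), index=records.index)
result = df.drop(columns="docs").join(exploded, how="outer")
print(result)  # two rows for s2, one per retrieved document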
@@ -394,6 +397,10 @@ def _concat_values(array: List[Any]) -> Dict[str, Any]: records = df.loc[:, self._array_tmp_col_label].map(_concat_values) df_concat = pd.DataFrame.from_records(records.to_list(), index=records.index) return df.drop(self._array_tmp_col_label, axis=1).join(df_concat, how="outer") + elif dialect is SupportedSQLDialect.POSTGRESQL: + pass + else: + assert_never(dialect) return df def to_dict(self) -> Dict[str, Any]: @@ -527,8 +534,7 @@ def __call__( root_spans_only=root_spans_only, ) assert session.bind is not None - dialect = cast(Literal["sqlite", "postgresql"], session.bind.dialect.name) - assert dialect in ("sqlite", "postgresql") + dialect = SupportedSQLDialect(session.bind.dialect.name) row_id = models.Span.id.label(self._pk_tmp_col_label) stmt: Select[Any] = ( # We do not allow `group_by` anything other than `row_id` because otherwise diff --git a/tests/conftest.py b/tests/conftest.py index 11a513b45f..dc44f76bd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,18 @@ -from typing import AsyncGenerator +import contextlib +from typing import AsyncContextManager, AsyncGenerator, AsyncIterator, Callable import pytest -import sqlean from phoenix.db import models +from phoenix.db.engines import aio_sqlite_engine from psycopg import Connection from pytest_postgresql import factories -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy import make_url +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) def pytest_addoption(parser): @@ -23,7 +29,10 @@ def pytest_collection_modifyitems(config, items): if not config.getoption("--run-postgres"): for item in items: if "session" in item.fixturenames: - if "postgres" in item.callspec.params.values(): + if "postgres_session" in item.callspec.params.values(): + item.add_marker(skip_postgres) + elif "db" in item.fixturenames: + if "postgres_db" in item.callspec.params.values(): item.add_marker(skip_postgres) @@ -37,7 +46,7 @@ def openai_api_key(monkeypatch: pytest.MonkeyPatch) -> str: phoenix_postgresql = factories.postgresql("postgresql_proc") -def create_async_postgres_engine(psycopg_connection: Connection) -> sessionmaker: +def create_async_postgres_engine(psycopg_connection: Connection) -> AsyncEngine: connection = psycopg_connection.cursor().connection user = connection.info.user password = connection.info.password @@ -48,12 +57,8 @@ def create_async_postgres_engine(psycopg_connection: Connection) -> sessionmaker return create_async_engine(async_database_url) -def create_async_sqlite_engine() -> sessionmaker: - return create_async_engine("sqlite+aiosqlite:///:memory:", module=sqlean) - - @pytest.fixture -async def postgres_engine(phoenix_postgresql: Connection) -> AsyncGenerator[sessionmaker, None]: +async def postgres_engine(phoenix_postgresql: Connection) -> AsyncGenerator[AsyncEngine, None]: engine = create_async_postgres_engine(phoenix_postgresql) async with engine.begin() as conn: await conn.run_sync(models.Base.metadata.create_all) @@ -62,30 +67,62 @@ async def postgres_engine(phoenix_postgresql: Connection) -> AsyncGenerator[sess @pytest.fixture -async def sqlite_engine() -> AsyncGenerator[sessionmaker, None]: - engine = create_async_sqlite_engine() +async def sqlite_engine() -> AsyncEngine: + engine = aio_sqlite_engine(make_url("sqlite+aiosqlite://"), migrate=False, shared_cache=False) async with engine.begin() as conn: await 
conn.run_sync(models.Base.metadata.create_all) - yield engine - await engine.dispose() + return engine -@pytest.fixture(params=["sqlite", "postgres"]) +@pytest.fixture(params=["sqlite_session", "postgres_session"]) def session(request) -> AsyncSession: return request.getfixturevalue(request.param) +@pytest.fixture(params=["sqlite_db", "postgres_db"]) +def db(request) -> async_sessionmaker: + return request.getfixturevalue(request.param) + + +@pytest.fixture +async def sqlite_db(sqlite_engine: AsyncEngine) -> Callable[[], AsyncContextManager[AsyncSession]]: + Session = async_sessionmaker(sqlite_engine, expire_on_commit=False) + + @contextlib.asynccontextmanager + async def factory() -> AsyncIterator[AsyncSession]: + async with Session.begin() as session: + yield session + + return factory + + +@pytest.fixture +async def postgres_db( + postgres_engine: AsyncEngine, +) -> Callable[[], AsyncContextManager[AsyncSession]]: + Session = async_sessionmaker(postgres_engine, expire_on_commit=False) + + @contextlib.asynccontextmanager + async def factory() -> AsyncIterator[AsyncSession]: + async with Session.begin() as session: + yield session + + return factory + + @pytest.fixture -async def sqlite(sqlite_engine: sessionmaker) -> AsyncGenerator[AsyncSession, None]: - async_session = sessionmaker(sqlite_engine, expire_on_commit=False, class_=AsyncSession) - async with async_session() as session: +async def sqlite_session( + sqlite_db: Callable[[], AsyncContextManager[AsyncSession]], +) -> AsyncGenerator[AsyncSession, None]: + async with sqlite_db() as session: yield session @pytest.fixture -async def postgres(postgres_engine: sessionmaker) -> AsyncGenerator[AsyncSession, None]: - async_session = sessionmaker(postgres_engine, expire_on_commit=False, class_=AsyncSession) - async with async_session() as session: +async def postgres_session( + postgres_db: Callable[[], AsyncContextManager[AsyncSession]], +) -> AsyncGenerator[AsyncSession, None]: + async with postgres_db() as session: yield session diff --git a/tests/server/api/dataloaders/conftest.py b/tests/server/api/dataloaders/conftest.py new file mode 100644 index 0000000000..53bf28bef4 --- /dev/null +++ b/tests/server/api/dataloaders/conftest.py @@ -0,0 +1,89 @@ +from datetime import datetime, timedelta +from random import randint, random, seed +from typing import AsyncContextManager, Callable + +import pytest +from phoenix.db import models +from sqlalchemy import insert +from sqlalchemy.ext.asyncio import AsyncSession + + +@pytest.fixture +async def data_for_testing_dataloaders( + db: Callable[[], AsyncContextManager[AsyncSession]], +) -> None: + seed(42) + orig_time = datetime.fromisoformat("2021-01-01T00:00:00.000+00:00") + I, J, K = 10, 10, 10 # noqa: E741 + async with db() as session: + for i in range(I): + project_row_id = await session.scalar( + insert(models.Project).values(name=f"{i}").returning(models.Project.id) + ) + for j in range(J): + seconds = randint(1, 1000) + start_time = orig_time + timedelta(seconds=seconds) + end_time = orig_time + timedelta(seconds=seconds * K * 2) + trace_row_id = await session.scalar( + insert(models.Trace) + .values( + trace_id=f"{i}_{j}", + project_rowid=project_row_id, + start_time=start_time, + end_time=end_time, + ) + .returning(models.Trace.id) + ) + for name in "ABCD": + await session.execute( + insert(models.TraceAnnotation).values( + name=name, + trace_rowid=trace_row_id, + label="XYZ"[randint(0, 2)], + score=random(), + metadata_={}, + annotator_kind="LLM", + ) + ) + for k in range(K): + seconds = 
randint(1, 1000) + start_time = orig_time + timedelta(seconds=seconds) + end_time = orig_time + timedelta(seconds=seconds * 2) + span_row_id = await session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id=f"{i}_{j}_{k}", + parent_id=None, + name=f"{i}_{j}_{k}", + span_kind="UNKNOWN", + start_time=start_time, + end_time=end_time, + attributes={ + "llm": { + "token_count": { + "prompt": randint(1, 1000), + "completion": randint(1, 1000), + } + } + }, + events=[], + status_code="OK", + status_message="okay", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=0, + cumulative_llm_token_count_completion=0, + ) + .returning(models.Span.id) + ) + for name in "ABCD": + await session.execute( + insert(models.SpanAnnotation).values( + name=name, + span_rowid=span_row_id, + label="XYZ"[randint(0, 2)], + score=random(), + metadata_={}, + annotator_kind="LLM", + ) + ) diff --git a/tests/server/api/dataloaders/test_evaluation_summaries.py b/tests/server/api/dataloaders/test_evaluation_summaries.py new file mode 100644 index 0000000000..a1039ccf0b --- /dev/null +++ b/tests/server/api/dataloaders/test_evaluation_summaries.py @@ -0,0 +1,75 @@ +from datetime import datetime +from typing import AsyncContextManager, Callable + +import pandas as pd +import pytest +from phoenix.db import models +from phoenix.server.api.dataloaders import EvaluationSummaryDataLoader +from phoenix.server.api.input_types.TimeRange import TimeRange +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + + +async def test_evaluation_summaries( + db: Callable[[], AsyncContextManager[AsyncSession]], + data_for_testing_dataloaders: None, +) -> None: + start_time = datetime.fromisoformat("2021-01-01T00:00:10.000+00:00") + end_time = datetime.fromisoformat("2021-01-01T00:10:00.000+00:00") + pid = models.Trace.project_rowid + async with db() as session: + span_df = await session.run_sync( + lambda s: pd.read_sql_query( + select( + pid, + models.SpanAnnotation.name, + func.avg(models.SpanAnnotation.score).label("mean_score"), + ) + .group_by(pid, models.SpanAnnotation.name) + .order_by(pid, models.SpanAnnotation.name) + .join_from(models.Trace, models.Span) + .join_from(models.Span, models.SpanAnnotation) + .where(models.Span.name.contains("_5_")) + .where(models.SpanAnnotation.name.in_(("A", "C"))) + .where(start_time <= models.Span.start_time) + .where(models.Span.start_time < end_time), + s.connection(), + ) + ) + trace_df = await session.run_sync( + lambda s: pd.read_sql_query( + select( + pid, + models.TraceAnnotation.name, + func.avg(models.TraceAnnotation.score).label("mean_score"), + ) + .group_by(pid, models.TraceAnnotation.name) + .order_by(pid, models.TraceAnnotation.name) + .join_from(models.Trace, models.TraceAnnotation) + .where(models.TraceAnnotation.name.in_(("B", "D"))) + .where(start_time <= models.Trace.start_time) + .where(models.Trace.start_time < end_time), + s.connection(), + ) + ) + expected = trace_df.loc[:, "mean_score"].to_list() + span_df.loc[:, "mean_score"].to_list() + actual = [ + smry.mean_score() + for smry in ( + await EvaluationSummaryDataLoader(db)._load_fn( + [ + ( + kind, + id_ + 1, + TimeRange(start=start_time, end=end_time), + "'_5_' in name" if kind == "span" else None, + eval_name, + ) + for kind in ("trace", "span") + for id_ in range(10) + for eval_name in (("B", "D") if kind == "trace" else ("A", "C")) + ] + ) + ) + ] + assert actual == pytest.approx(expected, 1e-7) diff --git 
a/tests/server/api/dataloaders/test_latency_ms_quantiles.py b/tests/server/api/dataloaders/test_latency_ms_quantiles.py new file mode 100644 index 0000000000..d4e75d3070 --- /dev/null +++ b/tests/server/api/dataloaders/test_latency_ms_quantiles.py @@ -0,0 +1,63 @@ +from datetime import datetime +from typing import AsyncContextManager, Callable + +import pandas as pd +import pytest +from phoenix.db import models +from phoenix.server.api.dataloaders import LatencyMsQuantileDataLoader +from phoenix.server.api.input_types.TimeRange import TimeRange +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + + +async def test_latency_ms_quantiles_p25_p50_p75( + db: Callable[[], AsyncContextManager[AsyncSession]], + data_for_testing_dataloaders: None, +) -> None: + start_time = datetime.fromisoformat("2021-01-01T00:00:10.000+00:00") + end_time = datetime.fromisoformat("2021-01-01T00:10:00.000+00:00") + pid = models.Trace.project_rowid + async with db() as session: + span_df = await session.run_sync( + lambda s: pd.read_sql_query( + select(pid, models.Span.latency_ms.label("latency_ms")) + .join_from(models.Trace, models.Span) + .where(models.Span.name.contains("_5_")) + .where(start_time <= models.Span.start_time) + .where(models.Span.start_time < end_time), + s.connection(), + ) + ) + trace_df = await session.run_sync( + lambda s: pd.read_sql_query( + select(pid, models.Trace.latency_ms.label("latency_ms")) + .where(start_time <= models.Trace.start_time) + .where(models.Trace.start_time < end_time), + s.connection(), + ) + ) + expected = ( + trace_df.groupby("project_rowid")["latency_ms"] + .quantile([0.25, 0.50, 0.75]) + .sort_index() + .to_list() + + span_df.groupby("project_rowid")["latency_ms"] + .quantile([0.25, 0.50, 0.75]) + .sort_index() + .to_list() + ) + actual = await LatencyMsQuantileDataLoader(db)._load_fn( + [ + ( + kind, + id_ + 1, + TimeRange(start=start_time, end=end_time), + "'_5_' in name" if kind == "span" else None, + probability, + ) + for kind in ("trace", "span") + for id_ in range(10) + for probability in (0.25, 0.50, 0.75) + ] + ) + assert actual == pytest.approx(expected, 1e-7) diff --git a/tests/server/api/dataloaders/test_record_counts.py b/tests/server/api/dataloaders/test_record_counts.py new file mode 100644 index 0000000000..3db60f622b --- /dev/null +++ b/tests/server/api/dataloaders/test_record_counts.py @@ -0,0 +1,55 @@ +from datetime import datetime +from typing import AsyncContextManager, Callable + +import pandas as pd +from phoenix.db import models +from phoenix.server.api.dataloaders import RecordCountDataLoader +from phoenix.server.api.input_types.TimeRange import TimeRange +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + + +async def test_record_counts( + db: Callable[[], AsyncContextManager[AsyncSession]], + data_for_testing_dataloaders: None, +) -> None: + start_time = datetime.fromisoformat("2021-01-01T00:00:10.000+00:00") + end_time = datetime.fromisoformat("2021-01-01T00:10:00.000+00:00") + pid = models.Trace.project_rowid + async with db() as session: + span_df = await session.run_sync( + lambda s: pd.read_sql_query( + select(pid, func.count().label("count")) + .join_from(models.Trace, models.Span) + .group_by(pid) + .order_by(pid) + .where(models.Span.name.contains("_5_")) + .where(start_time <= models.Span.start_time) + .where(models.Span.start_time < end_time), + s.connection(), + ) + ) + trace_df = await session.run_sync( + lambda s: pd.read_sql_query( + select(pid, 
func.count().label("count")) + .group_by(pid) + .order_by(pid) + .where(start_time <= models.Trace.start_time) + .where(models.Trace.start_time < end_time), + s.connection(), + ) + ) + expected = trace_df.loc[:, "count"].to_list() + span_df.loc[:, "count"].to_list() + actual = await RecordCountDataLoader(db)._load_fn( + [ + ( + kind, + id_ + 1, + TimeRange(start=start_time, end=end_time), + "'_5_' in name" if kind == "span" else None, + ) + for kind in ("trace", "span") + for id_ in range(10) + ] + ) + assert actual == expected diff --git a/tests/server/api/dataloaders/test_token_counts.py b/tests/server/api/dataloaders/test_token_counts.py new file mode 100644 index 0000000000..9453b5785f --- /dev/null +++ b/tests/server/api/dataloaders/test_token_counts.py @@ -0,0 +1,55 @@ +from datetime import datetime +from typing import AsyncContextManager, Callable + +import pandas as pd +from phoenix.db import models +from phoenix.server.api.dataloaders import TokenCountDataLoader +from phoenix.server.api.input_types.TimeRange import TimeRange +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + + +async def test_token_counts( + db: Callable[[], AsyncContextManager[AsyncSession]], + data_for_testing_dataloaders: None, +) -> None: + start_time = datetime.fromisoformat("2021-01-01T00:00:10.000+00:00") + end_time = datetime.fromisoformat("2021-01-01T00:10:00.000+00:00") + async with db() as session: + prompt = models.Span.attributes[["llm", "token_count", "prompt"]].as_float() + completion = models.Span.attributes[["llm", "token_count", "completion"]].as_float() + pid = models.Trace.project_rowid + span_df = await session.run_sync( + lambda s: pd.read_sql_query( + select( + pid, + func.sum(prompt).label("prompt"), + func.sum(completion).label("completion"), + ) + .join(models.Span) + .group_by(pid) + .order_by(pid) + .where(models.Span.name.contains("_5_")) + .where(start_time <= models.Span.start_time) + .where(models.Span.start_time < end_time), + s.connection(), + ) + ) + expected = ( + span_df.loc[:, "prompt"].to_list() + + span_df.loc[:, "completion"].to_list() + + (span_df.loc[:, "prompt"] + span_df.loc[:, "completion"]).to_list() + ) + actual = await TokenCountDataLoader(db)._load_fn( + [ + ( + kind, + id_ + 1, + TimeRange(start=start_time, end=end_time), + "'_5_' in name", + ) + for kind in ("prompt", "completion", "total") + for id_ in range(10) + ] + ) + assert actual == expected From 39926d01f25430d1526e69be4a6beaf8171aa5af Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 1 May 2024 14:25:14 -0700 Subject: [PATCH 03/74] add pagination it notebook --- .../pagination_query_testing.ipynb | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 integration-tests/pagination_query_testing.ipynb diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb new file mode 100644 index 0000000000..3be3c93c07 --- /dev/null +++ b/integration-tests/pagination_query_testing.ipynb @@ -0,0 +1,76 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dictdiffer import diff\n", + "from gql import Client, gql\n", + "from gql.transport.requests import RequestsHTTPTransport\n", + "\n", + "new_url = \"http://127.0.0.1:6006/graphql\"\n", + "old_url = \"http://127.0.0.1:6008/graphql\"\n", + "\n", + "client_new_url = Client(\n", + " transport=RequestsHTTPTransport(url=new_url, timeout=1),\n", + " 
fetch_schema_from_transport=True,\n", + ")\n", + "client_old_url = Client(\n", + " transport=RequestsHTTPTransport(url=old_url, timeout=1),\n", + " fetch_schema_from_transport=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "spans_query = gql(\n", + " \"\"\"query SpansQuery($projectId: GlobalID!, $after: String = null, $before: String = null, $first: Int = null, $last: Int = null) {\n", + " node(id: $projectId) {\n", + " ...on Project {\n", + " spans(after: $after, before: $before, first: $first, last: $last) {\n", + " edges {\n", + " node {\n", + " name\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\"\"\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "list(\n", + " diff(\n", + " client_old_url.execute(\n", + " spans_query, variable_values={\"projectId\": \"UHJvamVjdDow\", \"first\": 10}\n", + " ),\n", + " client_new_url.execute(\n", + " spans_query, variable_values={\"projectId\": \"UHJvamVjdDox\", \"first\": 10}\n", + " ),\n", + " tolerance=0.0001,\n", + " )\n", + ")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 986a14081dc0bc586340f38a8d6b4d2bc744dab5 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 1 May 2024 14:53:02 -0700 Subject: [PATCH 04/74] clean test --- .../pagination_query_testing.ipynb | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 3be3c93c07..d2301e01bb 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -6,13 +6,19 @@ "metadata": {}, "outputs": [], "source": [ - "from dictdiffer import diff\n", "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", + "from phoenix.server.api.types.pagination import (\n", + " cursor_to_id,\n", + " cursor_to_offset,\n", + ")\n", "\n", "new_url = \"http://127.0.0.1:6006/graphql\"\n", "old_url = \"http://127.0.0.1:6008/graphql\"\n", "\n", + "new_project_id = \"UHJvamVjdDox\"\n", + "old_project_id = \"UHJvamVjdDow\"\n", + "\n", "client_new_url = Client(\n", " transport=RequestsHTTPTransport(url=new_url, timeout=1),\n", " fetch_schema_from_transport=True,\n", @@ -35,9 +41,7 @@ " ...on Project {\n", " spans(after: $after, before: $before, first: $first, last: $last) {\n", " edges {\n", - " node {\n", - " name\n", - " }\n", + " cursor\n", " }\n", " }\n", " }\n", @@ -52,17 +56,18 @@ "metadata": {}, "outputs": [], "source": [ - "list(\n", - " diff(\n", - " client_old_url.execute(\n", - " spans_query, variable_values={\"projectId\": \"UHJvamVjdDow\", \"first\": 10}\n", - " ),\n", - " client_new_url.execute(\n", - " spans_query, variable_values={\"projectId\": \"UHJvamVjdDox\", \"first\": 10}\n", - " ),\n", - " tolerance=0.0001,\n", - " )\n", - ")" + "new_response = client_new_url.execute(\n", + " spans_query, variable_values={\"projectId\": new_project_id, \"first\": 10}\n", + ")\n", + "old_response = client_old_url.execute(\n", + " spans_query, variable_values={\"projectId\": old_project_id, \"first\": 10}\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "old_cursors = [edge[\"cursor\"] for edge in 
old_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", + "print(new_ids)\n", + "print(new_offsets)" ] } ], From 16935a4adcb3b0907a72c90ad15e3cdd37422527 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 1 May 2024 14:56:40 -0700 Subject: [PATCH 05/74] add helper functions for converting cursors to and from ids --- src/phoenix/server/api/types/pagination.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index a90602b483..ce1562e4ff 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -4,7 +4,9 @@ import strawberry from strawberry import UNSET +from typing_extensions import TypeAlias +ID: TypeAlias = int GenericType = TypeVar("GenericType") @@ -56,6 +58,21 @@ class Edge(Generic[GenericType]): CURSOR_PREFIX = "connection:" +def id_to_cursor(id: ID) -> Cursor: + """ + Creates a cursor string from an ID. + """ + return base64.b64encode(f"{CURSOR_PREFIX}{id}".encode("utf-8")).decode() + + +def cursor_to_id(cursor: Cursor) -> ID: + """ + Extracts the ID from the cursor string. + """ + _, id = base64.b64decode(cursor).decode().split(":") + return int(id) + + def offset_to_cursor(offset: int) -> Cursor: """ Creates the cursor string from an offset. From 56faea75414211230a2ebe8b612885ba3e983773 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 1 May 2024 15:30:56 -0700 Subject: [PATCH 06/74] add test case for after and first --- .../pagination_query_testing.ipynb | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index d2301e01bb..3e11ec3530 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -11,6 +11,8 @@ "from phoenix.server.api.types.pagination import (\n", " cursor_to_id,\n", " cursor_to_offset,\n", + " id_to_cursor,\n", + " offset_to_cursor,\n", ")\n", "\n", "new_url = \"http://127.0.0.1:6006/graphql\"\n", @@ -19,11 +21,11 @@ "new_project_id = \"UHJvamVjdDox\"\n", "old_project_id = \"UHJvamVjdDow\"\n", "\n", - "client_new_url = Client(\n", + "new_client = Client(\n", " transport=RequestsHTTPTransport(url=new_url, timeout=1),\n", " fetch_schema_from_transport=True,\n", ")\n", - "client_old_url = Client(\n", + "old_client = Client(\n", " transport=RequestsHTTPTransport(url=old_url, timeout=1),\n", " fetch_schema_from_transport=True,\n", ")" @@ -56,10 +58,10 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = client_new_url.execute(\n", + "new_response = new_client.execute(\n", " spans_query, variable_values={\"projectId\": new_project_id, \"first\": 10}\n", ")\n", - "old_response = client_old_url.execute(\n", + "old_response = old_client.execute(\n", " spans_query, variable_values={\"projectId\": old_project_id, \"first\": 10}\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", @@ -69,6 +71,36 @@ "print(new_ids)\n", "print(new_offsets)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(10),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "old_response = 
old_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": old_project_id,\n", + " \"after\": offset_to_cursor(10),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", + "print(new_ids)\n", + "print(new_offsets)" + ] } ], "metadata": { From 0c7d4df1cd802dacfe94585e5f5e9e29be41ef23 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 15:25:46 -0700 Subject: [PATCH 07/74] add page info to tests --- .../pagination_query_testing.ipynb | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 3e11ec3530..d5acb91153 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -6,6 +6,7 @@ "metadata": {}, "outputs": [], "source": [ + "from dictdiffer import diff\n", "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", @@ -45,11 +46,29 @@ " edges {\n", " cursor\n", " }\n", + " pageInfo {\n", + " hasNextPage\n", + " hasPreviousPage\n", + " totalCount\n", + " }\n", " }\n", " }\n", " }\n", "}\"\"\"\n", - ")" + ")\n", + "\n", + "\n", + "def compare_responses(new_response, old_response):\n", + " new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + " new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + " old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", + " new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", + " new_page_info = new_response[\"node\"][\"spans\"][\"pageInfo\"]\n", + " old_page_info = old_response[\"node\"][\"spans\"][\"pageInfo\"]\n", + " page_info_diff = list(diff(new_page_info, old_page_info))\n", + " print(f\"{new_ids=}\")\n", + " print(f\"{new_offsets=}\")\n", + " print(f\"{page_info_diff=}\")" ] }, { @@ -64,12 +83,7 @@ "old_response = old_client.execute(\n", " spans_query, variable_values={\"projectId\": old_project_id, \"first\": 10}\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", - "old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", - "print(new_ids)\n", - "print(new_offsets)" + "compare_responses(new_response, old_response)" ] }, { @@ -94,12 +108,7 @@ " \"first\": 10,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", - "old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", - "print(new_ids)\n", - "print(new_offsets)" + "compare_responses(new_response, old_response)" ] } ], From f2a70f3e5b3568b83df022699539c6a1c82f9fef Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 15:39:48 -0700 Subject: [PATCH 08/74] remove 
totalCount from PageInfo gql type --- app/schema.graphql | 1 - src/phoenix/server/api/types/pagination.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/app/schema.graphql b/app/schema.graphql index e127cb7f9c..235d56a0ee 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -508,7 +508,6 @@ type PageInfo { hasPreviousPage: Boolean! startCursor: String endCursor: String - totalCount: Int! } enum PerformanceMetric { diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index ce1562e4ff..03ecdf1f50 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -37,7 +37,6 @@ class PageInfo: has_previous_page: bool start_cursor: Optional[str] end_cursor: Optional[str] - total_count: int # A type alias for the connection cursor implementation @@ -186,6 +185,5 @@ def connection_from_list_slice( end_cursor=last_edge.cursor if last_edge else None, has_previous_page=start_offset > lower_bound if isinstance(args.last, int) else False, has_next_page=end_offset < upper_bound if isinstance(args.first, int) else False, - total_count=list_length, ), ) From 6bce63bfa0332e0d8a18908ea3558123cb32ad62 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 17:17:35 -0700 Subject: [PATCH 09/74] nail down test case and understand offset cursors --- .../pagination_query_testing.ipynb | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index d5acb91153..af6504d37f 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -45,11 +45,13 @@ " spans(after: $after, before: $before, first: $first, last: $last) {\n", " edges {\n", " cursor\n", + " node {\n", + " spanKind\n", + " }\n", " }\n", " pageInfo {\n", " hasNextPage\n", " hasPreviousPage\n", - " totalCount\n", " }\n", " }\n", " }\n", @@ -62,13 +64,15 @@ " new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", " new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", " old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", - " new_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", + " old_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", " new_page_info = new_response[\"node\"][\"spans\"][\"pageInfo\"]\n", " old_page_info = old_response[\"node\"][\"spans\"][\"pageInfo\"]\n", " page_info_diff = list(diff(new_page_info, old_page_info))\n", " print(f\"{new_ids=}\")\n", - " print(f\"{new_offsets=}\")\n", - " print(f\"{page_info_diff=}\")" + " print(f\"{old_offsets=}\")\n", + " print(f\"{page_info_diff=}\")\n", + " assert new_ids == [offset + 1 for offset in old_offsets], \"mismatched spans\"\n", + " assert page_info_diff == [], \"mismatched page info\"" ] }, { @@ -78,14 +82,34 @@ "outputs": [], "source": [ "new_response = new_client.execute(\n", - " spans_query, variable_values={\"projectId\": new_project_id, \"first\": 10}\n", + " spans_query,\n", + " variable_values={\"projectId\": new_project_id, \"first\": 10},\n", ")\n", "old_response = old_client.execute(\n", - " spans_query, variable_values={\"projectId\": old_project_id, \"first\": 10}\n", + " spans_query,\n", + " variable_values={\"projectId\": old_project_id, \"first\": 10},\n", ")\n", "compare_responses(new_response, old_response)" ] }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "old_response" + ] + }, { "cell_type": "code", "execution_count": null, @@ -96,7 +120,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(10),\n", + " \"after\": id_to_cursor(2),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -104,7 +128,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": old_project_id,\n", - " \"after\": offset_to_cursor(10),\n", + " \"after\": offset_to_cursor(1),\n", " \"first\": 10,\n", " },\n", ")\n", From 6e97aa77fe742ca59dcaa1694b75bd11e207cbef Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 17:30:33 -0700 Subject: [PATCH 10/74] implement forward cursor-based pagination with basic support (not tested on filtering or ordering) --- .../pagination_query_testing.ipynb | 18 --------------- src/phoenix/server/api/types/Project.py | 23 +++++++++++-------- src/phoenix/server/api/types/pagination.py | 22 +++++++++++++++++- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index af6504d37f..f78b85548e 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -92,24 +92,6 @@ "compare_responses(new_response, old_response)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_response" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "old_response" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 628c34927a..0192d0165f 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -18,9 +18,9 @@ from phoenix.server.api.types.node import Node from phoenix.server.api.types.pagination import ( Connection, - ConnectionArgs, Cursor, - connection_from_list, + connections, + cursor_to_id, ) from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.Trace import Trace @@ -162,12 +162,6 @@ async def spans( root_spans_only: Optional[bool] = UNSET, filter_condition: Optional[str] = UNSET, ) -> Connection[Span]: - args = ConnectionArgs( - first=first, - after=after if isinstance(after, Cursor) else None, - last=last, - before=before if isinstance(before, Cursor) else None, - ) stmt = ( select(models.Span) .join(models.Trace) @@ -192,12 +186,21 @@ async def spans( if filter_condition: span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) + if after: + span_rowid = cursor_to_id(after) + stmt = stmt.where(models.Span.id > span_rowid) + if first: + stmt = stmt.limit(first) if sort: stmt = sort.update_orm_expr(stmt) async with info.context.db() as session: spans = await session.scalars(stmt) - data = [to_gql_span(span) for span in spans] - return connection_from_list(data=data, args=args) + data = [(span.id, to_gql_span(span)) for span in spans] + return connections( + data, + has_previous_page=False, + has_next_page=True, # todo: overfetch to determine whether we have a next page + ) @strawberry.field( description="Names of all available evaluations for traces. 
" diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 03ecdf1f50..430831579d 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -1,6 +1,6 @@ import base64 from dataclasses import dataclass -from typing import Generic, List, Optional, TypeVar +from typing import Generic, List, Optional, Tuple, TypeVar import strawberry from strawberry import UNSET @@ -187,3 +187,23 @@ def connection_from_list_slice( has_next_page=end_offset < upper_bound if isinstance(args.first, int) else False, ), ) + + +def connections( + data: List[Tuple[ID, GenericType]], + has_previous_page: bool, + has_next_page: bool, +) -> Connection[GenericType]: + edges = [Edge(node=node, cursor=id_to_cursor(id)) for id, node in data] + has_edges = len(edges) > 0 + first_edge = edges[0] if has_edges else None + last_edge = edges[-1] if has_edges else None + return Connection( + edges=edges, + page_info=PageInfo( + start_cursor=first_edge.cursor if first_edge else None, + end_cursor=last_edge.cursor if last_edge else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) From 8ae29bfac9d8079e5b54c2969f05f1091e3eec48 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 17:49:24 -0700 Subject: [PATCH 11/74] add failing test cases to test for the correctness of hasNextPage --- .../pagination_query_testing.ipynb | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index f78b85548e..ff4ceb144a 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -116,6 +116,84 @@ ")\n", "compare_responses(new_response, old_response)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# there are 765 spans in the llama-index rag fixture\n", + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(754),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "old_response = old_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": old_project_id,\n", + " \"after\": offset_to_cursor(753),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "compare_responses(new_response, old_response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# there are 765 spans in the llama-index rag fixture\n", + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(755),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "old_response = old_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": old_project_id,\n", + " \"after\": offset_to_cursor(754),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "compare_responses(new_response, old_response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# there are 765 spans in the llama-index rag fixture\n", + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(756),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "old_response = old_client.execute(\n", + " 
spans_query,\n", + " variable_values={\n", + " \"projectId\": old_project_id,\n", + " \"after\": offset_to_cursor(755),\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "compare_responses(new_response, old_response)" + ] } ], "metadata": { From 8501f55800b91c6169b0494fa40dbd0d800f57bf Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 18:01:41 -0700 Subject: [PATCH 12/74] add overfetching to determine whether there's a next page and pass tests --- src/phoenix/server/api/types/Project.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 0192d0165f..cff5177543 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,4 +1,5 @@ from datetime import datetime +from itertools import islice from typing import List, Optional import strawberry @@ -190,16 +191,25 @@ async def spans( span_rowid = cursor_to_id(after) stmt = stmt.where(models.Span.id > span_rowid) if first: - stmt = stmt.limit(first) + stmt = stmt.limit( + first + 1 # overfetch by one to determine whether there's a next page + ) if sort: stmt = sort.update_orm_expr(stmt) async with info.context.db() as session: spans = await session.scalars(stmt) - data = [(span.id, to_gql_span(span)) for span in spans] + + data = [(span.id, to_gql_span(span)) for span in islice(spans, first)] + has_next_page = True + try: + next(spans) + except StopIteration: + has_next_page = False + return connections( data, has_previous_page=False, - has_next_page=True, # todo: overfetch to determine whether we have a next page + has_next_page=has_next_page, ) @strawberry.field( From 1aa6e2d95b4e1f369e86a0ab1c800a81ae64c879 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 18:26:44 -0700 Subject: [PATCH 13/74] add filter condition to span query and format notebook --- integration-tests/pagination_query_testing.ipynb | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index ff4ceb144a..e5c6c88015 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -39,10 +39,16 @@ "outputs": [], "source": [ "spans_query = gql(\n", - " \"\"\"query SpansQuery($projectId: GlobalID!, $after: String = null, $before: String = null, $first: Int = null, $last: Int = null) {\n", + " \"\"\"query SpansQuery($projectId: GlobalID!, $after: String = null, $before: String = null, $filterCondition: String = null, $first: Int = null, $last: Int = null) {\n", " node(id: $projectId) {\n", - " ...on Project {\n", - " spans(after: $after, before: $before, first: $first, last: $last) {\n", + " ... 
on Project {\n", + " spans(\n", + " after: $after\n", + " before: $before\n", + " filterCondition: $filterCondition\n", + " first: $first\n", + " last: $last\n", + " ) {\n", " edges {\n", " cursor\n", " node {\n", From 61177b822106adfcaa6b0797fc8574929ad8104d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 18:50:21 -0700 Subject: [PATCH 14/74] add failing test case for filter condition --- .../pagination_query_testing.ipynb | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index e5c6c88015..61cb17085f 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -48,6 +48,7 @@ " filterCondition: $filterCondition\n", " first: $first\n", " last: $last\n", + " rootSpansOnly: false\n", " ) {\n", " edges {\n", " cursor\n", @@ -200,6 +201,33 @@ ")\n", "compare_responses(new_response, old_response)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(2),\n", + " \"first\": 10,\n", + " \"filterCondition\": \"span_kind == 'LLM'\",\n", + " },\n", + ")\n", + "old_response = old_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": old_project_id,\n", + " \"after\": offset_to_cursor(1),\n", + " \"first\": 10,\n", + " \"filterCondition\": \"span_kind == 'LLM'\",\n", + " },\n", + ")\n", + "compare_responses(new_response, old_response)" + ] } ], "metadata": { From 2cbc805457afea11f37a00a76efb2198a5f9457a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 21:52:38 -0700 Subject: [PATCH 15/74] fix test case for filtering --- .../pagination_query_testing.ipynb | 49 +++++++++---------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 61cb17085f..ddd10f6621 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -10,9 +10,7 @@ "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", - " cursor_to_id,\n", " cursor_to_offset,\n", - " id_to_cursor,\n", " offset_to_cursor,\n", ")\n", "\n", @@ -69,7 +67,7 @@ "\n", "def compare_responses(new_response, old_response):\n", " new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - " new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + " new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", " old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", " old_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", " new_page_info = new_response[\"node\"][\"spans\"][\"pageInfo\"]\n", @@ -88,15 +86,15 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", - " spans_query,\n", - " variable_values={\"projectId\": new_project_id, \"first\": 10},\n", - ")\n", - "old_response = old_client.execute(\n", - " spans_query,\n", - " variable_values={\"projectId\": old_project_id, \"first\": 10},\n", - ")\n", - "compare_responses(new_response, old_response)" + "# new_response = new_client.execute(\n", + "# 
spans_query,\n", + "# variable_values={\"projectId\": new_project_id, \"first\": 10},\n", + "# )\n", + "# old_response = old_client.execute(\n", + "# spans_query,\n", + "# variable_values={\"projectId\": old_project_id, \"first\": 10},\n", + "# )\n", + "# compare_responses(new_response, old_response)" ] }, { @@ -109,7 +107,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(2),\n", + " \"after\": offset_to_cursor(2),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -135,7 +133,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(754),\n", + " \"after\": offset_to_cursor(754),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -161,7 +159,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(755),\n", + " \"after\": offset_to_cursor(755),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -187,7 +185,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(756),\n", + " \"after\": offset_to_cursor(756),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -212,21 +210,18 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(2),\n", + " \"after\": offset_to_cursor(2),\n", " \"first\": 10,\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", - "old_response = old_client.execute(\n", - " spans_query,\n", - " variable_values={\n", - " \"projectId\": old_project_id,\n", - " \"after\": offset_to_cursor(1),\n", - " \"first\": 10,\n", - " \"filterCondition\": \"span_kind == 'LLM'\",\n", - " },\n", - ")\n", - "compare_responses(new_response, old_response)" + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "print(new_ids)\n", + "span_kinds = [\n", + " edge[\"node\"][\"spanKind\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]\n", + "]\n", + "assert all(span_kind == \"llm\" for span_kind in span_kinds)" ] } ], From da66aa254cb3c0b7c30691ec455000875176d396 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 22:24:57 -0700 Subject: [PATCH 16/74] basic filtering and sorting test cases --- .../pagination_query_testing.ipynb | 50 +++++++++++++++---- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index ddd10f6621..b1baf8333d 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "spans_query = gql(\n", - " \"\"\"query SpansQuery($projectId: GlobalID!, $after: String = null, $before: String = null, $filterCondition: String = null, $first: Int = null, $last: Int = null) {\n", + " \"\"\"query SpansQuery($projectId: GlobalID!, $after: String = null, $before: String = null, $filterCondition: String = null, $first: Int = null, $last: Int = null, $sort: SpanSort = null) {\n", " node(id: $projectId) {\n", " ... 
on Project {\n", " spans(\n", @@ -47,12 +47,10 @@ " first: $first\n", " last: $last\n", " rootSpansOnly: false\n", + " sort: $sort\n", " ) {\n", " edges {\n", " cursor\n", - " node {\n", - " spanKind\n", - " }\n", " }\n", " pageInfo {\n", " hasNextPage\n", @@ -217,11 +215,45 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "print(new_ids)\n", - "span_kinds = [\n", - " edge[\"node\"][\"spanKind\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]\n", - "]\n", - "assert all(span_kind == \"llm\" for span_kind in span_kinds)" + "assert new_ids == [6, 11, 16, 21, 26, 31, 36, 41, 46, 51], new_ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "assert new_ids == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], new_ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "assert new_ids == [765, 764, 763, 762, 761, 760, 759, 758, 757, 756], new_ids" ] } ], From fcbe384eea5b97176065a5a5887510a89cd375c3 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 22:29:31 -0700 Subject: [PATCH 17/74] clean tests --- .../pagination_query_testing.ipynb | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index b1baf8333d..5827ec524e 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -10,7 +10,9 @@ "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", + " cursor_to_id,\n", " cursor_to_offset,\n", + " id_to_cursor,\n", " offset_to_cursor,\n", ")\n", "\n", @@ -65,7 +67,7 @@ "\n", "def compare_responses(new_response, old_response):\n", " new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - " new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + " new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", " old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", " old_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", " new_page_info = new_response[\"node\"][\"spans\"][\"pageInfo\"]\n", @@ -84,15 +86,15 @@ "metadata": {}, "outputs": [], "source": [ - "# new_response = new_client.execute(\n", - "# spans_query,\n", - "# variable_values={\"projectId\": new_project_id, \"first\": 10},\n", - "# )\n", - "# old_response = old_client.execute(\n", - "# spans_query,\n", - 
"# variable_values={\"projectId\": old_project_id, \"first\": 10},\n", - "# )\n", - "# compare_responses(new_response, old_response)" + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\"projectId\": new_project_id, \"first\": 10},\n", + ")\n", + "old_response = old_client.execute(\n", + " spans_query,\n", + " variable_values={\"projectId\": old_project_id, \"first\": 10},\n", + ")\n", + "compare_responses(new_response, old_response)" ] }, { @@ -105,7 +107,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": offset_to_cursor(2),\n", + " \"after\": id_to_cursor(2),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -131,7 +133,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": offset_to_cursor(754),\n", + " \"after\": id_to_cursor(754),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -157,7 +159,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": offset_to_cursor(755),\n", + " \"after\": id_to_cursor(755),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -183,7 +185,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": offset_to_cursor(756),\n", + " \"after\": id_to_cursor(756),\n", " \"first\": 10,\n", " },\n", ")\n", @@ -208,7 +210,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"after\": offset_to_cursor(2),\n", + " \"after\": id_to_cursor(2),\n", " \"first\": 10,\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", From 0a18e1c1bdb5eed75de8ea006ae23fe8c6ffb0c6 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 23:03:07 -0700 Subject: [PATCH 18/74] add test case covering attribute filter with cursor --- .../pagination_query_testing.ipynb | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 5827ec524e..9c90a3933d 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -220,6 +220,26 @@ "assert new_ids == [6, 11, 16, 21, 26, 31, 36, 41, 46, 51], new_ids" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(1), # skip the first span satisfying the filter condition\n", + " \"first\": 10,\n", + " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 200\",\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "assert new_ids == [21, 26, 31, 36, 41, 46, 51, 56, 61, 66]" + ] + }, { "cell_type": "code", "execution_count": null, From a617752fafc66003f49e4e3d4c9e7eeb3ea9ff63 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 23:21:47 -0700 Subject: [PATCH 19/74] descending ordering test with cursor --- integration-tests/pagination_query_testing.ipynb | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 9c90a3933d..231ff7788d 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ 
b/integration-tests/pagination_query_testing.ipynb @@ -237,7 +237,7 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [21, 26, 31, 36, 41, 46, 51, 56, 61, 66]" + "assert new_ids == [21, 26, 31, 36, 41, 46, 51, 56, 61, 66], new_ids" ] }, { @@ -269,13 +269,14 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": new_project_id,\n", - " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", + " \"after\": id_to_cursor(1),\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", " \"first\": 10,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [765, 764, 763, 762, 761, 760, 759, 758, 757, 756], new_ids" + "assert new_ids == [2, 3, 4, 5, 6, 7, 8, 9, 10, 11], new_ids" ] } ], From 3f7a0a0848eb7cc5110e3224e8e556e9349976ca Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 2 May 2024 23:59:09 -0700 Subject: [PATCH 20/74] add failing test to order by ascending start time with cursor --- .../pagination_query_testing.ipynb | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 231ff7788d..7a9de6b0d2 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -278,6 +278,37 @@ "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", "assert new_ids == [2, 3, 4, 5, 6, 7, 8, 9, 10, 11], new_ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "new_response = new_client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": new_project_id,\n", + " \"after\": id_to_cursor(765),\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", + " \"first\": 10,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "assert new_ids == [\n", + " 764,\n", + " 763,\n", + " 762,\n", + " 761,\n", + " 760,\n", + " 759,\n", + " 758,\n", + " 757,\n", + " 756,\n", + " 755,\n", + "], new_ids" + ] } ], "metadata": { From 0bf342e7dce3e665e06de2720af29580be9911b7 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 10:13:45 -0700 Subject: [PATCH 21/74] add failing test to order by ascending start time with cursor --- src/phoenix/server/api/types/Project.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 91edccbf1b..27cb7b95bd 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -3,7 +3,7 @@ import strawberry from aioitertools.itertools import islice -from sqlalchemy import and_, desc, distinct, select +from sqlalchemy import and_, distinct, select from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info @@ -198,7 +198,9 @@ async def spans( if sort: stmt = sort.update_orm_expr(stmt) else: - stmt = stmt.order_by(desc(models.Span.id)) + stmt = stmt.order_by( + models.Span.id + ) # todo: i changed this to conform to the previous behavior of the api stmt = stmt.limit( 
SPANS_LIMIT ) # todo: remove this after adding pagination https://github.com/Arize-ai/phoenix/issues/3003 From 7808effc590da4c985f22930799cb1c1a29fbb5b Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 11:33:33 -0700 Subject: [PATCH 22/74] adapt notebook test cases since we are now ordering by descending row id by default --- .../pagination_query_testing.ipynb | 245 ++++++++---------- src/phoenix/server/api/types/Project.py | 8 +- 2 files changed, 114 insertions(+), 139 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 7a9de6b0d2..4d00f3862f 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -6,28 +6,16 @@ "metadata": {}, "outputs": [], "source": [ - "from dictdiffer import diff\n", "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", " cursor_to_id,\n", - " cursor_to_offset,\n", " id_to_cursor,\n", - " offset_to_cursor,\n", ")\n", "\n", - "new_url = \"http://127.0.0.1:6006/graphql\"\n", - "old_url = \"http://127.0.0.1:6008/graphql\"\n", - "\n", - "new_project_id = \"UHJvamVjdDox\"\n", - "old_project_id = \"UHJvamVjdDow\"\n", - "\n", - "new_client = Client(\n", - " transport=RequestsHTTPTransport(url=new_url, timeout=1),\n", - " fetch_schema_from_transport=True,\n", - ")\n", - "old_client = Client(\n", - " transport=RequestsHTTPTransport(url=old_url, timeout=1),\n", + "project_id = \"UHJvamVjdDox\"\n", + "client = Client(\n", + " transport=RequestsHTTPTransport(url=\"http://127.0.0.1:6006/graphql\", timeout=1),\n", " fetch_schema_from_transport=True,\n", ")" ] @@ -62,22 +50,7 @@ " }\n", " }\n", "}\"\"\"\n", - ")\n", - "\n", - "\n", - "def compare_responses(new_response, old_response):\n", - " new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - " new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", - " old_cursors = [edge[\"cursor\"] for edge in old_response[\"node\"][\"spans\"][\"edges\"]]\n", - " old_offsets = [cursor_to_offset(cursor) for cursor in old_cursors]\n", - " new_page_info = new_response[\"node\"][\"spans\"][\"pageInfo\"]\n", - " old_page_info = old_response[\"node\"][\"spans\"][\"pageInfo\"]\n", - " page_info_diff = list(diff(new_page_info, old_page_info))\n", - " print(f\"{new_ids=}\")\n", - " print(f\"{old_offsets=}\")\n", - " print(f\"{page_info_diff=}\")\n", - " assert new_ids == [offset + 1 for offset in old_offsets], \"mismatched spans\"\n", - " assert page_info_diff == [], \"mismatched page info\"" + ")" ] }, { @@ -86,15 +59,14 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", - " spans_query,\n", - " variable_values={\"projectId\": new_project_id, \"first\": 10},\n", - ")\n", - "old_response = old_client.execute(\n", + "# basic query\n", + "new_response = client.execute(\n", " spans_query,\n", - " variable_values={\"projectId\": old_project_id, \"first\": 10},\n", + " variable_values={\"projectId\": project_id, \"first\": 5},\n", ")\n", - "compare_responses(new_response, old_response)" + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "assert new_ids == [765, 764, 763, 762, 761], new_ids" ] }, { @@ -103,23 +75,18 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# 
query with cursor\n",
+ "new_response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\n",
- "        \"projectId\": new_project_id,\n",
- "        \"after\": id_to_cursor(2),\n",
- "        \"first\": 10,\n",
+ "        \"projectId\": project_id,\n",
+ "        \"after\": id_to_cursor(761),\n",
+ "        \"first\": 5,\n",
 "    },\n",
 ")\n",
- "old_response = old_client.execute(\n",
- "    spans_query,\n",
- "    variable_values={\n",
- "        \"projectId\": old_project_id,\n",
- "        \"after\": offset_to_cursor(1),\n",
- "        \"first\": 10,\n",
- "    },\n",
- ")\n",
- "compare_responses(new_response, old_response)"
+ "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
+ "assert new_ids == [760, 759, 758, 757, 756], new_ids"
 ]
 },
 {
@@ -128,24 +95,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "# there are 765 spans in the llama-index rag fixture\n",
- "new_response = new_client.execute(\n",
+ "# page ends on the penultimate record and excludes last record\n",
+ "new_response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\n",
- "        \"projectId\": new_project_id,\n",
- "        \"after\": id_to_cursor(754),\n",
- "        \"first\": 10,\n",
+ "        \"projectId\": project_id,\n",
+ "        \"after\": id_to_cursor(7),\n",
+ "        \"first\": 5,\n",
 "    },\n",
 ")\n",
- "old_response = old_client.execute(\n",
- "    spans_query,\n",
- "    variable_values={\n",
- "        \"projectId\": old_project_id,\n",
- "        \"after\": offset_to_cursor(753),\n",
- "        \"first\": 10,\n",
- "    },\n",
- ")\n",
- "compare_responses(new_response, old_response)"
+ "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
+ "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n",
+ "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n",
+ "assert new_ids == [6, 5, 4, 3, 2], new_ids\n",
+ "assert has_next_page is True\n",
+ "assert has_previous_page is False"
 ]
 },
 {
@@ -154,24 +119,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "# there are 765 spans in the llama-index rag fixture\n",
- "new_response = new_client.execute(\n",
+ "# page ends on the last record exactly\n",
+ "new_response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\n",
- "        \"projectId\": new_project_id,\n",
- "        \"after\": id_to_cursor(755),\n",
- "        \"first\": 10,\n",
+ "        \"projectId\": project_id,\n",
+ "        \"after\": id_to_cursor(6),\n",
+ "        \"first\": 5,\n",
 "    },\n",
 ")\n",
- "old_response = old_client.execute(\n",
- "    spans_query,\n",
- "    variable_values={\n",
- "        \"projectId\": old_project_id,\n",
- "        \"after\": offset_to_cursor(754),\n",
- "        \"first\": 10,\n",
- "    },\n",
- ")\n",
- "compare_responses(new_response, old_response)"
+ "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
+ "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n",
+ "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n",
+ "assert new_ids == [5, 4, 3, 2, 1], new_ids\n",
+ "assert has_next_page is False\n",
+ "assert has_previous_page is False"
 ]
 },
 {
@@ -180,24 +143,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "# there are 765 spans in the llama-index rag fixture\n",
- "new_response = new_client.execute(\n",
+ "# page ends before it reaches the limit\n",
+ "new_response = 
client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(756),\n", - " \"first\": 10,\n", + " \"projectId\": project_id,\n", + " \"after\": id_to_cursor(5),\n", + " \"first\": 5,\n", " },\n", ")\n", - "old_response = old_client.execute(\n", - " spans_query,\n", - " variable_values={\n", - " \"projectId\": old_project_id,\n", - " \"after\": offset_to_cursor(755),\n", - " \"first\": 10,\n", - " },\n", - ")\n", - "compare_responses(new_response, old_response)" + "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", + "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", + "assert new_ids == [4, 3, 2, 1], new_ids\n", + "assert has_next_page is False\n", + "assert has_previous_page is False" ] }, { @@ -206,18 +167,24 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# basic filter condition\n", + "new_response = client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(2),\n", - " \"first\": 10,\n", + " \"projectId\": project_id,\n", + " \"first\": 5,\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [6, 11, 16, 21, 26, 31, 36, 41, 46, 51], new_ids" + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "assert new_ids == [\n", + " 761,\n", + " 756,\n", + " 751,\n", + " 746,\n", + " 741,\n", + "], new_ids" ] }, { @@ -226,18 +193,25 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# basic filter condition with cursor\n", + "new_response = client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(1), # skip the first span satisfying the filter condition\n", - " \"first\": 10,\n", - " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 200\",\n", + " \"projectId\": project_id,\n", + " \"first\": 5,\n", + " \"after\": id_to_cursor(761), # skip the first span satisfying the filter condition\n", + " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [21, 26, 31, 36, 41, 46, 51, 56, 61, 66], new_ids" + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "assert new_ids == [\n", + " 756,\n", + " 751,\n", + " 746,\n", + " 741,\n", + " 736,\n", + "], new_ids" ] }, { @@ -246,17 +220,25 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# compound filter condition with cursor\n", + "new_response = client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", - " \"first\": 10,\n", + " \"projectId\": project_id,\n", + " \"after\": id_to_cursor(761), # skip the first span satisfying the filter condition\n", + " \"first\": 5,\n", + " \"filterCondition\": \"span_kind == 'LLM' and 
cumulative_llm_token_count_prompt > 300\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], new_ids" + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "assert new_ids == [\n", + " 756,\n", + " 751,\n", + " 736,\n", + " 731,\n", + " 721,\n", + "], new_ids" ] }, { @@ -265,18 +247,18 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# order by start time\n", + "new_response = client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(1),\n", + " \"projectId\": project_id,\n", " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", - " \"first\": 10,\n", + " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", - "assert new_ids == [2, 3, 4, 5, 6, 7, 8, 9, 10, 11], new_ids" + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] }, { @@ -285,28 +267,23 @@ "metadata": {}, "outputs": [], "source": [ - "new_response = new_client.execute(\n", + "# order by cumulative prompt token count in descending order\n", + "new_response = client.execute(\n", " spans_query,\n", " variable_values={\n", - " \"projectId\": new_project_id,\n", - " \"after\": id_to_cursor(765),\n", - " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", - " \"first\": 10,\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"desc\"},\n", + " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_offset(cursor) for cursor in new_cursors]\n", + "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [\n", - " 764,\n", - " 763,\n", - " 762,\n", - " 761,\n", - " 760,\n", - " 759,\n", - " 758,\n", - " 757,\n", - " 756,\n", - " 755,\n", + " 60,\n", + " 57,\n", + " 56,\n", + " 125,\n", + " 122,\n", "], new_ids" ] } diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 27cb7b95bd..339a355758 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -3,7 +3,7 @@ import strawberry from aioitertools.itertools import islice -from sqlalchemy import and_, distinct, select +from sqlalchemy import and_, desc, distinct, select from sqlalchemy.orm import contains_eager from strawberry import ID, UNSET from strawberry.types import Info @@ -190,7 +190,7 @@ async def spans( stmt = span_filter(stmt) if after: span_rowid = cursor_to_id(after) - stmt = stmt.where(models.Span.id > span_rowid) + stmt = stmt.where(models.Span.id < span_rowid) if first: stmt = stmt.limit( first + 1 # overfetch by one to determine whether there's a next page @@ -198,9 +198,7 @@ async def spans( if sort: stmt = sort.update_orm_expr(stmt) else: - stmt = stmt.order_by( - models.Span.id - ) # todo: i changed this to conform to the previous behavior of the api + stmt = stmt.order_by(desc(models.Span.id)) stmt = stmt.limit( SPANS_LIMIT ) # todo: remove this after adding pagination https://github.com/Arize-ai/phoenix/issues/3003 From 92ede3daad116d7d8e3b207f880019ced513e57c Mon Sep 17 00:00:00 
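Aside: the hasNextPage handling that patch 12 introduced and the notebook cells above exercise reduces to a small pattern — ask the database for one row more than the requested page size, then check whether anything is left over. A minimal sketch in plain Python; the helper name is illustrative, not part of the codebase:

from itertools import islice
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")

def take_page(rows: Iterable[T], first: int) -> Tuple[List[T], bool]:
    # `rows` should come from a query limited to `first + 1`; the extra row,
    # if present, proves another page exists without fetching it in full.
    it = iter(rows)
    page = list(islice(it, first))
    has_next_page = next(it, None) is not None
    return page, has_next_page

assert take_page(range(6), 5) == ([0, 1, 2, 3, 4], True)
assert take_page(range(3), 5) == ([0, 1, 2], False)

The overfetched row is discarded rather than returned, so detecting the next page costs at most one extra row per request.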
2001
From: Alexander Song
Date: Sat, 4 May 2024 11:48:46 -0700
Subject: [PATCH 23/74] pass order by cumulative prompt token count test case

---
 src/phoenix/server/api/types/Project.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py
index 339a355758..9c4aed8084 100644
--- a/src/phoenix/server/api/types/Project.py
+++ b/src/phoenix/server/api/types/Project.py
@@ -197,8 +197,7 @@ async def spans(
 )
 if sort:
 stmt = sort.update_orm_expr(stmt)
- else:
- stmt = stmt.order_by(desc(models.Span.id))
+ stmt = stmt.order_by(desc(models.Span.id))
 stmt = stmt.limit(
 SPANS_LIMIT
 ) # todo: remove this after adding pagination https://github.com/Arize-ai/phoenix/issues/3003

From 7a5942347be85906ce9237bd6ae569b22c3702d7 Mon Sep 17 00:00:00 2001
From: Alexander Song
Date: Sat, 4 May 2024 11:50:26 -0700
Subject: [PATCH 24/74] rename variable in test notebook

---
 .../pagination_query_testing.ipynb | 52 +++++++++----------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb
index 4d00f3862f..a1e1bc4d38 100644
--- a/integration-tests/pagination_query_testing.ipynb
+++ b/integration-tests/pagination_query_testing.ipynb
@@ -60,11 +60,11 @@
 "outputs": [],
 "source": [
 "# basic query\n",
- "new_response = client.execute(\n",
+ "response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\"projectId\": project_id, \"first\": 5},\n",
 ")\n",
- "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n",
 "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
 "assert new_ids == [765, 764, 763, 762, 761], new_ids"
 ]
 },
@@ -76,7 +76,7 @@
 "outputs": [],
 "source": [
 "# query with cursor\n",
- "new_response = client.execute(\n",
+ "response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\n",
 "        \"projectId\": project_id,\n",
@@ -84,7 +84,7 @@
 "    },\n",
 ")\n",
- "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n",
 "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
 "assert new_ids == [760, 759, 758, 757, 756], new_ids"
 ]
 },
@@ -96,7 +96,7 @@
 "outputs": [],
 "source": [
 "# page ends on the penultimate record and excludes last record\n",
- "new_response = client.execute(\n",
+ "response = client.execute(\n",
 "    spans_query,\n",
 "    variable_values={\n",
 "        \"projectId\": project_id,\n",
@@ -104,10 +104,10 @@
 "    },\n",
 ")\n",
- "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n",
+ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n",
 "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n",
- "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n",
- "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n",
+ "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n",
+ "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n",
 "assert new_ids == [6, 5, 4, 3, 2], new_ids\n",
 "assert has_next_page is True\n",
 "assert has_previous_page is False"
@@ -120,7 +120,7 @@
 "outputs": [],
 "source": [
 "# page 
ends on the last record exactly\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -128,10 +128,10 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", - "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", - "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", + "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", + "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [5, 4, 3, 2, 1], new_ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" @@ -144,7 +144,7 @@ "outputs": [], "source": [ "# page ends before it reaches the limit\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -152,10 +152,10 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", - "has_next_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", - "has_previous_page = new_response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", + "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", + "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [4, 3, 2, 1], new_ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" @@ -168,7 +168,7 @@ "outputs": [], "source": [ "# basic filter condition\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -176,7 +176,7 @@ " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [\n", " 761,\n", @@ -194,7 +194,7 @@ "outputs": [], "source": [ "# basic filter condition with cursor\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -203,7 +203,7 @@ " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", @@ -221,7 +221,7 @@ "outputs": [], "source": [ "# compound filter condition with cursor\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -230,7 +230,7 @@ " \"filterCondition\": \"span_kind == 'LLM' and 
cumulative_llm_token_count_prompt > 300\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", @@ -248,7 +248,7 @@ "outputs": [], "source": [ "# order by start time\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -256,7 +256,7 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] @@ -268,7 +268,7 @@ "outputs": [], "source": [ "# order by cumulative prompt token count in descending order\n", - "new_response = client.execute(\n", + "response = client.execute(\n", " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", @@ -276,7 +276,7 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in new_response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", "assert new_ids == [\n", " 60,\n", From 5b88cf87de6ebd502c7b609f0a21928a84431a80 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 14:09:02 -0700 Subject: [PATCH 25/74] add cursor serialization and deserialization for floating point sortable fields --- src/phoenix/server/api/types/pagination.py | 41 +++++++++++++++++++++- tests/server/api/types/test_pagination.py | 30 +++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 430831579d..05513946db 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -1,6 +1,7 @@ import base64 from dataclasses import dataclass -from typing import Generic, List, Optional, Tuple, TypeVar +from enum import Enum +from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar import strawberry from strawberry import UNSET @@ -57,6 +58,44 @@ class Edge(Generic[GenericType]): CURSOR_PREFIX = "connection:" +class SortableFieldType(Enum): + float = "float" + + +@dataclass +class SortableField: + type: SortableFieldType + value: float + + +@dataclass +class TupleIdentifier: + rowid: int + sortable_field: Optional[SortableField] = None + + _DELIMITER: ClassVar[str] = ":" + + def to_cursor(self) -> Cursor: + cursor_components = [str(self.rowid)] + if (sortable_field := self.sortable_field) is not None: + cursor_components.extend([sortable_field.type.value, str(sortable_field.value)]) + return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() + + @classmethod + def from_cursor(cls, cursor: Cursor) -> "TupleIdentifier": + decoded = base64.b64decode(cursor).decode() + rowid_string = decoded + sortable_field = None + if (first_delimiter_index := decoded.find(cls._DELIMITER)) > -1: + rowid_string = decoded[:first_delimiter_index] + second_delimiter_index = decoded.index(cls._DELIMITER, first_delimiter_index + 1) + sortable_field = SortableField( + 
type=SortableFieldType(decoded[first_delimiter_index + 1 : second_delimiter_index]), + value=float(decoded[second_delimiter_index + 1 :]), + ) + return cls(rowid=int(rowid_string), sortable_field=sortable_field) + + def id_to_cursor(id: ID) -> Cursor: """ Creates a cursor string from an ID. diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index 61e050a03b..b7915f6ddd 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -1,7 +1,13 @@ import phoenix.core.model_schema as ms from phoenix.core.model_schema import FEATURE from phoenix.server.api.types.Dimension import Dimension -from phoenix.server.api.types.pagination import ConnectionArgs, connection_from_list +from phoenix.server.api.types.pagination import ( + ConnectionArgs, + SortableField, + SortableFieldType, + TupleIdentifier, + connection_from_list, +) def test_connection_from_list(): @@ -92,3 +98,25 @@ def test_connection_from_empty_list(): assert len(connection.edges) == 0 assert connection.page_info.has_next_page is False + + +class TestTupleIdentifier: + def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: + original = TupleIdentifier(rowid=10) + cursor = original.to_cursor() + deserialized = TupleIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert deserialized.sortable_field is None + + def test_to_and_from_cursor_with_rowid_and_float_sortable_field_deserializes_original( + self, + ) -> None: + original = TupleIdentifier( + rowid=10, sortable_field=SortableField(type=SortableFieldType.float, value=11.5) + ) + cursor = original.to_cursor() + deserialized = TupleIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert (sortable_field := deserialized.sortable_field) is not None + assert sortable_field.type == SortableFieldType.float + assert abs(sortable_field.value - 11.5) < 1e-6 From 2126500bc31e465db192c472e004a041c9666b09 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 14:41:17 -0700 Subject: [PATCH 26/74] refactor to be open to datetime cursor serialization --- src/phoenix/server/api/types/pagination.py | 15 ++++++++++++--- tests/server/api/types/test_pagination.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 05513946db..7ad357b7b2 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -67,6 +67,15 @@ class SortableField: type: SortableFieldType value: float + def stringify_value(self) -> str: + return str(self.value) + + @classmethod + def from_stringified_value( + cls, type: SortableFieldType, stringified_value: str + ) -> "SortableField": + return cls(type=type, value=float(stringified_value)) + @dataclass class TupleIdentifier: @@ -78,7 +87,7 @@ class TupleIdentifier: def to_cursor(self) -> Cursor: cursor_components = [str(self.rowid)] if (sortable_field := self.sortable_field) is not None: - cursor_components.extend([sortable_field.type.value, str(sortable_field.value)]) + cursor_components.extend([sortable_field.type.value, sortable_field.stringify_value()]) return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() @classmethod @@ -89,9 +98,9 @@ def from_cursor(cls, cursor: Cursor) -> "TupleIdentifier": if (first_delimiter_index := decoded.find(cls._DELIMITER)) > -1: rowid_string = decoded[:first_delimiter_index] second_delimiter_index = 
decoded.index(cls._DELIMITER, first_delimiter_index + 1) - sortable_field = SortableField( + sortable_field = SortableField.from_stringified_value( type=SortableFieldType(decoded[first_delimiter_index + 1 : second_delimiter_index]), - value=float(decoded[second_delimiter_index + 1 :]), + stringified_value=decoded[second_delimiter_index + 1 :], ) return cls(rowid=int(rowid_string), sortable_field=sortable_field) diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index b7915f6ddd..77fcb8f6b7 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -119,4 +119,4 @@ def test_to_and_from_cursor_with_rowid_and_float_sortable_field_deserializes_ori assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.float - assert abs(sortable_field.value - 11.5) < 1e-6 + assert abs(sortable_field.value - 11.5) < 1e-8 From 30a783b1b56794cc5dd28da4217c28d1633ddbd1 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 15:13:34 -0700 Subject: [PATCH 27/74] add serialization for datetimes to cursors --- src/phoenix/server/api/types/pagination.py | 30 +++++++++++----- tests/server/api/types/test_pagination.py | 40 ++++++++++++++++++++-- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 7ad357b7b2..ddf92f6978 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -1,7 +1,8 @@ import base64 from dataclasses import dataclass -from enum import Enum -from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar +from datetime import datetime +from enum import Enum, auto +from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, assert_never import strawberry from strawberry import UNSET @@ -9,6 +10,7 @@ ID: TypeAlias = int GenericType = TypeVar("GenericType") +SortableFieldValue: TypeAlias = Union[float, datetime] @strawberry.type @@ -59,22 +61,34 @@ class Edge(Generic[GenericType]): class SortableFieldType(Enum): - float = "float" + FLOAT = auto() + DATETIME = auto() @dataclass class SortableField: type: SortableFieldType - value: float + value: SortableFieldValue def stringify_value(self) -> str: - return str(self.value) + if isinstance(self.value, float): + return str(self.value) + if isinstance(self.value, datetime): + return self.value.isoformat() + assert_never(self.type) @classmethod def from_stringified_value( cls, type: SortableFieldType, stringified_value: str ) -> "SortableField": - return cls(type=type, value=float(stringified_value)) + value: SortableFieldValue + if type == SortableFieldType.FLOAT: + value = float(stringified_value) + elif type == SortableFieldType.DATETIME: + value = datetime.fromisoformat(stringified_value) + else: + assert_never(type) + return cls(type=type, value=value) @dataclass @@ -87,7 +101,7 @@ class TupleIdentifier: def to_cursor(self) -> Cursor: cursor_components = [str(self.rowid)] if (sortable_field := self.sortable_field) is not None: - cursor_components.extend([sortable_field.type.value, sortable_field.stringify_value()]) + cursor_components.extend([sortable_field.type.name, sortable_field.stringify_value()]) return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() @classmethod @@ -99,7 +113,7 @@ def from_cursor(cls, cursor: Cursor) -> "TupleIdentifier": rowid_string 
= decoded[:first_delimiter_index] second_delimiter_index = decoded.index(cls._DELIMITER, first_delimiter_index + 1) sortable_field = SortableField.from_stringified_value( - type=SortableFieldType(decoded[first_delimiter_index + 1 : second_delimiter_index]), + type=SortableFieldType[decoded[first_delimiter_index + 1 : second_delimiter_index]], stringified_value=decoded[second_delimiter_index + 1 :], ) return cls(rowid=int(rowid_string), sortable_field=sortable_field) diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index 77fcb8f6b7..c7fee32286 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -1,3 +1,5 @@ +from datetime import datetime + import phoenix.core.model_schema as ms from phoenix.core.model_schema import FEATURE from phoenix.server.api.types.Dimension import Dimension @@ -108,15 +110,47 @@ def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: assert deserialized.rowid == 10 assert deserialized.sortable_field is None - def test_to_and_from_cursor_with_rowid_and_float_sortable_field_deserializes_original( + def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( self, ) -> None: original = TupleIdentifier( - rowid=10, sortable_field=SortableField(type=SortableFieldType.float, value=11.5) + rowid=10, sortable_field=SortableField(type=SortableFieldType.FLOAT, value=11.5) ) cursor = original.to_cursor() deserialized = TupleIdentifier.from_cursor(cursor) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.float + assert sortable_field.type == SortableFieldType.FLOAT assert abs(sortable_field.value - 11.5) < 1e-8 + + def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_original( + self, + ) -> None: + timestamp = datetime.fromisoformat("2021-01-01T00:00:00") + original = TupleIdentifier( + rowid=10, + sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), + ) + cursor = original.to_cursor() + deserialized = TupleIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert (sortable_field := deserialized.sortable_field) is not None + assert sortable_field.type == SortableFieldType.DATETIME + assert sortable_field.value == timestamp + assert sortable_field.value.tzinfo is None + + def test_to_and_from_cursor_with_rowid_and_tz_aware_datetime_deserializes_original( + self, + ) -> None: + timestamp = datetime.fromisoformat("2021-01-01T00:00:00+00:00") + original = TupleIdentifier( + rowid=10, + sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), + ) + cursor = original.to_cursor() + deserialized = TupleIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert (sortable_field := deserialized.sortable_field) is not None + assert sortable_field.type == SortableFieldType.DATETIME + assert sortable_field.value == timestamp + assert sortable_field.value.tzinfo is not None From c31532c2c61be9c2e64b78fcd99e2a6b653dec9d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 15:44:09 -0700 Subject: [PATCH 28/74] refactor spans resolver to use new identifier abstraction --- .../pagination_query_testing.ipynb | 39 ++++++++++--------- src/phoenix/server/api/types/Project.py | 9 +++-- src/phoenix/server/api/types/pagination.py | 21 ++-------- 3 files changed, 31 insertions(+), 38 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb 
b/integration-tests/pagination_query_testing.ipynb index a1e1bc4d38..6407c74baf 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -9,8 +9,7 @@ "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", - " cursor_to_id,\n", - " id_to_cursor,\n", + " TupleIdentifier,\n", ")\n", "\n", "project_id = \"UHJvamVjdDox\"\n", @@ -65,7 +64,7 @@ " variable_values={\"projectId\": project_id, \"first\": 5},\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [765, 764, 763, 762, 761], new_ids" ] }, @@ -80,12 +79,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": id_to_cursor(761),\n", + " \"after\": TupleIdentifier(rowid=761).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [760, 759, 758, 757, 756], new_ids" ] }, @@ -100,12 +99,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": id_to_cursor(7),\n", + " \"after\": TupleIdentifier(7).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [6, 5, 4, 3, 2], new_ids\n", @@ -124,12 +123,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": id_to_cursor(6),\n", + " \"after\": TupleIdentifier(6).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [5, 4, 3, 2, 1], new_ids\n", @@ -148,12 +147,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": id_to_cursor(5),\n", + " \"after\": TupleIdentifier(5).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [4, 3, 2, 1], new_ids\n", @@ -177,7 +176,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in 
response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 761,\n", " 756,\n", @@ -199,12 +198,14 @@ " variable_values={\n", " \"projectId\": project_id,\n", " \"first\": 5,\n", - " \"after\": id_to_cursor(761), # skip the first span satisfying the filter condition\n", + " \"after\": TupleIdentifier(\n", + " 761\n", + " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", " 751,\n", @@ -225,13 +226,15 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": id_to_cursor(761), # skip the first span satisfying the filter condition\n", + " \"after\": TupleIdentifier(\n", + " 761\n", + " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"first\": 5,\n", " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 300\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", " 751,\n", @@ -257,7 +260,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] }, @@ -277,7 +280,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [cursor_to_id(cursor) for cursor in new_cursors]\n", + "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 60,\n", " 57,\n", diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 9c4aed8084..bee55316f3 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -19,8 +19,8 @@ from phoenix.server.api.types.pagination import ( Connection, Cursor, + TupleIdentifier, connections, - cursor_to_id, ) from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.Trace import Trace @@ -189,7 +189,7 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) if after: - span_rowid = cursor_to_id(after) + span_rowid = TupleIdentifier.from_cursor(after).rowid stmt = stmt.where(models.Span.id < span_rowid) if first: stmt = stmt.limit( @@ -203,7 +203,10 @@ async def spans( ) # todo: remove this after adding pagination https://github.com/Arize-ai/phoenix/issues/3003 async with info.context.db() as session: spans = await session.stream_scalars(stmt) - data = [(span.id, to_gql_span(span)) async for span in islice(spans, first)] + data = [ + (TupleIdentifier(rowid=span.id), to_gql_span(span)) + async for span in islice(spans, first) + ] has_next_page = True try: await spans.__anext__() diff --git 
a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index ddf92f6978..259a65d85f 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -119,21 +119,6 @@ def from_cursor(cls, cursor: Cursor) -> "TupleIdentifier": return cls(rowid=int(rowid_string), sortable_field=sortable_field) -def id_to_cursor(id: ID) -> Cursor: - """ - Creates a cursor string from an ID. - """ - return base64.b64encode(f"{CURSOR_PREFIX}{id}".encode("utf-8")).decode() - - -def cursor_to_id(cursor: Cursor) -> ID: - """ - Extracts the ID from the cursor string. - """ - _, id = base64.b64decode(cursor).decode().split(":") - return int(id) - - def offset_to_cursor(offset: int) -> Cursor: """ Creates the cursor string from an offset. @@ -252,11 +237,13 @@ def connection_from_list_slice( def connections( - data: List[Tuple[ID, GenericType]], + data: List[Tuple[TupleIdentifier, GenericType]], has_previous_page: bool, has_next_page: bool, ) -> Connection[GenericType]: - edges = [Edge(node=node, cursor=id_to_cursor(id)) for id, node in data] + edges = [ + Edge(node=node, cursor=tuple_identifier.to_cursor()) for tuple_identifier, node in data + ] has_edges = len(edges) > 0 first_edge = edges[0] if has_edges else None last_edge = edges[-1] if has_edges else None From a0cbe23622019f12e7a6bf450985d69d834aba96 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 15:52:54 -0700 Subject: [PATCH 29/74] refactor TupleIdentifier to NodeIdentifier --- .../pagination_query_testing.ipynb | 34 +++++++++---------- src/phoenix/server/api/types/Project.py | 6 ++-- src/phoenix/server/api/types/pagination.py | 6 ++-- tests/server/api/types/test_pagination.py | 20 +++++------ 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 6407c74baf..f502ccf4a4 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -9,7 +9,7 @@ "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", - " TupleIdentifier,\n", + " NodeIdentifier,\n", ")\n", "\n", "project_id = \"UHJvamVjdDox\"\n", @@ -64,7 +64,7 @@ " variable_values={\"projectId\": project_id, \"first\": 5},\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [765, 764, 763, 762, 761], new_ids" ] }, @@ -79,12 +79,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": TupleIdentifier(rowid=761).to_cursor(),\n", + " \"after\": NodeIdentifier(rowid=761).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [760, 759, 758, 757, 756], new_ids" ] }, @@ -99,12 +99,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": TupleIdentifier(7).to_cursor(),\n", + " \"after\": NodeIdentifier(7).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", 
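Aside: the rename in this patch keeps the cursor wire format from PATCH 27 intact: a row id, optionally followed by a type tag and a stringified sort value, joined by a delimiter and base64-encoded (the sample cursors in this series decode to strings like "100:DATETIME:2023-12-11T17:44:02.534129+00:00"). Below is a minimal, self-contained sketch of that round trip; the helper names and the ":" DELIMITER constant are illustrative assumptions, not the Phoenix API, whose real implementation lives in src/phoenix/server/api/types/pagination.py.

import base64
from datetime import datetime
from typing import Optional, Tuple

DELIMITER = ":"  # assumed to match the private delimiter ClassVar on NodeIdentifier

def encode_cursor(rowid: int, type_tag: Optional[str] = None, value: Optional[str] = None) -> str:
    # one component (rowid) or three (rowid, type tag, stringified sort value)
    parts = [str(rowid)]
    if type_tag is not None and value is not None:
        parts += [type_tag, value]
    return base64.b64encode(DELIMITER.join(parts).encode()).decode()

def decode_cursor(cursor: str) -> Tuple[int, Optional[str], Optional[str]]:
    # maxsplit=2 keeps ISO-8601 timestamps (which themselves contain ":") in one
    # piece, mirroring the two index() calls in NodeIdentifier.from_cursor
    parts = base64.b64decode(cursor).decode().split(DELIMITER, 2)
    if len(parts) == 1:
        return int(parts[0]), None, None
    return int(parts[0]), parts[1], parts[2]

timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00")
cursor = encode_cursor(10, "DATETIME", timestamp.isoformat())
rowid, type_tag, value = decode_cursor(cursor)
assert (rowid, type_tag) == (10, "DATETIME")
assert datetime.fromisoformat(value) == timestamp  # tz-aware datetimes survive intact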
"new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [6, 5, 4, 3, 2], new_ids\n", @@ -123,12 +123,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": TupleIdentifier(6).to_cursor(),\n", + " \"after\": NodeIdentifier(6).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [5, 4, 3, 2, 1], new_ids\n", @@ -147,12 +147,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": TupleIdentifier(5).to_cursor(),\n", + " \"after\": NodeIdentifier(5).to_cursor(),\n", " \"first\": 5,\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert new_ids == [4, 3, 2, 1], new_ids\n", @@ -176,7 +176,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 761,\n", " 756,\n", @@ -198,14 +198,14 @@ " variable_values={\n", " \"projectId\": project_id,\n", " \"first\": 5,\n", - " \"after\": TupleIdentifier(\n", + " \"after\": NodeIdentifier(\n", " 761\n", " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", " 751,\n", @@ -226,7 +226,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": TupleIdentifier(\n", + " \"after\": NodeIdentifier(\n", " 761\n", " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"first\": 5,\n", @@ -234,7 +234,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 756,\n", " 751,\n", @@ -260,7 +260,7 @@ " },\n", ")\n", 
"new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] }, @@ -280,7 +280,7 @@ " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [TupleIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", " 60,\n", " 57,\n", diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index bee55316f3..9423d3ea82 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -19,7 +19,7 @@ from phoenix.server.api.types.pagination import ( Connection, Cursor, - TupleIdentifier, + NodeIdentifier, connections, ) from phoenix.server.api.types.Span import Span, to_gql_span @@ -189,7 +189,7 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) if after: - span_rowid = TupleIdentifier.from_cursor(after).rowid + span_rowid = NodeIdentifier.from_cursor(after).rowid stmt = stmt.where(models.Span.id < span_rowid) if first: stmt = stmt.limit( @@ -204,7 +204,7 @@ async def spans( async with info.context.db() as session: spans = await session.stream_scalars(stmt) data = [ - (TupleIdentifier(rowid=span.id), to_gql_span(span)) + (NodeIdentifier(rowid=span.id), to_gql_span(span)) async for span in islice(spans, first) ] has_next_page = True diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 259a65d85f..04c006ea41 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -92,7 +92,7 @@ def from_stringified_value( @dataclass -class TupleIdentifier: +class NodeIdentifier: rowid: int sortable_field: Optional[SortableField] = None @@ -105,7 +105,7 @@ def to_cursor(self) -> Cursor: return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() @classmethod - def from_cursor(cls, cursor: Cursor) -> "TupleIdentifier": + def from_cursor(cls, cursor: Cursor) -> "NodeIdentifier": decoded = base64.b64decode(cursor).decode() rowid_string = decoded sortable_field = None @@ -237,7 +237,7 @@ def connection_from_list_slice( def connections( - data: List[Tuple[TupleIdentifier, GenericType]], + data: List[Tuple[NodeIdentifier, GenericType]], has_previous_page: bool, has_next_page: bool, ) -> Connection[GenericType]: diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index c7fee32286..0c9dd065a9 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -5,9 +5,9 @@ from phoenix.server.api.types.Dimension import Dimension from phoenix.server.api.types.pagination import ( ConnectionArgs, + NodeIdentifier, SortableField, SortableFieldType, - TupleIdentifier, connection_from_list, ) @@ -102,22 +102,22 @@ def test_connection_from_empty_list(): assert connection.page_info.has_next_page is False -class TestTupleIdentifier: +class TestNodeIdentifier: def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: - original = TupleIdentifier(rowid=10) + original = NodeIdentifier(rowid=10) cursor = original.to_cursor() - deserialized = 
TupleIdentifier.from_cursor(cursor) + deserialized = NodeIdentifier.from_cursor(cursor) assert deserialized.rowid == 10 assert deserialized.sortable_field is None def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( self, ) -> None: - original = TupleIdentifier( + original = NodeIdentifier( rowid=10, sortable_field=SortableField(type=SortableFieldType.FLOAT, value=11.5) ) cursor = original.to_cursor() - deserialized = TupleIdentifier.from_cursor(cursor) + deserialized = NodeIdentifier.from_cursor(cursor) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.FLOAT @@ -127,12 +127,12 @@ def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_origin self, ) -> None: timestamp = datetime.fromisoformat("2021-01-01T00:00:00") - original = TupleIdentifier( + original = NodeIdentifier( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), ) cursor = original.to_cursor() - deserialized = TupleIdentifier.from_cursor(cursor) + deserialized = NodeIdentifier.from_cursor(cursor) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.DATETIME @@ -143,12 +143,12 @@ def test_to_and_from_cursor_with_rowid_and_tz_aware_datetime_deserializes_origin self, ) -> None: timestamp = datetime.fromisoformat("2021-01-01T00:00:00+00:00") - original = TupleIdentifier( + original = NodeIdentifier( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), ) cursor = original.to_cursor() - deserialized = TupleIdentifier.from_cursor(cursor) + deserialized = NodeIdentifier.from_cursor(cursor) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.DATETIME From 837f30a74ee8cc134e8a1a999bfe420a2a303e8f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 16:27:19 -0700 Subject: [PATCH 30/74] fix type error --- src/phoenix/server/api/types/pagination.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 04c006ea41..b2a9c690e6 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -2,11 +2,11 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum, auto -from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, assert_never +from typing import ClassVar, Generic, List, Optional, Tuple, TypeVar, Union import strawberry from strawberry import UNSET -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, assert_never ID: TypeAlias = int GenericType = TypeVar("GenericType") From 470c0499dd5d770440cf6dfa09dbb33674bef66a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 16:43:35 -0700 Subject: [PATCH 31/74] add failing test case for order by datetime with cursor --- .../pagination_query_testing.ipynb | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index f502ccf4a4..c8d3653f78 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -10,6 +10,8 @@ "from gql.transport.requests import 
RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", " NodeIdentifier,\n", + " SortableField,\n", + " SortableFieldType,\n", ")\n", "\n", "project_id = \"UHJvamVjdDox\"\n", @@ -289,6 +291,33 @@ " 122,\n", "], new_ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by start time with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", + " \"first\": 5,\n", + " \"after\": NodeIdentifier(\n", + " 5,\n", + " sortable_field=SortableField.from_stringified_value(\n", + " type=SortableFieldType.DATETIME,\n", + " stringified_value=\"2024-05-03 00:37:38.223570\",\n", + " ),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "assert new_ids == [6, 7, 8, 9, 10], new_ids" + ] } ], "metadata": { From 789bcd9e377f42e06207aacabdec5d96918d4907 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 19:01:24 -0700 Subject: [PATCH 32/74] refactor troublesome test case --- tests/trace/dsl/test_filter.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index e1ffbb2324..e2f5052878 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -1,6 +1,6 @@ import ast import sys -from typing import Any, List, Optional +from typing import List, Optional from unittest.mock import patch import phoenix.trace.dsl.filter @@ -97,7 +97,7 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) "llm.token_count.total - llm.token_count.prompt > 1000", "attributes[['llm', 'token_count', 'total']].as_float() - attributes[['llm', 'token_count', 'prompt']].as_float() > 1000" # noqa E501 if sys.version_info >= (3, 9) - else "(attributes[['llm', 'token_count', 'total']].as_float() - attributes[['llm', 'token_count', 'prompt']].as_float()) > 1000", # noqa E501 + else "((attributes[['llm', 'token_count', 'total']].as_float() - attributes[['llm', 'token_count', 'prompt']].as_float()) > 1000)", # noqa E501 ), ( "first.value in (1,) and second.value in ('2',) and '3' in third.value", @@ -113,19 +113,19 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) "first.value + 1 < second.value", "attributes[['first', 'value']].as_float() + 1 < attributes[['second', 'value']].as_float()" # noqa E501 if sys.version_info >= (3, 9) - else "(attributes[['first', 'value']].as_float() + 1) < attributes[['second', 'value']].as_float()", # noqa E501 + else "((attributes[['first', 'value']].as_float() + 1) < attributes[['second', 'value']].as_float())", # noqa E501 ), ( "first.value * second.value > third.value", "attributes[['first', 'value']].as_float() * attributes[['second', 'value']].as_float() > attributes[['third', 'value']].as_float()" # noqa E501 if sys.version_info >= (3, 9) - else "(attributes[['first', 'value']].as_float() * attributes[['second', 'value']].as_float()) > attributes[['third', 'value']].as_float()", # noqa E501 + else "((attributes[['first', 'value']].as_float() * attributes[['second', 'value']].as_float()) > attributes[['third', 'value']].as_float())", # noqa E501 ), ( "first.value + second.value > third.value", 
"cast(attributes[['first', 'value']].as_string() + attributes[['second', 'value']].as_string(), String) > attributes[['third', 'value']].as_string()" # noqa E501 if sys.version_info >= (3, 9) - else "cast((attributes[['first', 'value']].as_string() + attributes[['second', 'value']].as_string()), String) > attributes[['third', 'value']].as_string()", # noqa E501 + else "(cast((attributes[['first', 'value']].as_string() + attributes[['second', 'value']].as_string()), String) > attributes[['third', 'value']].as_string())", # noqa E501 ), ( "my.value == '1.0' or float(my.value) < 2.0", @@ -156,7 +156,7 @@ async def test_filter_translated( return_value=0, ): f = SpanFilter(expression) - assert _unparse(f.translated) == expected + assert unparse(f.translated).strip() == expected # next line is only to test that the syntax is accepted await session.execute(f(select(models.Span.id))) @@ -230,12 +230,3 @@ def test_apply_eval_aliasing(filter_condition: str, expected: str) -> None: with patch.object(phoenix.trace.dsl.filter, "randint", return_value=0): aliased, _ = _apply_eval_aliasing(filter_condition) assert aliased == expected - - -def _unparse(exp: Any) -> str: - # `unparse` for python 3.8 outputs differently, - # otherwise this function is unnecessary. - s = unparse(exp).strip() - if s[0] == "(" and s[-1] == ")": - return s[1:-1] - return s From 329c2ad8d55445c36a82fd573b9a913747737832 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 21:55:13 -0700 Subject: [PATCH 33/74] pass test with order by start time with cursor --- integration-tests/pagination_query_testing.ipynb | 2 +- src/phoenix/server/api/types/Project.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index c8d3653f78..4993e8466b 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -309,7 +309,7 @@ " 5,\n", " sortable_field=SortableField.from_stringified_value(\n", " type=SortableFieldType.DATETIME,\n", - " stringified_value=\"2024-05-03 00:37:38.223570\",\n", + " stringified_value=\"2024-05-05T04:25:29.911245+00:00\",\n", " ),\n", " ).to_cursor(),\n", " },\n", diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 9423d3ea82..e8052036ab 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -5,6 +5,7 @@ from aioitertools.itertools import islice from sqlalchemy import and_, desc, distinct, select from sqlalchemy.orm import contains_eager +from sqlalchemy.sql.expression import tuple_ from strawberry import ID, UNSET from strawberry.types import Info @@ -189,8 +190,14 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) if after: - span_rowid = NodeIdentifier.from_cursor(after).rowid - stmt = stmt.where(models.Span.id < span_rowid) + node_identifier = NodeIdentifier.from_cursor(after) + if (sortable_field := node_identifier.sortable_field) is not None: + stmt = stmt.where( + tuple_(models.Span.start_time, models.Span.id) + < (sortable_field.value, node_identifier.rowid) + ) + else: + stmt = stmt.where(models.Span.id < node_identifier.rowid) if first: stmt = stmt.limit( first + 1 # overfetch by one to determine whether there's a next page @@ -198,9 +205,6 @@ async def spans( if sort: stmt = sort.update_orm_expr(stmt) stmt = stmt.order_by(desc(models.Span.id)) - stmt = stmt.limit( - 
SPANS_LIMIT - ) # todo: remove this after adding pagination https://github.com/Arize-ai/phoenix/issues/3003 async with info.context.db() as session: spans = await session.stream_scalars(stmt) data = [ From 3c22691e82d634cb6cb81c7504bd083390c63c6b Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 22:06:09 -0700 Subject: [PATCH 34/74] update node identifier tests with a more granular timestamp --- tests/server/api/types/test_pagination.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index 0c9dd065a9..8f7c2b5fbb 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -126,7 +126,7 @@ def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_original( self, ) -> None: - timestamp = datetime.fromisoformat("2021-01-01T00:00:00") + timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245") original = NodeIdentifier( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), @@ -142,7 +142,7 @@ def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_origin def test_to_and_from_cursor_with_rowid_and_tz_aware_datetime_deserializes_original( self, ) -> None: - timestamp = datetime.fromisoformat("2021-01-01T00:00:00+00:00") + timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00") original = NodeIdentifier( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), From e404fe42999bd09d9676667dae62113bb2ae5fab Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 22:33:42 -0700 Subject: [PATCH 35/74] pass test case for order by ascending timestamp with cursor --- .../pagination_query_testing.ipynb | 29 ++++++++++++++++++- src/phoenix/server/api/types/Project.py | 10 +++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 4993e8466b..c9597baf39 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -298,7 +298,7 @@ "metadata": {}, "outputs": [], "source": [ - "# order by start time with cursor\n", + "# order by descending start time with cursor\n", "response = client.execute(\n", " spans_query,\n", " variable_values={\n", @@ -318,6 +318,33 @@ "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [6, 7, 8, 9, 10], new_ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by ascending start time with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", + " \"first\": 5,\n", + " \"after\": NodeIdentifier(\n", + " 10,\n", + " sortable_field=SortableField.from_stringified_value(\n", + " type=SortableFieldType.DATETIME,\n", + " stringified_value=\"2024-05-05T04:25:29.053197+00:00\",\n", + " ),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "assert new_ids == [9, 8, 7, 6, 5], new_ids" + ] } ], "metadata": { diff --git 
a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index e8052036ab..ac031fb15e 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,3 +1,4 @@ +import operator from datetime import datetime from typing import List, Optional @@ -23,6 +24,7 @@ NodeIdentifier, connections, ) +from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.Trace import Trace from phoenix.server.api.types.ValidationResult import ValidationResult @@ -192,9 +194,13 @@ async def spans( if after: node_identifier = NodeIdentifier.from_cursor(after) if (sortable_field := node_identifier.sortable_field) is not None: + assert sort is not None # todo: refactor this into a validation check + compare = operator.lt if sort.dir is SortDir.desc else operator.gt stmt = stmt.where( - tuple_(models.Span.start_time, models.Span.id) - < (sortable_field.value, node_identifier.rowid) + compare( + tuple_(models.Span.start_time, models.Span.id), + (sortable_field.value, node_identifier.rowid), + ) ) else: stmt = stmt.where(models.Span.id < node_identifier.rowid) From 0cd5ee4a5803375b7bf0c51403eaa7aac1b5f7df Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 22:48:14 -0700 Subject: [PATCH 36/74] add test case to sort by ascending start time without cursor --- .../pagination_query_testing.ipynb | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index c9597baf39..8c893b3613 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -252,7 +252,7 @@ "metadata": {}, "outputs": [], "source": [ - "# order by start time\n", + "# order by descending start time\n", "response = client.execute(\n", " spans_query,\n", " variable_values={\n", @@ -266,6 +266,26 @@ "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by ascending start time\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", + " \"first\": 5,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "assert new_ids == [765, 764, 763, 762, 761], new_ids" + ] + }, { "cell_type": "code", "execution_count": null, From 1eba507e687fb5a8d713adbe55a082bb4451a257 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 4 May 2024 23:43:01 -0700 Subject: [PATCH 37/74] pass test for pageInfo cursor from order by start time with cursor --- .../pagination_query_testing.ipynb | 40 +++++++++++++++++-- src/phoenix/server/api/types/Project.py | 23 ++++++++--- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 8c893b3613..ccac750045 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -16,7 +16,7 @@ "\n", "project_id = \"UHJvamVjdDox\"\n", "client = Client(\n", - " transport=RequestsHTTPTransport(url=\"http://127.0.0.1:6006/graphql\", timeout=1),\n", + " 
transport=RequestsHTTPTransport(url=\"http://127.0.0.1:6006/graphql\", timeout=100),\n", " fetch_schema_from_transport=True,\n", ")" ] @@ -46,6 +46,8 @@ " pageInfo {\n", " hasNextPage\n", " hasPreviousPage\n", + " startCursor\n", + " endCursor\n", " }\n", " }\n", " }\n", @@ -336,7 +338,23 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [6, 7, 8, 9, 10], new_ids" + "startCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", + "endCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "start_node_identifier = NodeIdentifier.from_cursor(startCursor)\n", + "end_node_identifier = NodeIdentifier.from_cursor(endCursor)\n", + "assert new_ids == [6, 7, 8, 9, 10], new_ids\n", + "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", + "assert (\n", + " start_timestamp := start_sortable_field.stringify_value()\n", + ") == \"2024-05-05T04:25:29.258516+00:00\", start_timestamp\n", + "assert (\n", + " start_field_type := start_sortable_field.type\n", + ") == SortableFieldType.DATETIME, start_field_type\n", + "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (\n", + " end_timestamp := end_sortable_field.stringify_value()\n", + ") == \"2024-05-05T04:25:29.053197+00:00\", end_timestamp\n", + "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] }, { @@ -363,7 +381,23 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [9, 8, 7, 6, 5], new_ids" + "startCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", + "endCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "start_node_identifier = NodeIdentifier.from_cursor(startCursor)\n", + "end_node_identifier = NodeIdentifier.from_cursor(endCursor)\n", + "assert new_ids == [9, 8, 7, 6, 5], new_ids\n", + "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", + "assert (\n", + " start_timestamp := start_sortable_field.stringify_value()\n", + ") == \"2024-05-05T04:25:29.053273+00:00\", start_timestamp\n", + "assert (\n", + " start_field_type := start_sortable_field.type\n", + ") == SortableFieldType.DATETIME, start_field_type\n", + "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (\n", + " end_timestamp := end_sortable_field.stringify_value()\n", + ") == \"2024-05-05T04:25:29.911245+00:00\", end_timestamp\n", + "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] } ], diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index ac031fb15e..df5cf0e517 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -22,6 +22,8 @@ Connection, Cursor, NodeIdentifier, + SortableField, + SortableFieldType, connections, ) from phoenix.server.api.types.SortDir import SortDir @@ -191,9 +193,11 @@ async def spans( if filter_condition: span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) + sortable_field: Optional[SortableField] = None if after: node_identifier = NodeIdentifier.from_cursor(after) - if (sortable_field := node_identifier.sortable_field) 
is not None: + if node_identifier.sortable_field is not None: + sortable_field = node_identifier.sortable_field assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt stmt = stmt.where( @@ -211,12 +215,21 @@ async def spans( if sort: stmt = sort.update_orm_expr(stmt) stmt = stmt.order_by(desc(models.Span.id)) + data = [] async with info.context.db() as session: spans = await session.stream_scalars(stmt) - data = [ - (NodeIdentifier(rowid=span.id), to_gql_span(span)) - async for span in islice(spans, first) - ] + async for span in islice(spans, first): + sf = ( + SortableField(type=SortableFieldType.DATETIME, value=span.start_time) + if sortable_field is not None + else None + ) + node_identifier = NodeIdentifier( + rowid=span.id, + sortable_field=sf, + ) + data.append((node_identifier, to_gql_span(span))) + # todo: does this need to be inside the async with block? has_next_page = True try: await spans.__anext__() From 389bc5b722e47ac270c987526938522f4b52ed9d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 12:17:15 -0700 Subject: [PATCH 38/74] ensure fixture spans are inserted in chronological order and don't reset timestamps and span ids. update tests accordingly --- .../pagination_query_testing.ipynb | 152 +++++++++++++----- src/phoenix/server/main.py | 23 ++- 2 files changed, 121 insertions(+), 54 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index ccac750045..a49553d811 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -21,6 +21,54 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test query for doing sanity checks\n", + "response = client.execute(\n", + " gql(\n", + " \"\"\"query SpansTableSpansQuery($after: String = null, $filterCondition: String = null, $first: Int = 100, $sort: SpanSort = {col: startTime, dir: desc}, $timeRange: TimeRange, $id: GlobalID!) {\n", + " node(id: $id) {\n", + " ... 
on Project {\n", + " spans(\n", + " first: $first\n", + " after: $after\n", + " sort: $sort\n", + " filterCondition: $filterCondition\n", + " timeRange: $timeRange\n", + " ) {\n", + " edges {\n", + " cursor\n", + " }\n", + " pageInfo {\n", + " endCursor\n", + " hasNextPage\n", + " }\n", + " }\n", + " }\n", + " }\n", + "}\"\"\"\n", + " ),\n", + " variable_values={\n", + " \"after\": None,\n", + " \"filterCondition\": \"\",\n", + " \"first\": 100,\n", + " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", + " \"timeRange\": {\n", + " \"start\": \"2024-04-28T17:00:00.000Z\",\n", + " \"end\": \"2025-05-05T17:00:00.000Z\",\n", + " },\n", + " \"id\": \"UHJvamVjdDox\",\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "new_ids" + ] + }, { "cell_type": "code", "execution_count": null, @@ -182,11 +230,11 @@ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", - " 761,\n", - " 756,\n", - " 751,\n", - " 746,\n", - " 741,\n", + " 765,\n", + " 760,\n", + " 755,\n", + " 750,\n", + " 745,\n", "], new_ids" ] }, @@ -203,7 +251,7 @@ " \"projectId\": project_id,\n", " \"first\": 5,\n", " \"after\": NodeIdentifier(\n", - " 761\n", + " 765\n", " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", @@ -211,11 +259,11 @@ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", - " 756,\n", - " 751,\n", - " 746,\n", - " 741,\n", - " 736,\n", + " 760,\n", + " 755,\n", + " 750,\n", + " 745,\n", + " 740,\n", "], new_ids" ] }, @@ -231,7 +279,7 @@ " variable_values={\n", " \"projectId\": project_id,\n", " \"after\": NodeIdentifier(\n", - " 761\n", + " 745\n", " ).to_cursor(), # skip the first span satisfying the filter condition\n", " \"first\": 5,\n", " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 300\",\n", @@ -240,11 +288,11 @@ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", - " 756,\n", - " 751,\n", - " 736,\n", - " 731,\n", - " 721,\n", + " 740,\n", + " 730,\n", + " 720,\n", + " 710,\n", + " 690,\n", "], new_ids" ] }, @@ -265,7 +313,7 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [1, 2, 3, 4, 5], new_ids" + "assert new_ids == [765, 764, 763, 762, 761], new_ids" ] }, { @@ -285,7 +333,7 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [765, 764, 763, 762, 761], new_ids" + "assert new_ids == [1, 2, 3, 4, 5], new_ids" ] }, { @@ -306,14 +354,34 @@ "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", "assert new_ids == [\n", - " 60,\n", - " 57,\n", - " 56,\n", - " 125,\n", - " 
122,\n", + " 710,\n", + " 709,\n", + " 706,\n", + " 645,\n", + " 644,\n", "], new_ids" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by cumulative prompt token count in ascending order\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"asc\"},\n", + " \"first\": 5,\n", + " },\n", + ")\n", + "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "assert new_ids == [763, 762, 758, 757, 753], new_ids" + ] + }, { "cell_type": "code", "execution_count": null, @@ -328,32 +396,32 @@ " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", " \"first\": 5,\n", " \"after\": NodeIdentifier(\n", - " 5,\n", + " 760,\n", " sortable_field=SortableField.from_stringified_value(\n", " type=SortableFieldType.DATETIME,\n", - " stringified_value=\"2024-05-05T04:25:29.911245+00:00\",\n", + " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", " ),\n", " ).to_cursor(),\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "startCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", - "endCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = NodeIdentifier.from_cursor(startCursor)\n", - "end_node_identifier = NodeIdentifier.from_cursor(endCursor)\n", - "assert new_ids == [6, 7, 8, 9, 10], new_ids\n", + "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", + "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", + "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "assert new_ids == [759, 758, 757, 756, 755], new_ids\n", "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", "assert (\n", " start_timestamp := start_sortable_field.stringify_value()\n", - ") == \"2024-05-05T04:25:29.258516+00:00\", start_timestamp\n", + ") == \"2023-12-11T17:48:40.154139+00:00\", start_timestamp\n", "assert (\n", " start_field_type := start_sortable_field.type\n", ") == SortableFieldType.DATETIME, start_field_type\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", - ") == \"2024-05-05T04:25:29.053197+00:00\", end_timestamp\n", + ") == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] }, @@ -371,32 +439,32 @@ " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", " \"first\": 5,\n", " \"after\": NodeIdentifier(\n", - " 10,\n", + " 8,\n", " sortable_field=SortableField.from_stringified_value(\n", " type=SortableFieldType.DATETIME,\n", - " stringified_value=\"2024-05-05T04:25:29.053197+00:00\",\n", + " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", " ),\n", " ).to_cursor(),\n", " },\n", ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "startCursor = 
response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", - "endCursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = NodeIdentifier.from_cursor(startCursor)\n", - "end_node_identifier = NodeIdentifier.from_cursor(endCursor)\n", - "assert new_ids == [9, 8, 7, 6, 5], new_ids\n", + "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", + "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", + "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "assert new_ids == [9, 10, 11, 12, 13], new_ids\n", "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", "assert (\n", " start_timestamp := start_sortable_field.stringify_value()\n", - ") == \"2024-05-05T04:25:29.053273+00:00\", start_timestamp\n", + ") == \"2023-12-11T17:43:25.842986+00:00\", start_timestamp\n", "assert (\n", " start_field_type := start_sortable_field.type\n", ") == SortableFieldType.DATETIME, start_field_type\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", - ") == \"2024-05-05T04:25:29.911245+00:00\", end_timestamp\n", + ") == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] } diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 0238fec057..0acd1b44aa 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -39,7 +39,6 @@ download_traces_fixture, get_evals_from_fixture, get_trace_fixture_by_name, - reset_fixture_span_ids_and_timestamps, ) from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp from phoenix.trace.schemas import Span @@ -206,17 +205,17 @@ def _get_pid_file() -> Path: fixture_spans: List[Span] = [] fixture_evals: List[pb.Evaluation] = [] if trace_dataset_name is not None: - fixture_spans, fixture_evals = reset_fixture_span_ids_and_timestamps( - ( - # Apply `encode` here because legacy jsonl files contains UUIDs as strings. - # `encode` removes the hyphens in the UUIDs. - decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span))) - for json_span in download_traces_fixture( - get_trace_fixture_by_name(trace_dataset_name) - ) - ), - get_evals_from_fixture(trace_dataset_name), - ) + # todo: add boolean flag for --reset-span-ids-and-timestamps + # todo: ensure that fixture tuples are inserted in chronological order + fixture_spans = [ + # Apply `encode` here because legacy jsonl files contains UUIDs as strings. + # `encode` removes the hyphens in the UUIDs. 
+ decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span))) + for json_span in reversed( + download_traces_fixture(get_trace_fixture_by_name(trace_dataset_name)) + ) + ] + fixture_evals = list(get_evals_from_fixture(trace_dataset_name)) umap_params_list = args.umap_params.split(",") umap_params = UMAPParameters( min_dist=float(umap_params_list[0]), From 459fb3c5f98d51e5699ee5f44a4bb5ba88f6160a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 13:50:26 -0700 Subject: [PATCH 39/74] ensure that queries with sorts but no cursor still return the sorted field as part of the cursor --- .../pagination_query_testing.ipynb | 31 +++++++++++++++++-- src/phoenix/server/api/types/Project.py | 2 +- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index a49553d811..d7747d0505 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -21,6 +21,17 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cursor = \"MTAwOkRBVEVUSU1FOjIwMjMtMTItMTFUMTc6NDQ6MDIuNTM0MTI5KzAwOjAw\"\n", + "node_identifier = NodeIdentifier.from_cursor(cursor)\n", + "print(node_identifier)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -146,7 +157,7 @@ "metadata": {}, "outputs": [], "source": [ - "# page ends on the penultimate record and excludees last record\n", + "# page ends on the penultimate record and excludes last record\n", "response = client.execute(\n", " spans_query,\n", " variable_values={\n", @@ -313,7 +324,14 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [765, 764, 763, 762, 761], new_ids" + "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "assert new_ids == [765, 764, 763, 762, 761], new_ids\n", + "assert end_node_identifier.rowid == 761\n", + "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (\n", + " end_node_start_timestamp := end_sortable_field.value.isoformat()\n", + ") == \"2023-12-11T17:48:40.807667+00:00\", end_node_start_timestamp" ] }, { @@ -333,7 +351,14 @@ ")\n", "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [1, 2, 3, 4, 5], new_ids" + "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", + "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "assert new_ids == [1, 2, 3, 4, 5], new_ids\n", + "assert end_node_identifier.rowid == 5\n", + "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (\n", + " end_node_start_timestamp := end_sortable_field.value.isoformat()\n", + ") == \"2023-12-11T17:43:23.712144+00:00\", end_node_start_timestamp" ] }, { diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index df5cf0e517..41b626826e 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -221,7 +221,7 @@ async def spans( async for span in islice(spans, first): sf = ( SortableField(type=SortableFieldType.DATETIME, 
value=span.start_time) - if sortable_field is not None + if sort else None ) node_identifier = NodeIdentifier( From 1a04f7705b571c1fc3940d364c508d0c65ac3bd2 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 14:11:37 -0700 Subject: [PATCH 40/74] rename variables in tests --- .../pagination_query_testing.ipynb | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index d7747d0505..1fb4b6c9f2 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -75,9 +75,9 @@ " \"id\": \"UHJvamVjdDox\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "new_ids" + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids" ] }, { @@ -126,9 +126,9 @@ " spans_query,\n", " variable_values={\"projectId\": project_id, \"first\": 5},\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [765, 764, 763, 762, 761], new_ids" + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [765, 764, 763, 762, 761], ids" ] }, { @@ -146,9 +146,9 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [760, 759, 758, 757, 756], new_ids" + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [760, 759, 758, 757, 756], ids" ] }, { @@ -166,11 +166,11 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert new_ids == [6, 5, 4, 3, 2], new_ids\n", + "assert ids == [6, 5, 4, 3, 2], ids\n", "assert has_next_page is True\n", "assert has_previous_page is False" ] @@ -190,11 +190,11 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert new_ids == [5, 4, 3, 2, 1], new_ids\n", + 
"assert ids == [5, 4, 3, 2, 1], ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" ] @@ -214,11 +214,11 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert new_ids == [4, 3, 2, 1], new_ids\n", + "assert ids == [4, 3, 2, 1], ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" ] @@ -238,15 +238,15 @@ " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", " 765,\n", " 760,\n", " 755,\n", " 750,\n", " 745,\n", - "], new_ids" + "], ids" ] }, { @@ -267,15 +267,15 @@ " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", " 760,\n", " 755,\n", " 750,\n", " 745,\n", " 740,\n", - "], new_ids" + "], ids" ] }, { @@ -296,15 +296,15 @@ " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 300\",\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", " 740,\n", " 730,\n", " 720,\n", " 710,\n", " 690,\n", - "], new_ids" + "], ids" ] }, { @@ -322,11 +322,11 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", - "assert new_ids == [765, 764, 763, 762, 761], new_ids\n", + "assert ids == [765, 764, 763, 762, 761], ids\n", "assert end_node_identifier.rowid == 761\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", "assert (\n", @@ -349,11 +349,11 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in 
response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", - "assert new_ids == [1, 2, 3, 4, 5], new_ids\n", + "assert ids == [1, 2, 3, 4, 5], ids\n", "assert end_node_identifier.rowid == 5\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", "assert (\n", @@ -376,15 +376,15 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", " 710,\n", " 709,\n", " 706,\n", " 645,\n", " 644,\n", - "], new_ids" + "], ids" ] }, { @@ -402,9 +402,9 @@ " \"first\": 5,\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", - "assert new_ids == [763, 762, 758, 757, 753], new_ids" + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [763, 762, 758, 757, 753], ids" ] }, { @@ -429,13 +429,13 @@ " ).to_cursor(),\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", - "assert new_ids == [759, 758, 757, 756, 755], new_ids\n", + "assert ids == [759, 758, 757, 756, 755], ids\n", "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", "assert (\n", " start_timestamp := start_sortable_field.stringify_value()\n", @@ -472,13 +472,13 @@ " ).to_cursor(),\n", " },\n", ")\n", - "new_cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "new_ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in new_cursors]\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", - "assert new_ids == [9, 10, 11, 12, 13], new_ids\n", + "assert ids == [9, 10, 11, 12, 13], ids\n", "assert (start_sortable_field := 
start_node_identifier.sortable_field) is not None\n",
 "assert (\n",
 "    start_timestamp := start_sortable_field.stringify_value()\n",

From 4797f2f824cb7baaeb5d9fed5ff73fef8615fcf4 Mon Sep 17 00:00:00 2001
From: Alexander Song
Date: Sun, 5 May 2024 14:25:17 -0700
Subject: [PATCH 41/74] allow integer types to be passed into a node
 identifier's sortable field value

---
 src/phoenix/server/api/types/pagination.py |  2 +-
 tests/server/api/types/test_pagination.py  | 18 ++++++++++++++++++
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py
index b2a9c690e6..309a0002cc 100644
--- a/src/phoenix/server/api/types/pagination.py
+++ b/src/phoenix/server/api/types/pagination.py
@@ -71,7 +71,7 @@ class SortableField:
     value: SortableFieldValue
 
     def stringify_value(self) -> str:
-        if isinstance(self.value, float):
+        if isinstance(self.value, (int, float)):
             return str(self.value)
         if isinstance(self.value, datetime):
             return self.value.isoformat()
diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py
index 8f7c2b5fbb..a2452ec0d9 100644
--- a/tests/server/api/types/test_pagination.py
+++ b/tests/server/api/types/test_pagination.py
@@ -123,6 +123,24 @@ def test_to_and_from_cursor_with_rowid_and_float_deserializes_original(
         assert sortable_field.type == SortableFieldType.FLOAT
         assert abs(sortable_field.value - 11.5) < 1e-8
 
+    def test_to_and_from_cursor_with_rowid_and_int_deserializes_original_as_float(
+        self,
+    ) -> None:
+        original = NodeIdentifier(
+            rowid=10,
+            sortable_field=SortableField(
+                type=SortableFieldType.FLOAT,
+                value=11,  # an integer value
+            ),
+        )
+        cursor = original.to_cursor()
+        deserialized = NodeIdentifier.from_cursor(cursor)
+        assert deserialized.rowid == 10
+        assert (sortable_field := deserialized.sortable_field) is not None
+        assert sortable_field.type == SortableFieldType.FLOAT
+        assert isinstance((value := sortable_field.value), float)
+        assert abs(value - 11.0) < 1e-8
+
     def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_original(
         self,
     ) -> None:

From d2f35365f7a1011016fe69618f21e329bfa7b591 Mon Sep 17 00:00:00 2001
From: Alexander Song
Date: Sun, 5 May 2024 14:59:41 -0700
Subject: [PATCH 42/74] refactor SpanColumn type to advertise an
 `orm_expression` property

---
 src/phoenix/server/api/input_types/SpanSort.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py
index 6120af615d..be26b1759e 100644
--- a/src/phoenix/server/api/input_types/SpanSort.py
+++ b/src/phoenix/server/api/input_types/SpanSort.py
@@ -29,6 +29,10 @@ class SpanColumn(Enum):
     cumulativeTokenCountPrompt = auto()
     cumulativeTokenCountCompletion = auto()
 
+    @property
+    def orm_expression(self) -> Any:
+        return _SPAN_COLUMN_TO_ORM_EXPR_MAP[self]
+
 
 @strawberry.enum
 class EvalAttr(Enum):
@@ -76,7 +80,7 @@ class SpanSort:
 
     def update_orm_expr(self, stmt: Select[Any]) -> Select[Any]:
         if self.col and not self.eval_result_key:
-            expr = _SPAN_COLUMN_TO_ORM_EXPR_MAP[self.col]
+            expr = self.col.orm_expression
             if self.dir == SortDir.desc:
                 expr = desc(expr)
             return stmt.order_by(nulls_last(expr))

From 27577df9f13d1da45bea3f86778e5fe5261034f3 Mon Sep 17 00:00:00 2001
From: Alexander Song
Date: Sun, 5 May 2024 15:24:28 -0700
Subject: [PATCH 43/74] add and pass test for floating point order by with
 cursor

---
 .../pagination_query_testing.ipynb | 30 +++++++++++++++++++
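The new test above pins down a subtle coercion: a FLOAT-typed sortable field accepts an `int`, stringifies it with `str()`, and parses back as `float`. A minimal roundtrip sketch of that behavior using the same classes:

from phoenix.server.api.types.pagination import SortableField, SortableFieldType

sf = SortableField(type=SortableFieldType.FLOAT, value=11)  # int in
restored = SortableField.from_stringified_value(
    SortableFieldType.FLOAT, sf.stringify_value()
)
assert isinstance(restored.value, float)  # float out
assert abs(restored.value - 11.0) < 1e-8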
.../server/api/input_types/SpanSort.py | 18 +++++++++++ src/phoenix/server/api/types/Project.py | 21 ++++++++----- 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 1fb4b6c9f2..70f36ff63c 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -492,6 +492,36 @@ ") == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by cumulative prompt token count in descending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"desc\"},\n", + " \"first\": 5,\n", + " \"after\": NodeIdentifier(\n", + " rowid=644, # row 644 is in between rows 645 and 641, which also have 1054 cumulative prompt tokens\n", + " sortable_field=SortableField(type=SortableFieldType.FLOAT, value=1054),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", + " 641,\n", + " 550,\n", + " 549,\n", + " 546,\n", + " 60,\n", + "], ids" + ] } ], "metadata": { diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index be26b1759e..89ff00841b 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -6,9 +6,11 @@ from sqlalchemy import and_, desc, nulls_last from sqlalchemy.sql.expression import Select from strawberry import UNSET +from typing_extensions import assert_never import phoenix.trace.v1 as pb from phoenix.db import models +from phoenix.server.api.types.pagination import SortableFieldType from phoenix.server.api.types.SortDir import SortDir from phoenix.trace.schemas import SpanID @@ -33,6 +35,22 @@ class SpanColumn(Enum): def orm_expression(self) -> Any: return _SPAN_COLUMN_TO_ORM_EXPR_MAP[self] + @property + def data_type(self) -> SortableFieldType: + if self is SpanColumn.startTime or self is SpanColumn.endTime: + return SortableFieldType.DATETIME + if ( + self is SpanColumn.latencyMs + or self is SpanColumn.tokenCountTotal + or self is SpanColumn.tokenCountPrompt + or self is SpanColumn.tokenCountCompletion + or self is SpanColumn.cumulativeTokenCountTotal + or self is SpanColumn.cumulativeTokenCountPrompt + or self is SpanColumn.cumulativeTokenCountCompletion + ): + return SortableFieldType.FLOAT + assert_never(self) + @strawberry.enum class EvalAttr(Enum): diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 41b626826e..654a4284d9 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -23,7 +23,6 @@ Cursor, NodeIdentifier, SortableField, - SortableFieldType, connections, ) from phoenix.server.api.types.SortDir import SortDir @@ -200,12 +199,13 @@ async def spans( sortable_field = node_identifier.sortable_field assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt - stmt = stmt.where( - compare( - 
tuple_(models.Span.start_time, models.Span.id), - (sortable_field.value, node_identifier.rowid), + if sort_column := sort.col: + stmt = stmt.where( + compare( + tuple_(sort_column.orm_expression, models.Span.id), + (sortable_field.value, node_identifier.rowid), + ) ) - ) else: stmt = stmt.where(models.Span.id < node_identifier.rowid) if first: @@ -220,8 +220,13 @@ async def spans( spans = await session.stream_scalars(stmt) async for span in islice(spans, first): sf = ( - SortableField(type=SortableFieldType.DATETIME, value=span.start_time) - if sort + SortableField( + type=sort_col.data_type, + value=getattr( + span, sort_col.orm_expression.key + ), # todo: find a cleaner way to get this value + ) + if sort and (sort_col := sort.col) else None ) node_identifier = NodeIdentifier( From c1ac2f99915ccf93442ec0f4900a879562ae6c0c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 15:34:44 -0700 Subject: [PATCH 44/74] refactor map from graphql span column type to orm expressions and remove unnecessary test --- .../server/api/input_types/SpanSort.py | 36 +++++++++++-------- tests/server/api/input_types/test_SpanSort.py | 8 ----- 2 files changed, 22 insertions(+), 22 deletions(-) delete mode 100644 tests/server/api/input_types/test_SpanSort.py diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 89ff00841b..70961d431a 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -33,7 +33,28 @@ class SpanColumn(Enum): @property def orm_expression(self) -> Any: - return _SPAN_COLUMN_TO_ORM_EXPR_MAP[self] + if self is SpanColumn.startTime: + return models.Span.start_time + if self is SpanColumn.endTime: + return models.Span.end_time + if self is SpanColumn.latencyMs: + return models.Span.latency_ms + if self is SpanColumn.tokenCountTotal: + return models.Span.attributes[LLM_TOKEN_COUNT_TOTAL].as_float() + if self is SpanColumn.tokenCountPrompt: + return models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float() + if self is SpanColumn.tokenCountCompletion: + return models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float() + if self is SpanColumn.cumulativeTokenCountTotal: + return ( + models.Span.cumulative_llm_token_count_prompt + + models.Span.cumulative_llm_token_count_completion + ) + if self is SpanColumn.cumulativeTokenCountPrompt: + return models.Span.cumulative_llm_token_count_prompt + if self is SpanColumn.cumulativeTokenCountCompletion: + return models.Span.cumulative_llm_token_count_completion + assert_never(self) @property def data_type(self) -> SortableFieldType: @@ -58,19 +79,6 @@ class EvalAttr(Enum): label = "label" -_SPAN_COLUMN_TO_ORM_EXPR_MAP = { - SpanColumn.startTime: models.Span.start_time, - SpanColumn.endTime: models.Span.end_time, - SpanColumn.latencyMs: models.Span.latency_ms, - SpanColumn.tokenCountTotal: models.Span.attributes[LLM_TOKEN_COUNT_TOTAL].as_float(), - SpanColumn.tokenCountPrompt: models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float(), - SpanColumn.tokenCountCompletion: models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float(), - SpanColumn.cumulativeTokenCountTotal: models.Span.cumulative_llm_token_count_prompt - + models.Span.cumulative_llm_token_count_completion, - SpanColumn.cumulativeTokenCountPrompt: models.Span.cumulative_llm_token_count_prompt, - SpanColumn.cumulativeTokenCountCompletion: models.Span.cumulative_llm_token_count_completion, -} - _EVAL_ATTR_TO_ORM_EXPR_MAP = { EvalAttr.score: 
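The `tuple_` predicate above is standard keyset (cursor) pagination: the pair `(sort value, rowid)` is compared lexicographically, so ties on the sort column fall back to the unique rowid and no row is skipped or repeated across pages. A minimal sketch of the descending case, with values taken from the notebook test (note that SQLite supports row-value comparisons only from version 3.15):

from sqlalchemy import select, tuple_

from phoenix.db import models

after_value, after_rowid = 1054, 644  # decoded from the incoming cursor
stmt = (
    select(models.Span)
    .where(
        tuple_(models.Span.cumulative_llm_token_count_prompt, models.Span.id)
        < (after_value, after_rowid)
    )
    .order_by(
        models.Span.cumulative_llm_token_count_prompt.desc(),
        models.Span.id.desc(),
    )
    .limit(5)
)
# roughly: WHERE (cumulative_llm_token_count_prompt, id) < (:v, :r)
#          ORDER BY cumulative_llm_token_count_prompt DESC, id DESC LIMIT 5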
models.SpanAnnotation.score, EvalAttr.label: models.SpanAnnotation.label, diff --git a/tests/server/api/input_types/test_SpanSort.py b/tests/server/api/input_types/test_SpanSort.py deleted file mode 100644 index 895444f5f4..0000000000 --- a/tests/server/api/input_types/test_SpanSort.py +++ /dev/null @@ -1,8 +0,0 @@ -from phoenix.server.api.input_types.SpanSort import ( - _SPAN_COLUMN_TO_ORM_EXPR_MAP, - SpanColumn, -) - - -def test_span_column_has_orm_expr(): - assert set(SpanColumn) == set(_SPAN_COLUMN_TO_ORM_EXPR_MAP) From 423e354945790215db16b174ecbb4471fef391ca Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 16:00:57 -0700 Subject: [PATCH 45/74] add support for integer types to node identifier --- src/phoenix/server/api/types/pagination.py | 9 ++++++--- tests/server/api/types/test_pagination.py | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 309a0002cc..0d1ef5ae46 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -10,7 +10,7 @@ ID: TypeAlias = int GenericType = TypeVar("GenericType") -SortableFieldValue: TypeAlias = Union[float, datetime] +SortableFieldValue: TypeAlias = Union[int, float, datetime] @strawberry.type @@ -61,6 +61,7 @@ class Edge(Generic[GenericType]): class SortableFieldType(Enum): + INT = auto() FLOAT = auto() DATETIME = auto() @@ -82,9 +83,11 @@ def from_stringified_value( cls, type: SortableFieldType, stringified_value: str ) -> "SortableField": value: SortableFieldValue - if type == SortableFieldType.FLOAT: + if type is SortableFieldType.INT: + value = int(stringified_value) + elif type is SortableFieldType.FLOAT: value = float(stringified_value) - elif type == SortableFieldType.DATETIME: + elif type is SortableFieldType.DATETIME: value = datetime.fromisoformat(stringified_value) else: assert_never(type) diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index a2452ec0d9..89de0ffa4d 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -110,6 +110,20 @@ def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: assert deserialized.rowid == 10 assert deserialized.sortable_field is None + def test_to_and_from_cursor_with_rowid_and_int_deserializes_original( + self, + ) -> None: + original = NodeIdentifier( + rowid=10, sortable_field=SortableField(type=SortableFieldType.INT, value=11) + ) + cursor = original.to_cursor() + deserialized = NodeIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert (sortable_field := deserialized.sortable_field) is not None + assert sortable_field.type == SortableFieldType.INT + assert isinstance((value := sortable_field.value), int) + assert value == 11 + def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( self, ) -> None: @@ -123,7 +137,7 @@ def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( assert sortable_field.type == SortableFieldType.FLOAT assert abs(sortable_field.value - 11.5) < 1e-8 - def test_to_and_from_cursor_with_rowid_and_int_deserializes_original_as_float( + def test_to_and_from_cursor_with_rowid_and_float_passed_as_int_deserializes_original_as_float( self, ) -> None: original = NodeIdentifier( From 77ab626a35c6b05a0b9d0384d88b263ed3e3b49d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 18:02:12 -0700 Subject: [PATCH 46/74] 
pass tests for order by integer with cursor --- .../pagination_query_testing.ipynb | 24 +++++++++++++++++++ .../server/api/input_types/SpanSort.py | 13 ++++++---- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 70f36ff63c..a6d1f3d17b 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -522,6 +522,30 @@ " 60,\n", "], ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by cumulative prompt token count in ascending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"asc\"},\n", + " \"first\": 5,\n", + " \"after\": NodeIdentifier(\n", + " rowid=294, # row 294 is in between rows 295 and 291, which also have 276 cumulative prompt tokens\n", + " sortable_field=SortableField(type=SortableFieldType.INT, value=276),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [295, 115, 114, 111, 25], ids" + ] } ], "metadata": { diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 70961d431a..e5d317af4b 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -58,18 +58,21 @@ def orm_expression(self) -> Any: @property def data_type(self) -> SortableFieldType: - if self is SpanColumn.startTime or self is SpanColumn.endTime: - return SortableFieldType.DATETIME + if ( + self is SpanColumn.cumulativeTokenCountTotal + or self is SpanColumn.cumulativeTokenCountPrompt + or self is SpanColumn.cumulativeTokenCountCompletion + ): + return SortableFieldType.INT if ( self is SpanColumn.latencyMs or self is SpanColumn.tokenCountTotal or self is SpanColumn.tokenCountPrompt or self is SpanColumn.tokenCountCompletion - or self is SpanColumn.cumulativeTokenCountTotal - or self is SpanColumn.cumulativeTokenCountPrompt - or self is SpanColumn.cumulativeTokenCountCompletion ): return SortableFieldType.FLOAT + if self is SpanColumn.startTime or self is SpanColumn.endTime: + return SortableFieldType.DATETIME assert_never(self) From 0dcc245596ba45cadbc5ab97dbaedf9ecd11b63f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 18:37:24 -0700 Subject: [PATCH 47/74] paginating by latency working on column parameters --- src/phoenix/server/api/types/Project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 654a4284d9..0079639780 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -223,7 +223,7 @@ async def spans( SortableField( type=sort_col.data_type, value=getattr( - span, sort_col.orm_expression.key + span, sort_col.orm_expression.name ), # todo: find a cleaner way to get this value ) if sort and (sort_col := sort.col) From 85b08b8f84d503ed02368364e23d5bbe504c5976 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 18:52:22 -0700 Subject: [PATCH 48/74] add support for string-based cursors --- src/phoenix/server/api/types/pagination.py | 9 
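With INT now mapped for the cumulative token-count columns, an integer cursor serializes the same way as the datetime one decoded earlier. A sketch using the values from the ascending test above; assuming the payload follows the same colon-separated layout (the `to_cursor` body is not shown here), this prints `294:INT:276`:

import base64

from phoenix.server.api.types.pagination import (
    NodeIdentifier,
    SortableField,
    SortableFieldType,
)

cursor = NodeIdentifier(
    rowid=294,
    sortable_field=SortableField(type=SortableFieldType.INT, value=276),
).to_cursor()
print(base64.b64decode(cursor).decode())  # expected: 294:INT:276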
+++++++-- tests/server/api/types/test_pagination.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 0d1ef5ae46..24c465f880 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -10,7 +10,7 @@ ID: TypeAlias = int GenericType = TypeVar("GenericType") -SortableFieldValue: TypeAlias = Union[int, float, datetime] +SortableFieldValue: TypeAlias = Union[str, int, float, datetime] @strawberry.type @@ -61,6 +61,7 @@ class Edge(Generic[GenericType]): class SortableFieldType(Enum): + STRING = auto() INT = auto() FLOAT = auto() DATETIME = auto() @@ -72,6 +73,8 @@ class SortableField: value: SortableFieldValue def stringify_value(self) -> str: + if isinstance(self.value, str): + return self.value if isinstance(self.value, (int, float)): return str(self.value) if isinstance(self.value, datetime): @@ -83,7 +86,9 @@ def from_stringified_value( cls, type: SortableFieldType, stringified_value: str ) -> "SortableField": value: SortableFieldValue - if type is SortableFieldType.INT: + if type is SortableFieldType.STRING: + value = stringified_value + elif type is SortableFieldType.INT: value = int(stringified_value) elif type is SortableFieldType.FLOAT: value = float(stringified_value) diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index 89de0ffa4d..a44b5329bc 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -110,6 +110,19 @@ def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: assert deserialized.rowid == 10 assert deserialized.sortable_field is None + def test_to_and_from_cursor_with_rowid_and_string_deserializes_original( + self, + ) -> None: + original = NodeIdentifier( + rowid=10, sortable_field=SortableField(type=SortableFieldType.STRING, value="abc") + ) + cursor = original.to_cursor() + deserialized = NodeIdentifier.from_cursor(cursor) + assert deserialized.rowid == 10 + assert (sortable_field := deserialized.sortable_field) is not None + assert sortable_field.type == SortableFieldType.STRING + assert sortable_field.value == "abc" + def test_to_and_from_cursor_with_rowid_and_int_deserializes_original( self, ) -> None: From 0427c72693baed89b2b4c10ab33ab53033e6fa67 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 19:16:30 -0700 Subject: [PATCH 49/74] add failing test for order by eval label with cursor --- .../pagination_query_testing.ipynb | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index a6d1f3d17b..99997cff56 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -546,6 +546,39 @@ "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", "assert ids == [295, 115, 114, 111, 25], ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by cumulative prompt token count in descending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"label\"},\n", + " \"dir\": \"desc\",\n", + " },\n", + " \"first\": 5,\n", + " \"after\": 
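With STRING support in place, the eval-label cursor that the next (initially failing) test needs can round-trip just like the numeric ones. A sketch with the same rowid and label the test uses:

from phoenix.server.api.types.pagination import (
    NodeIdentifier,
    SortableField,
    SortableFieldType,
)

cursor = NodeIdentifier(
    rowid=141,
    sortable_field=SortableField(type=SortableFieldType.STRING, value="hallucinated"),
).to_cursor()
restored = NodeIdentifier.from_cursor(cursor)
assert restored.rowid == 141
assert (sf := restored.sortable_field) is not None
assert sf.type is SortableFieldType.STRING and sf.value == "hallucinated"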
NodeIdentifier(\n", + " rowid=141, # row 141 is surrounded by many other hallucinations\n", + " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", + " 121,\n", + " 116,\n", + " 106,\n", + " 76,\n", + " 66,\n", + "], ids" + ] } ], "metadata": { From 05745cae19ebdc949bb812cbe43969df12d3502f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 19:21:12 -0700 Subject: [PATCH 50/74] add passing tests for order by eval labels with no cursor, both ascending and descending --- .../pagination_query_testing.ipynb | 70 +++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 99997cff56..29f830a9b2 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -447,7 +447,9 @@ "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", ") == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", - "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" + "assert (\n", + " end_field_type := end_sortable_field.type\n", + ") == SortableFieldType.DATETIME, end_field_type" ] }, { @@ -490,7 +492,9 @@ "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", ") == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", - "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" + "assert (\n", + " end_field_type := end_sortable_field.type\n", + ") == SortableFieldType.DATETIME, end_field_type" ] }, { @@ -553,7 +557,7 @@ "metadata": {}, "outputs": [], "source": [ - "# order by cumulative prompt token count in descending order with cursor\n", + "# order by hallucination eval label in descending order\n", "response = client.execute(\n", " spans_query,\n", " variable_values={\n", @@ -563,10 +567,62 @@ " \"dir\": \"desc\",\n", " },\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " rowid=141, # row 141 is surrounded by many other hallucinations\n", - " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", - " ).to_cursor(),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [761, 756, 746, 741, 721], ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by hallucination eval label in ascending order\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"label\"},\n", + " \"dir\": \"asc\",\n", + " },\n", + " \"first\": 5,\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", + " 751,\n", + " 736,\n", + " 731,\n", + " 726,\n", + " 716,\n", + "], ids" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ 
+ "# order by hallucination eval label in descending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"label\"},\n", + " \"dir\": \"desc\",\n", + " },\n", + " \"first\": 5,\n", + " # \"after\": NodeIdentifier(\n", + " # rowid=141, # row 141 is surrounded by many other hallucinations\n", + " # sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", + " # ).to_cursor(),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", From 2ecc76a0a10cc6f969808213579c23c7209d5c3a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 20:51:15 -0700 Subject: [PATCH 51/74] refactor SpanFilter to return result object containing annotation aliases --- src/phoenix/server/api/types/Project.py | 3 ++- src/phoenix/trace/dsl/filter.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 0079639780..330a9cd293 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -191,7 +191,8 @@ async def spans( ).where(parent.c.span_id.is_(None)) if filter_condition: span_filter = SpanFilter(condition=filter_condition) - stmt = span_filter(stmt) + filter_result = span_filter.result(stmt) + stmt = filter_result.stmt sortable_field: Optional[SortableField] = None if after: node_identifier = NodeIdentifier.from_cursor(after) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index ea08af968f..b0c02c6a1a 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -129,6 +129,12 @@ def attribute_alias(self, attribute: EvalAttribute) -> str: ) +@dataclass(frozen=True) +class SpanFilterResult: + stmt: Select[typing.Any] + aliased_annotation_relations: typing.Tuple[AliasedAnnotationRelation, ...] 
+ + @dataclass(frozen=True) class SpanFilter: condition: str = "" @@ -172,6 +178,12 @@ def __post_init__(self) -> None: object.__setattr__(self, "_aliased_annotation_relations", aliased_annotation_relations) object.__setattr__(self, "_aliased_annotation_attributes", aliased_annotation_attributes) + def result(self, select: Select[typing.Any]) -> SpanFilterResult: + stmt = self(select) + return SpanFilterResult( + stmt=stmt, aliased_annotation_relations=self._aliased_annotation_relations + ) + def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: return select From d8b59de6042f4236fcc548a963ea7d1f448c9dcc Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 21:50:51 -0700 Subject: [PATCH 52/74] refactor SpanSort to return result object in preparation for returning aliases --- .../server/api/input_types/SpanSort.py | 27 ++++++++++++------- src/phoenix/server/api/types/Project.py | 3 ++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index e5d317af4b..9ab56720cf 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from enum import Enum, auto from typing import Any, Optional, Protocol @@ -98,6 +99,12 @@ class SupportsGetSpanEvaluation(Protocol): def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]: ... +@dataclass(frozen=True) +class SpanSortResult: + stmt: Select[Any] + eval_alias: Optional[str] = None + + @strawberry.input( description="The sort key and direction for span connections. Must " "specify one and only one of either `col` or `evalResultKey`." 
@@ -107,22 +114,24 @@ class SpanSort: eval_result_key: Optional[EvalResultKey] = UNSET dir: SortDir - def update_orm_expr(self, stmt: Select[Any]) -> Select[Any]: + def update_orm_expr(self, stmt: Select[Any]) -> SpanSortResult: if self.col and not self.eval_result_key: expr = self.col.orm_expression if self.dir == SortDir.desc: expr = desc(expr) - return stmt.order_by(nulls_last(expr)) + return SpanSortResult(stmt=stmt.order_by(nulls_last(expr))) if self.eval_result_key and not self.col: eval_name = self.eval_result_key.name expr = _EVAL_ATTR_TO_ORM_EXPR_MAP[self.eval_result_key.attr] if self.dir == SortDir.desc: expr = desc(expr) - return stmt.join( - models.SpanAnnotation, - onclause=and_( - models.SpanAnnotation.span_rowid == models.Span.id, - models.SpanAnnotation.name == eval_name, - ), - ).order_by(expr) + return SpanSortResult( + stmt=stmt.join( + models.SpanAnnotation, + onclause=and_( + models.SpanAnnotation.span_rowid == models.Span.id, + models.SpanAnnotation.name == eval_name, + ), + ).order_by(expr) + ) raise ValueError("Exactly one of `col` or `evalResultKey` must be specified on `SpanSort`.") diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 330a9cd293..8c558af411 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -214,7 +214,8 @@ async def spans( first + 1 # overfetch by one to determine whether there's a next page ) if sort: - stmt = sort.update_orm_expr(stmt) + sort_result = sort.update_orm_expr(stmt) + stmt = sort_result.stmt stmt = stmt.order_by(desc(models.Span.id)) data = [] async with info.context.db() as session: From 81c8f55b6b7ae86356222d4a5e05c87c8e1aecfd Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 21:51:31 -0700 Subject: [PATCH 53/74] fix style --- integration-tests/pagination_query_testing.ipynb | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 29f830a9b2..ae746ef6cb 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -447,9 +447,7 @@ "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", ") == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", - "assert (\n", - " end_field_type := end_sortable_field.type\n", - ") == SortableFieldType.DATETIME, end_field_type" + "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] }, { @@ -492,9 +490,7 @@ "assert (\n", " end_timestamp := end_sortable_field.stringify_value()\n", ") == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", - "assert (\n", - " end_field_type := end_sortable_field.type\n", - ") == SortableFieldType.DATETIME, end_field_type" + "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" ] }, { From eb06174828d0fca5e8bca97d2c4e583a5ef68910 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 5 May 2024 22:41:58 -0700 Subject: [PATCH 54/74] Revert "refactor SpanFilter to return result object containing annotation aliases" This reverts commit 2ecc76a0a10cc6f969808213579c23c7209d5c3a. 
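Sorting by an eval result hinges on the join assembled above: `span_annotations` is inner-joined on both the span rowid and the evaluation name, and the ordering applies to the annotation's score or label. Because it is an inner join, spans lacking that evaluation drop out of the page entirely, which matches the sparse row IDs asserted in the notebook. The bare join shape, as a sketch:

from sqlalchemy import and_, desc, select

from phoenix.db import models

stmt = (
    select(models.Span, models.SpanAnnotation.label)
    .join(
        models.SpanAnnotation,
        onclause=and_(
            models.SpanAnnotation.span_rowid == models.Span.id,
            models.SpanAnnotation.name == "Hallucination",
        ),
    )
    .order_by(desc(models.SpanAnnotation.label), desc(models.Span.id))
)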
--- src/phoenix/server/api/types/Project.py | 3 +-- src/phoenix/trace/dsl/filter.py | 12 ------------ 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 8c558af411..c218f928b9 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -191,8 +191,7 @@ async def spans( ).where(parent.c.span_id.is_(None)) if filter_condition: span_filter = SpanFilter(condition=filter_condition) - filter_result = span_filter.result(stmt) - stmt = filter_result.stmt + stmt = span_filter(stmt) sortable_field: Optional[SortableField] = None if after: node_identifier = NodeIdentifier.from_cursor(after) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index b0c02c6a1a..ea08af968f 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -129,12 +129,6 @@ def attribute_alias(self, attribute: EvalAttribute) -> str: ) -@dataclass(frozen=True) -class SpanFilterResult: - stmt: Select[typing.Any] - aliased_annotation_relations: typing.Tuple[AliasedAnnotationRelation, ...] - - @dataclass(frozen=True) class SpanFilter: condition: str = "" @@ -178,12 +172,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "_aliased_annotation_relations", aliased_annotation_relations) object.__setattr__(self, "_aliased_annotation_attributes", aliased_annotation_attributes) - def result(self, select: Select[typing.Any]) -> SpanFilterResult: - stmt = self(select) - return SpanFilterResult( - stmt=stmt, aliased_annotation_relations=self._aliased_annotation_relations - ) - def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: return select From ced6e20865e133a330c37ccff30469f9c841f3b8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 00:45:15 -0700 Subject: [PATCH 55/74] passing order by descending eval labels with cursor --- .../pagination_query_testing.ipynb | 8 +- .../server/api/input_types/SpanSort.py | 73 ++++++++++++++----- src/phoenix/server/api/types/Project.py | 44 +++++------ 3 files changed, 81 insertions(+), 44 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index ae746ef6cb..0d057400b6 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -615,10 +615,10 @@ " \"dir\": \"desc\",\n", " },\n", " \"first\": 5,\n", - " # \"after\": NodeIdentifier(\n", - " # rowid=141, # row 141 is surrounded by many other hallucinations\n", - " # sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", - " # ).to_cursor(),\n", + " \"after\": NodeIdentifier(\n", + " rowid=141, # row 141 is surrounded by many other hallucinations\n", + " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", + " ).to_cursor(),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 9ab56720cf..0e9e7b2277 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -1,10 +1,11 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Any, Optional, Protocol +from typing import Any, Optional, Protocol, cast import strawberry from openinference.semconv.trace 
import SpanAttributes from sqlalchemy import and_, desc, nulls_last +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.sql.expression import Select from strawberry import UNSET from typing_extensions import assert_never @@ -32,6 +33,10 @@ class SpanColumn(Enum): cumulativeTokenCountPrompt = auto() cumulativeTokenCountCompletion = auto() + @property + def orm_key(self) -> str: + return cast(str, self.orm_expression.name) + @property def orm_expression(self) -> Any: if self is SpanColumn.startTime: @@ -82,11 +87,28 @@ class EvalAttr(Enum): score = "score" label = "label" + @property + def orm_key(self) -> str: + return f"span_annotations_{self.value}" + + @property + def orm_expression(self) -> Any: + expr: InstrumentedAttribute[Any] + if self is EvalAttr.score: + expr = models.SpanAnnotation.score + elif self is EvalAttr.label: + expr = models.SpanAnnotation.label + else: + assert_never(self) + return expr.label(self.orm_key) -_EVAL_ATTR_TO_ORM_EXPR_MAP = { - EvalAttr.score: models.SpanAnnotation.score, - EvalAttr.label: models.SpanAnnotation.label, -} + @property + def data_type(self) -> SortableFieldType: + if self is EvalAttr.label: + return SortableFieldType.STRING + if self is EvalAttr.score: + return SortableFieldType.FLOAT + assert_never(self) @strawberry.input @@ -102,7 +124,9 @@ def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluat @dataclass(frozen=True) class SpanSortResult: stmt: Select[Any] - eval_alias: Optional[str] = None + orm_key: str + orm_expression: Any + data_type: SortableFieldType @strawberry.input( @@ -115,23 +139,34 @@ class SpanSort: dir: SortDir def update_orm_expr(self, stmt: Select[Any]) -> SpanSortResult: - if self.col and not self.eval_result_key: - expr = self.col.orm_expression + if (col := self.col) and not self.eval_result_key: + expr = col.orm_expression if self.dir == SortDir.desc: expr = desc(expr) - return SpanSortResult(stmt=stmt.order_by(nulls_last(expr))) - if self.eval_result_key and not self.col: - eval_name = self.eval_result_key.name - expr = _EVAL_ATTR_TO_ORM_EXPR_MAP[self.eval_result_key.attr] + return SpanSortResult( + stmt=stmt.order_by(nulls_last(expr)), + orm_key=col.orm_key, + orm_expression=col.orm_expression, + data_type=col.data_type, + ) + if (eval_result_key := self.eval_result_key) and not col: + eval_name = eval_result_key.name + eval_attr = eval_result_key.attr + expr = eval_result_key.attr.orm_expression + stmt = stmt.add_columns(expr) if self.dir == SortDir.desc: expr = desc(expr) + stmt = stmt.join( + models.SpanAnnotation, + onclause=and_( + models.SpanAnnotation.span_rowid == models.Span.id, + models.SpanAnnotation.name == eval_name, + ), + ).order_by(expr) return SpanSortResult( - stmt=stmt.join( - models.SpanAnnotation, - onclause=and_( - models.SpanAnnotation.span_rowid == models.Span.id, - models.SpanAnnotation.name == eval_name, - ), - ).order_by(expr) + stmt=stmt, + orm_key=eval_attr.orm_key, + orm_expression=eval_result_key.attr.orm_expression, + data_type=eval_attr.data_type, ) raise ValueError("Exactly one of `col` or `evalResultKey` must be specified on `SpanSort`.") diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index c218f928b9..74213b01af 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -13,7 +13,7 @@ from phoenix.datetime_utils import right_open_time_range from phoenix.db import models from phoenix.server.api.context import Context -from 
phoenix.server.api.input_types.SpanSort import SpanSort +from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortResult from phoenix.server.api.input_types.TimeRange import TimeRange from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary from phoenix.server.api.types.EvaluationSummary import EvaluationSummary @@ -193,16 +193,20 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) sortable_field: Optional[SortableField] = None + sort_result: Optional[SpanSortResult] = None + if sort: + sort_result = sort.update_orm_expr(stmt) + stmt = sort_result.stmt if after: node_identifier = NodeIdentifier.from_cursor(after) if node_identifier.sortable_field is not None: sortable_field = node_identifier.sortable_field assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt - if sort_column := sort.col: + if sort_result: stmt = stmt.where( compare( - tuple_(sort_column.orm_expression, models.Span.id), + tuple_(sort_result.orm_expression, models.Span.id), (sortable_field.value, node_identifier.rowid), ) ) @@ -212,34 +216,32 @@ async def spans( stmt = stmt.limit( first + 1 # overfetch by one to determine whether there's a next page ) - if sort: - sort_result = sort.update_orm_expr(stmt) - stmt = sort_result.stmt stmt = stmt.order_by(desc(models.Span.id)) data = [] async with info.context.db() as session: - spans = await session.stream_scalars(stmt) - async for span in islice(spans, first): - sf = ( - SortableField( - type=sort_col.data_type, - value=getattr( - span, sort_col.orm_expression.name - ), # todo: find a cleaner way to get this value - ) - if sort and (sort_col := sort.col) - else None - ) + rows = await session.execute(stmt) + async for row in islice(rows, first): + span = row[0] + eval_value = row[1] if len(row) > 1 else None node_identifier = NodeIdentifier( rowid=span.id, - sortable_field=sf, + sortable_field=( + SortableField( + type=sort_result.data_type, + value=eval_value + if eval_value is not None + else getattr(span, sort_result.orm_key), + ) + if sort_result + else None + ), ) data.append((node_identifier, to_gql_span(span))) # todo: does this need to be inside the async with block? 
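Because the sort expression is now attached as a labeled extra column, each result row comes back as a `(span, value)` pair rather than a bare ORM object, and the extra column is exactly what the next cursor needs. A sketch of the unpacking (assumes an open synchronous `session`; the label string follows the `orm_key` convention above):

from sqlalchemy import select

from phoenix.db import models

labeled_score = models.SpanAnnotation.score.label("span_annotations_score")
stmt = select(models.Span, labeled_score).join(
    models.SpanAnnotation,
    models.SpanAnnotation.span_rowid == models.Span.id,
)
for row in session.execute(stmt):
    span, score = row[0], row[1]  # row[1] feeds SortableField.value for the cursor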
has_next_page = True try: - await spans.__anext__() - except StopAsyncIteration: + next(rows) + except StopIteration: has_next_page = False return connections( From 795d7a72c6b9f227642dd32ec5fe8af5d752af90 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 00:57:45 -0700 Subject: [PATCH 56/74] add test for order by ascending eval labels with cursor --- .../pagination_query_testing.ipynb | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 0d057400b6..d612273ffd 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -631,6 +631,33 @@ " 66,\n", "], ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by hallucination eval label in ascending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"label\"},\n", + " \"dir\": \"asc\",\n", + " },\n", + " \"first\": 5,\n", + " \"after\": NodeIdentifier(\n", + " rowid=731, # row 746 is surrounded by many other hallucinations\n", + " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"factual\"),\n", + " ).to_cursor(),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "assert ids == [751, 736, 761, 756, 746], ids" + ] } ], "metadata": { From 5a1a97590a7d4e52d787b1e89107fcb3829e8552 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 22:55:51 -0700 Subject: [PATCH 57/74] rename NodeIdentifier to Cursor --- .../pagination_query_testing.ipynb | 146 +++++++++--------- src/phoenix/server/api/schema.py | 10 +- src/phoenix/server/api/types/Model.py | 18 +-- src/phoenix/server/api/types/Project.py | 10 +- src/phoenix/server/api/types/Trace.py | 10 +- src/phoenix/server/api/types/pagination.py | 26 ++-- tests/server/api/types/test_pagination.py | 60 +++---- 7 files changed, 143 insertions(+), 137 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index d612273ffd..342e597e77 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -9,7 +9,7 @@ "from gql import Client, gql\n", "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", - " NodeIdentifier,\n", + " Cursor,\n", " SortableField,\n", " SortableFieldType,\n", ")\n", @@ -28,7 +28,7 @@ "outputs": [], "source": [ "cursor = \"MTAwOkRBVEVUSU1FOjIwMjMtMTItMTFUMTc6NDQ6MDIuNTM0MTI5KzAwOjAw\"\n", - "node_identifier = NodeIdentifier.from_cursor(cursor)\n", + "node_identifier = Cursor.from_string(cursor)\n", "print(node_identifier)" ] }, @@ -76,7 +76,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "ids" ] }, @@ -127,7 +127,7 @@ " variable_values={\"projectId\": project_id, \"first\": 5},\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = 
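The `try`/`StopIteration` check above pairs with the earlier `LIMIT first + 1` overfetch: fetch one row beyond the page, return only `first` of them, and let the leftover row's existence answer `hasNextPage` without a second COUNT query. In miniature:

from itertools import islice

page_size = 5
rows = iter(range(page_size + 1))     # stands in for a result overfetched by one
page = list(islice(rows, page_size))  # what the page actually returns
has_next_page = next(rows, None) is not None
assert page == [0, 1, 2, 3, 4] and has_next_page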
[NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [765, 764, 763, 762, 761], ids" ] }, @@ -142,12 +142,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": NodeIdentifier(rowid=761).to_cursor(),\n", + " \"after\": str(Cursor(rowid=761)),\n", " \"first\": 5,\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [760, 759, 758, 757, 756], ids" ] }, @@ -162,12 +162,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": NodeIdentifier(7).to_cursor(),\n", + " \"after\": str(Cursor(7)),\n", " \"first\": 5,\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert ids == [6, 5, 4, 3, 2], ids\n", @@ -186,12 +186,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": NodeIdentifier(6).to_cursor(),\n", + " \"after\": str(Cursor(6)),\n", " \"first\": 5,\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert ids == [5, 4, 3, 2, 1], ids\n", @@ -210,12 +210,12 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": NodeIdentifier(5).to_cursor(),\n", + " \"after\": str(Cursor(5)),\n", " \"first\": 5,\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", "assert ids == [4, 3, 2, 1], ids\n", @@ -239,7 +239,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 765,\n", " 760,\n", @@ -261,14 +261,12 @@ " variable_values={\n", " \"projectId\": project_id,\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " 765\n", - " ).to_cursor(), # skip the first span satisfying the filter condition\n", + " \"after\": str(Cursor(765)), # skip the first span satisfying the filter condition\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in 
cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 760,\n", " 755,\n", @@ -289,15 +287,13 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": NodeIdentifier(\n", - " 745\n", - " ).to_cursor(), # skip the first span satisfying the filter condition\n", + " \"after\": str(Cursor(745)), # skip the first span satisfying the filter condition\n", " \"first\": 5,\n", " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 300\",\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 740,\n", " 730,\n", @@ -323,9 +319,9 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [765, 764, 763, 762, 761], ids\n", "assert end_node_identifier.rowid == 761\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", @@ -350,9 +346,9 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [1, 2, 3, 4, 5], ids\n", "assert end_node_identifier.rowid == 5\n", "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", @@ -377,7 +373,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 710,\n", " 709,\n", @@ -403,7 +399,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [763, 762, 758, 757, 753], ids" ] }, @@ -420,21 +416,23 @@ " \"projectId\": project_id,\n", " \"sort\": {\"col\": \"startTime\", \"dir\": \"desc\"},\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " 760,\n", - " sortable_field=SortableField.from_stringified_value(\n", - " type=SortableFieldType.DATETIME,\n", - " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", - " ),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " 760,\n", + " sortable_field=SortableField.from_stringified_value(\n", + " type=SortableFieldType.DATETIME,\n", + " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", + " ),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = 
[NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", - "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "start_node_identifier = Cursor.from_string(start_cursor)\n", + "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [759, 758, 757, 756, 755], ids\n", "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", "assert (\n", @@ -463,21 +461,23 @@ " \"projectId\": project_id,\n", " \"sort\": {\"col\": \"startTime\", \"dir\": \"asc\"},\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " 8,\n", - " sortable_field=SortableField.from_stringified_value(\n", - " type=SortableFieldType.DATETIME,\n", - " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", - " ),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " 8,\n", + " sortable_field=SortableField.from_stringified_value(\n", + " type=SortableFieldType.DATETIME,\n", + " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", + " ),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = NodeIdentifier.from_cursor(start_cursor)\n", - "end_node_identifier = NodeIdentifier.from_cursor(end_cursor)\n", + "start_node_identifier = Cursor.from_string(start_cursor)\n", + "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [9, 10, 11, 12, 13], ids\n", "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", "assert (\n", @@ -506,14 +506,16 @@ " \"projectId\": project_id,\n", " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"desc\"},\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " rowid=644, # row 644 is in between rows 645 and 641, which also have 1054 cumulative prompt tokens\n", - " sortable_field=SortableField(type=SortableFieldType.FLOAT, value=1054),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=644, # row 644 is in between rows 645 and 641, which also have 1054 cumulative prompt tokens\n", + " sortable_field=SortableField(type=SortableFieldType.FLOAT, value=1054),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 641,\n", " 550,\n", @@ -536,14 +538,16 @@ " \"projectId\": project_id,\n", " \"sort\": {\"col\": \"cumulativeTokenCountPrompt\", \"dir\": \"asc\"},\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " rowid=294, # row 294 is in between rows 295 and 291, which also have 276 cumulative prompt tokens\n", - " sortable_field=SortableField(type=SortableFieldType.INT, value=276),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=294, # row 
294 is in between rows 295 and 291, which also have 276 cumulative prompt tokens\n", + " sortable_field=SortableField(type=SortableFieldType.INT, value=276),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [295, 115, 114, 111, 25], ids" ] }, @@ -566,7 +570,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [761, 756, 746, 741, 721], ids" ] }, @@ -589,7 +593,7 @@ " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 751,\n", " 736,\n", @@ -615,14 +619,16 @@ " \"dir\": \"desc\",\n", " },\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " rowid=141, # row 141 is surrounded by many other hallucinations\n", - " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=141, # row 141 is surrounded by many other hallucinations\n", + " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", " 121,\n", " 116,\n", @@ -648,14 +654,16 @@ " \"dir\": \"asc\",\n", " },\n", " \"first\": 5,\n", - " \"after\": NodeIdentifier(\n", - " rowid=731, # row 746 is surrounded by many other hallucinations\n", - " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"factual\"),\n", - " ).to_cursor(),\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=731, # row 746 is surrounded by many other hallucinations\n", + " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"factual\"),\n", + " )\n", + " ),\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", - "ids = [NodeIdentifier.from_cursor(cursor).rowid for cursor in cursors]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [751, 736, 761, 756, 746], ids" ] } diff --git a/src/phoenix/server/api/schema.py b/src/phoenix/server/api/schema.py index 6c7e995034..514362ef23 100644 --- a/src/phoenix/server/api/schema.py +++ b/src/phoenix/server/api/schema.py @@ -42,7 +42,7 @@ from phoenix.server.api.types.pagination import ( Connection, ConnectionArgs, - Cursor, + CursorString, connection_from_list, ) from phoenix.server.api.types.Project import Project @@ -58,14 +58,14 @@ async def projects( info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, - after: Optional[Cursor] = UNSET, - before: Optional[Cursor] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, ) -> Connection[Project]: args = ConnectionArgs( first=first, - after=after if isinstance(after, Cursor) else None, + 
after=after if isinstance(after, CursorString) else None, last=last, - before=before if isinstance(before, Cursor) else None, + before=before if isinstance(before, CursorString) else None, ) async with info.context.db() as session: projects = await session.scalars(select(models.Project)) diff --git a/src/phoenix/server/api/types/Model.py b/src/phoenix/server/api/types/Model.py index 81729482b4..155c772ab2 100644 --- a/src/phoenix/server/api/types/Model.py +++ b/src/phoenix/server/api/types/Model.py @@ -19,7 +19,7 @@ from .Dimension import Dimension, to_gql_dimension from .EmbeddingDimension import EmbeddingDimension, to_gql_embedding_dimension from .ExportedFile import ExportedFile -from .pagination import Connection, ConnectionArgs, Cursor, connection_from_list +from .pagination import Connection, ConnectionArgs, CursorString, connection_from_list from .TimeSeries import ( PerformanceTimeSeries, ensure_timeseries_parameters, @@ -35,8 +35,8 @@ def dimensions( info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, - after: Optional[Cursor] = UNSET, - before: Optional[Cursor] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, include: Optional[DimensionFilter] = UNSET, exclude: Optional[DimensionFilter] = UNSET, ) -> Connection[Dimension]: @@ -50,9 +50,9 @@ def dimensions( ], args=ConnectionArgs( first=first, - after=after if isinstance(after, Cursor) else None, + after=after if isinstance(after, CursorString) else None, last=last, - before=before if isinstance(before, Cursor) else None, + before=before if isinstance(before, CursorString) else None, ), ) @@ -105,8 +105,8 @@ def embedding_dimensions( info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, - after: Optional[Cursor] = UNSET, - before: Optional[Cursor] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, ) -> Connection[EmbeddingDimension]: """ A non-trivial implementation should efficiently fetch only @@ -123,9 +123,9 @@ def embedding_dimensions( ], args=ConnectionArgs( first=first, - after=after if isinstance(after, Cursor) else None, + after=after if isinstance(after, CursorString) else None, last=last, - before=before if isinstance(before, Cursor) else None, + before=before if isinstance(before, CursorString) else None, ), ) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 829337603c..a3ad54649a 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -21,7 +21,7 @@ from phoenix.server.api.types.pagination import ( Connection, Cursor, - NodeIdentifier, + CursorString, SortableField, connections, ) @@ -158,8 +158,8 @@ async def spans( time_range: Optional[TimeRange] = UNSET, first: Optional[int] = 50, last: Optional[int] = UNSET, - after: Optional[Cursor] = UNSET, - before: Optional[Cursor] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, sort: Optional[SpanSort] = UNSET, root_spans_only: Optional[bool] = UNSET, filter_condition: Optional[str] = UNSET, @@ -194,7 +194,7 @@ async def spans( sort_result = sort.update_orm_expr(stmt) stmt = sort_result.stmt if after: - node_identifier = NodeIdentifier.from_cursor(after) + node_identifier = Cursor.from_string(after) if node_identifier.sortable_field is not None: sortable_field = node_identifier.sortable_field assert sort is not None # todo: refactor this into a validation check @@ -219,7 +219,7 @@ async def 
spans( async for row in islice(rows, first): span = row[0] eval_value = row[1] if len(row) > 1 else None - node_identifier = NodeIdentifier( + node_identifier = Cursor( rowid=span.id, sortable_field=( SortableField( diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 762adf9f67..d538fe66f0 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -13,7 +13,7 @@ from phoenix.server.api.types.pagination import ( Connection, ConnectionArgs, - Cursor, + CursorString, connection_from_list, ) from phoenix.server.api.types.Span import Span, to_gql_span @@ -27,14 +27,14 @@ async def spans( info: Info[Context, None], first: Optional[int] = 50, last: Optional[int] = UNSET, - after: Optional[Cursor] = UNSET, - before: Optional[Cursor] = UNSET, + after: Optional[CursorString] = UNSET, + before: Optional[CursorString] = UNSET, ) -> Connection[Span]: args = ConnectionArgs( first=first, - after=after if isinstance(after, Cursor) else None, + after=after if isinstance(after, CursorString) else None, last=last, - before=before if isinstance(before, Cursor) else None, + before=before if isinstance(before, CursorString) else None, ) stmt = ( select(models.Span) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 24c465f880..37458a57a2 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -43,7 +43,7 @@ class PageInfo: # A type alias for the connection cursor implementation -Cursor = str +CursorString = str @strawberry.type @@ -100,20 +100,20 @@ def from_stringified_value( @dataclass -class NodeIdentifier: +class Cursor: rowid: int sortable_field: Optional[SortableField] = None _DELIMITER: ClassVar[str] = ":" - def to_cursor(self) -> Cursor: + def __str__(self) -> CursorString: cursor_components = [str(self.rowid)] if (sortable_field := self.sortable_field) is not None: cursor_components.extend([sortable_field.type.name, sortable_field.stringify_value()]) return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() @classmethod - def from_cursor(cls, cursor: Cursor) -> "NodeIdentifier": + def from_string(cls, cursor: CursorString) -> "Cursor": decoded = base64.b64decode(cursor).decode() rowid_string = decoded sortable_field = None @@ -127,14 +127,14 @@ def from_cursor(cls, cursor: Cursor) -> "NodeIdentifier": return cls(rowid=int(rowid_string), sortable_field=sortable_field) -def offset_to_cursor(offset: int) -> Cursor: +def offset_to_cursor(offset: int) -> CursorString: """ Creates the cursor string from an offset. """ return base64.b64encode(f"{CURSOR_PREFIX}{offset}".encode("utf-8")).decode() -def cursor_to_offset(cursor: Cursor) -> int: +def cursor_to_offset(cursor: CursorString) -> int: """ Extracts the offset from the cursor string. """ @@ -142,13 +142,13 @@ def cursor_to_offset(cursor: Cursor) -> int: return int(offset) -def get_offset_with_default(cursor: Optional[Cursor], default_offset: int) -> int: +def get_offset_with_default(cursor: Optional[CursorString], default_offset: int) -> int: """ Given an optional cursor and a default offset, returns the offset to use; if the cursor contains a valid offset, that will be used, otherwise it will be the default. 
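
    An illustrative round trip for the offset cursor handled here, with
    CURSOR_PREFIX == "connection:" (the extraction step is a sketch, not
    necessarily this module's exact parsing):

    >>> import base64
    >>> base64.b64encode("connection:42".encode("utf-8")).decode()
    'Y29ubmVjdGlvbjo0Mg=='
    >>> base64.b64decode("Y29ubmVjdGlvbjo0Mg==").decode()
    'connection:42'
    >>> int("connection:42".split(":", 1)[1])
    42
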
""" - if not isinstance(cursor, Cursor): + if not isinstance(cursor, CursorString): return default_offset offset = cursor_to_offset(cursor) return offset if isinstance(offset, int) else default_offset @@ -161,9 +161,9 @@ class ConnectionArgs: """ first: Optional[int] = UNSET - after: Optional[Cursor] = UNSET + after: Optional[CursorString] = UNSET last: Optional[int] = UNSET - before: Optional[Cursor] = UNSET + before: Optional[CursorString] = UNSET def connection_from_list( @@ -245,13 +245,11 @@ def connection_from_list_slice( def connections( - data: List[Tuple[NodeIdentifier, GenericType]], + data: List[Tuple[Cursor, GenericType]], has_previous_page: bool, has_next_page: bool, ) -> Connection[GenericType]: - edges = [ - Edge(node=node, cursor=tuple_identifier.to_cursor()) for tuple_identifier, node in data - ] + edges = [Edge(node=node, cursor=str(cursor)) for cursor, node in data] has_edges = len(edges) > 0 first_edge = edges[0] if has_edges else None last_edge = edges[-1] if has_edges else None diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index a44b5329bc..8e2e91c0b7 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -5,7 +5,7 @@ from phoenix.server.api.types.Dimension import Dimension from phoenix.server.api.types.pagination import ( ConnectionArgs, - NodeIdentifier, + Cursor, SortableField, SortableFieldType, connection_from_list, @@ -102,98 +102,98 @@ def test_connection_from_empty_list(): assert connection.page_info.has_next_page is False -class TestNodeIdentifier: - def test_to_and_from_cursor_with_rowid_deserializes_original(self) -> None: - original = NodeIdentifier(rowid=10) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) +class TestCursor: + def test_to_and_from_string_with_rowid_deserializes_original(self) -> None: + original = Cursor(rowid=10) + cursor = str(original) + deserialized = Cursor.from_string(cursor) assert deserialized.rowid == 10 assert deserialized.sortable_field is None - def test_to_and_from_cursor_with_rowid_and_string_deserializes_original( + def test_to_and_from_string_with_rowid_and_string_deserializes_original( self, ) -> None: - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField(type=SortableFieldType.STRING, value="abc") ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.STRING assert sortable_field.value == "abc" - def test_to_and_from_cursor_with_rowid_and_int_deserializes_original( + def test_to_and_from_string_with_rowid_and_int_deserializes_original( self, ) -> None: - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField(type=SortableFieldType.INT, value=11) ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.INT assert isinstance((value := sortable_field.value), int) assert value == 11 - def test_to_and_from_cursor_with_rowid_and_float_deserializes_original( + def 
test_to_and_from_string_with_rowid_and_float_deserializes_original( self, ) -> None: - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField(type=SortableFieldType.FLOAT, value=11.5) ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.FLOAT assert abs(sortable_field.value - 11.5) < 1e-8 - def test_to_and_from_cursor_with_rowid_and_float_passed_as_int_deserializes_original_as_float( + def test_to_and_from_string_with_rowid_and_float_passed_as_int_deserializes_original_as_float( self, ) -> None: - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField( type=SortableFieldType.FLOAT, value=11, # an integer value ), ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.FLOAT assert isinstance((value := sortable_field.value), float) assert abs(value - 11.0) < 1e-8 - def test_to_and_from_cursor_with_rowid_and_tz_naive_datetime_deserializes_original( + def test_to_and_from_string_with_rowid_and_tz_naive_datetime_deserializes_original( self, ) -> None: timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245") - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.DATETIME assert sortable_field.value == timestamp assert sortable_field.value.tzinfo is None - def test_to_and_from_cursor_with_rowid_and_tz_aware_datetime_deserializes_original( + def test_to_and_from_string_with_rowid_and_tz_aware_datetime_deserializes_original( self, ) -> None: timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00") - original = NodeIdentifier( + original = Cursor( rowid=10, sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), ) - cursor = original.to_cursor() - deserialized = NodeIdentifier.from_cursor(cursor) + cursor_string = str(original) + deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sortable_field := deserialized.sortable_field) is not None assert sortable_field.type == SortableFieldType.DATETIME From 4a20bfc61045137e5fc5f0dda2a7e33555bee348 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 23:10:10 -0700 Subject: [PATCH 58/74] more refactoring of pagination types --- .../pagination_query_testing.ipynb | 60 ++++++++---------- .../server/api/input_types/SpanSort.py | 18 +++--- src/phoenix/server/api/types/Project.py | 14 ++--- src/phoenix/server/api/types/pagination.py | 46 +++++++------- tests/server/api/types/test_pagination.py | 62 +++++++++---------- 5 files changed, 95 insertions(+), 105 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 
342e597e77..c67830d458 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -10,8 +10,8 @@ "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", " Cursor,\n", - " SortableField,\n", - " SortableFieldType,\n", + " SortColumn,\n", + " SortColumnDataType,\n", ")\n", "\n", "project_id = \"UHJvamVjdDox\"\n", @@ -324,9 +324,9 @@ "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [765, 764, 763, 762, 761], ids\n", "assert end_node_identifier.rowid == 761\n", - "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", "assert (\n", - " end_node_start_timestamp := end_sortable_field.value.isoformat()\n", + " end_node_start_timestamp := end_sort_column.value.isoformat()\n", ") == \"2023-12-11T17:48:40.807667+00:00\", end_node_start_timestamp" ] }, @@ -351,9 +351,9 @@ "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [1, 2, 3, 4, 5], ids\n", "assert end_node_identifier.rowid == 5\n", - "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", + "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", "assert (\n", - " end_node_start_timestamp := end_sortable_field.value.isoformat()\n", + " end_node_start_timestamp := end_sort_column.value.isoformat()\n", ") == \"2023-12-11T17:43:23.712144+00:00\", end_node_start_timestamp" ] }, @@ -419,8 +419,8 @@ " \"after\": str(\n", " Cursor(\n", " 760,\n", - " sortable_field=SortableField.from_stringified_value(\n", - " type=SortableFieldType.DATETIME,\n", + " sort_column=SortColumn.from_string(\n", + " type=SortColumnDataType.DATETIME,\n", " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", " ),\n", " )\n", @@ -434,18 +434,14 @@ "start_node_identifier = Cursor.from_string(start_cursor)\n", "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [759, 758, 757, 756, 755], ids\n", - "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", + "assert (start_sort_column := start_node_identifier.sort_column) is not None\n", "assert (\n", - " start_timestamp := start_sortable_field.stringify_value()\n", + " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:48:40.154139+00:00\", start_timestamp\n", - "assert (\n", - " start_field_type := start_sortable_field.type\n", - ") == SortableFieldType.DATETIME, start_field_type\n", - "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", - "assert (\n", - " end_timestamp := end_sortable_field.stringify_value()\n", - ") == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", - "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" + "assert (start_field_type := start_sort_column.type) == SortColumnDataType.DATETIME, start_field_type\n", + "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", + "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" ] }, { @@ -464,8 +460,8 @@ " \"after\": str(\n", " Cursor(\n", " 8,\n", - " sortable_field=SortableField.from_stringified_value(\n", - " type=SortableFieldType.DATETIME,\n", + " sort_column=SortColumn.from_string(\n", + " 
type=SortColumnDataType.DATETIME,\n", " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", " ),\n", " )\n", @@ -479,18 +475,14 @@ "start_node_identifier = Cursor.from_string(start_cursor)\n", "end_node_identifier = Cursor.from_string(end_cursor)\n", "assert ids == [9, 10, 11, 12, 13], ids\n", - "assert (start_sortable_field := start_node_identifier.sortable_field) is not None\n", + "assert (start_sort_column := start_node_identifier.sort_column) is not None\n", "assert (\n", - " start_timestamp := start_sortable_field.stringify_value()\n", + " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:43:25.842986+00:00\", start_timestamp\n", - "assert (\n", - " start_field_type := start_sortable_field.type\n", - ") == SortableFieldType.DATETIME, start_field_type\n", - "assert (end_sortable_field := end_node_identifier.sortable_field) is not None\n", - "assert (\n", - " end_timestamp := end_sortable_field.stringify_value()\n", - ") == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", - "assert (end_field_type := end_sortable_field.type) == SortableFieldType.DATETIME, end_field_type" + "assert (start_field_type := start_sort_column.type) == SortColumnDataType.DATETIME, start_field_type\n", + "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", + "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" ] }, { @@ -509,7 +501,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=644, # row 644 is in between rows 645 and 641, which also have 1054 cumulative prompt tokens\n", - " sortable_field=SortableField(type=SortableFieldType.FLOAT, value=1054),\n", + " sort_column=SortColumn(type=SortColumnDataType.FLOAT, value=1054),\n", " )\n", " ),\n", " },\n", @@ -541,7 +533,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=294, # row 294 is in between rows 295 and 291, which also have 276 cumulative prompt tokens\n", - " sortable_field=SortableField(type=SortableFieldType.INT, value=276),\n", + " sort_column=SortColumn(type=SortColumnDataType.INT, value=276),\n", " )\n", " ),\n", " },\n", @@ -622,7 +614,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=141, # row 141 is surrounded by many other hallucinations\n", - " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"hallucinated\"),\n", + " sort_column=SortColumn(type=SortColumnDataType.STRING, value=\"hallucinated\"),\n", " )\n", " ),\n", " },\n", @@ -657,7 +649,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=731, # row 746 is surrounded by many other hallucinations\n", - " sortable_field=SortableField(type=SortableFieldType.STRING, value=\"factual\"),\n", + " sort_column=SortColumn(type=SortColumnDataType.STRING, value=\"factual\"),\n", " )\n", " ),\n", " },\n", diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 0e9e7b2277..84ca50016b 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -12,7 +12,7 @@ import phoenix.trace.v1 as pb from phoenix.db import models -from phoenix.server.api.types.pagination import SortableFieldType +from phoenix.server.api.types.pagination import SortColumnDataType from phoenix.server.api.types.SortDir import SortDir from phoenix.trace.schemas import SpanID @@ -63,22 +63,22 @@ def orm_expression(self) -> Any: assert_never(self) @property - def data_type(self) -> SortableFieldType: + def 
data_type(self) -> SortColumnDataType: if ( self is SpanColumn.cumulativeTokenCountTotal or self is SpanColumn.cumulativeTokenCountPrompt or self is SpanColumn.cumulativeTokenCountCompletion ): - return SortableFieldType.INT + return SortColumnDataType.INT if ( self is SpanColumn.latencyMs or self is SpanColumn.tokenCountTotal or self is SpanColumn.tokenCountPrompt or self is SpanColumn.tokenCountCompletion ): - return SortableFieldType.FLOAT + return SortColumnDataType.FLOAT if self is SpanColumn.startTime or self is SpanColumn.endTime: - return SortableFieldType.DATETIME + return SortColumnDataType.DATETIME assert_never(self) @@ -103,11 +103,11 @@ def orm_expression(self) -> Any: return expr.label(self.orm_key) @property - def data_type(self) -> SortableFieldType: + def data_type(self) -> SortColumnDataType: if self is EvalAttr.label: - return SortableFieldType.STRING + return SortColumnDataType.STRING if self is EvalAttr.score: - return SortableFieldType.FLOAT + return SortColumnDataType.FLOAT assert_never(self) @@ -126,7 +126,7 @@ class SpanSortResult: stmt: Select[Any] orm_key: str orm_expression: Any - data_type: SortableFieldType + data_type: SortColumnDataType @strawberry.input( diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index a3ad54649a..66988819b2 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -22,7 +22,7 @@ Connection, Cursor, CursorString, - SortableField, + SortColumn, connections, ) from phoenix.server.api.types.SortDir import SortDir @@ -188,22 +188,22 @@ async def spans( if filter_condition: span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) - sortable_field: Optional[SortableField] = None + sort_column: Optional[SortColumn] = None sort_result: Optional[SpanSortResult] = None if sort: sort_result = sort.update_orm_expr(stmt) stmt = sort_result.stmt if after: node_identifier = Cursor.from_string(after) - if node_identifier.sortable_field is not None: - sortable_field = node_identifier.sortable_field + if node_identifier.sort_column is not None: + sort_column = node_identifier.sort_column assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt if sort_result: stmt = stmt.where( compare( tuple_(sort_result.orm_expression, models.Span.id), - (sortable_field.value, node_identifier.rowid), + (sort_column.value, node_identifier.rowid), ) ) else: @@ -221,8 +221,8 @@ async def spans( eval_value = row[1] if len(row) > 1 else None node_identifier = Cursor( rowid=span.id, - sortable_field=( - SortableField( + sort_column=( + SortColumn( type=sort_result.data_type, value=eval_value if eval_value is not None diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 37458a57a2..0aff6bc066 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -10,7 +10,7 @@ ID: TypeAlias = int GenericType = TypeVar("GenericType") -SortableFieldValue: TypeAlias = Union[str, int, float, datetime] +SortColumnValue: TypeAlias = Union[str, int, float, datetime] @strawberry.type @@ -60,7 +60,7 @@ class Edge(Generic[GenericType]): CURSOR_PREFIX = "connection:" -class SortableFieldType(Enum): +class SortColumnDataType(Enum): STRING = auto() INT = auto() FLOAT = auto() @@ -68,11 +68,11 @@ class SortableFieldType(Enum): @dataclass -class SortableField: - type: SortableFieldType - value: 
SortableFieldValue +class SortColumn: + type: SortColumnDataType + value: SortColumnValue - def stringify_value(self) -> str: + def __str__(self) -> str: if isinstance(self.value, str): return self.value if isinstance(self.value, (int, float)): @@ -82,17 +82,15 @@ def stringify_value(self) -> str: assert_never(self.type) @classmethod - def from_stringified_value( - cls, type: SortableFieldType, stringified_value: str - ) -> "SortableField": - value: SortableFieldValue - if type is SortableFieldType.STRING: + def from_string(cls, type: SortColumnDataType, stringified_value: str) -> "SortColumn": + value: SortColumnValue + if type is SortColumnDataType.STRING: value = stringified_value - elif type is SortableFieldType.INT: + elif type is SortColumnDataType.INT: value = int(stringified_value) - elif type is SortableFieldType.FLOAT: + elif type is SortColumnDataType.FLOAT: value = float(stringified_value) - elif type is SortableFieldType.DATETIME: + elif type is SortColumnDataType.DATETIME: value = datetime.fromisoformat(stringified_value) else: assert_never(type) @@ -102,29 +100,31 @@ def from_stringified_value( @dataclass class Cursor: rowid: int - sortable_field: Optional[SortableField] = None + sort_column: Optional[SortColumn] = None _DELIMITER: ClassVar[str] = ":" def __str__(self) -> CursorString: - cursor_components = [str(self.rowid)] - if (sortable_field := self.sortable_field) is not None: - cursor_components.extend([sortable_field.type.name, sortable_field.stringify_value()]) - return base64.b64encode(self._DELIMITER.join(cursor_components).encode()).decode() + cursor_parts = [str(self.rowid)] + if (sort_column := self.sort_column) is not None: + cursor_parts.extend([sort_column.type.name, str(sort_column)]) + return base64.b64encode(self._DELIMITER.join(cursor_parts).encode()).decode() @classmethod def from_string(cls, cursor: CursorString) -> "Cursor": decoded = base64.b64decode(cursor).decode() rowid_string = decoded - sortable_field = None + sort_column = None if (first_delimiter_index := decoded.find(cls._DELIMITER)) > -1: rowid_string = decoded[:first_delimiter_index] second_delimiter_index = decoded.index(cls._DELIMITER, first_delimiter_index + 1) - sortable_field = SortableField.from_stringified_value( - type=SortableFieldType[decoded[first_delimiter_index + 1 : second_delimiter_index]], + sort_column = SortColumn.from_string( + type=SortColumnDataType[ + decoded[first_delimiter_index + 1 : second_delimiter_index] + ], stringified_value=decoded[second_delimiter_index + 1 :], ) - return cls(rowid=int(rowid_string), sortable_field=sortable_field) + return cls(rowid=int(rowid_string), sort_column=sort_column) def offset_to_cursor(offset: int) -> CursorString: diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index 8e2e91c0b7..f76f1e8b53 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -6,8 +6,8 @@ from phoenix.server.api.types.pagination import ( ConnectionArgs, Cursor, - SortableField, - SortableFieldType, + SortColumn, + SortColumnDataType, connection_from_list, ) @@ -108,64 +108,62 @@ def test_to_and_from_string_with_rowid_deserializes_original(self) -> None: cursor = str(original) deserialized = Cursor.from_string(cursor) assert deserialized.rowid == 10 - assert deserialized.sortable_field is None + assert deserialized.sort_column is None def test_to_and_from_string_with_rowid_and_string_deserializes_original( self, ) -> None: original = Cursor( - rowid=10, 
sortable_field=SortableField(type=SortableFieldType.STRING, value="abc") + rowid=10, sort_column=SortColumn(type=SortColumnDataType.STRING, value="abc") ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.STRING - assert sortable_field.value == "abc" + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.STRING + assert sort_column.value == "abc" def test_to_and_from_string_with_rowid_and_int_deserializes_original( self, ) -> None: - original = Cursor( - rowid=10, sortable_field=SortableField(type=SortableFieldType.INT, value=11) - ) + original = Cursor(rowid=10, sort_column=SortColumn(type=SortColumnDataType.INT, value=11)) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.INT - assert isinstance((value := sortable_field.value), int) + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.INT + assert isinstance((value := sort_column.value), int) assert value == 11 def test_to_and_from_string_with_rowid_and_float_deserializes_original( self, ) -> None: original = Cursor( - rowid=10, sortable_field=SortableField(type=SortableFieldType.FLOAT, value=11.5) + rowid=10, sort_column=SortColumn(type=SortColumnDataType.FLOAT, value=11.5) ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.FLOAT - assert abs(sortable_field.value - 11.5) < 1e-8 + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.FLOAT + assert abs(sort_column.value - 11.5) < 1e-8 def test_to_and_from_string_with_rowid_and_float_passed_as_int_deserializes_original_as_float( self, ) -> None: original = Cursor( rowid=10, - sortable_field=SortableField( - type=SortableFieldType.FLOAT, + sort_column=SortColumn( + type=SortColumnDataType.FLOAT, value=11, # an integer value ), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.FLOAT - assert isinstance((value := sortable_field.value), float) + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.FLOAT + assert isinstance((value := sort_column.value), float) assert abs(value - 11.0) < 1e-8 def test_to_and_from_string_with_rowid_and_tz_naive_datetime_deserializes_original( @@ -174,15 +172,15 @@ def test_to_and_from_string_with_rowid_and_tz_naive_datetime_deserializes_origin timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245") original = Cursor( rowid=10, - sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), + sort_column=SortColumn(type=SortColumnDataType.DATETIME, value=timestamp), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.DATETIME - 
assert sortable_field.value == timestamp - assert sortable_field.value.tzinfo is None + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.DATETIME + assert sort_column.value == timestamp + assert sort_column.value.tzinfo is None def test_to_and_from_string_with_rowid_and_tz_aware_datetime_deserializes_original( self, @@ -190,12 +188,12 @@ def test_to_and_from_string_with_rowid_and_tz_aware_datetime_deserializes_origin timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00") original = Cursor( rowid=10, - sortable_field=SortableField(type=SortableFieldType.DATETIME, value=timestamp), + sort_column=SortColumn(type=SortColumnDataType.DATETIME, value=timestamp), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 - assert (sortable_field := deserialized.sortable_field) is not None - assert sortable_field.type == SortableFieldType.DATETIME - assert sortable_field.value == timestamp - assert sortable_field.value.tzinfo is not None + assert (sort_column := deserialized.sort_column) is not None + assert sort_column.type == SortColumnDataType.DATETIME + assert sort_column.value == timestamp + assert sort_column.value.tzinfo is not None From 711ed68bfd460b7bf197a88eb7a1f1f0b982c04f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 23:11:28 -0700 Subject: [PATCH 59/74] more renaming of pagination types --- .../pagination_query_testing.ipynb | 32 +++++++++---------- src/phoenix/server/api/types/Project.py | 14 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index c67830d458..cebb51e9fd 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -28,8 +28,8 @@ "outputs": [], "source": [ "cursor = \"MTAwOkRBVEVUSU1FOjIwMjMtMTItMTFUMTc6NDQ6MDIuNTM0MTI5KzAwOjAw\"\n", - "node_identifier = Cursor.from_string(cursor)\n", - "print(node_identifier)" + "cursor = Cursor.from_string(cursor)\n", + "print(cursor)" ] }, { @@ -321,10 +321,10 @@ "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "end_node_identifier = Cursor.from_string(end_cursor)\n", + "end_cursor = Cursor.from_string(end_cursor)\n", "assert ids == [765, 764, 763, 762, 761], ids\n", - "assert end_node_identifier.rowid == 761\n", - "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert end_cursor.rowid == 761\n", + "assert (end_sort_column := end_cursor.sort_column) is not None\n", "assert (\n", " end_node_start_timestamp := end_sort_column.value.isoformat()\n", ") == \"2023-12-11T17:48:40.807667+00:00\", end_node_start_timestamp" @@ -348,10 +348,10 @@ "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "end_node_identifier = Cursor.from_string(end_cursor)\n", + "end_cursor = Cursor.from_string(end_cursor)\n", "assert ids == [1, 2, 3, 4, 5], ids\n", - "assert end_node_identifier.rowid == 5\n", - "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert end_cursor.rowid == 5\n", + "assert 
(end_sort_column := end_cursor.sort_column) is not None\n", "assert (\n", " end_node_start_timestamp := end_sort_column.value.isoformat()\n", ") == \"2023-12-11T17:43:23.712144+00:00\", end_node_start_timestamp" @@ -431,15 +431,15 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = Cursor.from_string(start_cursor)\n", - "end_node_identifier = Cursor.from_string(end_cursor)\n", + "start_cursor = Cursor.from_string(start_cursor)\n", + "end_cursor = Cursor.from_string(end_cursor)\n", "assert ids == [759, 758, 757, 756, 755], ids\n", - "assert (start_sort_column := start_node_identifier.sort_column) is not None\n", + "assert (start_sort_column := start_cursor.sort_column) is not None\n", "assert (\n", " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:48:40.154139+00:00\", start_timestamp\n", "assert (start_field_type := start_sort_column.type) == SortColumnDataType.DATETIME, start_field_type\n", - "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert (end_sort_column := end_cursor.sort_column) is not None\n", "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" ] @@ -472,15 +472,15 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "start_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"startCursor\"]\n", "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", - "start_node_identifier = Cursor.from_string(start_cursor)\n", - "end_node_identifier = Cursor.from_string(end_cursor)\n", + "start_cursor = Cursor.from_string(start_cursor)\n", + "end_cursor = Cursor.from_string(end_cursor)\n", "assert ids == [9, 10, 11, 12, 13], ids\n", - "assert (start_sort_column := start_node_identifier.sort_column) is not None\n", + "assert (start_sort_column := start_cursor.sort_column) is not None\n", "assert (\n", " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:43:25.842986+00:00\", start_timestamp\n", "assert (start_field_type := start_sort_column.type) == SortColumnDataType.DATETIME, start_field_type\n", - "assert (end_sort_column := end_node_identifier.sort_column) is not None\n", + "assert (end_sort_column := end_cursor.sort_column) is not None\n", "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" ] diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 66988819b2..f465f83384 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -194,20 +194,20 @@ async def spans( sort_result = sort.update_orm_expr(stmt) stmt = sort_result.stmt if after: - node_identifier = Cursor.from_string(after) - if node_identifier.sort_column is not None: - sort_column = node_identifier.sort_column + cursor = Cursor.from_string(after) + if cursor.sort_column is not None: + sort_column = cursor.sort_column assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt if sort_result: stmt = stmt.where( compare( tuple_(sort_result.orm_expression, 
models.Span.id), - (sort_column.value, node_identifier.rowid), + (sort_column.value, cursor.rowid), ) ) else: - stmt = stmt.where(models.Span.id < node_identifier.rowid) + stmt = stmt.where(models.Span.id < cursor.rowid) if first: stmt = stmt.limit( first + 1 # overfetch by one to determine whether there's a next page @@ -219,7 +219,7 @@ async def spans( async for row in islice(rows, first): span = row[0] eval_value = row[1] if len(row) > 1 else None - node_identifier = Cursor( + cursor = Cursor( rowid=span.id, sort_column=( SortColumn( @@ -232,7 +232,7 @@ async def spans( else None ), ) - data.append((node_identifier, to_gql_span(span))) + data.append((cursor, to_gql_span(span))) # todo: does this need to be inside the async with block? has_next_page = True try: From e4e00aa5d0fe8ca9dba4f88dd8c82a03dedd4094 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 23:34:07 -0700 Subject: [PATCH 60/74] rename more variables for clarity --- .../server/api/input_types/SpanSort.py | 20 ++++++------- src/phoenix/server/api/types/Project.py | 28 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 84ca50016b..b057439330 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -34,7 +34,7 @@ class SpanColumn(Enum): cumulativeTokenCountCompletion = auto() @property - def orm_key(self) -> str: + def column_name(self) -> str: return cast(str, self.orm_expression.name) @property @@ -88,7 +88,7 @@ class EvalAttr(Enum): label = "label" @property - def orm_key(self) -> str: + def column_name(self) -> str: return f"span_annotations_{self.value}" @property @@ -100,7 +100,7 @@ def orm_expression(self) -> Any: expr = models.SpanAnnotation.label else: assert_never(self) - return expr.label(self.orm_key) + return expr.label(self.column_name) @property def data_type(self) -> SortColumnDataType: @@ -122,9 +122,9 @@ def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluat @dataclass(frozen=True) -class SpanSortResult: +class SpanSortConfig: stmt: Select[Any] - orm_key: str + column_name: str orm_expression: Any data_type: SortColumnDataType @@ -138,14 +138,14 @@ class SpanSort: eval_result_key: Optional[EvalResultKey] = UNSET dir: SortDir - def update_orm_expr(self, stmt: Select[Any]) -> SpanSortResult: + def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig: if (col := self.col) and not self.eval_result_key: expr = col.orm_expression if self.dir == SortDir.desc: expr = desc(expr) - return SpanSortResult( + return SpanSortConfig( stmt=stmt.order_by(nulls_last(expr)), - orm_key=col.orm_key, + column_name=col.column_name, orm_expression=col.orm_expression, data_type=col.data_type, ) @@ -163,9 +163,9 @@ def update_orm_expr(self, stmt: Select[Any]) -> SpanSortResult: models.SpanAnnotation.name == eval_name, ), ).order_by(expr) - return SpanSortResult( + return SpanSortConfig( stmt=stmt, - orm_key=eval_attr.orm_key, + column_name=eval_attr.column_name, orm_expression=eval_result_key.attr.orm_expression, data_type=eval_attr.data_type, ) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index f465f83384..8eddb4717d 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -13,7 +13,7 @@ from phoenix.datetime_utils import right_open_time_range from phoenix.db import models from 
phoenix.server.api.context import Context -from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortResult +from phoenix.server.api.input_types.SpanSort import SpanSort, SpanSortConfig from phoenix.server.api.input_types.TimeRange import TimeRange from phoenix.server.api.types.DocumentEvaluationSummary import DocumentEvaluationSummary from phoenix.server.api.types.EvaluationSummary import EvaluationSummary @@ -189,20 +189,20 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) sort_column: Optional[SortColumn] = None - sort_result: Optional[SpanSortResult] = None + sort_config: Optional[SpanSortConfig] = None if sort: - sort_result = sort.update_orm_expr(stmt) - stmt = sort_result.stmt + sort_config = sort.update_orm_expr(stmt) + stmt = sort_config.stmt if after: cursor = Cursor.from_string(after) if cursor.sort_column is not None: sort_column = cursor.sort_column assert sort is not None # todo: refactor this into a validation check compare = operator.lt if sort.dir is SortDir.desc else operator.gt - if sort_result: + if sort_config: stmt = stmt.where( compare( - tuple_(sort_result.orm_expression, models.Span.id), + tuple_(sort_config.orm_expression, models.Span.id), (sort_column.value, cursor.rowid), ) ) @@ -215,20 +215,20 @@ async def spans( stmt = stmt.order_by(desc(models.Span.id)) data = [] async with info.context.db() as session: - rows = await session.execute(stmt) - async for row in islice(rows, first): - span = row[0] - eval_value = row[1] if len(row) > 1 else None + span_records = await session.execute(stmt) + async for span_record in islice(span_records, first): + span = span_record[0] + eval_value = span_record[1] if len(span_record) > 1 else None cursor = Cursor( rowid=span.id, sort_column=( SortColumn( - type=sort_result.data_type, + type=sort_config.data_type, value=eval_value if eval_value is not None - else getattr(span, sort_result.orm_key), + else getattr(span, sort_config.column_name), ) - if sort_result + if sort_config else None ), ) @@ -236,7 +236,7 @@ async def spans( # todo: does this need to be inside the async with block? 
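
# Aside (a self-contained sketch, not part of this patch): the
# hasNextPage check below relies on overfetching -- the statement above
# is limited to first + 1 rows, the loop consumes at most `first` of
# them, and a leftover row means another page exists. The same technique
# over a plain iterator, with hypothetical names:

from itertools import islice

def page_with_overfetch(rows, first):
    it = iter(rows)
    page = list(islice(it, first))         # consume at most `first` rows
    has_next = next(it, None) is not None  # leftover overfetched row => next page
    return page, has_next

assert page_with_overfetch(range(6), 5) == ([0, 1, 2, 3, 4], True)
assert page_with_overfetch(range(5), 5) == ([0, 1, 2, 3, 4], False)
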
has_next_page = True try: - next(rows) + next(span_records) except StopIteration: has_next_page = False From 3c462b025248c525f98a0c804165d0c4d52d5b78 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 23:42:02 -0700 Subject: [PATCH 61/74] rename pagination variables again --- .../server/api/input_types/SpanSort.py | 18 +++++----- src/phoenix/server/api/types/Project.py | 7 ++-- src/phoenix/server/api/types/pagination.py | 30 +++++++++-------- tests/server/api/types/test_pagination.py | 33 ++++++++++--------- 4 files changed, 46 insertions(+), 42 deletions(-) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index b057439330..69ad65ee73 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -12,7 +12,7 @@ import phoenix.trace.v1 as pb from phoenix.db import models -from phoenix.server.api.types.pagination import SortColumnDataType +from phoenix.server.api.types.pagination import CursorSortColumnDataType from phoenix.server.api.types.SortDir import SortDir from phoenix.trace.schemas import SpanID @@ -63,22 +63,22 @@ def orm_expression(self) -> Any: assert_never(self) @property - def data_type(self) -> SortColumnDataType: + def data_type(self) -> CursorSortColumnDataType: if ( self is SpanColumn.cumulativeTokenCountTotal or self is SpanColumn.cumulativeTokenCountPrompt or self is SpanColumn.cumulativeTokenCountCompletion ): - return SortColumnDataType.INT + return CursorSortColumnDataType.INT if ( self is SpanColumn.latencyMs or self is SpanColumn.tokenCountTotal or self is SpanColumn.tokenCountPrompt or self is SpanColumn.tokenCountCompletion ): - return SortColumnDataType.FLOAT + return CursorSortColumnDataType.FLOAT if self is SpanColumn.startTime or self is SpanColumn.endTime: - return SortColumnDataType.DATETIME + return CursorSortColumnDataType.DATETIME assert_never(self) @@ -103,11 +103,11 @@ def orm_expression(self) -> Any: return expr.label(self.column_name) @property - def data_type(self) -> SortColumnDataType: + def data_type(self) -> CursorSortColumnDataType: if self is EvalAttr.label: - return SortColumnDataType.STRING + return CursorSortColumnDataType.STRING if self is EvalAttr.score: - return SortColumnDataType.FLOAT + return CursorSortColumnDataType.FLOAT assert_never(self) @@ -126,7 +126,7 @@ class SpanSortConfig: stmt: Select[Any] column_name: str orm_expression: Any - data_type: SortColumnDataType + data_type: CursorSortColumnDataType @strawberry.input( diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 8eddb4717d..3bd8b5d2b3 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -21,8 +21,8 @@ from phoenix.server.api.types.pagination import ( Connection, Cursor, + CursorSortColumn, CursorString, - SortColumn, connections, ) from phoenix.server.api.types.SortDir import SortDir @@ -188,7 +188,7 @@ async def spans( if filter_condition: span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) - sort_column: Optional[SortColumn] = None + sort_column: Optional[CursorSortColumn] = None sort_config: Optional[SpanSortConfig] = None if sort: sort_config = sort.update_orm_expr(stmt) @@ -222,7 +222,7 @@ async def spans( cursor = Cursor( rowid=span.id, sort_column=( - SortColumn( + CursorSortColumn( type=sort_config.data_type, value=eval_value if eval_value is not None @@ -233,7 +233,6 @@ async def spans( ), ) data.append((cursor, 
to_gql_span(span))) - # todo: does this need to be inside the async with block? has_next_page = True try: next(span_records) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 0aff6bc066..a4de67cad5 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -10,7 +10,7 @@ ID: TypeAlias = int GenericType = TypeVar("GenericType") -SortColumnValue: TypeAlias = Union[str, int, float, datetime] +CursorSortColumnValue: TypeAlias = Union[str, int, float, datetime] @strawberry.type @@ -60,7 +60,7 @@ class Edge(Generic[GenericType]): CURSOR_PREFIX = "connection:" -class SortColumnDataType(Enum): +class CursorSortColumnDataType(Enum): STRING = auto() INT = auto() FLOAT = auto() @@ -68,9 +68,9 @@ class SortColumnDataType(Enum): @dataclass -class SortColumn: - type: SortColumnDataType - value: SortColumnValue +class CursorSortColumn: + type: CursorSortColumnDataType + value: CursorSortColumnValue def __str__(self) -> str: if isinstance(self.value, str): @@ -82,15 +82,17 @@ def __str__(self) -> str: assert_never(self.type) @classmethod - def from_string(cls, type: SortColumnDataType, stringified_value: str) -> "SortColumn": - value: SortColumnValue - if type is SortColumnDataType.STRING: + def from_string( + cls, type: CursorSortColumnDataType, stringified_value: str + ) -> "CursorSortColumn": + value: CursorSortColumnValue + if type is CursorSortColumnDataType.STRING: value = stringified_value - elif type is SortColumnDataType.INT: + elif type is CursorSortColumnDataType.INT: value = int(stringified_value) - elif type is SortColumnDataType.FLOAT: + elif type is CursorSortColumnDataType.FLOAT: value = float(stringified_value) - elif type is SortColumnDataType.DATETIME: + elif type is CursorSortColumnDataType.DATETIME: value = datetime.fromisoformat(stringified_value) else: assert_never(type) @@ -100,7 +102,7 @@ def from_string(cls, type: SortColumnDataType, stringified_value: str) -> "SortC @dataclass class Cursor: rowid: int - sort_column: Optional[SortColumn] = None + sort_column: Optional[CursorSortColumn] = None _DELIMITER: ClassVar[str] = ":" @@ -118,8 +120,8 @@ def from_string(cls, cursor: CursorString) -> "Cursor": if (first_delimiter_index := decoded.find(cls._DELIMITER)) > -1: rowid_string = decoded[:first_delimiter_index] second_delimiter_index = decoded.index(cls._DELIMITER, first_delimiter_index + 1) - sort_column = SortColumn.from_string( - type=SortColumnDataType[ + sort_column = CursorSortColumn.from_string( + type=CursorSortColumnDataType[ decoded[first_delimiter_index + 1 : second_delimiter_index] ], stringified_value=decoded[second_delimiter_index + 1 :], diff --git a/tests/server/api/types/test_pagination.py b/tests/server/api/types/test_pagination.py index f76f1e8b53..edf2bd9b3e 100644 --- a/tests/server/api/types/test_pagination.py +++ b/tests/server/api/types/test_pagination.py @@ -6,8 +6,8 @@ from phoenix.server.api.types.pagination import ( ConnectionArgs, Cursor, - SortColumn, - SortColumnDataType, + CursorSortColumn, + CursorSortColumnDataType, connection_from_list, ) @@ -114,24 +114,27 @@ def test_to_and_from_string_with_rowid_and_string_deserializes_original( self, ) -> None: original = Cursor( - rowid=10, sort_column=SortColumn(type=SortColumnDataType.STRING, value="abc") + rowid=10, + sort_column=CursorSortColumn(type=CursorSortColumnDataType.STRING, value="abc"), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid 
== 10 assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.STRING + assert sort_column.type == CursorSortColumnDataType.STRING assert sort_column.value == "abc" def test_to_and_from_string_with_rowid_and_int_deserializes_original( self, ) -> None: - original = Cursor(rowid=10, sort_column=SortColumn(type=SortColumnDataType.INT, value=11)) + original = Cursor( + rowid=10, sort_column=CursorSortColumn(type=CursorSortColumnDataType.INT, value=11) + ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.INT + assert sort_column.type == CursorSortColumnDataType.INT assert isinstance((value := sort_column.value), int) assert value == 11 @@ -139,13 +142,13 @@ def test_to_and_from_string_with_rowid_and_float_deserializes_original( self, ) -> None: original = Cursor( - rowid=10, sort_column=SortColumn(type=SortColumnDataType.FLOAT, value=11.5) + rowid=10, sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=11.5) ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.FLOAT + assert sort_column.type == CursorSortColumnDataType.FLOAT assert abs(sort_column.value - 11.5) < 1e-8 def test_to_and_from_string_with_rowid_and_float_passed_as_int_deserializes_original_as_float( @@ -153,8 +156,8 @@ def test_to_and_from_string_with_rowid_and_float_passed_as_int_deserializes_orig ) -> None: original = Cursor( rowid=10, - sort_column=SortColumn( - type=SortColumnDataType.FLOAT, + sort_column=CursorSortColumn( + type=CursorSortColumnDataType.FLOAT, value=11, # an integer value ), ) @@ -162,7 +165,7 @@ def test_to_and_from_string_with_rowid_and_float_passed_as_int_deserializes_orig deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.FLOAT + assert sort_column.type == CursorSortColumnDataType.FLOAT assert isinstance((value := sort_column.value), float) assert abs(value - 11.0) < 1e-8 @@ -172,13 +175,13 @@ def test_to_and_from_string_with_rowid_and_tz_naive_datetime_deserializes_origin timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245") original = Cursor( rowid=10, - sort_column=SortColumn(type=SortColumnDataType.DATETIME, value=timestamp), + sort_column=CursorSortColumn(type=CursorSortColumnDataType.DATETIME, value=timestamp), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.DATETIME + assert sort_column.type == CursorSortColumnDataType.DATETIME assert sort_column.value == timestamp assert sort_column.value.tzinfo is None @@ -188,12 +191,12 @@ def test_to_and_from_string_with_rowid_and_tz_aware_datetime_deserializes_origin timestamp = datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00") original = Cursor( rowid=10, - sort_column=SortColumn(type=SortColumnDataType.DATETIME, value=timestamp), + sort_column=CursorSortColumn(type=CursorSortColumnDataType.DATETIME, value=timestamp), ) cursor_string = str(original) deserialized = Cursor.from_string(cursor_string) assert deserialized.rowid == 10 
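# a timezone-aware timestamp should round-trip with both its value and its tzinfo preserved, as asserted below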
assert (sort_column := deserialized.sort_column) is not None - assert sort_column.type == SortColumnDataType.DATETIME + assert sort_column.type == CursorSortColumnDataType.DATETIME assert sort_column.value == timestamp assert sort_column.value.tzinfo is not None From e3203284a36232bc04ccd6593302bc525c676220 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 6 May 2024 23:54:04 -0700 Subject: [PATCH 62/74] add docstring --- src/phoenix/server/api/types/pagination.py | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index a4de67cad5..c8935ada86 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -101,6 +101,57 @@ def from_string( @dataclass class Cursor: + """ + Serializes and deserializes cursor strings for ID-based pagination. + + In the simplest case, a cursor encodes the rowid of a record. In the case + that a sort has been applied, the cursor additionally encodes the data type + and value of the column indexed for sorting so that the sort position can be + efficiently found. The encoding ensures that the cursor string is opaque to + the client and discourages the client from making use of the encoded + content. + + Examples: + # encodes "10" + Cursor(rowid=10) + + # encodes "11:STRING:abc" + Cursor( + rowid=11, + sort_column=CursorSortColumn( + type=CursorSortColumnDataType.STRING, + value="abc" + ) + ) + + # encodes "10:INT:5" + Cursor( + rowid=10, + sort_column=CursorSortColumn( + type=CursorSortColumnDataType.INT, + value=5 + ) + ) + + # encodes "17:FLOAT:5.7" + Cursor( + rowid=17, + sort_column=CursorSortColumn( + type=CursorSortColumnDataType.FLOAT, + value=5.7 + ) + ) + + # encodes "20:DATETIME:2024-05-05T04:25:29.911245+00:00" + Cursor( + rowid=20, + sort_column=CursorSortColumn( + type=CursorSortColumnDataType.DATETIME, + value=datetime.fromisoformat("2024-05-05T04:25:29.911245+00:00") + ) + ) + """ + rowid: int sort_column: Optional[CursorSortColumn] = None From f7f7786eeb851cd56dde2f03dd4ef0d6f42758d7 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 00:36:04 -0700 Subject: [PATCH 63/74] fix bug preventing sorts on token counts --- .../pagination_query_testing.ipynb | 34 +++++++++++-------- .../server/api/input_types/SpanSort.py | 7 ++-- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index cebb51e9fd..973ae2ebe5 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -10,8 +10,8 @@ "from gql.transport.requests import RequestsHTTPTransport\n", "from phoenix.server.api.types.pagination import (\n", " Cursor,\n", - " SortColumn,\n", - " SortColumnDataType,\n", + " CursorSortColumn,\n", + " CursorSortColumnDataType,\n", ")\n", "\n", "project_id = \"UHJvamVjdDox\"\n", @@ -419,8 +419,8 @@ " \"after\": str(\n", " Cursor(\n", " 760,\n", - " sort_column=SortColumn.from_string(\n", - " type=SortColumnDataType.DATETIME,\n", + " sort_column=CursorSortColumn.from_string(\n", + " type=CursorSortColumnDataType.DATETIME,\n", " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", " ),\n", " )\n", @@ -438,10 +438,12 @@ "assert (\n", " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:48:40.154139+00:00\", start_timestamp\n", - "assert (start_field_type := start_sort_column.type) == 
SortColumnDataType.DATETIME, start_field_type\n", + "assert (\n", + " start_field_type := start_sort_column.type\n", + ") == CursorSortColumnDataType.DATETIME, start_field_type\n", "assert (end_sort_column := end_cursor.sort_column) is not None\n", "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", - "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" + "assert (end_field_type := end_sort_column.type) == CursorSortColumnDataType.DATETIME, end_field_type" ] }, { @@ -460,8 +462,8 @@ " \"after\": str(\n", " Cursor(\n", " 8,\n", - " sort_column=SortColumn.from_string(\n", - " type=SortColumnDataType.DATETIME,\n", + " sort_column=CursorSortColumn.from_string(\n", + " type=CursorSortColumnDataType.DATETIME,\n", " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", " ),\n", " )\n", @@ -479,10 +481,12 @@ "assert (\n", " start_timestamp := str(start_sort_column)\n", ") == \"2023-12-11T17:43:25.842986+00:00\", start_timestamp\n", - "assert (start_field_type := start_sort_column.type) == SortColumnDataType.DATETIME, start_field_type\n", + "assert (\n", + " start_field_type := start_sort_column.type\n", + ") == CursorSortColumnDataType.DATETIME, start_field_type\n", "assert (end_sort_column := end_cursor.sort_column) is not None\n", "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", - "assert (end_field_type := end_sort_column.type) == SortColumnDataType.DATETIME, end_field_type" + "assert (end_field_type := end_sort_column.type) == CursorSortColumnDataType.DATETIME, end_field_type" ] }, { @@ -501,7 +505,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=644, # row 644 is in between rows 645 and 641, which also have 1054 cumulative prompt tokens\n", - " sort_column=SortColumn(type=SortColumnDataType.FLOAT, value=1054),\n", + " sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=1054),\n", " )\n", " ),\n", " },\n", @@ -533,7 +537,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=294, # row 294 is in between rows 295 and 291, which also have 276 cumulative prompt tokens\n", - " sort_column=SortColumn(type=SortColumnDataType.INT, value=276),\n", + " sort_column=CursorSortColumn(type=CursorSortColumnDataType.INT, value=276),\n", " )\n", " ),\n", " },\n", @@ -614,7 +618,9 @@ " \"after\": str(\n", " Cursor(\n", " rowid=141, # row 141 is surrounded by many other hallucinations\n", - " sort_column=SortColumn(type=SortColumnDataType.STRING, value=\"hallucinated\"),\n", + " sort_column=CursorSortColumn(\n", + " type=CursorSortColumnDataType.STRING, value=\"hallucinated\"\n", + " ),\n", " )\n", " ),\n", " },\n", @@ -649,7 +655,7 @@ " \"after\": str(\n", " Cursor(\n", " rowid=731, # row 746 is surrounded by many other hallucinations\n", - " sort_column=SortColumn(type=SortColumnDataType.STRING, value=\"factual\"),\n", + " sort_column=CursorSortColumn(type=CursorSortColumnDataType.STRING, value=\"factual\"),\n", " )\n", " ),\n", " },\n", diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 69ad65ee73..d83dcfbbb8 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Any, Optional, Protocol, cast +from typing import Any, Optional, Protocol import strawberry from openinference.semconv.trace import SpanAttributes @@ -35,7 
+35,10 @@ class SpanColumn(Enum): @property def column_name(self) -> str: - return cast(str, self.orm_expression.name) + for attribute_name in ("name", "key"): + if attribute_value := getattr(self.orm_expression, attribute_name, None): + return str(attribute_value) + raise ValueError(f"Could not determine column name for {self}") @property def orm_expression(self) -> Any: From d5a96a573be49e275e884fe63d53b331d913f8a5 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 00:53:58 -0700 Subject: [PATCH 64/74] add test for order by hallucination eval score in descending order with cursor --- .../pagination_query_testing.ipynb | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 973ae2ebe5..26f1cdefbb 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -664,6 +664,41 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [751, 736, 761, 756, 746], ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by hallucination eval score in descending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"score\"},\n", + " \"dir\": \"desc\",\n", + " },\n", + " \"first\": 5,\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=21,\n", + " sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=1),\n", + " )\n", + " ),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", + "assert ids == [\n", + " 6,\n", + " 761,\n", + " 756,\n", + " 746,\n", + " 741,\n", + "], ids" + ] } ], "metadata": { From d57001ed1c407d5d965c5baab94b9d7d8785890d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 01:04:34 -0700 Subject: [PATCH 65/74] add test for ascending order by hallucination score --- .../pagination_query_testing.ipynb | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 26f1cdefbb..737ebac89d 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -699,6 +699,35 @@ " 741,\n", "], ids" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# order by hallucination eval score in ascending order with cursor\n", + "response = client.execute(\n", + " spans_query,\n", + " variable_values={\n", + " \"projectId\": project_id,\n", + " \"sort\": {\n", + " \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"score\"},\n", + " \"dir\": \"desc\",\n", + " },\n", + " \"first\": 5,\n", + " \"after\": str(\n", + " Cursor(\n", + " rowid=26,\n", + " sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=1),\n", + " )\n", + " ),\n", + " },\n", + ")\n", + "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", + "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", + "assert ids == [21, 6, 761, 756, 746], ids" + ] } ], "metadata": { From b63426aca0d7f96ee3fa58ee15a883951721fe5f Mon Sep 17 
00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 01:05:53 -0700 Subject: [PATCH 66/74] revert changes to main --- src/phoenix/server/main.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 0acd1b44aa..0238fec057 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -39,6 +39,7 @@ download_traces_fixture, get_evals_from_fixture, get_trace_fixture_by_name, + reset_fixture_span_ids_and_timestamps, ) from phoenix.trace.otel import decode_otlp_span, encode_span_to_otlp from phoenix.trace.schemas import Span @@ -205,17 +206,17 @@ def _get_pid_file() -> Path: fixture_spans: List[Span] = [] fixture_evals: List[pb.Evaluation] = [] if trace_dataset_name is not None: - # todo: add boolean flag for --reset-span-ids-and-timestamps - # todo: ensure that fixture tuples are inserted in chronological order - fixture_spans = [ - # Apply `encode` here because legacy jsonl files contains UUIDs as strings. - # `encode` removes the hyphens in the UUIDs. - decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span))) - for json_span in reversed( - download_traces_fixture(get_trace_fixture_by_name(trace_dataset_name)) - ) - ] - fixture_evals = list(get_evals_from_fixture(trace_dataset_name)) + fixture_spans, fixture_evals = reset_fixture_span_ids_and_timestamps( + ( + # Apply `encode` here because legacy jsonl files contains UUIDs as strings. + # `encode` removes the hyphens in the UUIDs. + decode_otlp_span(encode_span_to_otlp(json_string_to_span(json_span))) + for json_span in download_traces_fixture( + get_trace_fixture_by_name(trace_dataset_name) + ) + ), + get_evals_from_fixture(trace_dataset_name), + ) umap_params_list = args.umap_params.split(",") umap_params = UMAPParameters( min_dist=float(umap_params_list[0]), From 380aceec8022383ad59965b4a2a7924c4224c970 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 01:19:28 -0700 Subject: [PATCH 67/74] clean --- src/phoenix/server/api/input_types/SpanSort.py | 15 +++++++++------ src/phoenix/server/api/types/Project.py | 7 +++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index d83dcfbbb8..e28df66681 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -127,9 +127,10 @@ def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluat @dataclass(frozen=True) class SpanSortConfig: stmt: Select[Any] - column_name: str orm_expression: Any - data_type: CursorSortColumnDataType + dir: SortDir + column_name: str + column_data_type: CursorSortColumnDataType @strawberry.input( @@ -148,9 +149,10 @@ def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig: expr = desc(expr) return SpanSortConfig( stmt=stmt.order_by(nulls_last(expr)), - column_name=col.column_name, orm_expression=col.orm_expression, - data_type=col.data_type, + dir=self.dir, + column_name=col.column_name, + column_data_type=col.data_type, ) if (eval_result_key := self.eval_result_key) and not col: eval_name = eval_result_key.name @@ -168,8 +170,9 @@ def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig: ).order_by(expr) return SpanSortConfig( stmt=stmt, - column_name=eval_attr.column_name, orm_expression=eval_result_key.attr.orm_expression, - data_type=eval_attr.data_type, + dir=self.dir, + column_name=eval_attr.column_name, + 
column_data_type=eval_attr.data_type, ) raise ValueError("Exactly one of `col` or `evalResultKey` must be specified on `SpanSort`.") diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 3bd8b5d2b3..0eb984df69 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -195,10 +195,9 @@ async def spans( stmt = sort_config.stmt if after: cursor = Cursor.from_string(after) - if cursor.sort_column is not None: + if sort_config and cursor.sort_column: sort_column = cursor.sort_column - assert sort is not None # todo: refactor this into a validation check - compare = operator.lt if sort.dir is SortDir.desc else operator.gt + compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt if sort_config: stmt = stmt.where( compare( @@ -223,7 +222,7 @@ async def spans( rowid=span.id, sort_column=( CursorSortColumn( - type=sort_config.data_type, + type=sort_config.column_data_type, value=eval_value if eval_value is not None else getattr(span, sort_config.column_name), From c5f0cad507efbfbf1434bd4ef39175c58681efbe Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 09:46:40 -0700 Subject: [PATCH 68/74] rename variable --- integration-tests/pagination_query_testing.ipynb | 4 ++-- src/phoenix/server/api/types/pagination.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 737ebac89d..23333e4878 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -421,7 +421,7 @@ " 760,\n", " sort_column=CursorSortColumn.from_string(\n", " type=CursorSortColumnDataType.DATETIME,\n", - " stringified_value=\"2023-12-11T17:48:40.154938+00:00\",\n", + " cursor_string=\"2023-12-11T17:48:40.154938+00:00\",\n", " ),\n", " )\n", " ),\n", @@ -464,7 +464,7 @@ " 8,\n", " sort_column=CursorSortColumn.from_string(\n", " type=CursorSortColumnDataType.DATETIME,\n", - " stringified_value=\"2023-12-11T17:43:25.540677+00:00\",\n", + " cursor_string=\"2023-12-11T17:43:25.540677+00:00\",\n", " ),\n", " )\n", " ),\n", diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index c8935ada86..551dfe213e 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -82,18 +82,16 @@ def __str__(self) -> str: assert_never(self.type) @classmethod - def from_string( - cls, type: CursorSortColumnDataType, stringified_value: str - ) -> "CursorSortColumn": + def from_string(cls, type: CursorSortColumnDataType, cursor_string: str) -> "CursorSortColumn": value: CursorSortColumnValue if type is CursorSortColumnDataType.STRING: - value = stringified_value + value = cursor_string elif type is CursorSortColumnDataType.INT: - value = int(stringified_value) + value = int(cursor_string) elif type is CursorSortColumnDataType.FLOAT: - value = float(stringified_value) + value = float(cursor_string) elif type is CursorSortColumnDataType.DATETIME: - value = datetime.fromisoformat(stringified_value) + value = datetime.fromisoformat(cursor_string) else: assert_never(type) return cls(type=type, value=value) @@ -175,7 +173,7 @@ def from_string(cls, cursor: CursorString) -> "Cursor": type=CursorSortColumnDataType[ decoded[first_delimiter_index + 1 : second_delimiter_index] ], - stringified_value=decoded[second_delimiter_index + 1 :], + 
cursor_string=decoded[second_delimiter_index + 1 :], ) return cls(rowid=int(rowid_string), sort_column=sort_column) From 55fe46e409eee2044fb28c05d3cc6c909a48c12c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 09:50:13 -0700 Subject: [PATCH 69/74] change type hint --- src/phoenix/server/api/types/pagination.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/phoenix/server/api/types/pagination.py b/src/phoenix/server/api/types/pagination.py index 551dfe213e..00174980a0 100644 --- a/src/phoenix/server/api/types/pagination.py +++ b/src/phoenix/server/api/types/pagination.py @@ -155,14 +155,14 @@ class Cursor: _DELIMITER: ClassVar[str] = ":" - def __str__(self) -> CursorString: + def __str__(self) -> str: cursor_parts = [str(self.rowid)] if (sort_column := self.sort_column) is not None: cursor_parts.extend([sort_column.type.name, str(sort_column)]) return base64.b64encode(self._DELIMITER.join(cursor_parts).encode()).decode() @classmethod - def from_string(cls, cursor: CursorString) -> "Cursor": + def from_string(cls, cursor: str) -> "Cursor": decoded = base64.b64decode(cursor).decode() rowid_string = decoded sort_column = None From 7f4b772565f444cf3b34b68a4b5a9c0a1d686eaf Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 10:00:18 -0700 Subject: [PATCH 70/74] remove unnecessary variable declaration --- src/phoenix/server/api/types/Project.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 0eb984df69..c02b5bfeff 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -188,7 +188,6 @@ async def spans( if filter_condition: span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) - sort_column: Optional[CursorSortColumn] = None sort_config: Optional[SpanSortConfig] = None if sort: sort_config = sort.update_orm_expr(stmt) From 20ce9e64112b6d50fe81e5397b37fc736f0cbc13 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 10:03:07 -0700 Subject: [PATCH 71/74] remove unnecessary nesting --- src/phoenix/server/api/types/Project.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index c02b5bfeff..1c687d5320 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -197,13 +197,12 @@ async def spans( if sort_config and cursor.sort_column: sort_column = cursor.sort_column compare = operator.lt if sort_config.dir is SortDir.desc else operator.gt - if sort_config: - stmt = stmt.where( - compare( - tuple_(sort_config.orm_expression, models.Span.id), - (sort_column.value, cursor.rowid), - ) + stmt = stmt.where( + compare( + tuple_(sort_config.orm_expression, models.Span.id), + (sort_column.value, cursor.rowid), ) + ) else: stmt = stmt.where(models.Span.id < cursor.rowid) if first: From a8c7d589239a6e327a11bbddded00c1c26342134 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 10:08:20 -0700 Subject: [PATCH 72/74] nest iter inside context manager, it works outside the context manager but putting it here for safety --- src/phoenix/server/api/types/Project.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 1c687d5320..1113d182b8 100644 --- a/src/phoenix/server/api/types/Project.py +++ 
b/src/phoenix/server/api/types/Project.py @@ -230,11 +230,11 @@ async def spans( ), ) data.append((cursor, to_gql_span(span))) - has_next_page = True - try: - next(span_records) - except StopIteration: - has_next_page = False + has_next_page = True + try: + next(span_records) + except StopIteration: + has_next_page = False return connections( data, From ae0ac67b99aae73da06d5dfb60c786ce90e8e296 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 15:25:25 -0700 Subject: [PATCH 73/74] ensure that direction of rowid column matches direction of sort column --- .../pagination_query_testing.ipynb | 142 +++++++++++------- src/phoenix/server/api/types/Project.py | 9 +- 2 files changed, 92 insertions(+), 59 deletions(-) diff --git a/integration-tests/pagination_query_testing.ipynb b/integration-tests/pagination_query_testing.ipynb index 23333e4878..e251ac5bd2 100644 --- a/integration-tests/pagination_query_testing.ipynb +++ b/integration-tests/pagination_query_testing.ipynb @@ -128,7 +128,7 @@ ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [765, 764, 763, 762, 761], ids" + "assert ids == [1, 2, 3, 4, 5], ids" ] }, { @@ -142,13 +142,13 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": str(Cursor(rowid=761)),\n", + " \"after\": str(Cursor(rowid=755)),\n", " \"first\": 5,\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [760, 759, 758, 757, 756], ids" + "assert ids == [756, 757, 758, 759, 760], ids" ] }, { @@ -162,7 +162,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": str(Cursor(7)),\n", + " \"after\": str(Cursor(759)),\n", " \"first\": 5,\n", " },\n", ")\n", @@ -170,7 +170,7 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert ids == [6, 5, 4, 3, 2], ids\n", + "assert ids == [760, 761, 762, 763, 764], ids\n", "assert has_next_page is True\n", "assert has_previous_page is False" ] @@ -186,7 +186,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": str(Cursor(6)),\n", + " \"after\": str(Cursor(760)),\n", " \"first\": 5,\n", " },\n", ")\n", @@ -194,7 +194,7 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert ids == [5, 4, 3, 2, 1], ids\n", + "assert ids == [761, 762, 763, 764, 765], ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" ] @@ -210,7 +210,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": str(Cursor(5)),\n", + " \"after\": str(Cursor(761)),\n", " \"first\": 5,\n", " },\n", ")\n", @@ -218,7 +218,7 @@ "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "has_next_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasNextPage\"]\n", "has_previous_page = response[\"node\"][\"spans\"][\"pageInfo\"][\"hasPreviousPage\"]\n", - "assert ids == [4, 3, 2, 1], ids\n", + "assert ids == [762, 763, 764, 765], 
ids\n", "assert has_next_page is False\n", "assert has_previous_page is False" ] @@ -241,11 +241,11 @@ "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", - " 765,\n", - " 760,\n", - " 755,\n", - " 750,\n", - " 745,\n", + " 5,\n", + " 10,\n", + " 15,\n", + " 20,\n", + " 25,\n", "], ids" ] }, @@ -261,18 +261,18 @@ " variable_values={\n", " \"projectId\": project_id,\n", " \"first\": 5,\n", - " \"after\": str(Cursor(765)), # skip the first span satisfying the filter condition\n", + " \"after\": str(Cursor(5)), # skip the first span satisfying the filter condition\n", " \"filterCondition\": \"span_kind == 'LLM'\",\n", " },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", - " 760,\n", - " 755,\n", - " 750,\n", - " 745,\n", - " 740,\n", + " 10,\n", + " 15,\n", + " 20,\n", + " 25,\n", + " 30,\n", "], ids" ] }, @@ -287,7 +287,7 @@ " spans_query,\n", " variable_values={\n", " \"projectId\": project_id,\n", - " \"after\": str(Cursor(745)), # skip the first span satisfying the filter condition\n", + " \"after\": str(Cursor(10)), # skip the first span satisfying the filter condition\n", " \"first\": 5,\n", " \"filterCondition\": \"span_kind == 'LLM' and cumulative_llm_token_count_prompt > 300\",\n", " },\n", @@ -295,11 +295,11 @@ "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", - " 740,\n", - " 730,\n", - " 720,\n", - " 710,\n", - " 690,\n", + " 15,\n", + " 30,\n", + " 35,\n", + " 45,\n", + " 60,\n", "], ids" ] }, @@ -400,7 +400,13 @@ ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [763, 762, 758, 757, 753], ids" + "assert ids == [\n", + " 2,\n", + " 3,\n", + " 7,\n", + " 8,\n", + " 12,\n", + "], ids" ] }, { @@ -418,10 +424,10 @@ " \"first\": 5,\n", " \"after\": str(\n", " Cursor(\n", - " 760,\n", + " 758,\n", " sort_column=CursorSortColumn.from_string(\n", " type=CursorSortColumnDataType.DATETIME,\n", - " cursor_string=\"2023-12-11T17:48:40.154938+00:00\",\n", + " cursor_string=\"2023-12-11T17:48:39.949837+00:00\",\n", " ),\n", " )\n", " ),\n", @@ -433,16 +439,22 @@ "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "start_cursor = Cursor.from_string(start_cursor)\n", "end_cursor = Cursor.from_string(end_cursor)\n", - "assert ids == [759, 758, 757, 756, 755], ids\n", + "assert ids == [\n", + " 757,\n", + " 756,\n", + " 755,\n", + " 754,\n", + " 753,\n", + "], ids\n", "assert (start_sort_column := start_cursor.sort_column) is not None\n", "assert (\n", " start_timestamp := str(start_sort_column)\n", - ") == \"2023-12-11T17:48:40.154139+00:00\", start_timestamp\n", + ") == \"2023-12-11T17:48:39.949695+00:00\", start_timestamp\n", "assert (\n", " start_field_type := start_sort_column.type\n", ") == CursorSortColumnDataType.DATETIME, start_field_type\n", "assert (end_sort_column := end_cursor.sort_column) is not None\n", - "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:48:38.803725+00:00\", end_timestamp\n", + "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:48:38.603846+00:00\", end_timestamp\n", "assert (end_field_type := 
end_sort_column.type) == CursorSortColumnDataType.DATETIME, end_field_type" ] }, @@ -461,10 +473,10 @@ "        \"first\": 5,\n", "        \"after\": str(\n", "            Cursor(\n", - "                8,\n", + "                9,\n", "                sort_column=CursorSortColumn.from_string(\n", "                    type=CursorSortColumnDataType.DATETIME,\n", - "                    cursor_string=\"2023-12-11T17:43:25.540677+00:00\",\n", + "                    cursor_string=\"2023-12-11T17:43:25.842986+00:00\",\n", "                ),\n", "            )\n", "        ),\n", @@ -476,16 +488,22 @@ "end_cursor = response[\"node\"][\"spans\"][\"pageInfo\"][\"endCursor\"]\n", "start_cursor = Cursor.from_string(start_cursor)\n", "end_cursor = Cursor.from_string(end_cursor)\n", - "assert ids == [9, 10, 11, 12, 13], ids\n", + "assert ids == [\n", + "    10,\n", + "    11,\n", + "    12,\n", + "    13,\n", + "    14,\n", + "], ids\n", "assert (start_sort_column := start_cursor.sort_column) is not None\n", "assert (\n", "    start_timestamp := str(start_sort_column)\n", - ") == \"2023-12-11T17:43:25.842986+00:00\", start_timestamp\n", + ") == \"2023-12-11T17:43:25.844758+00:00\", start_timestamp\n", "assert (\n", "    start_field_type := start_sort_column.type\n", ") == CursorSortColumnDataType.DATETIME, start_field_type\n", "assert (end_sort_column := end_cursor.sort_column) is not None\n", - "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:43:26.496177+00:00\", end_timestamp\n", + "assert (end_timestamp := str(end_sort_column)) == \"2023-12-11T17:43:26.704532+00:00\", end_timestamp\n", "assert (end_field_type := end_sort_column.type) == CursorSortColumnDataType.DATETIME, end_field_type" ] }, @@ -544,7 +562,13 @@ ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [295, 115, 114, 111, 25], ids" + "assert ids == [\n", + "    295,\n", + "    111,\n", + "    114,\n", + "    115,\n", + "    21,\n", + "], ids" ] }, { @@ -590,13 +614,7 @@ ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [\n", - "    751,\n", - "    736,\n", - "    731,\n", - "    726,\n", - "    716,\n", - "], ids" + "assert ids == [6, 21, 26, 31, 41], ids" ] }, { @@ -617,7 +635,7 @@ "        \"first\": 5,\n", "        \"after\": str(\n", "            Cursor(\n", - "                rowid=141, # row 141 is surrounded by many other hallucinations\n", + "                rowid=16, # this row is surrounded by many other hallucinations\n", "                sort_column=CursorSortColumn(\n", "                    type=CursorSortColumnDataType.STRING, value=\"hallucinated\"\n", "                ),\n", @@ -628,11 +646,11 @@ "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", "assert ids == [\n", - "    121,\n", - "    116,\n", - "    106,\n", - "    76,\n", - "    66,\n", + "    11,\n", + "    1,\n", + "    751,\n", + "    736,\n", + "    731,\n", "], ids" ] }, @@ -662,7 +680,13 @@ ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [751, 736, 761, 756, 746], ids" + "assert ids == [\n", + "    736,\n", + "    751,\n", + "    1,\n", + "    11,\n", + "    16,\n", + "], ids" ] }, { @@ -713,20 +737,26 @@ "        \"projectId\": project_id,\n", "        \"sort\": {\n", "            \"evalResultKey\": {\"name\": \"Hallucination\", \"attr\": \"score\"},\n", - "            \"dir\": \"desc\",\n", + "            \"dir\": \"asc\",\n", "        },\n", "        \"first\": 5,\n", "        \"after\": str(\n", "            Cursor(\n", - "                rowid=26,\n", - "                
sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=1),\n", + "                rowid=746,\n", + "                sort_column=CursorSortColumn(type=CursorSortColumnDataType.FLOAT, value=0),\n", "            )\n", "        ),\n", "    },\n", ")\n", "cursors = [edge[\"cursor\"] for edge in response[\"node\"][\"spans\"][\"edges\"]]\n", "ids = [Cursor.from_string(cursor).rowid for cursor in cursors]\n", - "assert ids == [21, 6, 761, 756, 746], ids" + "assert ids == [\n", + "    756,\n", + "    761,\n", + "    6,\n", + "    21,\n", + "    26,\n", + "], ids" ] } ], diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 1113d182b8..7fe69aacb0 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -1,6 +1,6 @@ import operator from datetime import datetime -from typing import List, Optional +from typing import Any, List, Optional import strawberry from aioitertools.itertools import islice @@ -189,9 +189,12 @@ async def spans( span_filter = SpanFilter(condition=filter_condition) stmt = span_filter(stmt) sort_config: Optional[SpanSortConfig] = None + cursor_rowid_column: Any = models.Span.id if sort: sort_config = sort.update_orm_expr(stmt) stmt = sort_config.stmt + if sort_config.dir is SortDir.desc: + cursor_rowid_column = desc(cursor_rowid_column) if after: cursor = Cursor.from_string(after) if sort_config and cursor.sort_column: @@ -204,12 +207,12 @@ async def spans( ) ) else: - stmt = stmt.where(models.Span.id < cursor.rowid) + stmt = stmt.where(models.Span.id > cursor.rowid) if first: stmt = stmt.limit( first + 1 # overfetch by one to determine whether there's a next page ) - stmt = stmt.order_by(desc(models.Span.id)) + stmt = stmt.order_by(cursor_rowid_column) data = [] async with info.context.db() as session: span_records = await session.execute(stmt) From 30f46daec3a2197d23e0f3ba76dbf53618853308 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 7 May 2024 17:16:43 -0700 Subject: [PATCH 74/74] fix bug with token counts --- .../server/api/input_types/SpanSort.py | 47 ++++++++++--------- src/phoenix/server/api/types/Project.py | 6 +-- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index e28df66681..80b1a0e619 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -35,35 +35,35 @@ class SpanColumn(Enum): @property def column_name(self) -> str: - for attribute_name in ("name", "key"): - if attribute_value := getattr(self.orm_expression, attribute_name, None): - return str(attribute_value) - raise ValueError(f"Could not determine column name for {self}") + return f"{self.name}_span_sort_column" @property def orm_expression(self) -> Any: + expr: Any if self is SpanColumn.startTime: - return models.Span.start_time + expr = models.Span.start_time - if self is SpanColumn.endTime: + elif self is SpanColumn.endTime: - return models.Span.end_time + expr = models.Span.end_time - if self is SpanColumn.latencyMs: + elif self is SpanColumn.latencyMs: - return models.Span.latency_ms + expr = 
models.Span.latency_ms + elif self is SpanColumn.tokenCountTotal: + expr = models.Span.attributes[LLM_TOKEN_COUNT_TOTAL].as_float() + elif self is SpanColumn.tokenCountPrompt: + expr = models.Span.attributes[LLM_TOKEN_COUNT_PROMPT].as_float() + elif self is SpanColumn.tokenCountCompletion: + expr = models.Span.attributes[LLM_TOKEN_COUNT_COMPLETION].as_float() + elif self is SpanColumn.cumulativeTokenCountTotal: + expr = ( models.Span.cumulative_llm_token_count_prompt + models.Span.cumulative_llm_token_count_completion ) - if self is SpanColumn.cumulativeTokenCountPrompt: - return models.Span.cumulative_llm_token_count_prompt - if self is SpanColumn.cumulativeTokenCountCompletion: - return models.Span.cumulative_llm_token_count_completion - assert_never(self) + elif self is SpanColumn.cumulativeTokenCountPrompt: + expr = models.Span.cumulative_llm_token_count_prompt + elif self is SpanColumn.cumulativeTokenCountCompletion: + expr = models.Span.cumulative_llm_token_count_completion + else: + assert_never(self) + return expr.label(self.column_name) @property def data_type(self) -> CursorSortColumnDataType: @@ -92,7 +92,7 @@ class EvalAttr(Enum): @property def column_name(self) -> str: - return f"span_annotations_{self.value}" + return f"{self.value}_eval_sort_column" @property def orm_expression(self) -> Any: @@ -145,6 +145,7 @@ class SpanSort: def update_orm_expr(self, stmt: Select[Any]) -> SpanSortConfig: if (col := self.col) and not self.eval_result_key: expr = col.orm_expression + stmt = stmt.add_columns(expr) if self.dir == SortDir.desc: expr = desc(expr) return SpanSortConfig( diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 7fe69aacb0..a5aeb16dc8 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -218,15 +218,13 @@ async def spans( span_records = await session.execute(stmt) async for span_record in islice(span_records, first): span = span_record[0] - eval_value = span_record[1] if len(span_record) > 1 else None + sort_column_value = span_record[1] if len(span_record) > 1 else None cursor = Cursor( rowid=span.id, sort_column=( CursorSortColumn( type=sort_config.column_data_type, - value=eval_value - if eval_value is not None - else getattr(span, sort_config.column_name), + value=sort_column_value, ) if sort_config else None