From c202fae17ec3e4bcf36d52bdab35e3fe91ab69fa Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Tue, 16 Apr 2024 18:04:47 -0700 Subject: [PATCH 01/46] wip --- app/src/openInference/tracing/types.ts | 92 +- app/src/pages/trace/TracePage.tsx | 204 ++- integration-tests/span_query_testing.ipynb | 1207 +++++++++++++++++ pyproject.toml | 3 +- src/phoenix/core/project.py | 13 +- src/phoenix/db/bulk_inserter.py | 25 +- .../migrations/versions/cf03bd6bae1d_init.py | 10 +- src/phoenix/db/models.py | 62 +- .../server/api/input_types/SpanSort.py | 6 +- src/phoenix/server/api/routers/v1/spans.py | 100 +- src/phoenix/server/api/types/Project.py | 10 +- src/phoenix/server/api/types/Span.py | 80 +- src/phoenix/server/app.py | 3 +- src/phoenix/session/client.py | 28 +- src/phoenix/session/session.py | 186 +-- src/phoenix/trace/dsl/filter.py | 553 +++++--- src/phoenix/trace/dsl/missing.py | 60 - src/phoenix/trace/dsl/query.py | 630 ++++++--- src/phoenix/trace/otel.py | 227 +--- src/phoenix/trace/schemas.py | 6 +- src/phoenix/utilities/__init__.py | 26 - src/phoenix/utilities/attributes.py | 278 ++++ tests/core/test_project.py | 14 +- tests/server/api/types/test_span.py | 17 - tests/trace/conftest.py | 19 + tests/trace/dsl/test_filter.py | 215 ++- tests/trace/dsl/test_query.py | 302 ----- tests/trace/test_otel.py | 262 ++-- tests/utilities/test_attributes.py | 90 ++ 29 files changed, 3005 insertions(+), 1723 deletions(-) create mode 100644 integration-tests/span_query_testing.ipynb delete mode 100644 src/phoenix/trace/dsl/missing.py create mode 100644 src/phoenix/utilities/attributes.py delete mode 100644 tests/server/api/types/test_span.py create mode 100644 tests/trace/conftest.py delete mode 100644 tests/trace/dsl/test_query.py create mode 100644 tests/utilities/test_attributes.py diff --git a/app/src/openInference/tracing/types.ts b/app/src/openInference/tracing/types.ts index e5ded49192..510e00d303 100644 --- a/app/src/openInference/tracing/types.ts +++ b/app/src/openInference/tracing/types.ts @@ -1,46 +1,86 @@ import { - DOCUMENT_CONTENT, - DOCUMENT_ID, - DOCUMENT_METADATA, - DOCUMENT_SCORE, - EMBEDDING_TEXT, + EmbeddingAttributePostfixes, + LLMAttributePostfixes, LLMPromptTemplateAttributePostfixes, - MESSAGE_CONTENT, - MESSAGE_NAME, - MESSAGE_ROLE, - MESSAGE_TOOL_CALLS, - TOOL_CALL_FUNCTION_ARGUMENTS_JSON, - TOOL_CALL_FUNCTION_NAME, + MessageAttributePostfixes, + RerankerAttributePostfixes, + RetrievalAttributePostfixes, + ToolAttributePostfixes, } from "@arizeai/openinference-semantic-conventions"; +import { + DocumentAttributePostfixes, + SemanticAttributePrefixes, +} from "@arizeai/openinference-semantic-conventions/src/trace/SemanticConventions"; +export type AttributeTool = { + [ToolAttributePostfixes.name]?: string; + [ToolAttributePostfixes.description]?: string; + [ToolAttributePostfixes.parameters]?: string; +}; export type AttributeToolCall = { - [TOOL_CALL_FUNCTION_NAME]: string; - [TOOL_CALL_FUNCTION_ARGUMENTS_JSON]: string; + function?: { + name?: string; + arguments?: string; + }; }; +export type AttributeMessages = { + [SemanticAttributePrefixes.message]?: AttributeMessage; +}[]; export type AttributeMessage = { - [MESSAGE_ROLE]: string; - [MESSAGE_CONTENT]: string; - [MESSAGE_NAME]?: string; - [MESSAGE_TOOL_CALLS]?: AttributeToolCall[]; - [key: string]: unknown; + [MessageAttributePostfixes.role]?: string; + [MessageAttributePostfixes.content]?: string; + [MessageAttributePostfixes.name]?: string; + [MessageAttributePostfixes.function_call_name]?: string; + 
[MessageAttributePostfixes.function_call_arguments_json]?: string; + [MessageAttributePostfixes.tool_calls]?: { + [SemanticAttributePrefixes.tool_call]?: AttributeToolCall; + }[]; }; +export type AttributeRetrieval = { + [RetrievalAttributePostfixes.documents]?: { + [SemanticAttributePrefixes.document]?: AttributeDocument; + }[]; +}; export type AttributeDocument = { - [DOCUMENT_ID]?: string; - [DOCUMENT_CONTENT]: string; - [DOCUMENT_SCORE]?: number; - [DOCUMENT_METADATA]?: string; - [key: string]: unknown; + [DocumentAttributePostfixes.id]?: string; + [DocumentAttributePostfixes.content]?: string; + [DocumentAttributePostfixes.score]?: number; + [DocumentAttributePostfixes.metadata]?: string; }; export type AttributeEmbedding = { - [EMBEDDING_TEXT]?: string; - [key: string]: unknown; + [EmbeddingAttributePostfixes.model_name]?: string; + [EmbeddingAttributePostfixes.embeddings]?: { + [SemanticAttributePrefixes.embedding]?: AttributeEmbeddingEmbedding; + }[]; +}; +export type AttributeEmbeddingEmbedding = { + [EmbeddingAttributePostfixes.text]?: string; +}; + +export type AttributeReranker = { + [RerankerAttributePostfixes.query]?: string; + [RerankerAttributePostfixes.input_documents]?: { + [SemanticAttributePrefixes.document]?: AttributeDocument; + }[]; + [RerankerAttributePostfixes.output_documents]?: { + [SemanticAttributePrefixes.document]?: AttributeDocument; + }[]; +}; + +export type AttributeLlm = { + [LLMAttributePostfixes.model_name]?: string; + [LLMAttributePostfixes.token_count]?: number; + [LLMAttributePostfixes.input_messages]?: AttributeMessages; + [LLMAttributePostfixes.output_messages]?: AttributeMessages; + [LLMAttributePostfixes.invocation_parameters]?: string; + [LLMAttributePostfixes.prompts]?: string[]; + [LLMAttributePostfixes.prompt_template]?: AttributePromptTemplate; }; export type AttributePromptTemplate = { [LLMPromptTemplateAttributePostfixes.template]: string; [LLMPromptTemplateAttributePostfixes.variables]: Record; - [key: string]: unknown; }; diff --git a/app/src/pages/trace/TracePage.tsx b/app/src/pages/trace/TracePage.tsx index 187c489654..739e79f8b5 100644 --- a/app/src/pages/trace/TracePage.tsx +++ b/app/src/pages/trace/TracePage.tsx @@ -36,27 +36,17 @@ import { ViewStyleProps, } from "@arizeai/components"; import { - DOCUMENT_CONTENT, - DOCUMENT_ID, - DOCUMENT_METADATA, - DOCUMENT_SCORE, - EMBEDDING_TEXT, EmbeddingAttributePostfixes, LLMAttributePostfixes, - LLMPromptTemplateAttributePostfixes, - MESSAGE_CONTENT, - MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, - MESSAGE_FUNCTION_CALL_NAME, - MESSAGE_NAME, - MESSAGE_ROLE, - MESSAGE_TOOL_CALLS, + MessageAttributePostfixes, RerankerAttributePostfixes, RetrievalAttributePostfixes, - SemanticAttributePrefixes, - TOOL_CALL_FUNCTION_ARGUMENTS_JSON, - TOOL_CALL_FUNCTION_NAME, ToolAttributePostfixes, } from "@arizeai/openinference-semantic-conventions"; +import { + DocumentAttributePostfixes, + SemanticAttributePrefixes, +} from "@arizeai/openinference-semantic-conventions/src/trace/SemanticConventions"; import { CopyToClipboardButton, ExternalLink } from "@phoenix/components"; import { resizeHandleCSS } from "@phoenix/components/resize"; @@ -75,8 +65,13 @@ import { ConnectedMarkdownModeRadioGroup } from "@phoenix/markdown/MarkdownModeR import { AttributeDocument, AttributeEmbedding, + AttributeEmbeddingEmbedding, + AttributeLlm, AttributeMessage, AttributePromptTemplate, + AttributeReranker, + AttributeRetrieval, + AttributeTool, } from "@phoenix/openInference/tracing/types"; import { assertUnreachable, isStringArray } 
from "@phoenix/typeUtils"; import { formatFloat, numberFormatter } from "@phoenix/utils/numberFormatUtils"; @@ -98,7 +93,13 @@ type DocumentEvaluation = Span["documentEvaluations"][number]; /** * A span attribute object that is a map of string to an unknown value */ -type AttributeObject = Record; +type AttributeObject = { + [SemanticAttributePrefixes.retrieval]?: AttributeRetrieval; + [SemanticAttributePrefixes.embedding]?: AttributeEmbedding; + [SemanticAttributePrefixes.tool]?: AttributeTool; + [SemanticAttributePrefixes.reranker]?: AttributeReranker; + [SemanticAttributePrefixes.llm]?: AttributeLlm; +}; /** * Hook that safely parses a JSON string. @@ -115,30 +116,6 @@ const useSafelyParsedJSON = ( }, [jsonStr]); }; -function isAttributeObject(value: unknown): value is AttributeObject { - if ( - value != null && - typeof value === "object" && - !Object.keys(value).find((key) => typeof key != "string") - ) { - return true; - } - return false; -} - -export function isAttributePromptTemplate( - value: unknown -): value is AttributePromptTemplate { - if ( - isAttributeObject(value) && - typeof value[LLMPromptTemplateAttributePostfixes.template] === "string" && - typeof value[LLMPromptTemplateAttributePostfixes.variables] === "object" - ) { - return true; - } - return false; -} - const spanHasException = (span: Span) => { return span.events.some((event) => event.name === "exception"); }; @@ -571,13 +548,10 @@ function SpanInfo({ span }: { span: Span }) { function LLMSpanInfo(props: { span: Span; spanAttributes: AttributeObject }) { const { spanAttributes, span } = props; const { input, output } = span; - const llmAttributes = useMemo(() => { - const llmAttrs = spanAttributes[SemanticAttributePrefixes.llm]; - if (typeof llmAttrs === "object") { - return llmAttrs as AttributeObject; - } - return null; - }, [spanAttributes]); + const llmAttributes = useMemo( + () => spanAttributes[SemanticAttributePrefixes.llm] || null, + [spanAttributes] + ); const modelName = useMemo(() => { if (llmAttributes == null) { @@ -594,16 +568,18 @@ function LLMSpanInfo(props: { span: Span; spanAttributes: AttributeObject }) { if (llmAttributes == null) { return []; } - return (llmAttributes[LLMAttributePostfixes.input_messages] || - []) as AttributeMessage[]; + return (llmAttributes[LLMAttributePostfixes.input_messages]?.map( + (obj) => obj[SemanticAttributePrefixes.message] + ) || []) as AttributeMessage[]; }, [llmAttributes]); const outputMessages = useMemo(() => { if (llmAttributes == null) { return []; } - return (llmAttributes[LLMAttributePostfixes.output_messages] || - []) as AttributeMessage[]; + return (llmAttributes[LLMAttributePostfixes.output_messages]?.map( + (obj) => obj[SemanticAttributePrefixes.message] + ) || []) as AttributeMessage[]; }, [llmAttributes]); const prompts = useMemo(() => { @@ -621,10 +597,9 @@ function LLMSpanInfo(props: { span: Span; spanAttributes: AttributeObject }) { if (llmAttributes == null) { return null; } - const maybePromptTemplate = llmAttributes[LLMAttributePostfixes.prompt_template]; - if (!isAttributePromptTemplate(maybePromptTemplate)) { + if (maybePromptTemplate == null) { return null; } return maybePromptTemplate; @@ -794,19 +769,17 @@ function RetrieverSpanInfo(props: { }) { const { spanAttributes, span } = props; const { input } = span; - const retrieverAttributes = useMemo(() => { - const retrieverAttrs = spanAttributes[SemanticAttributePrefixes.retrieval]; - if (typeof retrieverAttrs === "object") { - return retrieverAttrs as AttributeObject; - } - return null; - 
}, [spanAttributes]); + const retrieverAttributes = useMemo( + () => spanAttributes[SemanticAttributePrefixes.retrieval] || null, + [spanAttributes] + ); const documents = useMemo(() => { if (retrieverAttributes == null) { return []; } - return (retrieverAttributes[RetrievalAttributePostfixes.documents] || - []) as AttributeDocument[]; + return (retrieverAttributes[RetrievalAttributePostfixes.documents]?.map( + (obj) => obj[SemanticAttributePrefixes.document] + ) || []) as AttributeDocument[]; }, [retrieverAttributes]); // Construct a map of document position to document evaluations @@ -918,13 +891,10 @@ function RerankerSpanInfo(props: { spanAttributes: AttributeObject; }) { const { spanAttributes } = props; - const rerankerAttributes = useMemo(() => { - const rerankerAttrs = spanAttributes[SemanticAttributePrefixes.reranker]; - if (typeof rerankerAttrs === "object") { - return rerankerAttrs as AttributeObject; - } - return null; - }, [spanAttributes]); + const rerankerAttributes = useMemo( + () => spanAttributes[SemanticAttributePrefixes.reranker] || null, + [spanAttributes] + ); const query = useMemo(() => { if (rerankerAttributes == null) { return ""; @@ -936,14 +906,17 @@ function RerankerSpanInfo(props: { if (rerankerAttributes == null) { return []; } - return (rerankerAttributes[RerankerAttributePostfixes.input_documents] || - []) as AttributeDocument[]; + return (rerankerAttributes[RerankerAttributePostfixes.input_documents]?.map( + (obj) => obj[SemanticAttributePrefixes.document] + ) || []) as AttributeDocument[]; }, [rerankerAttributes]); const output_documents = useMemo(() => { if (rerankerAttributes == null) { return []; } - return (rerankerAttributes[RerankerAttributePostfixes.output_documents] || + return (rerankerAttributes[ + RerankerAttributePostfixes.output_documents + ]?.map((obj) => obj[SemanticAttributePrefixes.document]) || []) as AttributeDocument[]; }, [rerankerAttributes]); @@ -1024,19 +997,17 @@ function EmbeddingSpanInfo(props: { spanAttributes: AttributeObject; }) { const { spanAttributes } = props; - const embeddingAttributes = useMemo(() => { - const embeddingAttrs = spanAttributes[SemanticAttributePrefixes.embedding]; - if (typeof embeddingAttrs === "object") { - return embeddingAttrs as AttributeObject; - } - return null; - }, [spanAttributes]); - const embeddings = useMemo(() => { + const embeddingAttributes = useMemo( + () => spanAttributes[SemanticAttributePrefixes.embedding] || null, + [spanAttributes] + ); + const embeddings = useMemo(() => { if (embeddingAttributes == null) { return []; } - return (embeddingAttributes[EmbeddingAttributePostfixes.embeddings] || - []) as AttributeDocument[]; + return (embeddingAttributes[EmbeddingAttributePostfixes.embeddings]?.map( + (obj) => obj[SemanticAttributePrefixes.embedding] + ) || []) as AttributeEmbeddingEmbedding[]; }, [embeddingAttributes]); const hasEmbeddings = embeddings.length > 0; @@ -1071,7 +1042,7 @@ function EmbeddingSpanInfo(props: { title="Embedded Text" > - {embedding[EMBEDDING_TEXT] || ""} + {embedding[EmbeddingAttributePostfixes.text] || ""} @@ -1088,14 +1059,10 @@ function EmbeddingSpanInfo(props: { function ToolSpanInfo(props: { span: Span; spanAttributes: AttributeObject }) { const { spanAttributes } = props; - const toolAttributes = useMemo(() => { - const toolAttrs = spanAttributes[SemanticAttributePrefixes.tool]; - if (typeof toolAttrs === "object") { - return toolAttrs as AttributeObject; - } - return {}; - }, [spanAttributes]); - + const toolAttributes = useMemo( + () => 
spanAttributes[SemanticAttributePrefixes.tool] || {}, + [spanAttributes] + ); const hasToolAttributes = Object.keys(toolAttributes).length > 0; if (!hasToolAttributes) { return null; @@ -1168,8 +1135,9 @@ function DocumentItem({ borderColor: ViewProps["borderColor"]; labelColor: LabelProps["color"]; }) { - const metadata = document[DOCUMENT_METADATA]; + const metadata = document[DocumentAttributePostfixes.metadata]; const hasEvaluations = documentEvaluations && documentEvaluations.length; + const documentContent = document[DocumentAttributePostfixes.content]; return ( } /> - document {document[DOCUMENT_ID]} + + document {document[DocumentAttributePostfixes.id]} + } extra={ - typeof document[DOCUMENT_SCORE] === "number" && ( + typeof document[DocumentAttributePostfixes.score] === "number" && ( ) } > - - - {document[DOCUMENT_CONTENT]} - - + {documentContent && ( + + {documentContent} + + )} {metadata && ( <> @@ -1288,12 +1258,15 @@ function DocumentItem({ } function LLMMessage({ message }: { message: AttributeMessage }) { - const messageContent = message[MESSAGE_CONTENT]; - const toolCalls = message[MESSAGE_TOOL_CALLS] || []; + const messageContent = message[MessageAttributePostfixes.content]; + const toolCalls = + message[MessageAttributePostfixes.tool_calls]?.map( + (obj) => obj[SemanticAttributePrefixes.tool_call] + ) || []; const hasFunctionCall = - message[MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] && - message[MESSAGE_FUNCTION_CALL_NAME]; - const role = message[MESSAGE_ROLE]; + message[MessageAttributePostfixes.function_call_arguments_json] && + message[MessageAttributePostfixes.function_call_name]; + const role = message[MessageAttributePostfixes.role] || "unknown"; const messageStyles = useMemo(() => { if (role === "user") { return { @@ -1328,7 +1301,10 @@ function LLMMessage({ message }: { message: AttributeMessage }) { {...defaultCardProps} {...messageStyles} title={ - role + (message[MESSAGE_NAME] ? `: ${message[MESSAGE_NAME]}` : "") + role + + (message[MessageAttributePostfixes.name] + ? `: ${message[MessageAttributePostfixes.name]}` + : "") } extra={ @@ -1341,9 +1317,7 @@ function LLMMessage({ message }: { message: AttributeMessage }) { > {messageContent ? ( - - {message[MESSAGE_CONTENT]} - + {messageContent} ) : null} {toolCalls.length > 0 ? 
toolCalls.map((toolCall, idx) => { @@ -1355,11 +1329,9 @@ function LLMMessage({ message }: { message: AttributeMessage }) { margin: var(--ac-global-dimension-static-size-100) 0; `} > - {toolCall[TOOL_CALL_FUNCTION_NAME] as string}( + {toolCall?.function?.name as string}( {JSON.stringify( - JSON.parse( - toolCall[TOOL_CALL_FUNCTION_ARGUMENTS_JSON] as string - ), + JSON.parse(toolCall?.function?.arguments as string), null, 2 )} @@ -1376,10 +1348,12 @@ function LLMMessage({ message }: { message: AttributeMessage }) { margin: var(--ac-global-dimension-static-size-100) 0; `} > - {message[MESSAGE_FUNCTION_CALL_NAME] as string}( + {message[MessageAttributePostfixes.function_call_name] as string}( {JSON.stringify( JSON.parse( - message[MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON] as string + message[ + MessageAttributePostfixes.function_call_arguments_json + ] as string ), null, 2 diff --git a/integration-tests/span_query_testing.ipynb b/integration-tests/span_query_testing.ipynb new file mode 100644 index 0000000000..41b88ad032 --- /dev/null +++ b/integration-tests/span_query_testing.ipynb @@ -0,0 +1,1207 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "287646a0", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "pd.set_option(\"display.max_colwidth\", None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "408327c6-1ca1-4052-bcf9-ef7a82e21282", + "metadata": {}, + "outputs": [], + "source": [ + "import phoenix as px\n", + "from phoenix.db import models\n", + "from phoenix.trace.dsl.helpers import get_qa_with_reference, get_retrieved_documents\n", + "from phoenix.trace.dsl.query import SpanQuery\n", + "from sqlalchemy import create_engine, select\n", + "from sqlalchemy.orm import sessionmaker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b29d83f-ea8b-4bb8-bdf6-b0efa5638647", + "metadata": {}, + "outputs": [], + "source": [ + "PostgresSession = sessionmaker(\n", + " create_engine(\n", + " \"postgresql+psycopg://localhost:5432/postgres?user=postgres&password=mysecretpassword\",\n", + " echo=True,\n", + " ),\n", + " expire_on_commit=False,\n", + ")\n", + "SqliteSession = sessionmaker(\n", + " create_engine(\"sqlite:////Users/rogeryang/.phoenix/phoenix.db\", echo=True),\n", + " expire_on_commit=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d8dcce0f", + "metadata": {}, + "source": [ + "# latency ms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "775c6a49-b7fb-4156-a0ae-33712c408a2d", + "metadata": {}, + "outputs": [], + "source": [ + "stmt = select(models.Span.latency_ms)\n", + "with SqliteSession.begin() as session:\n", + " print(session.scalar(stmt))\n", + "with PostgresSession.begin() as session:\n", + " print(session.scalar(stmt))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50889f12-e619-4cbe-8fcc-abb0d5da2573", + "metadata": {}, + "outputs": [], + "source": [ + "orig_endpoint = \"http://127.0.0.1:6007\"\n", + "postgres_endpoint = \"http://127.0.0.1:6006\"\n", + "sqlite_endpoint = \"http://127.0.0.1:6005\"" + ] + }, + { + "cell_type": "markdown", + "id": "c458d469", + "metadata": {}, + "source": [ + "# get spans dataframe with filter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de77e3e0", + "metadata": {}, + "outputs": [], + "source": [ + "filter_condition = \"latency_ms > 1000 and 'service' in output.value\"\n", + "df_orig_root_spans = (\n", + " px.Client(endpoint=orig_endpoint)\n", + 
" .get_spans_dataframe(filter_condition, root_spans_only=True)\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + ")\n", + "print(f\"{df_orig_root_spans.shape=}\")\n", + "df_postgres_root_spans = (\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition, root_spans_only=True)\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + ")\n", + "print(f\"{df_postgres_root_spans.shape=}\")\n", + "df_sqlite_root_spans = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition, root_spans_only=True)\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + ")\n", + "print(f\"{df_sqlite_root_spans.shape=}\")\n", + "print(df_orig_root_spans.columns)\n", + "print(df_postgres_root_spans.columns)\n", + "print(df_sqlite_root_spans.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37c86673", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(df_orig_root_spans.columns)):\n", + " print(\n", + " f\"{df_orig_root_spans.iloc[:,i].equals(df_postgres_root_spans.iloc[:,i])}, {df_orig_root_spans.iloc[:,i].equals(df_sqlite_root_spans.iloc[:,i])}, {df_postgres_root_spans.iloc[:,i].equals(df_sqlite_root_spans.iloc[:,i])},, {df_orig_root_spans.columns[i]} {i=}\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "ebf0e5e3", + "metadata": {}, + "source": [ + "# get spans dataframe no filter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df02c19f", + "metadata": {}, + "outputs": [], + "source": [ + "df_orig = (\n", + " px.Client(endpoint=orig_endpoint)\n", + " .get_spans_dataframe()\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + " .drop(\"conversation\", axis=1)\n", + ")\n", + "print(f\"{df_orig.shape=}\")\n", + "df_postgres = (\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe()\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + " .drop(\"attributes.openinference.span.kind\", axis=1)\n", + ")\n", + "print(f\"{df_postgres.shape=}\")\n", + "df_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe()\n", + " .sort_index()\n", + " .sort_index(axis=1)\n", + " .drop(\"attributes.openinference.span.kind\", axis=1)\n", + ")\n", + "print(f\"{df_sqlite.shape=}\")\n", + "print(df_orig.columns)\n", + "print(df_postgres.columns)\n", + "print(df_sqlite.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2964dea", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(len(df_orig.columns)):\n", + " print(\n", + " f\"{df_orig.iloc[:,i].equals(df_postgres.iloc[:,i])}\",\n", + " f\"{df_orig.iloc[:,i].equals(df_sqlite.iloc[:,i])}\",\n", + " f\"{df_postgres.iloc[:,i].equals(df_sqlite.iloc[:,i])}\",\n", + " f\"{df_orig.columns[i]}\",\n", + " f\"{i=}\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "4415316b", + "metadata": {}, + "source": [ + "# qa with reference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2523c3a2", + "metadata": {}, + "outputs": [], + "source": [ + "qa_orig = get_qa_with_reference(px.Client(endpoint=orig_endpoint)).sort_index().sort_index(axis=1)\n", + "qa_postgres = (\n", + " get_qa_with_reference(px.Client(endpoint=postgres_endpoint)).sort_index().sort_index(axis=1)\n", + ")\n", + "qa_sqlite = (\n", + " get_qa_with_reference(px.Client(endpoint=sqlite_endpoint)).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{qa_orig.shape=}\")\n", + "print(f\"{qa_postgres.shape=}\")\n", + 
"print(f\"{qa_sqlite.shape=}\")\n", + "print(f\"{qa_orig.equals(qa_postgres)=}\")\n", + "print(f\"{qa_orig.equals(qa_sqlite)=}\")\n", + "print(f\"{qa_postgres.equals(qa_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " qa_orig.sample(5, random_state=42),\n", + " qa_postgres.sample(5, random_state=42),\n", + " qa_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "2c555dcd", + "metadata": {}, + "source": [ + "# get retrieved documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8524aed", + "metadata": {}, + "outputs": [], + "source": [ + "docs_orig = (\n", + " get_retrieved_documents(px.Client(endpoint=orig_endpoint)).sort_index().sort_index(axis=1)\n", + ")\n", + "docs_postgres = (\n", + " get_retrieved_documents(px.Client(endpoint=postgres_endpoint)).sort_index().sort_index(axis=1)\n", + ")\n", + "docs_sqlite = (\n", + " get_retrieved_documents(px.Client(endpoint=sqlite_endpoint)).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{docs_orig.shape=}\")\n", + "print(f\"{docs_postgres.shape=}\")\n", + "print(f\"{docs_sqlite.shape=}\")\n", + "print(f\"{docs_orig.equals(docs_postgres)=}\")\n", + "print(f\"{docs_orig.equals(docs_sqlite)=}\")\n", + "print(f\"{docs_postgres.equals(docs_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " docs_orig.sample(5, random_state=42),\n", + " docs_postgres.sample(5, random_state=42),\n", + " docs_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "550bd82f", + "metadata": {}, + "source": [ + "# select" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "538e0570-c8aa-40c6-a322-ef09a68dffbd", + "metadata": {}, + "outputs": [], + "source": [ + "select_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + ")\n", + "df_select_orig = px.Client(endpoint=orig_endpoint).query_spans(select_query).sort_index()\n", + "print(f\"{df_select_orig.shape=}\")\n", + "df_select_postgres = px.Client(endpoint=postgres_endpoint).query_spans(select_query).sort_index()\n", + "print(f\"{df_select_postgres.shape=}\")\n", + "df_select_sqlite = px.Client(endpoint=sqlite_endpoint).query_spans(select_query).sort_index()\n", + "print(f\"{df_select_sqlite.shape=}\")\n", + "print(f\"{df_select_orig.equals(df_select_postgres)=}\")\n", + "print(f\"{df_select_orig.equals(df_select_sqlite)=}\")\n", + "print(f\"{df_select_postgres.equals(df_select_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_select_orig.sample(5, random_state=42),\n", + " df_select_postgres.sample(5, random_state=42),\n", + " df_select_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "733cbdc9", + "metadata": {}, + "source": [ + "# explode (no select or concat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbf947d0-3cea-4d28-9e35-b389c2adf91e", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\", score=\"document.score\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + 
"df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "455a66f0", + "metadata": {}, + "source": [ + "# explode with select (no concat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed7aa469", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\", score=\"document.score\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "aae901dc", + "metadata": {}, + "source": [ + "# explode with concat (no select)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b842f02", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + 
"print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "163828d2", + "metadata": {}, + "source": [ + "# explode with concat and select" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30fe7d7d", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "3c5cf80d", + "metadata": {}, + "source": [ + "# explode with no kwargs (no select or concat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d01a5b7", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "1594b943", + "metadata": {}, + "source": [ + "# concat (no select or explode)" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "id": "e2d62404", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .concat(\"retrieval.documents\", content=\"document.content\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "9325ee82", + "metadata": {}, + "source": [ + "# concat with explode (no select)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "926127e2", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "178e93ba", + "metadata": {}, + "source": [ + "# concat with select (no explode)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f0f2c16-21cb-4bb6-9080-8eb5eb4a3c8b", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .concat(\"retrieval.documents\", content=\"document.content\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " 
px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "fa106e7e", + "metadata": {}, + "source": [ + "# concat with select and explode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb674580", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "90f020ff", + "metadata": {}, + "source": [ + "# no kwargs concat (no select or explode)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d67024f7", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + 
"print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "e804609a", + "metadata": {}, + "source": [ + "# no kwargs concat with explode (no select)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "782f17a3", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "e3cf2da5", + "metadata": {}, + "source": [ + "# no kwargs concat with select (no explode)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc3abf22", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "c79a5e52", + 
"metadata": {}, + "source": [ + "# no kwargs concat with select and explode" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53433554", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\", content=\"document.content\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "00915314", + "metadata": {}, + "source": [ + "# no kwargs concat with no kwargs explode and select" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1458a8ac", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .select(\"trace_id\", \"input.value\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "cca42d6b", + "metadata": {}, + "source": [ + "# no kwargs concat with no kwargs explode (no select)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a326ae97", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\")\n", + " 
.concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "73ec4a95", + "metadata": {}, + "source": [ + "# concat index by name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7dc3b94", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .with_index(\"name\")\n", + " .concat(\"retrieval.documents\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89ef08ea", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = (\n", + " SpanQuery()\n", + " .with_index(\"name\")\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .explode(\"retrieval.documents\")\n", + ")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + 
"print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d814d7a6", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .with_index(\"trace_id\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "755a3acd", + "metadata": {}, + "outputs": [], + "source": [ + "concat_query = (\n", + " SpanQuery()\n", + " .where(\"span_kind == 'RETRIEVER' and parent_id is not None and latency_ms > 200\")\n", + " .with_index(\"span_id\")\n", + " .concat(\"retrieval.documents\", score=\"document.score\")\n", + " .with_concat_separator(\"🌟\")\n", + ")\n", + "df_concat_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_orig.shape=}\")\n", + "df_concat_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_postgres.shape=}\")\n", + "df_concat_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(concat_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_concat_sqlite.shape=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_postgres)=}\")\n", + "print(f\"{df_concat_orig.equals(df_concat_sqlite)=}\")\n", + "print(f\"{df_concat_postgres.equals(df_concat_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_concat_orig.sample(5, random_state=42),\n", + " df_concat_postgres.sample(5, random_state=42),\n", + " df_concat_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + }, + { + "cell_type": "markdown", + "id": "4a33522e", + "metadata": {}, + "source": [ + "# explode embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7690d542", + "metadata": {}, + "outputs": [], + "source": [ + "explode_query = SpanQuery().explode(\"embedding.embeddings\", vector=\"embedding.vector\")\n", + "df_explode_orig = (\n", + " px.Client(endpoint=orig_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + 
"print(f\"{df_explode_orig.shape=}\")\n", + "df_explode_postgres = (\n", + " px.Client(endpoint=postgres_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_postgres.shape=}\")\n", + "df_explode_sqlite = (\n", + " px.Client(endpoint=sqlite_endpoint).query_spans(explode_query).sort_index().sort_index(axis=1)\n", + ")\n", + "print(f\"{df_explode_sqlite.shape=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_postgres)=}\")\n", + "print(f\"{df_explode_orig.equals(df_explode_sqlite)=}\")\n", + "print(f\"{df_explode_postgres.equals(df_explode_sqlite)=}\")\n", + "pd.concat(\n", + " [\n", + " df_explode_orig.sample(5, random_state=42),\n", + " df_explode_postgres.sample(5, random_state=42),\n", + " df_explode_sqlite.sample(5, random_state=42),\n", + " ]\n", + ").sort_index()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index d7bc799581..59babcabc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ dependencies = [ "openinference-instrumentation-langchain>=0.1.12", "openinference-instrumentation-llama-index>=1.2.0", "openinference-instrumentation-openai>=0.1.4", - "sqlalchemy[asyncio]>=2, <3", + "sqlalchemy[asyncio]>=2.0.4, <3", "alembic>=1.3.0, <2", "aiosqlite", ] @@ -144,6 +144,7 @@ dependencies = [ "respx", # For OpenAI testing "nest-asyncio", # for executor testing "pyfakefs", # for experimental storage implementations + "astunparse; python_version<'3.9'", ] [tool.hatch.envs.type] diff --git a/src/phoenix/core/project.py b/src/phoenix/core/project.py index 4d554e594b..d7e3bd9b62 100644 --- a/src/phoenix/core/project.py +++ b/src/phoenix/core/project.py @@ -37,6 +37,7 @@ SpanStatusCode, TraceID, ) +from phoenix.utilities.attributes import get_attribute_value logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -64,7 +65,7 @@ def get_computed_value(self, key: str) -> Optional[Union[float, int]]: def __getitem__(self, key: Union[str, ComputedAttributes]) -> Any: if isinstance(key, ComputedAttributes): return self._self_computed_values.get(key) - return self.__wrapped__.attributes.get(key) + return get_attribute_value(self.__wrapped__.attributes, key) def __setitem__(self, key: ComputedAttributes, value: Any) -> None: if not isinstance(key, ComputedAttributes): @@ -508,10 +509,16 @@ def _add_span_to_trace(self, span: WrappedSpan) -> None: def _update_cached_statistics(self, span: WrappedSpan) -> None: # Update statistics for quick access later span_id = span.context.span_id - if token_count_update := span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_TOTAL): + if token_count_update := get_attribute_value( + span.attributes, SpanAttributes.LLM_TOKEN_COUNT_TOTAL + ): self._token_count_total += token_count_update if isinstance( - (retrieval_documents := span.attributes.get(SpanAttributes.RETRIEVAL_DOCUMENTS)), + ( + retrieval_documents := get_attribute_value( + span.attributes, SpanAttributes.RETRIEVAL_DOCUMENTS + ) + ), Sized, ) and (num_documents_update := len(retrieval_documents)): self._num_documents[span_id] += num_documents_update diff --git a/src/phoenix/db/bulk_inserter.py b/src/phoenix/db/bulk_inserter.py index 457fd14f71..c9fa0df537 100644 --- a/src/phoenix/db/bulk_inserter.py +++ b/src/phoenix/db/bulk_inserter.py @@ -22,6 +22,7 @@ from phoenix.db import models from phoenix.exceptions import PhoenixException from phoenix.trace.schemas import Span, SpanStatusCode +from 
phoenix.utilities.attributes import get_attribute_value logger = logging.getLogger(__name__) @@ -229,7 +230,6 @@ async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> .values( start_time=trace_start_time, end_time=trace_end_time, - latency_ms=(trace_end_time - trace_start_time).total_seconds() * 1000, ) ) else: @@ -242,17 +242,16 @@ async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> trace_id=span.context.trace_id, start_time=span.start_time, end_time=span.end_time, - latency_ms=(span.end_time - span.start_time).total_seconds() * 1000, ) .returning(models.Trace.id) ), ) cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) cumulative_llm_token_count_prompt = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 0) + int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 ) cumulative_llm_token_count_completion = cast( - int, span.attributes.get(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 0) + int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_COMPLETION) or 0 ) if accumulation := ( await session.execute( @@ -260,27 +259,25 @@ async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> func.sum(models.Span.cumulative_error_count), func.sum(models.Span.cumulative_llm_token_count_prompt), func.sum(models.Span.cumulative_llm_token_count_completion), - ).where(models.Span.parent_span_id == span.context.span_id) + ).where(models.Span.parent_id == span.context.span_id) ) ).first(): cumulative_error_count += cast(int, accumulation[0] or 0) cumulative_llm_token_count_prompt += cast(int, accumulation[1] or 0) cumulative_llm_token_count_completion += cast(int, accumulation[2] or 0) - latency_ms = (span.end_time - span.start_time).total_seconds() * 1000 session.add( models.Span( span_id=span.context.span_id, trace_rowid=trace_rowid, - parent_span_id=span.parent_id, - kind=span.span_kind.value, + parent_id=span.parent_id, + span_kind=span.span_kind.value, name=span.name, start_time=span.start_time, end_time=span.end_time, attributes=span.attributes, events=span.events, - status=span.status_code.value, + status_code=span.status_code.value, status_message=span.status_message, - latency_ms=latency_ms, cumulative_error_count=cumulative_error_count, cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, @@ -288,17 +285,17 @@ async def _insert_span(session: AsyncSession, span: Span, project_name: str) -> ) # Propagate cumulative values to ancestors. This is usually a no-op, since # the parent usually arrives after the child. But in the event that a - # child arrives after its parent, we need to make sure the all the + # child arrives after its parent, we need to make sure that all the # ancestors' cumulative values are updated. 
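+    # The recursive CTE below anchors on the row whose span_id equals this
+    # span's parent_id, then repeatedly joins each ancestor to its own parent
+    # (matching span_id to the previous level's parent_id), collecting the
+    # rowids of the whole ancestor chain so the cumulative counts computed
+    # above can be propagated to every ancestor.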
ancestors = ( - select(models.Span.id, models.Span.parent_span_id) + select(models.Span.id, models.Span.parent_id) .where(models.Span.span_id == span.parent_id) .cte(recursive=True) ) child = ancestors.alias() ancestors = ancestors.union_all( - select(models.Span.id, models.Span.parent_span_id).join( - child, models.Span.span_id == child.c.parent_span_id + select(models.Span.id, models.Span.parent_id).join( + child, models.Span.span_id == child.c.parent_id ) ) await session.execute( diff --git a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py index 1f59b17e71..da262cb10e 100644 --- a/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py +++ b/src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py @@ -61,7 +61,6 @@ def upgrade() -> None: sa.Column("trace_id", sa.String, nullable=False, unique=True), sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False, index=True), sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False), - sa.Column("latency_ms", sa.Float, nullable=False), ) op.create_table( @@ -75,24 +74,23 @@ def upgrade() -> None: index=True, ), sa.Column("span_id", sa.String, nullable=False, unique=True), - sa.Column("parent_span_id", sa.String, nullable=True, index=True), + sa.Column("parent_id", sa.String, nullable=True, index=True), sa.Column("name", sa.String, nullable=False), - sa.Column("kind", sa.String, nullable=False), + sa.Column("span_kind", sa.String, nullable=False), sa.Column("start_time", sa.TIMESTAMP(timezone=True), nullable=False), sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False), sa.Column("attributes", JSON_, nullable=False), sa.Column("events", JSON_, nullable=False), sa.Column( - "status", + "status_code", sa.String, # TODO(mikeldking): this doesn't seem to work... 
- sa.CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", "valid_status"), + sa.CheckConstraint("status_code IN ('OK', 'ERROR', 'UNSET')", "valid_status"), nullable=False, default="UNSET", server_default="UNSET", ), sa.Column("status_message", sa.String, nullable=False), - sa.Column("latency_ms", sa.Float, nullable=False), sa.Column("cumulative_error_count", sa.Integer, nullable=False), sa.Column("cumulative_llm_token_count_prompt", sa.Integer, nullable=False), sa.Column("cumulative_llm_token_count_completion", sa.Integer, nullable=False), diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 8044fe53e8..cc2ce36872 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -5,7 +5,9 @@ JSON, TIMESTAMP, CheckConstraint, + ColumnElement, Dialect, + Float, ForeignKey, MetaData, TypeDecorator, @@ -15,6 +17,8 @@ ) from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -22,6 +26,7 @@ mapped_column, relationship, ) +from sqlalchemy.sql import expression JSON_ = JSON().with_variant( postgresql.JSONB(), # type: ignore @@ -120,7 +125,15 @@ class Trace(Base): trace_id: Mapped[str] start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) - latency_ms: Mapped[float] + + @hybrid_property + def latency_ms(self) -> float: + return (self.end_time - self.start_time).total_seconds() * 1000 + + @latency_ms.inplace.expression + @classmethod + def _latency_ms_expression(cls) -> ColumnElement[float]: + return LatencyMs(cls.start_time, cls.end_time) project: Mapped["Project"] = relationship( "Project", @@ -149,24 +162,36 @@ class Span(Base): index=True, ) span_id: Mapped[str] - parent_span_id: Mapped[Optional[str]] = mapped_column(index=True) + parent_id: Mapped[Optional[str]] = mapped_column(index=True) name: Mapped[str] - kind: Mapped[str] + span_kind: Mapped[str] start_time: Mapped[datetime] = mapped_column(UtcTimeStamp) end_time: Mapped[datetime] = mapped_column(UtcTimeStamp) attributes: Mapped[Dict[str, Any]] events: Mapped[List[Dict[str, Any]]] - status: Mapped[str] = mapped_column( - CheckConstraint("status IN ('OK', 'ERROR', 'UNSET')", name="valid_status") + status_code: Mapped[str] = mapped_column( + CheckConstraint("status_code IN ('OK', 'ERROR', 'UNSET')", name="valid_status") ) status_message: Mapped[str] # TODO(mikeldking): is computed columns possible here - latency_ms: Mapped[float] cumulative_error_count: Mapped[int] cumulative_llm_token_count_prompt: Mapped[int] cumulative_llm_token_count_completion: Mapped[int] + @hybrid_property + def latency_ms(self) -> float: + return (self.end_time - self.start_time).total_seconds() * 1000 + + @latency_ms.inplace.expression + @classmethod + def _latency_ms_expression(cls) -> ColumnElement[float]: + return LatencyMs(cls.start_time, cls.end_time) + + @hybrid_property + def cumulative_llm_token_count_total(self) -> int: + return self.cumulative_llm_token_count_prompt + self.cumulative_llm_token_count_completion + trace: Mapped["Trace"] = relationship("Trace", back_populates="spans") document_annotations: Mapped[List["DocumentAnnotation"]] = relationship(back_populates="span") @@ -179,6 +204,31 @@ class Span(Base): ) +class LatencyMs(expression.FunctionElement[float]): + inherit_cache = True + type = Float() + name = "latency_ms" + + +@compiles(LatencyMs) # type: ignore +def 
_(element: Any, compiler: Any, **kw: Any) -> Any: + start_time, end_time = list(element.clauses) + return compiler.process( + (func.extract("EPOCH", end_time) - func.extract("EPOCH", start_time)) * 1000, **kw + ) + + +@compiles(LatencyMs, "sqlite") # type: ignore +def _(element: Any, compiler: Any, **kw: Any) -> Any: + start_time, end_time = list(element.clauses) + return compiler.process( + # FIXME: We don't know why sqlite returns a slightly different value. + # postgresql is correct because it matches the value computed by Python. + (func.unixepoch(end_time, "subsec") - func.unixepoch(start_time, "subsec")) * 1000, + **kw, + ) + + async def init_models(engine: AsyncEngine) -> None: async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) diff --git a/src/phoenix/server/api/input_types/SpanSort.py b/src/phoenix/server/api/input_types/SpanSort.py index 042fc02896..99050887be 100644 --- a/src/phoenix/server/api/input_types/SpanSort.py +++ b/src/phoenix/server/api/input_types/SpanSort.py @@ -15,9 +15,9 @@ from phoenix.server.api.types.SortDir import SortDir from phoenix.trace.schemas import SpanID -LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT -LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION -LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".") +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".") +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL.split(".") @strawberry.enum diff --git a/src/phoenix/server/api/routers/v1/spans.py b/src/phoenix/server/api/routers/v1/spans.py index 403fb6f4a3..a7816da41b 100644 --- a/src/phoenix/server/api/routers/v1/spans.py +++ b/src/phoenix/server/api/routers/v1/spans.py @@ -1,5 +1,3 @@ -import asyncio -from functools import partial from typing import AsyncIterator from starlette.requests import Request @@ -7,10 +5,8 @@ from starlette.status import HTTP_404_NOT_FOUND, HTTP_422_UNPROCESSABLE_ENTITY from phoenix.config import DEFAULT_PROJECT_NAME -from phoenix.core.traces import Traces from phoenix.server.api.routers.utils import df_to_bytes, from_iso_format from phoenix.trace.dsl import SpanQuery -from phoenix.utilities import query_spans # TODO: Add property details to SpanQuery schema @@ -67,7 +63,6 @@ async def query_spans_handler(request: Request) -> Response: 422: description: Request body is invalid """ - traces: Traces = request.app.state.traces payload = await request.json() queries = payload.pop("queries", []) project_name = ( @@ -76,42 +71,25 @@ async def query_spans_handler(request: Request) -> Response: or request.headers.get("project-name") or DEFAULT_PROJECT_NAME ) - if not (project := traces.get_project(project_name)): - return Response(status_code=HTTP_404_NOT_FOUND) - loop = asyncio.get_running_loop() - valid_eval_names = ( - await loop.run_in_executor( - None, - project.get_span_evaluation_names, - ) - if project - else () - ) try: - span_queries = [ - SpanQuery.from_dict( - query, - evals=project, - valid_eval_names=valid_eval_names, - ) - for query in queries - ] + span_queries = [SpanQuery.from_dict(query) for query in queries] except Exception as e: return Response( status_code=HTTP_422_UNPROCESSABLE_ENTITY, content=f"Invalid query: {e}", ) - results = await loop.run_in_executor( - None, - partial( - query_spans, - project, - *span_queries, - start_time=from_iso_format(payload.get("start_time")), - stop_time=from_iso_format(payload.get("stop_time")), - 
root_spans_only=payload.get("root_spans_only"), - ), - ) + async with request.app.state.db() as session: + results = [] + for query in span_queries: + results.append( + await session.run_sync( + query, + project_name=project_name, + start_time=from_iso_format(payload.get("start_time")), + stop_time=from_iso_format(payload.get("stop_time")), + root_spans_only=payload.get("root_spans_only"), + ) + ) if not results: return Response(status_code=HTTP_404_NOT_FOUND) @@ -131,54 +109,4 @@ async def get_spans_handler(request: Request) -> Response: operationId: legacyQuerySpans deprecated: true """ - traces: Traces = request.app.state.traces - payload = await request.json() - queries = payload.pop("queries", []) - project_name = request.query_params.get("project_name", DEFAULT_PROJECT_NAME) - if not (project := traces.get_project(project_name)): - return Response(status_code=HTTP_404_NOT_FOUND) - loop = asyncio.get_running_loop() - valid_eval_names = ( - await loop.run_in_executor( - None, - project.get_span_evaluation_names, - ) - if project - else () - ) - try: - span_queries = [ - SpanQuery.from_dict( - query, - evals=project, - valid_eval_names=valid_eval_names, - ) - for query in queries - ] - except Exception as e: - return Response( - status_code=HTTP_422_UNPROCESSABLE_ENTITY, - content=f"Invalid query: {e}", - ) - results = await loop.run_in_executor( - None, - partial( - query_spans, - project, - *span_queries, - start_time=from_iso_format(payload.get("start_time")), - stop_time=from_iso_format(payload.get("stop_time")), - root_spans_only=payload.get("root_spans_only"), - ), - ) - if not results: - return Response(status_code=HTTP_404_NOT_FOUND) - - async def content() -> AsyncIterator[bytes]: - for result in results: - yield df_to_bytes(result) - - return StreamingResponse( - content=content(), - media_type="application/x-pandas-arrow", - ) + return await query_spans_handler(request) diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index f1350499e2..cd16f3c0dc 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -190,9 +190,11 @@ async def spans( parent = select(models.Span.span_id).alias() stmt = stmt.outerjoin( parent, - models.Span.parent_span_id == parent.c.span_id, + models.Span.parent_id == parent.c.span_id, ).where(parent.c.span_id.is_(None)) - # TODO(persistence): enable filter + if filter_condition: + span_filter = SpanFilter(condition=filter_condition) + stmt = span_filter(stmt) if sort: stmt = stmt.order_by(sort.to_orm_expr()) async with info.context.db() as session: @@ -367,5 +369,5 @@ def validate_span_filter_condition(self, condition: str) -> ValidationResult: ) -LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT -LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT.split(".") +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION.split(".") diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 5071737275..03e828c405 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -1,8 +1,7 @@ import json -from collections import defaultdict from datetime import datetime from enum import Enum -from typing import Any, DefaultDict, Dict, Iterable, List, Mapping, Optional, Sized, cast +from typing import Any, List, Mapping, Optional, Sized, cast import numpy as np import strawberry @@ -21,6 +20,7 @@ from 
phoenix.server.api.types.Evaluation import DocumentEvaluation, SpanEvaluation from phoenix.server.api.types.MimeType import MimeType from phoenix.trace.schemas import SpanID +from phoenix.utilities.attributes import get_attribute_value EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR @@ -54,7 +54,9 @@ class SpanKind(Enum): @classmethod def _missing_(cls, v: Any) -> Optional["SpanKind"]: - return None if v else cls.unknown + if v and isinstance(v, str) and not v.isupper(): + return cls(v.upper()) + return cls.unknown @strawberry.type @@ -206,14 +208,14 @@ async def descendants( async with info.context.db() as session: descendant_ids = ( select(models.Span.id, models.Span.span_id) - .filter(models.Span.parent_span_id == str(self.context.span_id)) + .filter(models.Span.parent_id == str(self.context.span_id)) .cte(recursive=True) ) parent_ids = descendant_ids.alias() descendant_ids = descendant_ids.union_all( select(models.Span.id, models.Span.span_id).join( parent_ids, - models.Span.parent_span_id == parent_ids.c.span_id, + models.Span.parent_id == parent_ids.c.span_id, ) ) spans = await session.scalars( @@ -227,18 +229,18 @@ async def descendants( def to_gql_span(span: models.Span, project: Project) -> Span: events: List[SpanEvent] = list(map(SpanEvent.from_dict, span.events)) - input_value = cast(Optional[str], span.attributes.get(INPUT_VALUE)) - output_value = cast(Optional[str], span.attributes.get(OUTPUT_VALUE)) - retrieval_documents = span.attributes.get(RETRIEVAL_DOCUMENTS) + input_value = cast(Optional[str], get_attribute_value(span.attributes, INPUT_VALUE)) + output_value = cast(Optional[str], get_attribute_value(span.attributes, OUTPUT_VALUE)) + retrieval_documents = get_attribute_value(span.attributes, RETRIEVAL_DOCUMENTS) num_documents = len(retrieval_documents) if isinstance(retrieval_documents, Sized) else None return Span( project=project, span_rowid=span.id, name=span.name, - status_code=SpanStatusCode(span.status), + status_code=SpanStatusCode(span.status_code), status_message=span.status_message, - parent_id=cast(Optional[ID], span.parent_span_id), - span_kind=SpanKind(span.kind), + parent_id=cast(Optional[ID], span.parent_id), + span_kind=SpanKind(span.span_kind), start_time=span.start_time, end_time=span.end_time, latency_ms=span.latency_ms, @@ -246,35 +248,36 @@ def to_gql_span(span: models.Span, project: Project) -> Span: trace_id=cast(ID, span.trace.trace_id), span_id=cast(ID, span.span_id), ), - attributes=json.dumps( - _nested_attributes(_hide_embedding_vectors(span.attributes)), - cls=_JSONEncoder, - ), - metadata=_convert_metadata_to_string(span.attributes.get(METADATA)), + attributes=json.dumps(span.attributes, cls=_JSONEncoder), + # TODO(persistence): hide the embedding vectors as a string instead, + # e.g. 
f"<{len(vector)} dimensional vector>" + metadata=_convert_metadata_to_string(get_attribute_value(span.attributes, METADATA)), num_documents=num_documents, token_count_total=cast( Optional[int], - span.attributes.get(LLM_TOKEN_COUNT_TOTAL), + get_attribute_value(span.attributes, LLM_TOKEN_COUNT_TOTAL), ), token_count_prompt=cast( Optional[int], - span.attributes.get(LLM_TOKEN_COUNT_PROMPT), + get_attribute_value(span.attributes, LLM_TOKEN_COUNT_PROMPT), ), token_count_completion=cast( Optional[int], - span.attributes.get(LLM_TOKEN_COUNT_COMPLETION), + get_attribute_value(span.attributes, LLM_TOKEN_COUNT_COMPLETION), ), cumulative_token_count_total=span.cumulative_llm_token_count_prompt + span.cumulative_llm_token_count_completion, cumulative_token_count_prompt=span.cumulative_llm_token_count_prompt, cumulative_token_count_completion=span.cumulative_llm_token_count_completion, propagated_status_code=( - SpanStatusCode.ERROR if span.cumulative_error_count else SpanStatusCode(span.status) + SpanStatusCode.ERROR + if span.cumulative_error_count + else SpanStatusCode(span.status_code) ), events=events, input=( SpanIOValue( - mime_type=MimeType(span.attributes.get(INPUT_MIME_TYPE)), + mime_type=MimeType(get_attribute_value(span.attributes, INPUT_MIME_TYPE)), value=input_value, ) if input_value is not None @@ -282,7 +285,7 @@ def to_gql_span(span: models.Span, project: Project) -> Span: ), output=( SpanIOValue( - mime_type=MimeType(span.attributes.get(OUTPUT_MIME_TYPE)), + mime_type=MimeType(get_attribute_value(span.attributes, OUTPUT_MIME_TYPE)), value=output_value, ) if output_value is not None @@ -306,39 +309,6 @@ def default(self, obj: Any) -> Any: return super().default(obj) -def _trie() -> DefaultDict[str, Any]: - return defaultdict(_trie) - - -def _nested_attributes( - attributes: Mapping[str, Any], -) -> DefaultDict[str, Any]: - nested_attributes = _trie() - for attribute_name, attribute_value in attributes.items(): - trie = nested_attributes - keys = attribute_name.split(".") - for key in keys[:-1]: - trie = trie[key] - trie[keys[-1]] = attribute_value - return nested_attributes - - -def _hide_embedding_vectors( - attributes: Mapping[str, Any], -) -> Dict[str, Any]: - _attributes = dict(attributes) - if not isinstance((embeddings := _attributes.get(EMBEDDING_EMBEDDINGS)), Iterable): - return _attributes - _embeddings = [] - for embedding in embeddings: - _embedding = dict(embedding) - if isinstance((vector := _embedding.get(EMBEDDING_VECTOR)), Sized): - _embedding[EMBEDDING_VECTOR] = f"<{len(vector)} dimensional vector>" - _embeddings.append(_embedding) - _attributes[EMBEDDING_EMBEDDINGS] = _embeddings - return _attributes - - def _convert_metadata_to_string(metadata: Any) -> Optional[str]: """ Converts metadata to a string representation. 
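Throughout the patch, flat-key lookups such as `span.attributes.get("llm.token_count.total")` are replaced with `get_attribute_value(span.attributes, ...)`, because span attributes are now stored as nested mappings rather than flat dot-delimited keys (which is also why the flat-to-nested `_nested_attributes` trie above is no longer needed). The following is only a minimal sketch of the lookup behavior these call sites appear to rely on; the real helper lives in `phoenix.utilities.attributes` and may differ in detail:

```python
from typing import Any, Mapping, Optional


def get_attribute_value(attributes: Optional[Mapping[str, Any]], key: str) -> Any:
    # Walk a nested mapping along a dotted key, e.g. "llm.token_count.total"
    # -> attributes["llm"]["token_count"]["total"], returning None when any
    # segment is missing or a non-mapping value is hit along the way.
    value: Any = attributes
    for part in key.split("."):
        if not isinstance(value, Mapping):
            return None
        value = value.get(part)
    return value


# Illustrative only:
assert get_attribute_value({"llm": {"token_count": {"total": 42}}}, "llm.token_count.total") == 42
assert get_attribute_value({"llm": {}}, "llm.token_count.total") is None
```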
diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 0e91055a71..b10560943b 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -243,7 +243,7 @@ def create_app( ) ) initial_batch_of_evaluations = () if initial_evaluations is None else initial_evaluations - engine = create_engine(database) + engine = create_engine(database, echo=True) db = _db(engine) graphql = GraphQLWithContext( db=db, @@ -301,6 +301,7 @@ def create_app( ), ], ) + app.state.db = db app.state.traces = traces app.state.store = span_store app.state.read_only = read_only diff --git a/src/phoenix/session/client.py b/src/phoenix/session/client.py index 3565580dbe..502ac99276 100644 --- a/src/phoenix/session/client.py +++ b/src/phoenix/session/client.py @@ -2,7 +2,7 @@ import weakref from datetime import datetime from io import BytesIO -from typing import List, Optional, Union, cast +from typing import Any, List, Optional, Union, cast from urllib.parse import urljoin import pandas as pd @@ -10,7 +10,6 @@ from pyarrow import ArrowInvalid from requests import Session -import phoenix as px from phoenix.config import ( get_env_collector_endpoint, get_env_host, @@ -25,24 +24,14 @@ class Client(TraceDataExtractor): - def __init__( - self, - *, - endpoint: Optional[str] = None, - use_active_session_if_available: bool = True, - ): + def __init__(self, *, endpoint: Optional[str] = None, **kwargs: Any): """ Client for connecting to a Phoenix server. Args: endpoint (str, optional): Phoenix server endpoint, e.g. http://localhost:6006. If not provided, the endpoint will be inferred from the environment variables. - use_active_session_if_available (bool, optional): If px.active_session() is available - in the same runtime, e.g. the same Jupyter notebook, delegate the request to the - active session instead of making HTTP requests. This argument is set to False if - endpoint is provided explicitly. """ - self._use_active_session_if_available = use_active_session_if_available and not endpoint host = get_env_host() if host == "0.0.0.0": host = "127.0.0.1" @@ -51,8 +40,7 @@ def __init__( ) self._session = Session() weakref.finalize(self, self._session.close) - if not (self._use_active_session_if_available and px.active_session()): - self._warn_if_phoenix_is_not_running() + self._warn_if_phoenix_is_not_running() def query_spans( self, @@ -80,14 +68,6 @@ def query_spans( project_name = project_name or get_env_project_name() if not queries: queries = (SpanQuery(),) - if self._use_active_session_if_available and (session := px.active_session()): - return session.query_spans( - *queries, - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - project_name=project_name, - ) response = self._session.post( url=urljoin(self._base_url, "/v1/spans"), params={"project_name": project_name}, @@ -134,8 +114,6 @@ def get_evaluations( empty list if no evaluations are found. 
""" project_name = project_name or get_env_project_name() - if self._use_active_session_if_available and (session := px.active_session()): - return session.get_evaluations(project_name=project_name) response = self._session.get( urljoin(self._base_url, "/v1/evaluations"), params={"project_name": project_name}, diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index 956585d255..223bbddf09 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -35,7 +35,6 @@ get_env_database_connection_str, get_env_host, get_env_port, - get_env_project_name, get_exported_files, get_working_dir, ) @@ -52,7 +51,6 @@ from phoenix.trace import Evaluations from phoenix.trace.dsl.query import SpanQuery from phoenix.trace.trace_dataset import TraceDataset -from phoenix.utilities import query_spans from phoenix.utilities.span_store import get_span_store, load_traces_data_from_store try: @@ -131,6 +129,77 @@ def __init__( self.exported_data = ExportedData() self.notebook_env = notebook_env or _get_notebook_environment() self.root_path = _get_root_path(self.notebook_env, self.port) + host = "127.0.0.1" if self.host == "0.0.0.0" else self.host + self._client = Client( + endpoint=f"http://{host}:{self.port}", + use_active_session_if_available=False, + ) + + def query_spans( + self, + *queries: SpanQuery, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + root_spans_only: Optional[bool] = None, + project_name: Optional[str] = None, + ) -> Optional[Union[pd.DataFrame, List[pd.DataFrame]]]: + """ + Queries the spans in the project based on the provided parameters. + + Parameters + ---------- + queries : *SpanQuery + Variable-length argument list of SpanQuery objects representing + the queries to be executed. + + start_time : datetime, optional + datetime representing the start time of the query. + + stop_time : datetime, optional + datetime representing the stop time of the query. + + root_spans_only : boolean, optional + whether to include only root spans in the results. + + project_name : string, optional + name of the project to query. Defaults to the project name set + in the environment variable `PHOENIX_PROJECT_NAME` or 'default' if not set. + + Returns: + results : DataFrame + DataFrame or list of DataFrames containing the query results. + """ + return self._client.query_spans( + *queries, + start_time=start_time, + stop_time=stop_time, + root_spans_only=root_spans_only, + project_name=project_name, + ) + + def get_evaluations( + self, + project_name: Optional[str] = None, + ) -> List[Evaluations]: + """ + Get the evaluations for a project. + + Parameters + ---------- + project_name : str, optional + The name of the project. If not provided, the project name set + in the environment variable `PHOENIX_PROJECT_NAME` will be used. + Otherwise, 'default' will be used. + + Returns + ------- + evaluations : List[Evaluations] + A list of evaluations for the specified project. 
+ + """ + return self._client.get_evaluations( + project_name=project_name, + ) @abstractmethod def end(self) -> None: @@ -228,11 +297,6 @@ def __init__( self.trace_dataset.name if self.trace_dataset is not None else None ), ) - host = "127.0.0.1" if self.host == "0.0.0.0" else self.host - self._client = Client( - endpoint=f"http://{host}:{self.port}", - use_active_session_if_available=False, - ) @property def active(self) -> bool: @@ -242,28 +306,6 @@ def end(self) -> None: self.app_service.stop() self.temp_dir.cleanup() - def query_spans( - self, - *queries: SpanQuery, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - root_spans_only: Optional[bool] = None, - project_name: Optional[str] = None, - ) -> Optional[Union[pd.DataFrame, List[pd.DataFrame]]]: - return self._client.query_spans( - *queries, - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - project_name=project_name, - ) - - def get_evaluations( - self, - project_name: Optional[str] = None, - ) -> List[Evaluations]: - return self._client.get_evaluations() - class ThreadSession(Session): def __init__( @@ -346,92 +388,6 @@ def end(self) -> None: self.server.close() self.temp_dir.cleanup() - def query_spans( - self, - *queries: SpanQuery, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - root_spans_only: Optional[bool] = None, - project_name: Optional[str] = None, - ) -> Optional[Union[pd.DataFrame, List[pd.DataFrame]]]: - """ - Queries the spans in the project based on the provided parameters. - - Parameters - ---------- - queries : *SpanQuery - Variable-length argument list of SpanQuery objects representing - the queries to be executed. - - start_time : datetime, optional - datetime representing the start time of the query. - - stop_time : datetime, optional - datetime representing the stop time of the query. - - root_spans_only : boolean, optional - whether to include only root spans in the results. - - project_name : string, optional - name of the project to query. Defaults to the project name set - in the environment variable `PHOENIX_PROJECT_NAME` or 'default' if not set. - - Returns: - results : DataFrame - DataFrame or list of DataFrames containing the query results. - """ - if not (traces := self.traces) or not ( - project := traces.get_project(project_name or get_env_project_name()) - ): - return None - if not queries: - queries = (SpanQuery(),) - valid_eval_names = project.get_span_evaluation_names() if project else () - queries = tuple( - SpanQuery.from_dict( - query.to_dict(), - evals=project, - valid_eval_names=valid_eval_names, - ) - for query in queries - ) - results = query_spans( - project, - *queries, - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - ) - if len(results) == 1: - df = results[0] - return None if df.shape == (0, 0) else df - return results - - def get_evaluations( - self, - project_name: Optional[str] = None, - ) -> List[Evaluations]: - """ - Get the evaluations for a project. - - Parameters - ---------- - project_name : str, optional - The name of the project. If not provided, the project name set - in the environment variable `PHOENIX_PROJECT_NAME` will be used. - Otherwise, 'default' will be used. - - Returns - ------- - evaluations : List[Evaluations] - A list of evaluations for the specified project. 
- - """ - project_name = project_name or get_env_project_name() - if not (traces := self.traces) or not (project := traces.get_project(project_name)): - return [] - return project.export_evaluations() - def delete_all(prompt_before_delete: Optional[bool] = True) -> None: """ diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 3a55b2ab38..f24eaae02f 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -1,199 +1,417 @@ import ast -import inspect import sys +import typing from dataclasses import dataclass, field from difflib import SequenceMatcher -from typing import ( - Any, - Dict, - Iterable, - Iterator, - Mapping, - Optional, - Protocol, - Sequence, - Tuple, - cast, -) -from openinference.semconv import trace +import sqlalchemy +from sqlalchemy.sql.expression import Select from typing_extensions import TypeGuard import phoenix.trace.v1 as pb -from phoenix.trace.dsl.missing import MISSING -from phoenix.trace.schemas import ComputedAttributes, Span, SpanID +from phoenix.db import models +from phoenix.trace.schemas import SpanID -_VALID_EVAL_ATTRIBUTES: Tuple[str, ...] = tuple( +_VALID_EVAL_ATTRIBUTES: typing.Tuple[str, ...] = tuple( field.name for field in pb.Evaluation.Result.DESCRIPTOR.fields ) -class SupportsGetSpanEvaluation(Protocol): - def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]: ... +# Because postgresql is strongly typed, we cast JSON values to string +# by default unless it's hinted otherwise as done here. +_FLOAT_ATTRIBUTES: typing.FrozenSet[str] = frozenset( + { + "llm.token_count.completion", + "llm.token_count.prompt", + "llm.token_count.total", + } +) + +_STRING_NAMES = { + "span_id": models.Span.span_id, + "trace_id": models.Trace.trace_id, + "context.span_id": models.Span.span_id, + "context.trace_id": models.Trace.trace_id, + "parent_id": models.Span.parent_id, + "span_kind": models.Span.span_kind, + "name": models.Span.name, + "status_code": models.Span.status_code, + "status_message": models.Span.status_message, +} +_FLOAT_NAMES = { + "latency_ms": models.Span.latency_ms, + "cumulative_llm_token_count_completion": models.Span.cumulative_llm_token_count_completion, + "cumulative_llm_token_count_prompt": models.Span.cumulative_llm_token_count_prompt, + "cumulative_llm_token_count_total": models.Span.cumulative_llm_token_count_total, +} +# TODO(persistence): find a better home (and a better name) for _NAMES +_NAMES = { + **_STRING_NAMES, + **_FLOAT_NAMES, + "attributes": models.Span.attributes, + "events": models.Span.events, +} + + +# TODO(persistence): remove this protocol +class SupportsGetSpanEvaluation(typing.Protocol): + def get_span_evaluation(self, span_id: SpanID, name: str) -> typing.Optional[pb.Evaluation]: ... 
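The tables above map the filter DSL's bare names onto ORM columns, while anything unrecognized is resolved as a JSON path into `Span.attributes`, cast to string by default or to float for the token-count keys listed in `_FLOAT_ATTRIBUTES`. As a rough illustration (not part of the patch, and assuming the models defined earlier in this diff) of the kind of SQLAlchemy expression a translated condition evaluates to:

```python
from sqlalchemy import and_, select

from phoenix.db import models

# condition: "span_kind == 'LLM' and llm.token_count.total > 1000"
stmt = select(models.Span).where(
    and_(
        # `span_kind` is a known column, resolved via _STRING_NAMES
        models.Span.span_kind == "LLM",
        # unknown names become JSON path lookups on the attributes column;
        # token counts are hinted as numeric via _FLOAT_ATTRIBUTES
        models.Span.attributes[["llm", "token_count", "total"]].as_float() > 1000,
    )
)
```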
@dataclass(frozen=True) class SpanFilter: condition: str = "" - evals: Optional[SupportsGetSpanEvaluation] = None - valid_eval_names: Optional[Sequence[str]] = None + # TODO(persistence): remove `evals` and `valid_eval_names` from this class + evals: typing.Optional[SupportsGetSpanEvaluation] = None + valid_eval_names: typing.Optional[typing.Sequence[str]] = None translated: ast.Expression = field(init=False, repr=False) - compiled: Any = field(init=False, repr=False) + compiled: typing.Any = field(init=False, repr=False) def __bool__(self) -> bool: return bool(self.condition) def __post_init__(self) -> None: - condition = self.condition or "True" # default to no op - root = ast.parse(condition, mode="eval") - if self.condition: - _validate_expression(root, condition, valid_eval_names=self.valid_eval_names) - translated = _Translator(condition).visit(root) + if not (source := self.condition): + return + root = ast.parse(source, mode="eval") + _validate_expression(root, source, valid_eval_names=self.valid_eval_names) + translated = _Translator(source).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) - object.__setattr__(self, "evals", self.evals or MISSING) - def __call__(self, span: Span) -> bool: - return cast( - bool, + def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: + if not self.condition: + return select + return select.where( eval( self.compiled, - {"span": span, "_MISSING": MISSING, "evals": self.evals}, - ), + { + **_NAMES, + "not_": sqlalchemy.not_, + "and_": sqlalchemy.and_, + "or_": sqlalchemy.or_, + "cast": sqlalchemy.cast, + "Float": sqlalchemy.Float, + "String": sqlalchemy.String, + }, + ) ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> typing.Dict[str, typing.Any]: return {"condition": self.condition} @classmethod def from_dict( cls, - obj: Mapping[str, Any], - evals: Optional[SupportsGetSpanEvaluation] = None, - valid_eval_names: Optional[Sequence[str]] = None, + obj: typing.Mapping[str, typing.Any], + # TODO(persistence): remove `evals` and `valid_eval_names` from this class + evals: typing.Optional[SupportsGetSpanEvaluation] = None, + valid_eval_names: typing.Optional[typing.Sequence[str]] = None, ) -> "SpanFilter": - return cls( - condition=obj.get("condition") or "", - evals=evals, - valid_eval_names=valid_eval_names, - ) + return cls(condition=obj.get("condition") or "") -def _replace_none_with_missing( - value: ast.expr, - as_str: bool = False, -) -> ast.IfExp: - """ - E.g. 
`value` becomes - `_MISSING if (_VALUE := value) is None else _VALUE` - """ - _store_VALUE = ast.Name(id="_VALUE", ctx=ast.Store()) - _load_VALUE = ast.Name(id="_VALUE", ctx=ast.Load()) - return ast.IfExp( - test=ast.Compare( - left=ast.NamedExpr(target=_store_VALUE, value=value), - ops=[ast.Is()], - comparators=[ast.Constant(value=None)], +def _is_string_constant(node: typing.Any) -> TypeGuard[ast.Constant]: + return isinstance(node, ast.Constant) and isinstance(node.value, str) + + +def _is_float_constant(node: typing.Any) -> TypeGuard[ast.Constant]: + return isinstance(node, ast.Constant) and isinstance(node.value, typing.SupportsFloat) + + +def _is_string_attribute(node: typing.Any) -> TypeGuard[ast.Call]: + return ( + isinstance(node, ast.Call) + and isinstance(func := node.func, ast.Attribute) + and func.attr == "as_string" + and isinstance(value := func.value, ast.Subscript) + and isinstance(name := value.value, ast.Name) + and name.id == "attributes" + ) + + +def _is_float_attribute(node: typing.Any) -> TypeGuard[ast.Call]: + return ( + isinstance(node, ast.Call) + and isinstance(func := node.func, ast.Attribute) + and func.attr == "as_float" + and isinstance(value := func.value, ast.Subscript) + and isinstance(name := value.value, ast.Name) + and name.id == "attributes" + ) + + +def _as_string_attribute(node: ast.Call) -> ast.Call: + return ast.Call( + func=ast.Attribute( + value=typing.cast(ast.Attribute, node.func).value, + attr="as_string", + ctx=ast.Load(), ), - body=ast.Name(id="_MISSING", ctx=ast.Load()), - orelse=_as_str(_load_VALUE) if as_str else _load_VALUE, + args=[], + keywords=[], ) -def _as_str(value: ast.expr) -> ast.Call: - """E.g. `value` becomes `str(value)`""" - return ast.Call(func=ast.Name(id="str", ctx=ast.Load()), args=[value], keywords=[]) +def _as_float_attribute(node: ast.Call) -> ast.Call: + return ast.Call( + func=ast.Attribute( + value=typing.cast(ast.Attribute, node.func).value, + attr="as_float", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) -def _ast_replacement(expression: str) -> ast.expr: - as_str = expression in ( - "span.status_code", - "span.span_kind", - "span.parent_id", - "span.context.span_id", - "span.context.trace_id", +def _is_cast( + node: typing.Any, + type_: typing.Optional[typing.Literal["Float", "String"]] = None, +) -> TypeGuard[ast.Call]: + return ( + isinstance(node, ast.Call) + and isinstance(func := node.func, ast.Name) + and func.id == "cast" + and len(node.args) == 2 + and isinstance(name := node.args[1], ast.Name) + and (not type_ or name.id == type_) ) - return _replace_none_with_missing(ast.parse(expression, mode="eval").body, as_str) - - -def _allowed_replacements() -> Iterator[Tuple[str, ast.expr]]: - for source_segment, ast_replacement in { - "name": _ast_replacement("span.name"), - "status_code": _ast_replacement("span.status_code"), - "span_kind": _ast_replacement("span.span_kind"), - "parent_id": _ast_replacement("span.parent_id"), - }.items(): - yield source_segment, ast_replacement - yield "span." + source_segment, ast_replacement - - for source_segment, ast_replacement in { - "span_id": _ast_replacement("span.context.span_id"), - "trace_id": _ast_replacement("span.context.trace_id"), - }.items(): - yield source_segment, ast_replacement - yield "context." + source_segment, ast_replacement - yield "span.context." 
+ source_segment, ast_replacement - - for field_name in ( - getattr(klass, attr) - for name in dir(trace) - if name.endswith("Attributes") and inspect.isclass(klass := getattr(trace, name)) - for attr in dir(klass) - if attr.isupper() - ): - source_segment = field_name - ast_replacement = _ast_replacement(f"span.attributes.get('{field_name}')") - yield source_segment, ast_replacement - yield "attributes." + source_segment, ast_replacement - yield "span.attributes." + source_segment, ast_replacement - - for computed_attribute in ComputedAttributes: - source_segment = computed_attribute.value - ast_replacement = _ast_replacement(f"span.get_computed_value('{source_segment}')") - yield source_segment, ast_replacement -class _Translator(ast.NodeTransformer): - _allowed_fields: Mapping[str, ast.expr] = dict(_allowed_replacements()) +def _remove_cast(node: typing.Any) -> typing.Any: + return node.args[0] if _is_cast(node) else node + + +def _cast_as( + type_: typing.Literal["Float", "String"], + node: typing.Any, +) -> ast.Call: + if type_ == "Float" and _is_string(node): + if _is_string_attribute(node): + return _as_float_attribute(node) + if type_ == "String" and _is_float_attribute(node): + return _as_string_attribute(node) + return ast.Call( + func=ast.Name(id="cast", ctx=ast.Load()), + args=[ + _remove_cast(node), + ast.Name(id=type_, ctx=ast.Load()), + ], + keywords=[], + ) + +def _is_string(node: typing.Any) -> TypeGuard[ast.Call]: + return ( + isinstance(node, ast.Name) + and node.id in _STRING_NAMES + or _is_cast(node, "String") + or _is_string_constant(node) + or _is_string_attribute(node) + or isinstance(node, (ast.List, ast.Tuple)) + and len(node.elts) > 0 + and _is_string(node.elts[0]) + ) + + +def _is_float(node: typing.Any) -> TypeGuard[ast.Call]: + return ( + isinstance(node, ast.Name) + and node.id in _FLOAT_NAMES + or _is_cast(node, "Float") + or _is_float_constant(node) + or _is_float_attribute(node) + or isinstance(node, (ast.List, ast.Tuple)) + and len(node.elts) > 0 + and _is_float(node.elts[0]) + ) + + +def _split(key: str) -> typing.List[typing.Union[str, int]]: + return [int(part) if part.isdigit() else part for part in key.split(".")] + + +# TODO(persistence): support `evals['name'].score` et. al. +class _Translator(ast.NodeTransformer): def __init__(self, source: str) -> None: # Regarding the need for `source: str` for getting source segments: # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). self._source = source - def visit_Subscript(self, node: ast.Subscript) -> Any: - if _is_metadata(node) and (key := _get_subscript_key(node)): - return _ast_metadata_subscript(key) - source_segment: str = cast(str, ast.get_source_segment(self._source, node)) - raise SyntaxError(f"invalid expression: {source_segment}") # TODO: add details - - def visit_Attribute(self, node: ast.Attribute) -> Any: - if _is_eval(node.value) and (eval_name := _get_subscript_key(node.value)): - # e.g. 
`evals["name"].score` - return _ast_evaluation_result_value(eval_name, node.attr) - source_segment: str = cast(str, ast.get_source_segment(self._source, node)) - if replacement := self._allowed_fields.get(source_segment): - return replacement - raise SyntaxError(f"invalid expression: {source_segment}") # TODO: add details - - def visit_Name(self, node: ast.Name) -> Any: - source_segment: str = cast(str, ast.get_source_segment(self._source, node)) - if replacement := self._allowed_fields.get(source_segment): - return replacement - raise SyntaxError(f"invalid expression: {source_segment}") # TODO: add details - - def visit_Constant(self, node: ast.Constant) -> Any: - return ast.Name(id="_MISSING", ctx=ast.Load()) if node.value is None else node + def visit_Compare(self, node: ast.Compare) -> typing.Any: + if len(node.comparators) > 1: + args: typing.List[typing.Any] = [] + left = node.left + for i, (op, comparator) in enumerate(zip(node.ops, node.comparators)): + args.append(self.visit(ast.Compare(left=left, ops=[op], comparators=[comparator]))) + left = comparator + return ast.Call(func=ast.Name(id="and_", ctx=ast.Load()), args=args, keywords=[]) + left, op, right = self.visit(node.left), node.ops[0], self.visit(node.comparators[0]) + if _is_string(left) and _is_float(right): + left = _cast_as("Float", left) + elif _is_float(left) and _is_string(right): + right = _cast_as("Float", right) + if isinstance(op, (ast.In, ast.NotIn)): + if ( + _is_string_attribute(right) + or (typing.cast(str, ast.get_source_segment(self._source, right))) in _NAMES + ): + call = ast.Call( + # TODO(persistence): FIXME: This turns into `LIKE` which for sqlite is + # case-insensitive. We want case-sensitive matching for strings, + # so for sqlite we need to turn this into `GLOB` instead. + # TODO(persistence): FIXME: Special characters such as `%` for `LIKE` + # and `*` for `GLOB` need to be escaped. 
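+                    # e.g. `'llm' in name` translates to `name.contains('llm')`;
+                    # membership against a literal list/tuple is instead handled
+                    # in the branch below via `.in_(...)` / `.not_in(...)`.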
+ func=ast.Attribute(value=right, attr="contains", ctx=ast.Load()), + args=[left], + keywords=[], + ) + if isinstance(op, ast.NotIn): + call = ast.Call( + func=ast.Name(id="not_", ctx=ast.Load()), args=[call], keywords=[] + ) + return call + elif isinstance(right, (ast.List, ast.Tuple)): + attr = "in_" if isinstance(op, ast.In) else "not_in" + return ast.Call( + func=ast.Attribute(value=left, attr=attr, ctx=ast.Load()), + args=[right], + keywords=[], + ) + else: + raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, op)}") + if isinstance(op, ast.Is): + op = ast.Eq() + elif isinstance(op, ast.IsNot): + op = ast.NotEq() + return ast.Compare(left=left, ops=[op], comparators=[right]) + + def visit_BoolOp(self, node: ast.BoolOp) -> typing.Any: + if isinstance(node.op, ast.And): + func = ast.Name(id="and_", ctx=ast.Load()) + elif isinstance(node.op, ast.Or): + func = ast.Name(id="or_", ctx=ast.Load()) + else: + raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}") + args = [self.visit(value) for value in node.values] + return ast.Call(func=func, args=args, keywords=[]) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> typing.Any: + if isinstance(node.op, ast.Not): + return ast.Call( + func=ast.Name(id="not_", ctx=ast.Load()), + args=[self.visit(node.operand)], + keywords=[], + ) + if isinstance(node.op, ast.USub): + if _is_string_attribute(node.operand): + return _cast_as( + "Float", + ast.UnaryOp( + op=ast.USub(), + operand=_as_float_attribute(node.operand), + ), + ) + return _cast_as("Float", node) + return node + + def visit_BinOp(self, node: ast.BinOp) -> typing.Any: + left, right = self.visit(node.left), self.visit(node.right) + type_: typing.Literal["Float", "String"] = "String" + if _is_float(left) or _is_float(right): + type_ = "Float" + if _is_string_attribute(left): + left = _as_float_attribute(left) + elif _is_string_attribute(right): + right = _as_float_attribute(right) + return _cast_as( + type_, + ast.BinOp( + left=_remove_cast(left), + op=node.op, + right=_remove_cast(right), + ), + ) + def visit_Call(self, node: ast.Call) -> typing.Any: + source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) + if len(node.args) != 1: + raise SyntaxError(f"invalid expression: {source_segment}") + if not isinstance(node.func, ast.Name) or node.func.id not in ("str", "float", "int"): + raise SyntaxError( + f"invalid expression: {ast.get_source_segment(self._source, node.func)}" + ) + arg = self.visit(node.args[0]) + if node.func.id in ("float", "int") and _is_string(arg): + if _is_string_attribute(arg): + return _as_float_attribute(arg) + return _cast_as("Float", arg) + if node.func.id in ("str",) and _is_float(arg): + return _cast_as("String", arg) + return node + + def visit_Attribute(self, node: ast.Attribute) -> typing.Any: + source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) + if source_segment in _NAMES: + return node + attr = "as_float" if source_segment in _FLOAT_ATTRIBUTES else "as_string" + elts = [ast.Constant(value=part, kind=None) for part in _split(source_segment)] + return ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Name(id="attributes", ctx=ast.Load()), + slice=ast.List(elts=elts, ctx=ast.Load()) + if sys.version_info >= (3, 9) + else ast.Index(value=ast.List(elts=elts, ctx=ast.Load())), + ctx=ast.Load(), + ), + attr=attr, + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + def visit_Name(self, node: ast.Name) -> typing.Any: + source_segment = 
typing.cast(str, ast.get_source_segment(self._source, node)) + if source_segment in _STRING_NAMES or source_segment in _FLOAT_NAMES: + return node + raise SyntaxError(f"invalid expression: {source_segment}") + + def visit_Subscript(self, node: ast.Subscript) -> typing.Any: + if _is_metadata(node): + elts = [ + ast.Constant(value="metadata", kind=None), + ast.Constant(value=typing.cast(str, _get_subscript_key(node)), kind=None), + ] + return ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Name(id="attributes", ctx=ast.Load()), + slice=ast.List(elts=elts, ctx=ast.Load()) + if sys.version_info >= (3, 9) + else ast.Index(value=ast.List(elts=elts, ctx=ast.Load())), + ctx=ast.Load(), + ), + attr="as_string", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}") + + +# TODO(persistence): validate the expression def _validate_expression( expression: ast.Expression, source: str, - valid_eval_names: Optional[Sequence[str]] = None, - valid_eval_attributes: Tuple[str, ...] = _VALID_EVAL_ATTRIBUTES, + valid_eval_names: typing.Optional[typing.Sequence[str]] = None, + valid_eval_attributes: typing.Tuple[str, ...] = _VALID_EVAL_ATTRIBUTES, ) -> None: """ Validate primarily the structural (i.e. not semantic) characteristics of an @@ -208,7 +426,7 @@ def _validate_expression( # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). if not isinstance(expression, ast.Expression): - raise SyntaxError(f"invalid expression: {source}") # TODO: add details + raise SyntaxError(f"invalid expression: {source}") for i, node in enumerate(ast.walk(expression.body)): if i == 0: if isinstance(node, (ast.BoolOp, ast.Compare)): @@ -220,7 +438,7 @@ def _validate_expression( if not (eval_name := _get_subscript_key(node)) or ( valid_eval_names is not None and eval_name not in valid_eval_names ): - source_segment = cast(str, ast.get_source_segment(source, node)) + source_segment = typing.cast(str, ast.get_source_segment(source, node)) if eval_name and valid_eval_names: # suggest a valid eval name most similar to the one given choice, score = _find_best_match(eval_name, valid_eval_names) @@ -240,7 +458,7 @@ def _validate_expression( elif isinstance(node, ast.Attribute) and _is_eval(node.value): # e.g. 
`evals["name"].score` if (attr := node.attr) not in valid_eval_attributes: - source_segment = cast(str, ast.get_source_segment(source, node)) + source_segment = typing.cast(str, ast.get_source_segment(source, node)) # suggest a valid attribute most similar to the one given choice, score = _find_best_match(attr, valid_eval_attributes) if choice and score > 0.75: # arbitrary threshold @@ -256,6 +474,13 @@ def _validate_expression( else "" ) continue + elif ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id in ("str", "float", "int") + ): + # allow type casting functions + continue elif isinstance( node, ( @@ -281,41 +506,11 @@ def _validate_expression( ), ): continue - source_segment = cast(str, ast.get_source_segment(source, node)) - raise SyntaxError(f"invalid expression: {source_segment}") # TODO: add details - - -def _ast_evaluation_result_value(name: str, attr: str) -> ast.expr: - source = ( - f"_RESULT.{attr}.value if (" - f" _RESULT := (" - f" _MISSING if (" - f" _VALUE := evals.get_span_evaluation(" - f" span.context.span_id, '{name}'" - f" )" - f" ) is None " - f" else _VALUE" - f" ).result" - f").HasField('{attr}') " - f"else _MISSING" - ) - return ast.parse(source, mode="eval").body - - -def _ast_metadata_subscript(key: str) -> ast.expr: - source = ( - f"_MISSING if (" - f" _MD := span.attributes.get('metadata')" - f") is None else (" - f" _MISSING if not hasattr(_MD, 'get') or (" - f" _VALUE := _MD.get('{key}')" - f" ) is None else _VALUE" - f")" - ) - return ast.parse(source, mode="eval").body + source_segment = typing.cast(str, ast.get_source_segment(source, node)) + raise SyntaxError(f"invalid expression: {source_segment}") -def _is_eval(node: Any) -> TypeGuard[ast.Subscript]: +def _is_eval(node: typing.Any) -> TypeGuard[ast.Subscript]: # e.g. `evals["name"]` return ( isinstance(node, ast.Subscript) @@ -324,7 +519,15 @@ def _is_eval(node: Any) -> TypeGuard[ast.Subscript]: ) -def _is_metadata(node: Any) -> TypeGuard[ast.Subscript]: +def _is_attribute(node: typing.Any) -> TypeGuard[ast.Subscript]: + return ( + isinstance(node, ast.Subscript) + and isinstance(value := node.value, ast.Name) + and value.id == "attributes" + ) + + +def _is_metadata(node: typing.Any) -> TypeGuard[ast.Subscript]: # e.g. `metadata["name"]` return ( isinstance(node, ast.Subscript) @@ -333,7 +536,7 @@ def _is_metadata(node: Any) -> TypeGuard[ast.Subscript]: ) -def _get_subscript_key(node: ast.Subscript) -> Optional[str]: +def _get_subscript_key(node: ast.Subscript) -> typing.Optional[str]: if sys.version_info < (3, 9): # Note that `ast.Index` is deprecated in Python 3.9+, but is necessary # for Python 3.8 as part of `ast.Subscript`. @@ -352,7 +555,7 @@ def _get_subscript_key(node: ast.Subscript) -> Optional[str]: ) -def _disjunction(choices: Sequence[str]) -> str: +def _disjunction(choices: typing.Sequence[str]) -> str: """ E.g. 
`["a", "b", "c"]` becomes `"one of a, b, or c"` """ @@ -365,7 +568,9 @@ def _disjunction(choices: Sequence[str]) -> str: return f"one of {', '.join(choices[:-1])}, or {choices[-1]}" -def _find_best_match(source: str, choices: Iterable[str]) -> Tuple[Optional[str], float]: +def _find_best_match( + source: str, choices: typing.Iterable[str] +) -> typing.Tuple[typing.Optional[str], float]: best_choice, best_score = None, 0.0 for choice in choices: score = SequenceMatcher(None, source, choice).ratio() diff --git a/src/phoenix/trace/dsl/missing.py b/src/phoenix/trace/dsl/missing.py deleted file mode 100644 index 4b3aaeb214..0000000000 --- a/src/phoenix/trace/dsl/missing.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import Any - - -class _Missing: - """ - Falsify all comparisons except those with self; return self when getattr() - is called. Also, self is callable returning self. All this may seem peculiar - but is useful for getting the desired (and intuitive) behavior from any - boolean (i.e. comparative) expression without needing error handling when - missing values are encountered. `_Missing()` is intended to be a (fancier) - replacement for `None`. - """ - - def __lt__(self, _: Any) -> bool: - return False - - def __le__(self, _: Any) -> bool: - return False - - def __gt__(self, _: Any) -> bool: - return False - - def __ge__(self, _: Any) -> bool: - return False - - def __eq__(self, other: Any) -> bool: - return isinstance(other, _Missing) - - def __ne__(self, _: Any) -> bool: - return False - - def __len__(self) -> int: - return 0 - - def __iter__(self) -> Any: - return self - - def __next__(self) -> Any: - raise StopIteration() - - def __contains__(self, _: Any) -> bool: - return False - - def __str__(self) -> str: - return "" - - def __float__(self) -> float: - return float("nan") - - def __bool__(self) -> bool: - return False - - def __getattr__(self, _: Any) -> "_Missing": - return self - - def __call__(self, *_: Any, **__: Any) -> "_Missing": - return self - - -MISSING: Any = _Missing() diff --git a/src/phoenix/trace/dsl/query.py b/src/phoenix/trace/dsl/query.py index a272d49b43..780e9f0b1e 100644 --- a/src/phoenix/trace/dsl/query.py +++ b/src/phoenix/trace/dsl/query.py @@ -1,31 +1,44 @@ -import json from collections import defaultdict -from dataclasses import dataclass, field, fields, replace -from functools import cached_property, partial +from dataclasses import dataclass, field, replace +from datetime import datetime +from functools import cached_property +from itertools import chain +from random import randint, random from types import MappingProxyType from typing import ( Any, - Callable, - ClassVar, + DefaultDict, Dict, - Iterable, - Iterator, List, Mapping, Optional, Sequence, - Sized, - Tuple, cast, ) import pandas as pd from openinference.semconv.trace import SpanAttributes +from sqlalchemy import JSON, Column, Select, and_, func, select +from sqlalchemy.dialects.postgresql import aggregate_order_by +from sqlalchemy.orm import Session, aliased +from phoenix.config import DEFAULT_PROJECT_NAME +from phoenix.db import models from phoenix.trace.dsl import SpanFilter -from phoenix.trace.dsl.filter import SupportsGetSpanEvaluation -from phoenix.trace.schemas import ATTRIBUTE_PREFIX, CONTEXT_PREFIX, Span -from phoenix.trace.span_json_encoder import span_to_json +from phoenix.trace.dsl.filter import _NAMES, SupportsGetSpanEvaluation +from phoenix.trace.schemas import ATTRIBUTE_PREFIX +from phoenix.utilities.attributes import ( + JSON_STRING_ATTRIBUTES, + SEMANTIC_CONVENTIONS, + 
flatten, + get_attribute_value, + load_json_strings, + unflatten, +) + +# supported dialects +_SQLITE = "sqlite" +_POSTGRESQL = "postgresql" RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS @@ -39,60 +52,29 @@ "trace_id": "context.trace_id", } -# Because span_kind is an enum, it needs to be converted to string, -# so it's serializable by pyarrow. -_CONVERT_TO_STRING = ("span_kind",) - def _unalias(key: str) -> str: return _ALIASES.get(key, key) @dataclass(frozen=True) -class Projection: - key: str = "" - value: Callable[[Span], Any] = field(init=False, repr=False) - span_fields: ClassVar[Tuple[str, ...]] = tuple(f.name for f in fields(Span)) - - def __bool__(self) -> bool: - return bool(self.key) +class _Base: + """The sole purpose of this class is for `super().__post_init__()` to work""" def __post_init__(self) -> None: - key = _unalias(self.key) - object.__setattr__(self, "key", key) - if key.startswith(CONTEXT_PREFIX): - key = key[len(CONTEXT_PREFIX) :] - value = partial(self._from_context, key=key) - elif key.startswith(ATTRIBUTE_PREFIX): - key = self.key[len(ATTRIBUTE_PREFIX) :] - value = partial(self._from_attributes, key=key) - elif key in self.span_fields: - value = partial(self._from_span, key=key) - else: - value = partial(self._from_attributes, key=key) - if self.key in _CONVERT_TO_STRING: - object.__setattr__( - self, - "value", - lambda span: None if (v := value(span)) is None else str(v), - ) - else: - object.__setattr__(self, "value", value) + pass - def __call__(self, span: Span) -> Any: - return self.value(span) - @staticmethod - def _from_attributes(span: Span, key: str) -> Any: - return span.attributes.get(key) +@dataclass(frozen=True) +class Projection(_Base): + key: str = "" - @staticmethod - def _from_context(span: Span, key: str) -> Any: - return getattr(span.context, key, None) + def __post_init__(self) -> None: + super().__post_init__() + object.__setattr__(self, "key", _unalias(self.key)) - @staticmethod - def _from_span(span: Span, key: str) -> Any: - return getattr(span, key, None) + def __bool__(self) -> bool: + return bool(self.key) def to_dict(self) -> Dict[str, Any]: return {"key": self.key} @@ -105,61 +87,155 @@ def from_dict(cls, obj: Mapping[str, Any]) -> "Projection": @dataclass(frozen=True) -class Explosion(Projection): +class _HasTmpSuffix(_Base): + _tmp_suffix: str = field(init=False, repr=False) + """Ideally every column name should get a temporary random suffix that will + be removed at the end. This is necessary during query construction because + sqlalchemy is not always foolproof, so we should actively avoid name + collisions, which is increasingly likely as queries get more complex. The + suffix is randomized per instance. 
+    """
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        object.__setattr__(self, "_tmp_suffix", f"{randint(0, 10**6):06d}")
+
+    def _remove_tmp_suffix(self, name: str) -> str:
+        if name.endswith(self._tmp_suffix):
+            return name[: -len(self._tmp_suffix)]
+        return name
+
+    def _add_tmp_suffix(self, name: str) -> str:
+        if name.endswith(self._tmp_suffix):
+            return name
+        return name + self._tmp_suffix
+
+
+@dataclass(frozen=True)
+class Explosion(_HasTmpSuffix, Projection):
     kwargs: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
     primary_index_key: str = "context.span_id"
-    position_prefix: str = field(init=False, repr=False)
-    primary_index: Projection = field(init=False, repr=False)
+    _position_prefix: str = field(init=False, repr=False)
+    _primary_index: Projection = field(init=False, repr=False)
+    _array_tmp_col_label: str = field(init=False, repr=False)
+    """For sqlite we need to store the array in a temporary column to be able
+    to explode it later in pandas. `_array_tmp_col_label` is the name of this
+    temporary column. The temporary column will have a unique name
+    per instance.
+    """
 
     def __post_init__(self) -> None:
         super().__post_init__()
         position_prefix = _PRESCRIBED_POSITION_PREFIXES.get(self.key, "")
-        object.__setattr__(self, "position_prefix", position_prefix)
-        object.__setattr__(self, "primary_index", Projection(self.primary_index_key))
+        object.__setattr__(self, "_position_prefix", position_prefix)
+        object.__setattr__(self, "_primary_index", Projection(self.primary_index_key))
+        object.__setattr__(self, "_array_tmp_col_label", f"__array_tmp_col_{random()}")
 
     @cached_property
-    def index_keys(self) -> Tuple[str, str]:
-        return (self.primary_index.key, f"{self.position_prefix}position")
-
-    def with_primary_index_key(self, primary_index_key: str) -> "Explosion":
-        return replace(self, primary_index_key=primary_index_key)
-
-    def __call__(self, span: Span) -> Iterator[Dict[str, Any]]:
-        if not isinstance(seq := self.value(span), Iterable):
-            return
-        has_mapping = False
-        for item in seq:
-            if isinstance(item, Mapping):
-                has_mapping = True
-                break
-        if not has_mapping:
-            for i, item in enumerate(seq):
-                if item is not None:
-                    yield {
-                        self.key: item,
-                        self.primary_index.key: self.primary_index(span),
-                        f"{self.position_prefix}position": i,
-                    }
-            return
-        for i, item in enumerate(seq):
-            if not isinstance(item, Mapping):
-                continue
-            record = (
-                {name: item.get(key) for name, key in self.kwargs.items()}
-                if self.kwargs
-                else dict(item)
+    def index_keys(self) -> List[str]:
+        return [self._primary_index.key, f"{self._position_prefix}position"]
+
+    def with_primary_index_key(self, _: str) -> "Explosion":
+        print("`.with_primary_index_key(...)` is deprecated and will be removed in the future.")
+        return self
+
+    def update_sql(self, sql: Select[Any], dialect: str) -> Select[Any]:
+        array = models.Span.attributes[self.key.split(".")]
+        if dialect == _SQLITE:
+            # Because sqlite doesn't support `WITH ORDINALITY`, the order of
+            # the returned (table) values is not guaranteed. So we resort to
+            # post hoc processing using pandas.
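+            # Here only the raw JSON array is selected, under a temporary label;
+            # `update_df` below explodes it into rows on the pandas side.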
+ return sql.where( + func.json_type(array) == "array", + ).add_columns( + array.label(self._array_tmp_col_label), ) - for v in record.values(): - if v is not None: - break - else: - record = {} - if not record: - continue - record[self.primary_index.key] = self.primary_index(span) - record[f"{self.position_prefix}position"] = i - yield record + elif dialect == _POSTGRESQL: + element = ( + func.jsonb_array_elements(array) + .table_valued( + Column("obj", JSON), + with_ordinality="position", + joins_implicitly=True, + ) + .render_derived() + ) + obj, position = element.c.obj, element.c.position + return sql.where( + and_( + func.jsonb_typeof(array) == "array", + func.jsonb_typeof(obj) == "object", + ) + ).add_columns( + # Use zero-based indexing for backward-compatibility. + (position - 1).label(f"{self._position_prefix}position"), + *( + ( + obj[key.split(".")].label(self._add_tmp_suffix(name)) + for name, key in self.kwargs.items() + ) + if self.kwargs + else (obj.label(self._array_tmp_col_label),) + ), + ) + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + def update_df(self, df: pd.DataFrame, dialect: str) -> pd.DataFrame: + df = df.rename(self._remove_tmp_suffix, axis=1) + if df.empty: + columns = list( + set( + chain( + self.index_keys, + df.drop(self._array_tmp_col_label, axis=1, errors="ignore").columns, + self.kwargs.keys(), + ) + ) + ) + return pd.DataFrame(columns=columns).set_index(self.index_keys) + if dialect == _POSTGRESQL and not self.kwargs: + records = df.loc[:, self._array_tmp_col_label].map(flatten).map(dict).dropna() + return pd.concat( + [ + df.drop(self._array_tmp_col_label, axis=1), + pd.DataFrame.from_records(records.to_list(), index=records.index), + ], + axis=1, + ).set_index(self.index_keys) + if dialect == _SQLITE: + # Because sqlite doesn't support `WITH ORDINALITY`, the order of + # the returned (table) values is not guaranteed. So we resort to + # post hoc processing using pandas. + def _extract_values(array: List[Any]) -> List[Dict[str, Any]]: + if not self.kwargs: + return [ + { + **dict(flatten(obj)), + f"{self._position_prefix}position": i, + } + for i, obj in enumerate(array) + if isinstance(obj, Mapping) + ] + res: List[Dict[str, Any]] = [] + for i, obj in enumerate(array): + if not isinstance(obj, Mapping): + continue + values: Dict[str, Any] = {f"{self._position_prefix}position": i} + for name, key in self.kwargs.items(): + if (value := get_attribute_value(obj, key)) is not None: + values[name] = value + res.append(values) + return res + + records = df.loc[:, self._array_tmp_col_label].map(_extract_values).explode().dropna() + df_explode = pd.DataFrame.from_records(records.to_list(), index=records.index) + return ( + df.drop(self._array_tmp_col_label, axis=1) + .join(df_explode, how="outer") + .set_index(self.index_keys) + ) + return df.set_index(self.index_keys) def to_dict(self) -> Dict[str, Any]: return { @@ -186,27 +262,111 @@ def from_dict(cls, obj: Mapping[str, Any]) -> "Explosion": @dataclass(frozen=True) -class Concatenation(Projection): +class Concatenation(_HasTmpSuffix, Projection): kwargs: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({})) separator: str = "\n\n" + _array_tmp_col_label: str = field(init=False, repr=False) + """For sqlite we need to store the array in a temporary column to be able + to concatenate it later in pandas. `_array_tmp_col_label` is the name of + this temporary column. The temporary column will have a unique name + per instance. 
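+    (Under postgresql the concatenation is instead done in SQL via `string_agg`,
+    so this temporary column is not needed there.)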
+ """ + + def __post_init__(self) -> None: + super().__post_init__() + object.__setattr__(self, "_array_tmp_col_label", f"__array_tmp_col_{random()}") + def with_separator(self, separator: str = "\n\n") -> "Concatenation": return replace(self, separator=separator) - def __call__(self, span: Span) -> Iterator[Tuple[str, str]]: - if not isinstance(seq := self.value(span), Iterable): - return - if not self.kwargs: - yield self.key, self.separator.join(map(str, seq)) - record = defaultdict(list) - for item in seq: - if not isinstance(item, Mapping): - continue - for k, v in self.kwargs.items(): - if value := item.get(v): - record[k].append(value) - for name, values in record.items(): - yield name, self.separator.join(map(str, values)) + def update_sql(self, stmt: Select[Any], dialect: str) -> Select[Any]: + array = models.Span.attributes[self.key.split(".")] + if dialect == _SQLITE: + # Because sqlite doesn't support WITH ORDINALITY, the order of + # the returned (table) values is not guaranteed. So we resort to + # post-processing using pandas. + return stmt.where( + func.json_type(array) == "array", + ).add_columns( + array.label(self._array_tmp_col_label), + ) + if dialect == _POSTGRESQL: + element = ( + ( + func.jsonb_array_elements(array) + if self.kwargs + else func.jsonb_array_elements_text(array) + ) + .table_valued( + Column("obj", JSON), + with_ordinality="position", + joins_implicitly=True, + ) + .render_derived() + ) + obj, position = element.c.obj, element.c.position + return ( + stmt.where( + and_( + func.jsonb_typeof(array) == "array", + *((func.jsonb_typeof(obj) == "object",) if self.kwargs else ()), + ) + ) + .add_columns( + *( + ( + func.string_agg( + obj[key.split(".")].as_string(), + aggregate_order_by(self.separator, position), # type: ignore + ).label(self._add_tmp_suffix(name)) + for name, key in self.kwargs.items() + ) + if self.kwargs + else ( + func.string_agg( + obj, + aggregate_order_by(self.separator, position), # type: ignore + ).label(self.key), + ) + ), + ) + .group_by(*stmt.columns.keys()) + ) + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + def update_df(self, df: pd.DataFrame, dialect: str) -> pd.DataFrame: + df = df.rename(self._remove_tmp_suffix, axis=1) + if df.empty: + columns = list( + set( + chain( + df.drop(self._array_tmp_col_label, axis=1, errors="ignore").columns, + self.kwargs.keys(), + ) + ) + ) + return pd.DataFrame(columns=columns, index=df.index) + if dialect == _SQLITE: + # Because sqlite doesn't support WITH ORDINALITY, the order of + # the returned (table) values is not guaranteed. So we resort to + # post-processing using pandas. 
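+            # e.g. when no kwargs are given, an array ["a", "b", "c"] with the
+            # default separator is concatenated into "a\n\nb\n\nc" by the helper below.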
+            def _concat_values(array: List[Any]) -> Dict[str, Any]:
+                if not self.kwargs:
+                    return {self.key: self.separator.join(str(obj) for obj in array)}
+                values: DefaultDict[str, List[str]] = defaultdict(list)
+                for i, obj in enumerate(array):
+                    if not isinstance(obj, Mapping):
+                        continue
+                    for name, key in self.kwargs.items():
+                        if (value := get_attribute_value(obj, key)) is not None:
+                            values[name].append(str(value))
+                return {k: self.separator.join(v) for k, v in values.items()}
+
+            records = df.loc[:, self._array_tmp_col_label].map(_concat_values)
+            df_concat = pd.DataFrame.from_records(records.to_list(), index=records.index)
+            return df.drop(self._array_tmp_col_label, axis=1).join(df_concat, how="outer")
+        return df
 
     def to_dict(self) -> Dict[str, Any]:
         return {
@@ -233,7 +393,7 @@ def from_dict(cls, obj: Mapping[str, Any]) -> "Concatenation":
 
 
 @dataclass(frozen=True)
-class SpanQuery:
+class SpanQuery(_HasTmpSuffix):
     _select: Mapping[str, Projection] = field(default_factory=lambda: MappingProxyType({}))
     _concat: Concatenation = field(default_factory=Concatenation)
     _explode: Explosion = field(default_factory=Explosion)
@@ -241,6 +401,17 @@ class SpanQuery:
     _rename: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({}))
     _index: Projection = field(default_factory=lambda: Projection("context.span_id"))
 
+    _pk_tmp_col_label: str = field(init=False, repr=False)
+    """We use `_pk_tmp_col_label` as a temporary column for storing
+    the row id, i.e. the primary key, of the spans table. This will help
+    us with joins without the risk of naming conflicts. The temporary
+    column will have a unique name per instance.
+    """
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        object.__setattr__(self, "_pk_tmp_col_label", f"__pk_tmp_col_{random()}")
+
     def __bool__(self) -> bool:
         return bool(self._select) or bool(self._filter) or bool(self._explode) or bool(self._concat)
 
@@ -268,75 +439,116 @@ def rename(self, **kwargs: str) -> "SpanQuery":
 
     def with_index(self, key: str = "context.span_id") -> "SpanQuery":
         _index = Projection(key=key)
-        return replace(self, _index=_index)
+        return replace(self, _index=_index, _explode=replace(self._explode, primary_index_key=key))
 
     def with_concat_separator(self, separator: str = "\n\n") -> "SpanQuery":
         _concat = self._concat.with_separator(separator)
         return replace(self, _concat=_concat)
 
-    def with_explode_primary_index_key(self, primary_index_key: str) -> "SpanQuery":
-        _explode = self._explode.with_primary_index_key(primary_index_key)
-        return replace(self, _explode=_explode)
-
-    def __call__(self, spans: Iterable[Span]) -> pd.DataFrame:
+    def with_explode_primary_index_key(self, _: str) -> "SpanQuery":
+        print(
+            "`.with_explode_primary_index_key(...)` is deprecated and will be "
+            "removed in the future. Use `.with_index(...)` instead."
+ ) + return self + + def __call__( + self, + session: Session, + project_name: str = DEFAULT_PROJECT_NAME, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + root_spans_only: Optional[bool] = None, + ) -> pd.DataFrame: + if not (self._select or self._explode or self._concat): + return _get_spans_dataframe( + session, + self._filter, + project_name, + start_time, + stop_time, + root_spans_only, + ) + assert session.bind is not None + dialect = session.bind.dialect.name + conn = session.connection() + index = _NAMES[self._index.key].label(self._add_tmp_suffix(self._index.key)) + row_id = models.Span.id.label(self._pk_tmp_col_label) + stmt = stmt0_orig = ( + # We do not allow `group_by` anything other than `row_id` because otherwise + # it's too complex for the post hoc processing step in pandas. + select(row_id) + .join_from(models.Span, models.Trace) + .join(models.Project) + .where(models.Project.name == project_name) + ) + stmt1_filter = None if self._filter: - spans = filter(self._filter, spans) - if self._explode: - spans = filter( - lambda span: (isinstance(seq := self._explode.value(span), Sized) and len(seq)), - spans, + stmt = stmt1_filter = self._filter(stmt) + stmt2_select = None + if self._select: + stmt = stmt2_select = stmt.add_columns( + *( + ( + models.Span.attributes[proj.key.split(".")] + if proj.key not in _NAMES + else _NAMES[proj.key] + ).label(self._add_tmp_suffix(name)) + for name, proj in self._select.items() + ) ) + stmt3_explode = None + if self._explode: + stmt = stmt3_explode = self._explode.update_sql(stmt, dialect) + df: Optional[pd.DataFrame] = None + # `concat` is separate because it has `group_by` but we can't always + # join to it as a subquery because it may require post hoc processing + # in pandas, so it's kept separate for simplicity. + df_concat: Optional[pd.DataFrame] = None + if self._explode or not self._concat: + if index.name not in stmt.selected_columns.keys(): + stmt = stmt.add_columns(index) + df = pd.read_sql(stmt, conn) if self._concat: - spans = filter( - lambda span: (isinstance(seq := self._concat.value(span), Sized) and len(seq)), - spans, - ) - if not (self._select or self._explode or self._concat): - if not (data := [json.loads(span_to_json(span)) for span in spans]): - return pd.DataFrame() - return ( - pd.json_normalize(data, max_level=1) - .rename(self._rename, axis=1, errors="ignore") - .set_index("context.span_id", drop=False) - ) - _selected: List[Dict[str, Any]] = [] - _exploded: List[Dict[str, Any]] = [] - for span in spans: - if self._select: - record = {name: proj(span) for name, proj in self._select.items()} - for v in record.values(): - if v is not None: - break - else: - record = {} - if self._concat: - record.update(self._concat(span)) - if record: - if self._index.key not in record: - record[self._index.key] = self._index(span) - _selected.append(record) - elif self._concat: - record = {self._index.key: self._index(span)} - record.update(self._concat(span)) - if record: - _selected.append(record) - if self._explode: - _exploded.extend(self._explode(span)) - if _selected: - select_df = pd.DataFrame(_selected) + if df is not None: + # We can't include stmt3_explode because it may be trying to + # explode the same column that we're trying to concatenate, + # resulting in duplicates. 
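+                # (Hypothetical illustration: a query that both `.concat(...)`s
+                # and `.explode(...)`s RETRIEVAL_DOCUMENTS would otherwise
+                # multiply the concatenated rows by the exploded rows.)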
+ stmt_no_explode = ( + stmt2_select + if stmt2_select is not None + else (stmt1_filter if stmt1_filter is not None else stmt0_orig) + ) + stmt4_concat = stmt_no_explode.with_only_columns(row_id) + else: + assert stmt3_explode is None + stmt4_concat = ( + stmt.add_columns(index) + if index.name not in stmt.selected_columns.keys() + else stmt + ) + stmt4_concat = self._concat.update_sql(stmt4_concat, dialect) + df_concat = pd.read_sql(stmt4_concat, conn) + df_concat = self._concat.update_df(df_concat, dialect) + assert df_concat is not None + if df is not None: + df_concat = df_concat.set_index(self._pk_tmp_col_label) + assert df is not None or df_concat is not None + if df is None: + assert df_concat is not None + df = df_concat.drop(self._pk_tmp_col_label, axis=1) + elif df_concat is not None: + df = df.set_index(self._pk_tmp_col_label) + df = df.join(df_concat, how="inner") else: - select_df = pd.DataFrame(columns=[self._index.key]) - select_df = select_df.set_index(self._index.key) + df = df.drop(self._pk_tmp_col_label, axis=1) + df = df.rename(self._remove_tmp_suffix, axis=1) if self._explode: - if _exploded: - explode_df = pd.DataFrame(_exploded) - else: - explode_df = pd.DataFrame(columns=self._explode.index_keys) - explode_df = explode_df.set_index(list(self._explode.index_keys)) - if not self._select: - return explode_df.rename(self._rename, axis=1, errors="ignore") - select_df = select_df.join(explode_df, how="outer") - return select_df.rename(self._rename, axis=1, errors="ignore") + df = self._explode.update_df(df, dialect) + else: + df = df.set_index(self._index.key) + df = df.rename(_ALIASES, axis=1, errors="ignore") + return df.rename(self._rename, axis=1, errors="ignore") def to_dict(self) -> Dict[str, Any]: return { @@ -404,3 +616,73 @@ def from_dict( else {} ), ) + + +def _get_spans_dataframe( + session: Session, + span_filter: SpanFilter, + project_name: str, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + root_spans_only: Optional[bool] = None, +) -> pd.DataFrame: + # legacy labels for backward-compatibility + span_id_label = "context.span_id" + trace_id_label = "context.trace_id" + stmt = ( + select( + models.Span.name, + models.Span.span_kind, + models.Span.parent_id, + models.Span.start_time, + models.Span.end_time, + models.Span.status_code, + models.Span.status_message, + models.Span.events, + models.Span.span_id.label(span_id_label), + models.Trace.trace_id.label(trace_id_label), + models.Span.attributes, + ) + .join(models.Trace) + .join(models.Project) + .where(models.Project.name == project_name) + ) + stmt = span_filter(stmt) + if start_time: + stmt = stmt.where(start_time <= models.Span.start_time) + if stop_time: + stmt = stmt.where(models.Span.start_time < stop_time) + if root_spans_only: + parent = aliased(models.Span) + stmt = stmt.outerjoin( + parent, + models.Span.parent_id == parent.span_id, + ).where(parent.span_id == None) # noqa E711 + df = pd.read_sql(stmt, session.connection()).set_index(span_id_label, drop=False) + if (attrs_label := "attributes") in df.columns: + df_attributes = pd.DataFrame.from_records( + df.attributes.map(_flatten_semantic_conventions), + ).set_axis(df.index, axis=0) + df = pd.concat( + [ + df.drop(attrs_label, axis=1), + df_attributes.add_prefix(attrs_label + "."), + ], + axis=1, + ) + return df + + +def _flatten_semantic_conventions(attributes: Mapping[str, Any]) -> Dict[str, Any]: + # This may be inefficient, but is needed to preserve backward-compatibility. 
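+    # For example, a span stored with the nested attribute
+    # {"llm": {"token_count": {"total": 10}}} comes back as the flat column
+    # "attributes.llm.token_count.total" in the resulting dataframe.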
+ # Custom attributes do not get flattened. + return unflatten( + load_json_strings( + flatten( + attributes, + recurse_on_sequence=True, + json_string_attributes=JSON_STRING_ATTRIBUTES, + ), + ), + prefix_exclusions=SEMANTIC_CONVENTIONS, + ) diff --git a/src/phoenix/trace/otel.py b/src/phoenix/trace/otel.py index 21bb417b4a..031397e194 100644 --- a/src/phoenix/trace/otel.py +++ b/src/phoenix/trace/otel.py @@ -1,28 +1,22 @@ -import inspect import json from binascii import hexlify, unhexlify from datetime import datetime, timezone from types import MappingProxyType from typing import ( Any, - DefaultDict, Dict, Iterable, Iterator, - List, Mapping, Optional, Sequence, - Set, SupportsFloat, Tuple, - Union, cast, ) import numpy as np import opentelemetry.proto.trace.v1.trace_pb2 as otlp -from openinference.semconv import trace from openinference.semconv.trace import DocumentAttributes, SpanAttributes from opentelemetry.proto.common.v1.common_pb2 import AnyValue, ArrayValue, KeyValue from opentelemetry.util.types import Attributes, AttributeValue @@ -42,6 +36,14 @@ SpanStatusCode, TraceID, ) +from phoenix.utilities.attributes import ( + JSON_STRING_ATTRIBUTES, + flatten, + get_attribute_value, + has_mapping, + load_json_strings, + unflatten, +) DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE @@ -62,12 +64,8 @@ def decode(otlp_span: otlp.Span) -> Span: start_time = _decode_unix_nano(otlp_span.start_time_unix_nano) end_time = _decode_unix_nano(otlp_span.end_time_unix_nano) - attributes = dict(_unflatten(_load_json_strings(_decode_key_values(otlp_span.attributes)))) - span_kind = SpanKind(attributes.pop(OPENINFERENCE_SPAN_KIND, None)) - - for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): - if mime_type in attributes: - attributes[mime_type] = attributes[mime_type] + attributes = unflatten(load_json_strings(_decode_key_values(otlp_span.attributes))) + span_kind = SpanKind(get_attribute_value(attributes, OPENINFERENCE_SPAN_KIND)) status_code, status_message = _decode_status(otlp_span.status) events = [_decode_event(event) for event in otlp_span.events] @@ -149,28 +147,6 @@ def _decode_value(any_value: AnyValue) -> Any: assert_never(which) -_JSON_STRING_ATTRIBUTES = ( - DOCUMENT_METADATA, - LLM_PROMPT_TEMPLATE_VARIABLES, - METADATA, - TOOL_PARAMETERS, -) - - -def _load_json_strings(key_values: Iterable[Tuple[str, Any]]) -> Iterator[Tuple[str, Any]]: - for key, value in key_values: - if key.endswith(_JSON_STRING_ATTRIBUTES): - try: - dict_value = json.loads(value) - except Exception: - yield key, value - else: - if dict_value: - yield key, dict_value - else: - yield key, value - - StatusMessage: TypeAlias = str _STATUS_DECODING = MappingProxyType( @@ -187,120 +163,6 @@ def _decode_status(otlp_status: otlp.Status) -> Tuple[SpanStatusCode, StatusMess return status_code, otlp_status.message -_SEMANTIC_CONVENTIONS: List[str] = sorted( - ( - getattr(klass, attr) - for name in dir(trace) - if name.endswith("Attributes") and inspect.isclass(klass := getattr(trace, name)) - for attr in dir(klass) - if attr.isupper() - ), - reverse=True, -) # sorted so the longer strings go first - - -def _semantic_convention_prefix_partition(key: str, separator: str = ".") -> Tuple[str, str, str]: - """Return the longest prefix of `key` that is a semantic convention, and the remaining suffix - separated by `.`. For example, if `key` is "retrieval.documents.2.document.score", return - ("retrieval.documents", ".", "2.document.score"). 
The return signature is based on Python's - `.partition` method for strings. - """ - for prefix in _SEMANTIC_CONVENTIONS: - if key == prefix: - return key, "", "" - if key.startswith(prefix) and key[len(prefix) :].startswith(separator): - return prefix, separator, key[len(prefix) + len(separator) :] - return "", "", "" - - -class _Trie(DefaultDict[Union[str, int], "_Trie"]): - """Prefix Tree with special handling for indices (i.e. all-digit keys).""" - - def __init__(self) -> None: - super().__init__(_Trie) - self.value: Any = None - self.indices: Set[int] = set() - self.branches: Set[Union[str, int]] = set() - - def set_value(self, value: Any) -> None: - self.value = value - # value and indices must not coexist - self.branches.update(self.indices) - self.indices.clear() - - def add_index(self, index: int) -> "_Trie": - if self.value is not None: - self.branches.add(index) - elif index not in self.branches: - self.indices.add(index) - return self[index] - - def add_branch(self, branch: Union[str, int]) -> "_Trie": - if branch in self.indices: - self.indices.discard(cast(int, branch)) - self.branches.add(branch) - return self[branch] - - -# FIXME: Ideally we should not need something so complicated as a Trie, but it's useful here -# for backward compatibility reasons regarding some deeply nested objects such as TOOL_PARAMETERS. -# In the future, we should `json_dumps` them and not let things get too deeply nested. -def _build_trie( - key_value_pairs: Iterable[Tuple[str, Any]], - separator: str = ".", -) -> _Trie: - """Build a Trie (a.k.a. prefix tree) from `key_value_pairs`, by partitioning the keys by - separator. Each partition is a branch in the Trie. Special handling is done for partitions - that are all digits, e.g. "0", "12", etc., which are converted to integers and collected - as indices. 
- """ - trie = _Trie() - for key, value in key_value_pairs: - if value is None: - continue - t = trie - while True: - prefix, _, suffix = _semantic_convention_prefix_partition(key, separator) - if prefix: - t = t.add_branch(prefix) - else: - prefix, _, suffix = key.partition(separator) - if prefix.isdigit(): - index = int(prefix) - t = t.add_index(index) if suffix else t.add_branch(index) - else: - t = t.add_branch(prefix) - if not suffix: - break - key = suffix - t.set_value(value) - return trie - - -def _walk(trie: _Trie, prefix: str = "") -> Iterator[Tuple[str, Any]]: - if trie.value is not None: - yield prefix, trie.value - elif prefix and trie.indices: - yield prefix, [dict(_walk(trie[index])) for index in sorted(trie.indices)] - elif trie.indices: - for index in trie.indices: - yield from _walk(trie[index], prefix=f"{index}") - elif prefix: - yield prefix, dict(_walk(trie)) - return - for branch in trie.branches: - new_prefix = f"{prefix}.{branch}" if prefix else f"{branch}" - yield from _walk(trie[branch], new_prefix) - - -def _unflatten( - key_value_pairs: Iterable[Tuple[str, Any]], - separator: str = ".", -) -> Iterator[Tuple[str, Any]]: - trie = _build_trie(key_value_pairs, separator) - yield from _walk(trie) - - _BILLION = 1_000_000_000 # for converting seconds to nanoseconds @@ -313,11 +175,7 @@ def encode(span: Span) -> otlp.Span: start_time_unix_nano: int = int(span.start_time.timestamp() * _BILLION) end_time_unix_nano: int = int(span.end_time.timestamp() * _BILLION) if span.end_time else 0 - attributes: Dict[str, Any] = span.attributes.copy() - - for mime_type in (INPUT_MIME_TYPE, OUTPUT_MIME_TYPE): - if mime_type in attributes: - attributes[mime_type] = attributes[mime_type] + attributes: Dict[str, Any] = dict(span.attributes) for key, value in span.attributes.items(): if value is None: @@ -325,19 +183,34 @@ def encode(span: Span) -> otlp.Span: attributes.pop(key, None) elif isinstance(value, Mapping): attributes.pop(key, None) - if key.endswith(_JSON_STRING_ATTRIBUTES): + if key.endswith(JSON_STRING_ATTRIBUTES): attributes[key] = json.dumps(value) else: - attributes.update(_flatten_mapping(value, key)) + attributes.update( + flatten( + value, + prefix=key, + recurse_on_sequence=True, + json_string_attributes=JSON_STRING_ATTRIBUTES, + ) + ) elif ( not isinstance(value, str) and (isinstance(value, Sequence) or isinstance(value, np.ndarray)) - and _has_mapping(value) + and has_mapping(value) ): attributes.pop(key, None) - attributes.update(_flatten_sequence(value, key)) - - attributes[OPENINFERENCE_SPAN_KIND] = span.span_kind.value + attributes.update( + flatten( + value, + prefix=key, + recurse_on_sequence=True, + json_string_attributes=JSON_STRING_ATTRIBUTES, + ) + ) + + if OPENINFERENCE_SPAN_KIND not in attributes: + attributes[OPENINFERENCE_SPAN_KIND] = span.span_kind.value status = _encode_status(span.status_code, span.status_message) events = map(_encode_event, span.events) @@ -378,42 +251,6 @@ def _encode_identifier(identifier: Optional[str]) -> bytes: return unhexlify(identifier) -def _has_mapping(sequence: Sequence[Any]) -> bool: - for item in sequence: - if isinstance(item, Mapping): - return True - return False - - -def _flatten_mapping( - mapping: Mapping[str, Any], - prefix: str, -) -> Iterator[Tuple[str, Any]]: - for key, value in mapping.items(): - prefixed_key = f"{prefix}.{key}" - if isinstance(value, Mapping): - if key.endswith(_JSON_STRING_ATTRIBUTES): - yield prefixed_key, json.dumps(value) - else: - yield from _flatten_mapping(value, prefixed_key) - elif 
isinstance(value, Sequence): - yield from _flatten_sequence(value, prefixed_key) - elif value is not None: - yield prefixed_key, value - - -def _flatten_sequence( - sequence: Sequence[Any], - prefix: str, -) -> Iterator[Tuple[str, Any]]: - if isinstance(sequence, str) or not _has_mapping(sequence): - yield prefix, sequence - for idx, obj in enumerate(sequence): - if not isinstance(obj, Mapping): - continue - yield from _flatten_mapping(obj, f"{prefix}.{idx}") - - def _encode_event(event: SpanEvent) -> otlp.Span.Event: return otlp.Span.Event( name=event.name, diff --git a/src/phoenix/trace/schemas.py b/src/phoenix/trace/schemas.py index 7caeab9bf4..994756cf7d 100644 --- a/src/phoenix/trace/schemas.py +++ b/src/phoenix/trace/schemas.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, NamedTuple, Optional, Union +from typing import Any, Dict, List, Mapping, NamedTuple, Optional from uuid import UUID EXCEPTION_TYPE = "exception.type" @@ -54,9 +54,7 @@ def _missing_(cls, v: Any) -> Optional["SpanKind"]: TraceID = str SpanID = str -AttributePrimitiveValue = Union[str, bool, float, int] -AttributeValue = Union[AttributePrimitiveValue, List[AttributePrimitiveValue]] -SpanAttributes = Dict[str, AttributeValue] +SpanAttributes = Mapping[str, Any] @dataclass(frozen=True) diff --git a/src/phoenix/utilities/__init__.py b/src/phoenix/utilities/__init__.py index 8a86c72d19..e69de29bb2 100644 --- a/src/phoenix/utilities/__init__.py +++ b/src/phoenix/utilities/__init__.py @@ -1,26 +0,0 @@ -from datetime import datetime -from typing import List, Optional - -import pandas as pd - -from phoenix.core.project import Project -from phoenix.trace.dsl import SpanQuery - - -def query_spans( - project: Optional[Project], - *queries: SpanQuery, - start_time: Optional[datetime] = None, - stop_time: Optional[datetime] = None, - root_spans_only: Optional[bool] = None, -) -> List[pd.DataFrame]: - if not queries or not project: - return [] - spans = tuple( - project.get_spans( - start_time=start_time, - stop_time=stop_time, - root_spans_only=root_spans_only, - ) - ) - return [query(spans) for query in queries] diff --git a/src/phoenix/utilities/attributes.py b/src/phoenix/utilities/attributes.py new file mode 100644 index 0000000000..5413726c78 --- /dev/null +++ b/src/phoenix/utilities/attributes.py @@ -0,0 +1,278 @@ +import inspect +import json +from typing import ( + Any, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, +) + +from openinference.semconv import trace +from openinference.semconv.trace import DocumentAttributes, SpanAttributes +from typing_extensions import assert_never + +DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA +LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES +METADATA = SpanAttributes.METADATA +TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS + +# attributes interpreted as JSON strings during ingestion +JSON_STRING_ATTRIBUTES = ( + DOCUMENT_METADATA, + LLM_PROMPT_TEMPLATE_VARIABLES, + METADATA, + TOOL_PARAMETERS, +) + +SEMANTIC_CONVENTIONS: List[str] = sorted( + # e.g. "input.value", "llm.token_count.total", etc. 
+ ( + cast(str, getattr(klass, attr)) + for name in dir(trace) + if name.endswith("Attributes") and inspect.isclass(klass := getattr(trace, name)) + for attr in dir(klass) + if attr.isupper() + ), + key=len, + reverse=True, +) # sorted so the longer strings go first + + +def unflatten( + key_value_pairs: Iterable[Tuple[str, Any]], + *, + prefix_exclusions: Sequence[str] = (), + separator: str = ".", +) -> Dict[str, Any]: + # `prefix_exclusions` is intended to contain the semantic conventions + trie = _build_trie(key_value_pairs, separator=separator, prefix_exclusions=prefix_exclusions) + return dict(_walk(trie, separator=separator)) + + +def flatten( + obj: Union[Mapping[str, Any], Iterable[Any]], + *, + prefix: str = "", + separator: str = ".", + recurse_on_sequence: bool = False, + json_string_attributes: Optional[Sequence[str]] = None, +) -> Iterator[Tuple[str, Any]]: + if isinstance(obj, Mapping): + yield from _flatten_mapping( + obj, + prefix=prefix, + recurse_on_sequence=recurse_on_sequence, + json_string_attributes=json_string_attributes, + separator=separator, + ) + elif isinstance(obj, Iterable): + yield from _flatten_sequence( + obj, + prefix=prefix, + recurse_on_sequence=recurse_on_sequence, + json_string_attributes=json_string_attributes, + separator=separator, + ) + else: + assert_never(obj) + + +def has_mapping(sequence: Iterable[Any]) -> bool: + for item in sequence: + if isinstance(item, Mapping): + return True + return False + + +def get_attribute_value( + attributes: Optional[Mapping[str, Any]], + key: str, + separator: str = ".", +) -> Optional[Any]: + if not attributes: + return None + sub_keys = key.split(separator) + for sub_key in sub_keys[:-1]: + attributes = attributes.get(sub_key) + if not attributes: + return None + return attributes.get(sub_keys[-1]) + + +def load_json_strings(key_values: Iterable[Tuple[str, Any]]) -> Iterator[Tuple[str, Any]]: + for key, value in key_values: + if key.endswith(JSON_STRING_ATTRIBUTES): + try: + dict_value = json.loads(value) + except Exception: + yield key, value + else: + if dict_value: + yield key, dict_value + else: + yield key, value + + +def _partition_with_prefix_exclusion( + key: str, + separator: str = ".", + prefix_exclusions: Sequence[str] = (), +) -> Tuple[str, str, str]: + # prefix_exclusions should be sorted by length from the longest to the shortest + for prefix in prefix_exclusions: + if key.startswith(prefix) and ( + len(key) == len(prefix) or key[len(prefix) :].startswith(separator) + ): + return prefix, separator, key[len(prefix) + len(separator) :] + return key.partition(separator) + + +class _Trie(DefaultDict[Union[str, int], "_Trie"]): + """Prefix Tree with special handling for indices (i.e. 
all-digit keys).""" + + def __init__(self) -> None: + super().__init__(_Trie) + self.value: Any = None + self.indices: Set[int] = set() + self.branches: Set[Union[str, int]] = set() + + def set_value(self, value: Any) -> None: + self.value = value + # value and indices must not coexist + self.branches.update(self.indices) + self.indices.clear() + + def add_index(self, index: int) -> "_Trie": + if self.value is not None: + self.branches.add(index) + elif index not in self.branches: + self.indices.add(index) + return self[index] + + def add_branch(self, branch: Union[str, int]) -> "_Trie": + if branch in self.indices: + self.indices.discard(cast(int, branch)) + self.branches.add(branch) + return self[branch] + + +def _build_trie( + key_value_pairs: Iterable[Tuple[str, Any]], + *, + prefix_exclusions: Sequence[str] = (), + separator: str = ".", +) -> _Trie: + """Build a Trie (a.k.a. prefix tree) from `key_value_pairs`, by partitioning the keys by + separator. Each partition is a branch in the Trie. Special handling is done for partitions + that are all digits, e.g. "0", "12", etc., which are converted to integers and collected + as indices. + """ + trie = _Trie() + for key, value in key_value_pairs: + if value is None: + continue + t = trie + while True: + prefix, _, suffix = _partition_with_prefix_exclusion( + key, + separator, + prefix_exclusions, + ) + if prefix.isdigit(): + index = int(prefix) + t = t.add_index(index) if suffix else t.add_branch(index) + else: + t = t.add_branch(prefix) + if not suffix: + break + key = suffix + t.set_value(value) + return trie + + +def _walk( + trie: _Trie, + *, + prefix: str = "", + separator: str = ".", +) -> Iterator[Tuple[str, Any]]: + if trie.value is not None: + yield prefix, trie.value + elif prefix and trie.indices: + yield ( + prefix, + [dict(_walk(trie[index], separator=separator)) for index in sorted(trie.indices)], + ) + elif trie.indices: + for index in trie.indices: + yield from _walk(trie[index], prefix=f"{index}", separator=separator) + elif prefix: + yield prefix, dict(_walk(trie, separator=separator)) + return + for branch in trie.branches: + new_prefix = f"{prefix}{separator}{branch}" if prefix else f"{branch}" + yield from _walk(trie[branch], prefix=new_prefix, separator=separator) + + +def _flatten_mapping( + mapping: Mapping[str, Any], + *, + prefix: str = "", + recurse_on_sequence: bool = False, + json_string_attributes: Optional[Sequence[str]] = None, + separator: str = ".", +) -> Iterator[Tuple[str, Any]]: + for key, value in mapping.items(): + prefixed_key = f"{prefix}{separator}{key}" if prefix else key + if isinstance(value, Mapping): + if json_string_attributes and prefixed_key.endswith(JSON_STRING_ATTRIBUTES): + yield prefixed_key, json.dumps(value) + else: + yield from _flatten_mapping( + value, + prefix=prefixed_key, + recurse_on_sequence=recurse_on_sequence, + json_string_attributes=json_string_attributes, + separator=separator, + ) + elif isinstance(value, Sequence) and recurse_on_sequence: + yield from _flatten_sequence( + value, + prefix=prefixed_key, + recurse_on_sequence=recurse_on_sequence, + json_string_attributes=json_string_attributes, + separator=separator, + ) + elif value is not None: + yield prefixed_key, value + + +def _flatten_sequence( + sequence: Iterable[Any], + *, + prefix: str = "", + recurse_on_sequence: bool = False, + json_string_attributes: Optional[Sequence[str]] = None, + separator: str = ".", +) -> Iterator[Tuple[str, Any]]: + if isinstance(sequence, str) or not has_mapping(sequence): + yield 
prefix, sequence + for idx, obj in enumerate(sequence): + if not isinstance(obj, Mapping): + continue + yield from _flatten_mapping( + obj, + prefix=f"{prefix}{separator}{idx}" if prefix else f"{idx}", + recurse_on_sequence=recurse_on_sequence, + json_string_attributes=json_string_attributes, + separator=separator, + ) diff --git a/tests/core/test_project.py b/tests/core/test_project.py index 3baf8d4b13..06b55a2571 100644 --- a/tests/core/test_project.py +++ b/tests/core/test_project.py @@ -10,6 +10,7 @@ from phoenix.core.project import Project, _Spans from phoenix.trace.otel import decode from phoenix.trace.schemas import ComputedAttributes +from phoenix.utilities.attributes import get_attribute_value @pytest.mark.parametrize("permutation", list(permutations(range(5)))) @@ -33,7 +34,10 @@ def test_ingestion( assert _id_str(otlp_span.span_id) in _spans, f"{i=}, {s=}" latest_span = next(project.get_spans(span_ids=[_id_str(otlp_span.span_id)])) - expected_token_count_total += latest_span.attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] + expected_token_count_total += get_attribute_value( + latest_span.attributes, + SpanAttributes.LLM_TOKEN_COUNT_TOTAL, + ) assert project.token_count_total == expected_token_count_total, f"{i=}, {s=}" ingested_ids.add(latest_span.context.span_id) @@ -43,10 +47,10 @@ def test_ingestion( # across a missing parent. for span_id in ingested_ids.intersection(child_ids.keys()): span = next(project.get_spans(span_ids=[span_id])) - assert span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] == span.attributes[ - SpanAttributes.LLM_TOKEN_COUNT_TOTAL - ] + sum( - span.attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] + assert span[ComputedAttributes.CUMULATIVE_LLM_TOKEN_COUNT_TOTAL] == get_attribute_value( + span.attributes, SpanAttributes.LLM_TOKEN_COUNT_TOTAL + ) + sum( + get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_TOTAL) for span in project.get_spans( span_ids=list(_connected_descendant_ids(span_id, child_ids, ingested_ids)) ) diff --git a/tests/server/api/types/test_span.py b/tests/server/api/types/test_span.py deleted file mode 100644 index 77c352004d..0000000000 --- a/tests/server/api/types/test_span.py +++ /dev/null @@ -1,17 +0,0 @@ -from phoenix.server.api.types.Span import _nested_attributes - - -def test_nested_attributes() -> None: - assert _nested_attributes( - { - "llm.model_name": ..., - "llm.prompt_template.variables": ..., - }, - ) == { - "llm": { - "model_name": ..., - "prompt_template": { - "variables": ..., - }, - }, - } diff --git a/tests/trace/conftest.py b/tests/trace/conftest.py new file mode 100644 index 0000000000..476711c032 --- /dev/null +++ b/tests/trace/conftest.py @@ -0,0 +1,19 @@ +from typing import Iterator + +import pytest +from phoenix.db.models import Base +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + + +@pytest.fixture(scope="session") +def session_maker() -> sessionmaker: + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return sessionmaker(engine) + + +@pytest.fixture() +def session(session_maker: sessionmaker) -> Iterator[Session]: + with session_maker.begin() as session: + yield session diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index ff936ff60d..6a0957cd2d 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -1,139 +1,80 @@ -import ast -from collections import namedtuple -from itertools import count, islice -from random import random +import sys +from typing 
import Any -import phoenix.trace.v1 as pb import pytest -from google.protobuf.wrappers_pb2 import DoubleValue, StringValue -from openinference.semconv.trace import SpanAttributes -from phoenix.trace.dsl.filter import SpanFilter, _validate_expression - -LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL - - -def test_span_filter() -> None: - key = LLM_TOKEN_COUNT_TOTAL - Span = namedtuple("Span", "attributes name parent_id") - span_0 = Span({key: 0}, 0, "2") - span_1 = Span({key: 1}, 1, None) - span_2 = Span({}, None, "3") - spans = [span_0, span_1, span_2] - assert list(filter(SpanFilter(), spans)) == spans # no op - assert list(filter(SpanFilter("parent_id is None"), spans)) == [span_1] - assert list(filter(SpanFilter("parent_id is not None"), spans)) == [span_0, span_2] - assert list(filter(SpanFilter("parent_id == '3'"), spans)) == [spans[2]] - assert list(filter(SpanFilter("parent_id in ('2', '3')"), spans)) == [span_0, span_2] - assert list(filter(SpanFilter("parent_id in ['2']"), spans)) == [span_0] - for k in (key, "name"): - assert list(filter(SpanFilter(f"{k} > 0.5"), spans)) == [span_1] - assert list(filter(SpanFilter(f"{k} < 0.5"), spans)) == [span_0] - assert list(filter(SpanFilter(f"{k} >= 0.5"), spans)) == [span_1] - assert list(filter(SpanFilter(f"{k} <= 0.5"), spans)) == [span_0] - assert list(filter(SpanFilter(f"{k} == 0.5"), spans)) == [] - assert list(filter(SpanFilter(f"{k} != 0.5"), spans)) == [span_0, span_1] - assert list(filter(SpanFilter(f"{k} is not None"), spans)) == [span_0, span_1] - assert list(filter(SpanFilter(f"{k} is None"), spans)) == [span_2] - - -def test_ast_validate_expression() -> None: - _validate("a is None") - _validate("a > b") - _validate("a > b and c < -d") - _validate("a > b + c") - _validate("a > b and (c < d or e == f) and g >= h") - with pytest.raises(SyntaxError): - _validate("sqrt(x)") - with pytest.raises(SyntaxError): - _validate("abs(x) and x") - with pytest.raises(SyntaxError): - _validate("{} == {}") - - -def _validate(source: str): - _validate_expression(ast.parse(source, mode="eval"), source) - - -def test_span_filter_by_eval(spans, evals, eval_name): - spans = list(islice(spans, 3)) - - sf = SpanFilter(f"evals['{eval_name}'].score < 0.5", evals=evals) - assert list(filter(sf, spans)) == [spans[0]] - - sf = SpanFilter(f"evals['{eval_name}'].label == '1'", evals=evals) - assert list(filter(sf, spans)) == [spans[1]] - - sf = SpanFilter(f"evals['{eval_name}'].score is None", evals=evals) - assert list(filter(sf, spans)) == [spans[2]] - - sf = SpanFilter(f"evals['{eval_name}'].label is None", evals=evals) - assert list(filter(sf, spans)) == [spans[0], spans[2]] - - sf = SpanFilter(f"evals['{eval_name}'].score is not None", evals=evals) - assert list(filter(sf, spans)) == [spans[0], spans[1]] - - sf = SpanFilter(f"evals['{eval_name}'].label is not None", evals=evals) - assert list(filter(sf, spans)) == [spans[1]] - - # evals is None - sf = SpanFilter(f"evals['{eval_name}'].score < 0.5", evals=None) - assert list(filter(sf, spans)) == [] - - # non-existent evaluation name - sf = SpanFilter(f"evals['{random()}'].score < 0.5", evals=evals) - assert list(filter(sf, spans)) == [] - - # non-existent evaluation name - sf = SpanFilter(f"evals['{random()}'].label == '1'", evals=evals) - assert list(filter(sf, spans)) == [] - - -def test_span_filter_by_eval_exceptions(spans, evals, eval_name): - with pytest.raises(SyntaxError): - # no valid eval names - SpanFilter(f"evals['{eval_name}'].score < 0.5", evals=evals, valid_eval_names=[]) - 
with pytest.raises(SyntaxError): - # invalid attribute - SpanFilter(f"evals['{eval_name}'].scor < 0.5", evals=evals) - with pytest.raises(SyntaxError): - # misspelled evals - SpanFilter(f"eval['{eval_name}'].score < 0.5", evals=evals) - with pytest.raises(SyntaxError): - # non-string eval name - SpanFilter("evals[123].score < 0.5", evals=evals) - - -def test_span_filter_by_metadata(spans): - spans = list(islice(spans, 4)) - - sf = SpanFilter('metadata["odd index"] == 1') - assert list(filter(sf, spans)) == [spans[1]] - - -Span = namedtuple("Span", "context attributes") -Context = namedtuple("Context", "span_id") -Evals = namedtuple("Evals", "get_span_evaluation") - - -@pytest.fixture -def evals(eval_name): - result0 = pb.Evaluation.Result(score=DoubleValue(value=0)) - result1 = pb.Evaluation.Result(score=DoubleValue(value=1), label=StringValue(value="1")) - evaluations = {eval_name: {0: pb.Evaluation(result=result0), 1: pb.Evaluation(result=result1)}} - return Evals(lambda span_id, name: evaluations.get(name, {}).get(span_id)) - - -@pytest.fixture -def eval_name(): - return "correctness" - - -@pytest.fixture -def spans(): - return ( - Span( - context=Context(i), - attributes={**({"metadata": {"odd index": i}} if i % 2 else {})}, - ) - for i in count() - ) +from phoenix.db import models +from phoenix.trace.dsl.filter import SpanFilter +from sqlalchemy import select +from sqlalchemy.orm import Session + +if sys.version_info >= (3, 9): + from ast import unparse +else: + from astunparse import unparse + + +@pytest.mark.parametrize( + "expression,expected", + [ + ( + "parent_id is not None and 'abc' in name or span_kind == 'LLM' and span_id in ('123',)", # noqa E501 + "or_(and_(parent_id != None, name.contains('abc')), and_(span_kind == 'LLM', span_id.in_(('123',))))" # noqa E501 + if sys.version_info >= (3, 9) + else "or_(and_((parent_id != None), name.contains('abc')), and_((span_kind == 'LLM'), span_id.in_(('123',))))", # noqa E501 + ), + ( + "(parent_id is None or 'abc' not in name) and not (span_kind != 'LLM' or span_id not in ('123',))", # noqa E501 + "and_(or_(parent_id == None, not_(name.contains('abc'))), not_(or_(span_kind != 'LLM', span_id.not_in(('123',)))))" # noqa E501 + if sys.version_info >= (3, 9) + else "and_(or_((parent_id == None), not_(name.contains('abc'))), not_(or_((span_kind != 'LLM'), span_id.not_in(('123',)))))", # noqa E501 + ), + ( + "1000 < latency_ms < 2000 or status_code == 'ERROR' or 2000 <= cumulative_llm_token_count_total", # noqa E501 + "or_(and_(1000 < latency_ms, latency_ms < 2000), status_code == 'ERROR', 2000 <= cumulative_llm_token_count_total)" # noqa E501 + if sys.version_info >= (3, 9) + else "or_(and_((1000 < latency_ms), (latency_ms < 2000)), (status_code == 'ERROR'), (2000 <= cumulative_llm_token_count_total))", # noqa E501 + ), + ( + "llm.token_count.total - llm.token_count.prompt > 1000", + "cast(attributes[['llm', 'token_count', 'total']].as_float() - attributes[['llm', 'token_count', 'prompt']].as_float(), Float) > 1000" # noqa E501 + if sys.version_info >= (3, 9) + else "cast((attributes[['llm', 'token_count', 'total']].as_float() - attributes[['llm', 'token_count', 'prompt']].as_float()), Float) > 1000", # noqa E501 + ), + ( + "first.value in (1,) and second.value in ('2',) and '3' in third.value", + "and_(attributes[['first', 'value']].as_float().in_((1,)), attributes[['second', 'value']].as_string().in_(('2',)), attributes[['third', 'value']].as_string().contains('3'))", # noqa E501 + ), + ( + "'1.0' < my.value < 2.0", + "and_('1.0' < 
attributes[['my', 'value']].as_string(), attributes[['my', 'value']].as_float() < 2.0)" # noqa E501 + if sys.version_info >= (3, 9) + else "and_(('1.0' < attributes[['my', 'value']].as_string()), (attributes[['my', 'value']].as_float() < 2.0))", # noqa E501 + ), + ( + "first.value + 1 < second.value", + "cast(attributes[['first', 'value']].as_float() + 1, Float) < attributes[['second', 'value']].as_float()" # noqa E501 + if sys.version_info >= (3, 9) + else "cast((attributes[['first', 'value']].as_float() + 1), Float) < attributes[['second', 'value']].as_float()", # noqa E501 + ), + ( + "my.value == '1.0' or float(my.value) < 2.0", + "or_(attributes[['my', 'value']].as_string() == '1.0', attributes[['my', 'value']].as_float() < 2.0)" # noqa E501 + if sys.version_info >= (3, 9) + else "or_((attributes[['my', 'value']].as_string() == '1.0'), (attributes[['my', 'value']].as_float() < 2.0))", # noqa E501 + ), + ], +) +def test_translated(session: Session, expression: str, expected: str) -> None: + f = SpanFilter(expression) + assert _unparse(f.translated) == expected + # next line is only to test that the syntax is accepted + session.scalar(f(select(models.Span.id))) + + +def _unparse(exp: Any) -> str: + # `unparse` for python 3.8 outputs differently, + # otherwise this function is unnecessary. + s = unparse(exp).strip() + if s[0] == "(" and s[-1] == ")": + return s[1:-1] + return s diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py deleted file mode 100644 index 114d817804..0000000000 --- a/tests/trace/dsl/test_query.py +++ /dev/null @@ -1,302 +0,0 @@ -from collections import namedtuple -from random import random - -import numpy as np -import pandas as pd -import pytest -from openinference.semconv.trace import DocumentAttributes, SpanAttributes -from pandas.testing import assert_frame_equal -from phoenix.trace.dsl.query import Concatenation, Explosion, Projection, SpanQuery -from phoenix.trace.schemas import ATTRIBUTE_PREFIX, CONTEXT_PREFIX - -DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT -DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE -INPUT_VALUE = SpanAttributes.INPUT_VALUE -OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE -RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS - -Context = namedtuple("Context", "span_id trace_id") -Span = namedtuple("Span", "context parent_id attributes") - -SPAN_ID = "span_id" -TRACE_ID = "trace_id" - - -def test_projection(spans): - for field in (TRACE_ID, f"{CONTEXT_PREFIX}{TRACE_ID}"): - project = Projection(field) - assert project(spans[0]) == "99" - assert project(spans[1]) == "99" - - for field in (SPAN_ID, f"{CONTEXT_PREFIX}{SPAN_ID}"): - project = Projection(field) - assert project(spans[0]) == "0" - assert project(spans[1]) == "1" - - for field in (INPUT_VALUE, f"{ATTRIBUTE_PREFIX}{INPUT_VALUE}"): - project = Projection(field) - assert project(spans[0]) == "000" - assert project(spans[1]) is None - - for field in (RETRIEVAL_DOCUMENTS, f"{ATTRIBUTE_PREFIX}{RETRIEVAL_DOCUMENTS}"): - project = Projection(field) - assert project(spans[0]) == [] - assert project(spans[1]) == [ - { - DOCUMENT_CONTENT: "10", - DOCUMENT_SCORE: 100, - } - ] - - -def test_concatenation(spans): - concat = Concatenation("12") - assert list(concat(spans[2])) == [("12", "1\n\n2")] - - -def test_explosion(spans): - explode = Explosion("12") - assert list(explode(spans[2])) == [ - { - "12": 1, - "context.span_id": "2", - "position": 0, - }, - { - "12": 2, - "context.span_id": "2", - "position": 1, - }, - ] - - explode = 
Explosion(RETRIEVAL_DOCUMENTS) - - assert list(explode(spans[0])) == [] - assert list(explode(spans[1])) == [ - { - DOCUMENT_CONTENT: "10", - DOCUMENT_SCORE: 100, - "context.span_id": "1", - "document_position": 0, - } - ] - assert list(explode(spans[2])) == [ - { - DOCUMENT_CONTENT: "20", - "context.span_id": "2", - "document_position": 0, - }, - { - DOCUMENT_SCORE: 201, - "context.span_id": "2", - "document_position": 1, - }, - { - DOCUMENT_CONTENT: "22", - DOCUMENT_SCORE: 203, - "context.span_id": "2", - "document_position": 3, - }, - ] - - -def test_query_select(spans): - query = SpanQuery().select( - input=INPUT_VALUE, - output=OUTPUT_VALUE, - ) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["0", "2"], - "input": ["000", None], - "output": [None, "222"], - } - ).set_index("context.span_id") - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - -def test_query_concat(spans): - sep = str(random()) - - query = ( - SpanQuery() - .concat( - RETRIEVAL_DOCUMENTS, - reference=DOCUMENT_CONTENT, - ) - .with_concat_separator(separator=sep) - ) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["1", "2"], - "reference": ["10", f"20{sep}22"], - } - ).set_index("context.span_id") - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - query = ( - SpanQuery() - .concat( - RETRIEVAL_DOCUMENTS, - score=DOCUMENT_SCORE, - ) - .with_concat_separator(separator=sep) - ) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["1", "2"], - "score": ["100", f"201{sep}203"], - } - ).set_index("context.span_id") - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - -def test_query_explode(spans): - query = ( - SpanQuery() - .select( - input=INPUT_VALUE, - output=OUTPUT_VALUE, - ) - .explode(RETRIEVAL_DOCUMENTS) - ) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["1", "2", "2", "2"], - "document_position": [0, 0, 1, 3], - "input": [None, None, None, None], - "output": [None, "222", "222", "222"], - DOCUMENT_CONTENT: ["10", "20", None, "22"], - DOCUMENT_SCORE: [100, None, 201, 203], - } - ).set_index(["context.span_id", "document_position"]) - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - query = SpanQuery().explode(RETRIEVAL_DOCUMENTS) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["1", "2", "2", "2"], - "document_position": [0, 0, 1, 3], - DOCUMENT_CONTENT: ["10", "20", None, "22"], - DOCUMENT_SCORE: [100, None, 201, 203], - } - ).set_index(["context.span_id", "document_position"]) - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - query = SpanQuery().explode( - RETRIEVAL_DOCUMENTS, - reference=DOCUMENT_CONTENT, - ) - actual = query(spans) - desired = pd.DataFrame( - { - "context.span_id": ["1", "2", "2"], - "document_position": [0, 0, 3], - "reference": ["10", "20", "22"], - } - ).set_index(["context.span_id", "document_position"]) - assert_frame_equal(actual, desired) - assert_frame_equal(SpanQuery.from_dict(query.to_dict())(spans), desired) - del query, actual, desired - - -def test_join(spans): - left_query = 
SpanQuery().select(input=INPUT_VALUE) - right_query = ( - SpanQuery() - .select(span_id="parent_id") - .concat( - RETRIEVAL_DOCUMENTS, - reference=DOCUMENT_CONTENT, - ) - ) - left_result = left_query(spans) - right_result = right_query(spans) - actual = pd.concat( - [left_result, right_result], - axis=1, - join="outer", - ) - desired = pd.DataFrame( - { - "context.span_id": ["0", "1"], - "input": ["000", None], - "reference": ["10", "20\n\n22"], - } - ).set_index("context.span_id") - assert_frame_equal(actual, desired) - assert_frame_equal( - pd.concat( - [ - SpanQuery.from_dict(left_query.to_dict())(spans), - SpanQuery.from_dict(right_query.to_dict())(spans), - ], - axis=1, - join="outer", - ), - desired, - ) - - -@pytest.fixture(scope="module") -def spans(): - return ( - Span( - context=Context(span_id="0", trace_id="99"), - parent_id=None, - attributes={ - INPUT_VALUE: "000", - RETRIEVAL_DOCUMENTS: [], - }, - ), - Span( - context=Context(span_id="1", trace_id="99"), - parent_id="0", - attributes={ - RETRIEVAL_DOCUMENTS: np.array( - [ - { - DOCUMENT_CONTENT: "10", - DOCUMENT_SCORE: 100, - } - ] - ), - }, - ), - Span( - context=Context(span_id="2", trace_id="99"), - parent_id="1", - attributes={ - "12": [1, 2], - OUTPUT_VALUE: "222", - RETRIEVAL_DOCUMENTS: [ - { - DOCUMENT_CONTENT: "20", - }, - { - DOCUMENT_SCORE: 201, - }, - None, - { - DOCUMENT_CONTENT: "22", - DOCUMENT_SCORE: 203, - }, - ], - }, - ), - ) diff --git a/tests/trace/test_otel.py b/tests/trace/test_otel.py index 14618a9df1..60a5b0b1c3 100644 --- a/tests/trace/test_otel.py +++ b/tests/trace/test_otel.py @@ -8,14 +8,10 @@ import pytest from google.protobuf.json_format import MessageToJson from openinference.semconv.trace import ( - DocumentAttributes, - EmbeddingAttributes, - MessageAttributes, SpanAttributes, - ToolCallAttributes, ) from opentelemetry.proto.common.v1.common_pb2 import AnyValue, ArrayValue, KeyValue -from phoenix.trace.otel import _decode_identifier, _encode_identifier, _unflatten, decode, encode +from phoenix.trace.otel import _decode_identifier, _encode_identifier, decode, encode from phoenix.trace.schemas import ( EXCEPTION_ESCAPED, EXCEPTION_MESSAGE, @@ -30,22 +26,7 @@ ) from pytest import approx -DOCUMENT_CONTENT = DocumentAttributes.DOCUMENT_CONTENT -DOCUMENT_ID = DocumentAttributes.DOCUMENT_ID -DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA -DOCUMENT_SCORE = DocumentAttributes.DOCUMENT_SCORE -EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS -EMBEDDING_TEXT = EmbeddingAttributes.EMBEDDING_TEXT -EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR -LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES -LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES -MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE -MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS OPENINFERENCE_SPAN_KIND = SpanAttributes.OPENINFERENCE_SPAN_KIND -RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS -TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON -TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME -TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS def test_decode_encode(span): @@ -101,7 +82,11 @@ def test_decode_encode_status_code(span, span_status_code, otlp_status_code): @pytest.mark.parametrize("span_kind", list(SpanKind)) def test_decode_encode_span_kind(span, span_kind): - span = replace(span, span_kind=span_kind) + span = replace( + span, + span_kind=span_kind, + attributes={"openinference": {"span": {"kind": 
span_kind.value}}}, + ) otlp_span = encode(span) assert MessageToJson( KeyValue( @@ -151,7 +136,7 @@ def test_decode_encode_span_kind(span, span_kind): ], ) def test_decode_encode_attributes(span, attributes, otlp_key_value): - span = replace(span, attributes=attributes) + span = replace(span, attributes={**span.attributes, **attributes}) otlp_span = encode(span) assert MessageToJson(otlp_key_value) in set(map(MessageToJson, otlp_span.attributes)) decoded_span = decode(otlp_span) @@ -249,13 +234,15 @@ def test_decode_encode_documents(span): "m7": [4444.0], } attributes = { - RETRIEVAL_DOCUMENTS: [ - {DOCUMENT_ID: "d1", DOCUMENT_CONTENT: content, DOCUMENT_SCORE: score}, - {DOCUMENT_ID: "d2"}, - {DOCUMENT_CONTENT: content}, - {DOCUMENT_SCORE: score}, - {DOCUMENT_METADATA: document_metadata}, - ] + "retrieval": { + "documents": [ + {"document": {"id": "d1", "content": content, "score": score}}, + {"document": {"id": "d2"}}, + {"document": {"content": content}}, + {"document": {"score": score}}, + {"document": {"metadata": document_metadata}}, + ] + } } span = replace(span, attributes=attributes) otlp_span = encode(span) @@ -265,48 +252,53 @@ def test_decode_encode_documents(span): value=AnyValue(string_value="LLM"), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_ID}", + key="retrieval.documents.0.document.id", value=AnyValue(string_value="d1"), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_CONTENT}", + key="retrieval.documents.0.document.content", value=AnyValue(string_value=content), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.0.{DOCUMENT_SCORE}", + key="retrieval.documents.0.document.score", value=AnyValue(double_value=score), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.1.{DOCUMENT_ID}", + key="retrieval.documents.1.document.id", value=AnyValue(string_value="d2"), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.2.{DOCUMENT_CONTENT}", + key="retrieval.documents.2.document.content", value=AnyValue(string_value=content), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.3.{DOCUMENT_SCORE}", + key="retrieval.documents.3.document.score", value=AnyValue(double_value=score), ), KeyValue( - key=f"{RETRIEVAL_DOCUMENTS}.4.{DOCUMENT_METADATA}", + key="retrieval.documents.4.document.metadata", value=AnyValue(string_value=json.dumps(document_metadata)), ), ] assert set(map(MessageToJson, otlp_span.attributes)) == set(map(MessageToJson, otlp_attributes)) decoded_span = decode(otlp_span) - assert decoded_span.attributes[RETRIEVAL_DOCUMENTS] == span.attributes[RETRIEVAL_DOCUMENTS] + assert ( + decoded_span.attributes["retrieval"]["documents"] + == span.attributes["retrieval"]["documents"] + ) def test_decode_encode_embeddings(span): text = str(random()) vector = list(np.random.rand(3)) attributes = { - EMBEDDING_EMBEDDINGS: [ - {EMBEDDING_VECTOR: vector}, - {EMBEDDING_VECTOR: vector, EMBEDDING_TEXT: text}, - {EMBEDDING_TEXT: text}, - ] + "embedding": { + "embeddings": [ + {"embedding": {"vector": vector}}, + {"embedding": {"vector": vector, "text": text}}, + {"embedding": {"text": text}}, + ], + }, } span = replace(span, attributes=attributes) otlp_span = encode(span) @@ -321,43 +313,52 @@ def test_decode_encode_embeddings(span): value=AnyValue(string_value="LLM"), ), KeyValue( - key=f"{EMBEDDING_EMBEDDINGS}.0.{EMBEDDING_VECTOR}", + key="embedding.embeddings.0.embedding.vector", value=AnyValue(array_value=ArrayValue(values=vector_otlp_values)), ), KeyValue( - key=f"{EMBEDDING_EMBEDDINGS}.1.{EMBEDDING_VECTOR}", + key="embedding.embeddings.1.embedding.vector", 
value=AnyValue(array_value=ArrayValue(values=vector_otlp_values)), ), KeyValue( - key=f"{EMBEDDING_EMBEDDINGS}.1.{EMBEDDING_TEXT}", + key="embedding.embeddings.1.embedding.text", value=AnyValue(string_value=text), ), KeyValue( - key=f"{EMBEDDING_EMBEDDINGS}.2.{EMBEDDING_TEXT}", + key="embedding.embeddings.2.embedding.text", value=AnyValue(string_value=text), ), ] assert set(map(MessageToJson, otlp_span.attributes)) == set(map(MessageToJson, otlp_attributes)) decoded_span = decode(otlp_span) - assert decoded_span.attributes[EMBEDDING_EMBEDDINGS] == span.attributes[EMBEDDING_EMBEDDINGS] + assert ( + decoded_span.attributes["embedding"]["embeddings"] + == span.attributes["embedding"]["embeddings"] + ) def test_decode_encode_message_tool_calls(span): attributes = { - LLM_OUTPUT_MESSAGES: [ - { - MESSAGE_ROLE: "user", - }, - { - MESSAGE_ROLE: "assistant", - MESSAGE_TOOL_CALLS: [ - { - TOOL_CALL_FUNCTION_NAME: "multiply", - TOOL_CALL_FUNCTION_ARGUMENTS_JSON: '{\n "a": 2,\n "b": 3\n}', - } - ], - }, - ] + "llm": { + "output_messages": [ + {"message": {"role": "user"}}, + { + "message": { + "role": "assistant", + "tool_calls": [ + { + "tool_call": { + "function": { + "name": "multiply", + "arguments": '{\n "a": 2,\n "b": 3\n}', + }, + }, + }, + ], + }, + }, + ], + }, } span = replace(span, attributes=attributes) otlp_span = encode(span) @@ -367,29 +368,34 @@ def test_decode_encode_message_tool_calls(span): value=AnyValue(string_value="LLM"), ), KeyValue( - key=f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", + key="llm.output_messages.0.message.role", value=AnyValue(string_value="user"), ), KeyValue( - key=f"{LLM_OUTPUT_MESSAGES}.1.{MESSAGE_ROLE}", + key="llm.output_messages.1.message.role", value=AnyValue(string_value="assistant"), ), KeyValue( - key=f"{LLM_OUTPUT_MESSAGES}.1.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_NAME}", + key="llm.output_messages.1.message.tool_calls.0.tool_call.function.name", value=AnyValue(string_value="multiply"), ), KeyValue( - key=f"{LLM_OUTPUT_MESSAGES}.1.{MESSAGE_TOOL_CALLS}.0.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}", + key="llm.output_messages.1.message.tool_calls.0.tool_call.function.arguments", value=AnyValue(string_value='{\n "a": 2,\n "b": 3\n}'), ), ] assert set(map(MessageToJson, otlp_span.attributes)) == set(map(MessageToJson, otlp_attributes)) decoded_span = decode(otlp_span) - assert decoded_span.attributes[LLM_OUTPUT_MESSAGES] == span.attributes[LLM_OUTPUT_MESSAGES] + assert ( + decoded_span.attributes["llm"]["output_messages"] + == span.attributes["llm"]["output_messages"] + ) def test_decode_encode_llm_prompt_template_variables(span): - attributes = {LLM_PROMPT_TEMPLATE_VARIABLES: {"context_str": "123", "query_str": "321"}} + attributes = { + "llm": {"prompt_template": {"variables": {"context_str": "123", "query_str": "321"}}} + } span = replace(span, attributes=attributes) otlp_span = encode(span) otlp_attributes = [ @@ -398,29 +404,33 @@ def test_decode_encode_llm_prompt_template_variables(span): value=AnyValue(string_value="LLM"), ), KeyValue( - key=f"{LLM_PROMPT_TEMPLATE_VARIABLES}", - value=AnyValue(string_value=json.dumps(attributes[LLM_PROMPT_TEMPLATE_VARIABLES])), + key="llm.prompt_template.variables", + value=AnyValue( + string_value=json.dumps(attributes["llm"]["prompt_template"]["variables"]) + ), ), ] assert set(map(MessageToJson, otlp_span.attributes)) == set(map(MessageToJson, otlp_attributes)) decoded_span = decode(otlp_span) assert ( - decoded_span.attributes[LLM_PROMPT_TEMPLATE_VARIABLES] - == span.attributes[LLM_PROMPT_TEMPLATE_VARIABLES] + 
decoded_span.attributes["llm"]["prompt_template"]["variables"] + == span.attributes["llm"]["prompt_template"]["variables"] ) def test_decode_encode_tool_parameters(span): attributes = { - TOOL_PARAMETERS: { - "title": "multiply", - "properties": { - "a": {"type": "integer", "title": "A"}, - "b": {"title": "B", "type": "integer"}, + "tool": { + "parameters": { + "title": "multiply", + "properties": { + "a": {"type": "integer", "title": "A"}, + "b": {"title": "B", "type": "integer"}, + }, + "required": ["a", "b"], + "type": "object", }, - "required": ["a", "b"], - "type": "object", - } + }, } span = replace(span, attributes=attributes) otlp_span = encode(span) @@ -430,99 +440,13 @@ def test_decode_encode_tool_parameters(span): value=AnyValue(string_value="LLM"), ), KeyValue( - key=f"{TOOL_PARAMETERS}", - value=AnyValue(string_value=json.dumps(attributes[TOOL_PARAMETERS])), + key="tool.parameters", + value=AnyValue(string_value=json.dumps(attributes["tool"]["parameters"])), ), ] assert set(map(MessageToJson, otlp_span.attributes)) == set(map(MessageToJson, otlp_attributes)) decoded_span = decode(otlp_span) - assert decoded_span.attributes[TOOL_PARAMETERS] == span.attributes[TOOL_PARAMETERS] - - -@pytest.mark.parametrize( - "key_value_pairs,desired", - [ - ((), {}), - ((("1", 0),), {"1": 0}), - ((("1.2", 0),), {"1": {"2": 0}}), - ((("1.0.2", 0),), {"1": [{"2": 0}]}), - ((("1.0.2.3", 0),), {"1": [{"2": {"3": 0}}]}), - ((("1.0.2.0.3", 0),), {"1": [{"2": [{"3": 0}]}]}), - ((("1.0.2.0.3.4", 0),), {"1": [{"2": [{"3": {"4": 0}}]}]}), - ((("1.0.2.0.3.0.4", 0),), {"1": [{"2": [{"3": [{"4": 0}]}]}]}), - ((("1.2", 1), ("1", 0)), {"1": 0, "1.2": 1}), - ((("1.2.3", 1), ("1", 0)), {"1": 0, "1.2": {"3": 1}}), - ((("1.2.3", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": 1}}), - ((("1.2.0.3", 1), ("1", 0)), {"1": 0, "1.2": [{"3": 1}]}), - ((("1.2.3.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": {"4": 1}}}), - ((("1.0.2.3", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": 1}]}), - ((("1.2.0.3.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": {"4": 1}}]}), - ((("1.2.3.0.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": 1}]}}), - ((("1.0.2.3.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": {"4": 1}}]}), - ((("1.0.2.3.4", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": 1}}]}), - ((("1.2.0.3.0.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": [{"4": 1}]}]}), - ((("1.2.3.0.4.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": {"5": 1}}]}}), - ((("1.0.2.3.0.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": 1}]}]}), - ((("1.0.2.3.4.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": {"5": 1}}}]}), - ((("1.0.2.0.3.4", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": 1}]}]}), - ((("1.2.0.3.0.4.5", 1), ("1", 0)), {"1": 0, "1.2": [{"3": [{"4": {"5": 1}}]}]}), - ((("1.2.3.0.4.0.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": [{"5": 1}]}]}}), - ((("1.0.2.3.0.4.5", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": {"5": 1}}]}]}), - ((("1.0.2.3.4.0.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": [{"5": 1}]}}]}), - ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": {"5": 1}}]}]}), - ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3.4", 0)), {"1": [{"2": [{"3": {"4": 0, "4.5": 1}}]}]}), - ( - (("1.0.2.3.4.5.6", 2), ("1.0.2.3.4", 1), ("1.0.2", 0)), - {"1": [{"2": 0, "2.3": {"4": 1, "4.5": {"6": 2}}}]}, - ), - ( - (("0.0.0.0.0", 4), ("0.0.0.0", 3), ("0.0.0", 2), ("0.0", 1), ("0", 0)), - {"0": 0, "0.0": 1, "0.0.0": 2, "0.0.0.0": 3, "0.0.0.0.0": 4}, - ), - ( - (("a.9999999.c", 2), ("a.9999999.b", 1), ("a.99999.b", 0)), - {"a": [{"b": 0}, {"b": 
1, "c": 2}]}, - ), - ((("a", 0), ("c", 2), ("b", 1), ("d", 3)), {"a": 0, "b": 1, "c": 2, "d": 3}), - ( - (("a.b.c", 0), ("a.e", 2), ("a.b.d", 1), ("f", 3)), - {"a": {"b": {"c": 0, "d": 1}, "e": 2}, "f": 3}, - ), - ( - (("a.1.d", 3), ("a.0.d", 2), ("a.0.c", 1), ("a.b", 0)), - {"a.b": 0, "a": [{"c": 1, "d": 2}, {"d": 3}]}, - ), - ( - (("a.0.d", 3), ("a.0.c", 2), ("a.b", 1), ("a", 0)), - {"a": 0, "a.b": 1, "a.0": {"c": 2, "d": 3}}, - ), - ( - (("a.0.1.d", 3), ("a.0.0.c", 2), ("a", 1), ("a.b", 0)), - {"a.b": 0, "a": 1, "a.0": [{"c": 2}, {"d": 3}]}, - ), - ( - (("a.1.0.e", 3), ("a.0.0.d", 2), ("a.0.0.c", 1), ("a.b", 0)), - {"a.b": 0, "a": [{"0": {"c": 1, "d": 2}}, {"0": {"e": 3}}]}, - ), - ( - (("a.b.1.e.0.f", 2), ("a.b.0.c", 0), ("a.b.0.d.e.0.f", 1)), - {"a": {"b": [{"c": 0, "d": {"e": [{"f": 1}]}}, {"e": [{"f": 2}]}]}}, - ), - ], -) -def test_unflatten(key_value_pairs, desired): - actual = dict(_unflatten(key_value_pairs)) - assert actual == desired - actual = dict(_unflatten(reversed(key_value_pairs))) - assert actual == desired - - -@pytest.mark.parametrize("key_value_pairs,desired", [((("1.0.2", 0),), {"1": [{"2": 0}]})]) -def test_unflatten_separator(key_value_pairs, desired): - separator = str(random()) - key_value_pairs = ((key.replace(".", separator), value) for key, value in key_value_pairs) - actual = dict(_unflatten(key_value_pairs, separator)) - assert actual == desired + assert decoded_span.attributes["tool"]["parameters"] == span.attributes["tool"]["parameters"] @pytest.fixture @@ -539,7 +463,7 @@ def span() -> Span: span_kind=SpanKind.LLM, start_time=start_time, end_time=end_time, - attributes={}, + attributes={"openinference": {"span": {"kind": "LLM"}}}, status_code=SpanStatusCode.ERROR, status_message="xyz", events=[], diff --git a/tests/utilities/test_attributes.py b/tests/utilities/test_attributes.py new file mode 100644 index 0000000000..85a31792e4 --- /dev/null +++ b/tests/utilities/test_attributes.py @@ -0,0 +1,90 @@ +from random import random + +import pytest +from phoenix.utilities.attributes import unflatten + + +@pytest.mark.parametrize( + "key_value_pairs,desired", + [ + ((), {}), + ((("1", 0),), {"1": 0}), + ((("1.2", 0),), {"1": {"2": 0}}), + ((("1.0.2", 0),), {"1": [{"2": 0}]}), + ((("1.0.2.3", 0),), {"1": [{"2": {"3": 0}}]}), + ((("1.0.2.0.3", 0),), {"1": [{"2": [{"3": 0}]}]}), + ((("1.0.2.0.3.4", 0),), {"1": [{"2": [{"3": {"4": 0}}]}]}), + ((("1.0.2.0.3.0.4", 0),), {"1": [{"2": [{"3": [{"4": 0}]}]}]}), + ((("1.2", 1), ("1", 0)), {"1": 0, "1.2": 1}), + ((("1.2.3", 1), ("1", 0)), {"1": 0, "1.2": {"3": 1}}), + ((("1.2.3", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": 1}}), + ((("1.2.0.3", 1), ("1", 0)), {"1": 0, "1.2": [{"3": 1}]}), + ((("1.2.3.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": {"4": 1}}}), + ((("1.0.2.3", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": 1}]}), + ((("1.2.0.3.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": {"4": 1}}]}), + ((("1.2.3.0.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": 1}]}}), + ((("1.0.2.3.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": {"4": 1}}]}), + ((("1.0.2.3.4", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": 1}}]}), + ((("1.2.0.3.0.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": [{"4": 1}]}]}), + ((("1.2.3.0.4.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": {"5": 1}}]}}), + ((("1.0.2.3.0.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": 1}]}]}), + ((("1.0.2.3.4.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": {"5": 1}}}]}), + ((("1.0.2.0.3.4", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": 1}]}]}), + ((("1.2.0.3.0.4.5", 1), 
("1", 0)), {"1": 0, "1.2": [{"3": [{"4": {"5": 1}}]}]}), + ((("1.2.3.0.4.0.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": [{"5": 1}]}]}}), + ((("1.0.2.3.0.4.5", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": {"5": 1}}]}]}), + ((("1.0.2.3.4.0.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": [{"5": 1}]}}]}), + ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": {"5": 1}}]}]}), + ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3.4", 0)), {"1": [{"2": [{"3": {"4": 0, "4.5": 1}}]}]}), + ( + (("1.0.2.3.4.5.6", 2), ("1.0.2.3.4", 1), ("1.0.2", 0)), + {"1": [{"2": 0, "2.3": {"4": 1, "4.5": {"6": 2}}}]}, + ), + ( + (("0.0.0.0.0", 4), ("0.0.0.0", 3), ("0.0.0", 2), ("0.0", 1), ("0", 0)), + {"0": 0, "0.0": 1, "0.0.0": 2, "0.0.0.0": 3, "0.0.0.0.0": 4}, + ), + ( + (("a.9999999.c", 2), ("a.9999999.b", 1), ("a.99999.b", 0)), + {"a": [{"b": 0}, {"b": 1, "c": 2}]}, + ), + ((("a", 0), ("c", 2), ("b", 1), ("d", 3)), {"a": 0, "b": 1, "c": 2, "d": 3}), + ( + (("a.b.c", 0), ("a.e", 2), ("a.b.d", 1), ("f", 3)), + {"a": {"b": {"c": 0, "d": 1}, "e": 2}, "f": 3}, + ), + ( + (("a.1.d", 3), ("a.0.d", 2), ("a.0.c", 1), ("a.b", 0)), + {"a.b": 0, "a": [{"c": 1, "d": 2}, {"d": 3}]}, + ), + ( + (("a.0.d", 3), ("a.0.c", 2), ("a.b", 1), ("a", 0)), + {"a": 0, "a.b": 1, "a.0": {"c": 2, "d": 3}}, + ), + ( + (("a.0.1.d", 3), ("a.0.0.c", 2), ("a", 1), ("a.b", 0)), + {"a.b": 0, "a": 1, "a.0": [{"c": 2}, {"d": 3}]}, + ), + ( + (("a.1.0.e", 3), ("a.0.0.d", 2), ("a.0.0.c", 1), ("a.b", 0)), + {"a.b": 0, "a": [{"0": {"c": 1, "d": 2}}, {"0": {"e": 3}}]}, + ), + ( + (("a.b.1.e.0.f", 2), ("a.b.0.c", 0), ("a.b.0.d.e.0.f", 1)), + {"a": {"b": [{"c": 0, "d": {"e": [{"f": 1}]}}, {"e": [{"f": 2}]}]}}, + ), + ], +) +def test_unflatten(key_value_pairs, desired): + actual = dict(unflatten(key_value_pairs)) + assert actual == desired + actual = dict(unflatten(reversed(key_value_pairs))) + assert actual == desired + + +@pytest.mark.parametrize("key_value_pairs,desired", [((("1.0.2", 0),), {"1": [{"2": 0}]})]) +def test_unflatten_separator(key_value_pairs, desired): + separator = str(random()) + key_value_pairs = ((key.replace(".", separator), value) for key, value in key_value_pairs) + actual = dict(unflatten(key_value_pairs, separator=separator)) + assert actual == desired From bd518b23bee31f27f583711a129ec43fe2bf56c8 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Tue, 16 Apr 2024 18:47:09 -0700 Subject: [PATCH 02/46] change unixepoch to julianday --- src/phoenix/db/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index cc2ce36872..22e20329fc 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -224,7 +224,8 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any: return compiler.process( # FIXME: We don't know why sqlite returns a slightly different value. # postgresql is correct because it matches the value computed by Python. - (func.unixepoch(end_time, "subsec") - func.unixepoch(start_time, "subsec")) * 1000, + # unixepoch() gives the same results. 
+ (func.julianday(end_time) - func.julianday(start_time)) * 86_400_000, **kw, ) From b90adf9d747808bb7d172f41f7e024790a12781d Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Tue, 16 Apr 2024 19:08:38 -0700 Subject: [PATCH 03/46] add sqlean for unit tests --- pyproject.toml | 1 + tests/trace/conftest.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c30129f756..d5e4294ad0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,7 @@ dependencies = [ "respx", # For OpenAI testing "nest-asyncio", # for executor testing "astunparse; python_version<'3.9'", + "sqlean.py", ] [tool.hatch.envs.type] diff --git a/tests/trace/conftest.py b/tests/trace/conftest.py index 476711c032..8482caa97b 100644 --- a/tests/trace/conftest.py +++ b/tests/trace/conftest.py @@ -1,6 +1,7 @@ from typing import Iterator import pytest +import sqlean from phoenix.db.models import Base from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker @@ -8,7 +9,7 @@ @pytest.fixture(scope="session") def session_maker() -> sessionmaker: - engine = create_engine("sqlite:///:memory:") + engine = create_engine("sqlite:///:memory:", module=sqlean) Base.metadata.create_all(engine) return sessionmaker(engine) From fa0b1f64352631a55693001519d6df47ae2b44d7 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 08:27:47 -0700 Subject: [PATCH 04/46] add tests --- src/phoenix/trace/dsl/query.py | 19 +- tests/trace/conftest.py | 20 -- tests/trace/dsl/conftest.py | 228 +++++++++++++ tests/trace/dsl/test_helpers.py | 69 ++++ tests/trace/dsl/test_query.py | 556 ++++++++++++++++++++++++++++++++ 5 files changed, 869 insertions(+), 23 deletions(-) delete mode 100644 tests/trace/conftest.py create mode 100644 tests/trace/dsl/conftest.py create mode 100644 tests/trace/dsl/test_helpers.py create mode 100644 tests/trace/dsl/test_query.py diff --git a/src/phoenix/trace/dsl/query.py b/src/phoenix/trace/dsl/query.py index 780e9f0b1e..17152ae276 100644 --- a/src/phoenix/trace/dsl/query.py +++ b/src/phoenix/trace/dsl/query.py @@ -430,7 +430,7 @@ def explode(self, key: str, **kwargs: str) -> "SpanQuery": return replace(self, _explode=_explode) def concat(self, key: str, **kwargs: str) -> "SpanQuery": - _concat = Concatenation(key=key, kwargs=kwargs) + _concat = Concatenation(key=key, kwargs=kwargs, separator=self._concat.separator) return replace(self, _concat=_concat) def rename(self, **kwargs: str) -> "SpanQuery": @@ -455,11 +455,13 @@ def with_explode_primary_index_key(self, _: str) -> "SpanQuery": def __call__( self, session: Session, - project_name: str = DEFAULT_PROJECT_NAME, + project_name: Optional[str] = None, start_time: Optional[datetime] = None, stop_time: Optional[datetime] = None, root_spans_only: Optional[bool] = None, ) -> pd.DataFrame: + if not project_name: + project_name = DEFAULT_PROJECT_NAME if not (self._select or self._explode or self._concat): return _get_spans_dataframe( session, @@ -474,7 +476,7 @@ def __call__( conn = session.connection() index = _NAMES[self._index.key].label(self._add_tmp_suffix(self._index.key)) row_id = models.Span.id.label(self._pk_tmp_col_label) - stmt = stmt0_orig = ( + stmt = ( # We do not allow `group_by` anything other than `row_id` because otherwise # it's too complex for the post hoc processing step in pandas. 
select(row_id) @@ -482,6 +484,17 @@ def __call__( .join(models.Project) .where(models.Project.name == project_name) ) + if start_time: + stmt = stmt.where(start_time <= models.Span.start_time) + if stop_time: + stmt = stmt.where(models.Span.start_time < stop_time) + if root_spans_only: + parent = aliased(models.Span) + stmt = stmt.outerjoin( + parent, + models.Span.parent_id == parent.span_id, + ).where(parent.span_id == None) # noqa E711 + stmt0_orig = stmt stmt1_filter = None if self._filter: stmt = stmt1_filter = self._filter(stmt) diff --git a/tests/trace/conftest.py b/tests/trace/conftest.py deleted file mode 100644 index 8482caa97b..0000000000 --- a/tests/trace/conftest.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Iterator - -import pytest -import sqlean -from phoenix.db.models import Base -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker - - -@pytest.fixture(scope="session") -def session_maker() -> sessionmaker: - engine = create_engine("sqlite:///:memory:", module=sqlean) - Base.metadata.create_all(engine) - return sessionmaker(engine) - - -@pytest.fixture() -def session(session_maker: sessionmaker) -> Iterator[Session]: - with session_maker.begin() as session: - yield session diff --git a/tests/trace/dsl/conftest.py b/tests/trace/dsl/conftest.py new file mode 100644 index 0000000000..75eff1afdf --- /dev/null +++ b/tests/trace/dsl/conftest.py @@ -0,0 +1,228 @@ +from datetime import datetime +from typing import Iterator + +import pytest +import sqlean +from phoenix.config import DEFAULT_PROJECT_NAME +from phoenix.db import models +from phoenix.db.models import Base +from sqlalchemy import create_engine, insert +from sqlalchemy.orm import Session, sessionmaker + + +@pytest.fixture(scope="session") +def session_maker() -> sessionmaker: + # `sqlean` is added to help with running the test on GitHub CI for Windows, + # because its version of SQLite doesn't have `JSON_EXTRACT`. 
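+    # Passing `module=sqlean` points SQLAlchemy's SQLite dialect at sqlean.py as the
+    # DBAPI driver, a drop-in replacement for the stdlib `sqlite3` module that bundles
+    # a newer SQLite build, instead of the interpreter's built-in SQLite.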
+ engine = create_engine("sqlite:///:memory:", module=sqlean, echo=True) + Base.metadata.create_all(engine) + session_maker = sessionmaker(engine) + with session_maker.begin() as session: + _insert_project_default(session) + _insert_project_abc(session) + return session_maker + + +@pytest.fixture() +def session(session_maker: sessionmaker) -> Iterator[Session]: + with session_maker.begin() as session: + yield session + + +def _insert_project_default(session: Session) -> None: + project_row_id = session.scalar( + insert(models.Project).values(name=DEFAULT_PROJECT_NAME).returning(models.Project.id) + ) + trace_row_id = session.scalar( + insert(models.Trace) + .values( + trace_id="0123", + project_rowid=project_row_id, + start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:01:00.000+00:00"), + ) + .returning(models.Trace.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="2345", + parent_id=None, + name="root span", + span_kind="UNKNOWN", + start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"), + attributes={ + "input": {"value": "210"}, + "output": {"value": "321"}, + }, + events=[], + status_code="OK", + status_message="okay", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=0, + cumulative_llm_token_count_completion=0, + ) + .returning(models.Span.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="4567", + parent_id="2345", + name="retriever span", + span_kind="RETRIEVER", + start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"), + attributes={ + "input": { + "value": "xyz", + }, + "retrieval": { + "documents": [ + {"document": {"content": "A", "score": 1}}, + {"document": {"content": "B", "score": 2}}, + {"document": {"content": "C", "score": 3}}, + ], + }, + }, + events=[], + status_code="OK", + status_message="okay", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=0, + cumulative_llm_token_count_completion=0, + ) + .returning(models.Span.id) + ) + + +def _insert_project_abc(session: Session) -> None: + project_row_id = session.scalar( + insert(models.Project).values(name="abc").returning(models.Project.id) + ) + trace_row_id = session.scalar( + insert(models.Trace) + .values( + trace_id="012", + project_rowid=project_row_id, + start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:01:00.000+00:00"), + ) + .returning(models.Trace.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="234", + parent_id="123", + name="root span", + span_kind="UNKNOWN", + start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"), + attributes={ + "input": {"value": "210"}, + "output": {"value": "321"}, + }, + events=[], + status_code="OK", + status_message="okay", + cumulative_error_count=1, + cumulative_llm_token_count_prompt=100, + cumulative_llm_token_count_completion=200, + ) + .returning(models.Span.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="345", + parent_id="234", + name="embedding span", + span_kind="EMBEDDING", + start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + 
end_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"), + attributes={ + "metadata": { + "a.b.c": 123, + "1.2.3": "abc", + }, + "embedding": { + "model_name": "xyz", + "embeddings": [ + {"embedding": {"vector": [1, 2, 3], "text": "123"}}, + {"embedding": {"vector": [2, 3, 4], "text": "234"}}, + ], + }, + }, + events=[], + status_code="OK", + status_message="no problemo", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=0, + cumulative_llm_token_count_completion=0, + ) + .returning(models.Span.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="456", + parent_id="234", + name="retriever span", + span_kind="RETRIEVER", + start_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"), + attributes={ + "input": { + "value": "xyz", + }, + "retrieval": { + "documents": [ + {"document": {"content": "A", "score": 1}}, + {"document": {"content": "B", "score": 2}}, + {"document": {"content": "C", "score": 3}}, + ], + }, + }, + events=[], + status_code="OK", + status_message="okay", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=0, + cumulative_llm_token_count_completion=0, + ) + .returning(models.Span.id) + ) + _ = session.scalar( + insert(models.Span) + .values( + trace_rowid=trace_row_id, + span_id="567", + parent_id="234", + name="llm span", + span_kind="LLM", + start_time=datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"), + end_time=datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"), + attributes={ + "llm": { + "token_count": { + "prompt": 100, + "completion": 200, + }, + }, + }, + events=[], + status_code="ERROR", + status_message="uh-oh", + cumulative_error_count=1, + cumulative_llm_token_count_prompt=100, + cumulative_llm_token_count_completion=200, + ) + .returning(models.Span.id) + ) diff --git a/tests/trace/dsl/test_helpers.py b/tests/trace/dsl/test_helpers.py new file mode 100644 index 0000000000..50a186c7f2 --- /dev/null +++ b/tests/trace/dsl/test_helpers.py @@ -0,0 +1,69 @@ +from datetime import datetime +from typing import List, Optional + +import pandas as pd +from pandas.testing import assert_frame_equal +from phoenix.trace.dsl import SpanQuery +from phoenix.trace.dsl.helpers import get_qa_with_reference, get_retrieved_documents +from sqlalchemy.orm import Session + + +def test_get_retrieved_documents(session: Session) -> None: + mock = _Mock(session) + expected = pd.DataFrame( + { + "context.span_id": ["4567", "4567", "4567"], + "document_position": [0, 1, 2], + "context.trace_id": ["0123", "0123", "0123"], + "input": ["xyz", "xyz", "xyz"], + "reference": ["A", "B", "C"], + "document_score": [1, 2, 3], + } + ).set_index(["context.span_id", "document_position"]) + actual = get_retrieved_documents(mock) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + + +def test_get_qa_with_reference(session: Session) -> None: + mock = _Mock(session) + expected = pd.DataFrame( + { + "context.span_id": ["2345"], + "input": ["210"], + "output": ["321"], + "reference": ["A\n\nB\n\nC"], + } + ).set_index("context.span_id") + actual = get_qa_with_reference(mock) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + + +class _Mock: + def __init__(self, session: Session) -> None: + self.session = session + + def query_spans( + self, + *span_queries: SpanQuery, + start_time: Optional[datetime] = None, + stop_time: 
Optional[datetime] = None, + project_name: Optional[str] = None, + ) -> List[pd.DataFrame]: + ans = [ + sq( + self.session, + start_time=start_time, + stop_time=stop_time, + project_name=project_name, + ) + for sq in span_queries + ] + if len(ans) == 1: + return ans[0] + return ans diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py new file mode 100644 index 0000000000..92a758cc35 --- /dev/null +++ b/tests/trace/dsl/test_query.py @@ -0,0 +1,556 @@ +from datetime import datetime + +import pandas as pd +from pandas.testing import assert_frame_equal +from phoenix.trace.dsl import SpanQuery +from sqlalchemy.orm import Session + + +def test_select_all(session: Session) -> None: + # i.e. `get_spans_dataframe` + sq = SpanQuery() + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "context.trace_id": ["012", "012", "012", "012"], + "parent_id": ["123", "234", "234", "234"], + "name": ["root span", "embedding span", "retriever span", "llm span"], + "span_kind": ["UNKNOWN", "EMBEDDING", "RETRIEVER", "LLM"], + "status_code": ["OK", "OK", "OK", "ERROR"], + "status_message": ["okay", "no problemo", "okay", "uh-oh"], + "start_time": [ + datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"), + ], + "end_time": [ + datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:20.000+00:00"), + datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"), + ], + "attributes.input.value": ["210", None, "xyz", None], + "attributes.output.value": ["321", None, None, None], + "attributes.llm.token_count.prompt": [None, None, None, 100.0], + "attributes.llm.token_count.completion": [None, None, None, 200.0], + "attributes.metadata": [None, {"a.b.c": 123, "1.2.3": "abc"}, None, None], + "attributes.embedding.model_name": [None, "xyz", None, None], + "attributes.embedding.embeddings": [ + None, + [ + {"embedding.vector": [1, 2, 3], "embedding.text": "123"}, + {"embedding.vector": [2, 3, 4], "embedding.text": "234"}, + ], + None, + None, + ], + "attributes.retrieval.documents": [ + None, + None, + [ + {"document.content": "A", "document.score": 1.0}, + {"document.content": "B", "document.score": 2.0}, + {"document.content": "C", "document.score": 3.0}, + ], + None, + ], + "events": [[], [], [], []], + } + ).set_index("context.span_id", drop=False) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_select(session: Session) -> None: + sq = SpanQuery().select("name", tcp="llm.token_count.prompt") + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "name": ["root span", "embedding span", "retriever span", "llm span"], + "tcp": [None, None, None, 100.0], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = SpanQuery().select("name", span_id="parent_id") + expected = pd.DataFrame( + { + "context.span_id": ["123", "234", "234", "234"], + "name": ["root span", "embedding span", "retriever span", "llm span"], + } + ).set_index("context.span_id") + actual = 
sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = SpanQuery().select("span_id").with_index("trace_id") + expected = pd.DataFrame( + { + "context.trace_id": ["012", "012", "012", "012"], + "context.span_id": ["234", "345", "456", "567"], + } + ).set_index("context.trace_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1).sort_values("context.span_id"), + expected.sort_index().sort_index(axis=1).sort_values("context.span_id"), + ) + del sq, actual, expected + + +def test_default_project(session: Session) -> None: + sq = SpanQuery().select( + "name", + **{"Latency (milliseconds)": "latency_ms"}, + ) + expected = pd.DataFrame( + { + "context.span_id": ["2345"], + "name": ["root span"], + "Latency (milliseconds)": [30000.0], + } + ).set_index("context.span_id") + actual = sq(session, root_spans_only=True) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_root_spans_only(session: Session) -> None: + sq = SpanQuery().select( + "name", + **{"Latency (milliseconds)": "latency_ms"}, + ) + expected = pd.DataFrame( + { + "context.span_id": ["234"], + "name": ["root span"], + "Latency (milliseconds)": [30000.0], + } + ).set_index("context.span_id") + actual = sq( + session, + project_name="abc", + root_spans_only=True, + ) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_start_time(session: Session) -> None: + sq = SpanQuery().select("name") + expected = pd.DataFrame( + { + "context.span_id": ["567"], + "name": ["llm span"], + } + ).set_index("context.span_id") + actual = sq( + session, + project_name="abc", + start_time=datetime.fromisoformat( + "2021-01-01T00:00:20.000+00:00", + ), + ) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_stop_time(session: Session) -> None: + sq = SpanQuery().select("name") + expected = pd.DataFrame( + { + "context.span_id": ["234", "345"], + "name": ["root span", "embedding span"], + } + ).set_index("context.span_id") + actual = sq( + session, + project_name="abc", + stop_time=datetime.fromisoformat( + "2021-01-01T00:00:01.000+00:00", + ), + ) + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_filter_on_latency(session: Session) -> None: + sq = ( + SpanQuery() + .select( + "name", + **{"Latency (milliseconds)": "latency_ms"}, + ) + .where("9_000 < latency_ms < 11_000") + ) + expected = pd.DataFrame( + { + "context.span_id": ["567"], + "name": ["llm span"], + "Latency (milliseconds)": [10000.0], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_filter_on_metadata(session: Session) -> None: + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "metadata['a.b.c'] == 123", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345"], + "embedding.model_name": ["xyz"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + 
actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "metadata['1.2.3'] == 'abc'", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345"], + "embedding.model_name": ["xyz"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_filter_on_span_id(session: Session) -> None: + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "span_id == '345'", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345"], + "embedding.model_name": ["xyz"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "span_id in ['345', '567']", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345", "567"], + "embedding.model_name": ["xyz", None], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_filter_on_trace_id(session: Session) -> None: + sq = ( + SpanQuery() + .select("metadata") + .where( + "trace_id == '012'", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "metadata": [None, {"a.b.c": 123, "1.2.3": "abc"}, None, None], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("metadata") + .where( + "trace_id in ('012',)", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "metadata": [None, {"a.b.c": 123, "1.2.3": "abc"}, None, None], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_explode(session: Session) -> None: + sq = SpanQuery().explode("embedding.embeddings") + expected = pd.DataFrame( + { + "context.span_id": ["345", "345"], + "position": [0, 1], + "embedding.text": ["123", "234"], + "embedding.vector": [[1, 2, 3], [2, 3, 4]], + } + ).set_index(["context.span_id", "position"]) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("embedding.model_name") + .explode( + "embedding.embeddings", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345", "345"], + "position": [0, 1], + "embedding.model_name": ["xyz", "xyz"], + "embedding.text": ["123", "234"], + "embedding.vector": [[1, 2, 3], [2, 3, 4]], + } + ).set_index(["context.span_id", "position"]) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = SpanQuery().explode( + "retrieval.documents", + 
content="document.content", + score="document.score", + ) + expected = pd.DataFrame( + { + "context.span_id": ["456", "456", "456"], + "document_position": [0, 1, 2], + "content": ["A", "B", "C"], + "score": [1, 2, 3], + } + ).set_index(["context.span_id", "document_position"]) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("trace_id") + .explode( + "retrieval.documents", + **{ + "콘텐츠": "document.content", + "スコア": "document.score", + }, + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["456", "456", "456"], + "document_position": [0, 1, 2], + "context.trace_id": ["012", "012", "012"], + "콘텐츠": ["A", "B", "C"], + "スコア": [1, 2, 3], + } + ).set_index(["context.span_id", "document_position"]) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_concat(session: Session) -> None: + sq = SpanQuery().concat( + "retrieval.documents", + content="document.content", + ) + expected = pd.DataFrame( + { + "context.span_id": ["456"], + "content": ["A\n\nB\n\nC"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("trace_id") + .concat( + "retrieval.documents", + content="document.content", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["456"], + "context.trace_id": ["012"], + "content": ["A\n\nB\n\nC"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .with_index("name") + .with_concat_separator(",") + .concat( + "embedding.embeddings", + text="embedding.text", + ) + ) + expected = pd.DataFrame( + { + "name": ["embedding span"], + "text": ["123,234"], + } + ).set_index("name") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_explode_and_concat(session: Session) -> None: + sq = ( + SpanQuery() + .concat( + "retrieval.documents", + content="document.content", + ) + .explode( + "retrieval.documents", + score="document.score", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["456", "456", "456"], + "document_position": [0, 1, 2], + "content": ["A\n\nB\n\nC", "A\n\nB\n\nC", "A\n\nB\n\nC"], + "score": [1, 2, 3], + } + ).set_index(["context.span_id", "document_position"]) + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("trace_id") + .concat( + "retrieval.documents", + **{"콘텐츠": "document.content"}, + ) + .explode( + "retrieval.documents", + **{"スコア": "document.score"}, + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["456", "456", "456"], + "document_position": [0, 1, 2], + "context.trace_id": ["012", "012", "012"], + "콘텐츠": ["A\n\nB\n\nC", "A\n\nB\n\nC", "A\n\nB\n\nC"], + "スコア": [1, 2, 3], + } + ).set_index(["context.span_id", "document_position"]) + 
actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected From a0df97dd97ab822ad53ca93ca0e9287907a8a109 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 08:54:11 -0700 Subject: [PATCH 05/46] clean up --- src/phoenix/server/app.py | 3 +- src/phoenix/session/client.py | 7 ++- tests/trace/dsl/test_query.py | 96 ++++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 135edc5cc9..2b3bf377e2 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -241,7 +241,7 @@ def create_app( ) ) initial_batch_of_evaluations = () if initial_evaluations is None else initial_evaluations - engine = create_engine(database, echo=True) + engine = create_engine(database) db = _db(engine) graphql = GraphQLWithContext( db=db, @@ -299,7 +299,6 @@ def create_app( ), ], ) - app.state.db = db app.state.traces = traces app.state.read_only = read_only app.state.db = db diff --git a/src/phoenix/session/client.py b/src/phoenix/session/client.py index 502ac99276..8d52724215 100644 --- a/src/phoenix/session/client.py +++ b/src/phoenix/session/client.py @@ -24,7 +24,12 @@ class Client(TraceDataExtractor): - def __init__(self, *, endpoint: Optional[str] = None, **kwargs: Any): + def __init__( + self, + *, + endpoint: Optional[str] = None, + **kwargs: Any, # for backward-compatibility + ): """ Client for connecting to a Phoenix server. diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index 92a758cc35..44b4465977 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -199,6 +199,98 @@ def test_stop_time(session: Session) -> None: del sq, actual, expected +def test_filter_for_none(session: Session) -> None: + sq = ( + SpanQuery() + .select("name") + .where( + "parent_id is None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": [], + "name": [], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + check_dtype=False, + check_column_type=False, + check_frame_type=False, + check_index_type=False, + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("name") + .where( + "output.value is not None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234"], + "name": ["root span"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + check_dtype=False, + check_column_type=False, + check_frame_type=False, + check_index_type=False, + ) + del sq, actual, expected + + +def test_filter_on_substring(session: Session) -> None: + sq = ( + SpanQuery() + .select("name") + .where( + "'y' in input.value", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["456"], + "name": ["retriever span"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("name") + .where( + "'y' not in input.value", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234"], + "name": ["root span"], + } + ).set_index("context.span_id") + actual = 
sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + def test_filter_on_latency(session: Session) -> None: sq = ( SpanQuery() @@ -228,7 +320,7 @@ def test_filter_on_metadata(session: Session) -> None: SpanQuery() .select("embedding.model_name") .where( - "metadata['a.b.c'] == 123", + "12 - metadata['a.b.c'] == -111", ) ) expected = pd.DataFrame( @@ -248,7 +340,7 @@ def test_filter_on_metadata(session: Session) -> None: SpanQuery() .select("embedding.model_name") .where( - "metadata['1.2.3'] == 'abc'", + "'b' in metadata['1.2.3']", ) ) expected = pd.DataFrame( From c13976177271184b7bdf0d528fc8f9aba75939a8 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 10:38:21 -0700 Subject: [PATCH 06/46] clean up --- src/phoenix/trace/dsl/filter.py | 19 +++++++- tests/trace/dsl/test_query.py | 78 +++++++++++++++++++++++++++++---- 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index f24eaae02f..377e836972 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -380,7 +380,24 @@ def visit_Name(self, node: ast.Name) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) if source_segment in _STRING_NAMES or source_segment in _FLOAT_NAMES: return node - raise SyntaxError(f"invalid expression: {source_segment}") + name = source_segment + attr = "as_float" if name in _FLOAT_ATTRIBUTES else "as_string" + elts = [ast.Constant(value=name, kind=None)] + return ast.Call( + func=ast.Attribute( + value=ast.Subscript( + value=ast.Name(id="attributes", ctx=ast.Load()), + slice=ast.List(elts=elts, ctx=ast.Load()) + if sys.version_info >= (3, 9) + else ast.Index(value=ast.List(elts=elts, ctx=ast.Load())), + ctx=ast.Load(), + ), + attr=attr, + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) def visit_Subscript(self, node: ast.Subscript) -> typing.Any: if _is_metadata(node): diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index 44b4465977..cbf63498bf 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -111,6 +111,24 @@ def test_select(session: Session) -> None: del sq, actual, expected +def test_select_nonexistent(session: Session) -> None: + sq = SpanQuery().select("name", "opq", "opq.rst") + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "name": ["root span", "embedding span", "retriever span", "llm span"], + "opq": [None, None, None, None], + "opq.rst": [None, None, None, None], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + def test_default_project(session: Session) -> None: sq = SpanQuery().select( "name", @@ -241,18 +259,14 @@ def test_filter_for_none(session: Session) -> None: assert_frame_equal( actual.sort_index().sort_index(axis=1), expected.sort_index().sort_index(axis=1), - check_dtype=False, - check_column_type=False, - check_frame_type=False, - check_index_type=False, ) del sq, actual, expected -def test_filter_on_substring(session: Session) -> None: +def test_filter_for_substring(session: Session) -> None: sq = ( SpanQuery() - .select("name") + .select("input.value") .where( "'y' in input.value", ) @@ -260,7 +274,7 @@ def test_filter_on_substring(session: 
Session) -> None: expected = pd.DataFrame( { "context.span_id": ["456"], - "name": ["retriever span"], + "input.value": ["xyz"], } ).set_index("context.span_id") actual = sq(session, project_name="abc") @@ -272,7 +286,7 @@ def test_filter_on_substring(session: Session) -> None: sq = ( SpanQuery() - .select("name") + .select("input.value") .where( "'y' not in input.value", ) @@ -280,7 +294,53 @@ def test_filter_on_substring(session: Session) -> None: expected = pd.DataFrame( { "context.span_id": ["234"], - "name": ["root span"], + "input.value": ["210"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + +def test_filter_on_nonexistent(session: Session) -> None: + sq = ( + SpanQuery() + .select("name") + .where( + "opq is not None or opq.rst is not None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": [], + "name": [], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + check_dtype=False, + check_column_type=False, + check_frame_type=False, + check_index_type=False, + ) + del sq, actual, expected + + sq = ( + SpanQuery() + .select("name") + .where( + "opq is None or opq.rst is None", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["234", "345", "456", "567"], + "name": ["root span", "embedding span", "retriever span", "llm span"], } ).set_index("context.span_id") actual = sq(session, project_name="abc") From ce7d9e178e496df59e0e1afaa69f152d5a87c4a1 Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 11:32:30 -0700 Subject: [PATCH 07/46] add comments --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d5e4294ad0..cefbda368e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,8 +143,8 @@ dependencies = [ "httpx", # For OpenAI testing "respx", # For OpenAI testing "nest-asyncio", # for executor testing - "astunparse; python_version<'3.9'", - "sqlean.py", + "astunparse; python_version<'3.9'", # `ast.unparse(...)` is only available starting with Python 3.9 + "sqlean.py", # for running GitHub CI on Windows, because its SQLite doesn't support JSON_EXTRACT(...). 
] [tool.hatch.envs.type] From aa7f0d6f6100823d4752364fefb0aab91f1f28aa Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 14:01:09 -0700 Subject: [PATCH 08/46] fix cumulative token count --- src/phoenix/trace/dsl/filter.py | 10 ++++++++-- tests/trace/dsl/test_query.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 377e836972..5ec0abc2b2 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -51,6 +51,12 @@ "attributes": models.Span.attributes, "events": models.Span.events, } +_BACKWARD_COMPATIBILITY_REPLACEMENTS = { + # for backward-compatibility with the previous implementation + "cumulative_token_count.completion": "cumulative_llm_token_count_completion", + "cumulative_token_count.prompt": "cumulative_llm_token_count_prompt", + "cumulative_token_count.total": "cumulative_llm_token_count_total", +} # TODO(persistence): remove this protocol @@ -356,8 +362,8 @@ def visit_Call(self, node: ast.Call) -> typing.Any: def visit_Attribute(self, node: ast.Attribute) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if source_segment in _NAMES: - return node + if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): + return ast.Name(id=replacement, ctx=ast.Load()) attr = "as_float" if source_segment in _FLOAT_ATTRIBUTES else "as_string" elts = [ast.Constant(value=part, kind=None) for part in _split(source_segment)] return ast.Call( diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index cbf63498bf..376f4f801a 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -375,6 +375,26 @@ def test_filter_on_latency(session: Session) -> None: del sq, actual, expected +def test_filter_on_cumulative_token_count(session: Session) -> None: + sq = ( + SpanQuery() + .select("name") + .where("290 < cumulative_token_count.total < 310 and llm.token_count.prompt is None") + ) + expected = pd.DataFrame( + { + "context.span_id": ["234"], + "name": ["root span"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + + def test_filter_on_metadata(session: Session) -> None: sq = ( SpanQuery() From 4ad058b05f37a9e3a13a2f0723ed6be46ced993d Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 17 Apr 2024 14:11:03 -0700 Subject: [PATCH 09/46] fix type-cast functions --- src/phoenix/trace/dsl/filter.py | 2 +- tests/trace/dsl/test_query.py | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 5ec0abc2b2..5c78ad60dc 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -358,7 +358,7 @@ def visit_Call(self, node: ast.Call) -> typing.Any: return _cast_as("Float", arg) if node.func.id in ("str",) and _is_float(arg): return _cast_as("String", arg) - return node + return arg def visit_Attribute(self, node: ast.Attribute) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) diff --git a/tests/trace/dsl/test_query.py b/tests/trace/dsl/test_query.py index 376f4f801a..1df6f71093 100644 --- a/tests/trace/dsl/test_query.py +++ b/tests/trace/dsl/test_query.py @@ -416,6 +416,26 @@ def test_filter_on_metadata(session: 
Session) -> None: ) del sq, actual, expected + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "12 - int(metadata['a.b.c']) == -111", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345"], + "embedding.model_name": ["xyz"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + sq = ( SpanQuery() .select("embedding.model_name") @@ -436,6 +456,26 @@ def test_filter_on_metadata(session: Session) -> None: ) del sq, actual, expected + sq = ( + SpanQuery() + .select("embedding.model_name") + .where( + "'b' in str(metadata['1.2.3'])", + ) + ) + expected = pd.DataFrame( + { + "context.span_id": ["345"], + "embedding.model_name": ["xyz"], + } + ).set_index("context.span_id") + actual = sq(session, project_name="abc") + assert_frame_equal( + actual.sort_index().sort_index(axis=1), + expected.sort_index().sort_index(axis=1), + ) + del sq, actual, expected + def test_filter_on_span_id(session: Session) -> None: sq = ( From b116600838fc78a9ed4493769c9444be6bdd7f32 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 18 Apr 2024 13:31:47 -0700 Subject: [PATCH 10/46] add notebook --- integration-tests/eval_query_testing.ipynb | 192 +++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 integration-tests/eval_query_testing.ipynb diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb new file mode 100644 index 0000000000..4f4e06751b --- /dev/null +++ b/integration-tests/eval_query_testing.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from phoenix.db import models\n", + "from sqlalchemy import and_, create_engine, select\n", + "from sqlalchemy.orm import aliased, sessionmaker" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# PostgresSession = sessionmaker(\n", + "# create_engine(\n", + "# \"postgresql+psycopg://localhost:5432/postgres?user=postgres&password=mysecretpassword\",\n", + "# echo=True,\n", + "# ),\n", + "# expire_on_commit=False,\n", + "# )\n", + "SqliteSession = sessionmaker(\n", + " create_engine(\"sqlite:////Users/xandersong/.phoenix/phoenix.db\", echo=True),\n", + " expire_on_commit=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "orig_endpoint = \"http://127.0.0.1:6008\"\n", + "postgres_endpoint = \"http://127.0.0.1:6007\"\n", + "sqlite_endpoint = \"http://127.0.0.1:6006\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# SELECT * FROM spans\n", + "# JOIN (\n", + "# SELECT spans.id, sa.score, sa.label FROM spans\n", + "# JOIN span_annotations sa on spans.id = sa.span_rowid\n", + "# ) B ON spans.id == B.id\n", + "# WHERE B.score > 0.5 AND B.label == 'factual' AND spans.name == 'query';" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with SqliteSession() as session:\n", + " df = pd.read_sql(\n", + " select(models.Span.id)\n", + " .join(models.SpanAnnotation)\n", + " .where(models.SpanAnnotation.score > 0.5)\n", + " .where(models.SpanAnnotation.label == \"factual\"),\n", + " session.connection(),\n", + " )\n", + "df" + ] + 
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].label == \"hallucinated\"\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with SqliteSession() as session:\n", + " A = aliased(models.SpanAnnotation)\n", + " B = aliased(models.SpanAnnotation)\n", + " df = pd.read_sql(\n", + " select(models.Span.id, A.name, A.label, B.name, B.label)\n", + " .join(A)\n", + " .join(B)\n", + " .where(\n", + " and_(\n", + " A.name == \"Q&A Correctness\",\n", + " A.label == \"correct\",\n", + " B.name == \"Hallucination\",\n", + " B.label == \"hallucinated\",\n", + " ),\n", + " ),\n", + " session.connection(),\n", + " )\n", + "df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].score == 0\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with SqliteSession() as session:\n", + " A = aliased(models.SpanAnnotation)\n", + " B = aliased(models.SpanAnnotation)\n", + " df = pd.read_sql(\n", + " select(models.Span.id, A.name, A.label, B.name, B.label)\n", + " .join(A)\n", + " .join(B)\n", + " .where(\n", + " and_(\n", + " A.name == \"Q&A Correctness\",\n", + " A.label == \"correct\",\n", + " B.name == \"Hallucination\",\n", + " B.score == 0,\n", + " ),\n", + " ),\n", + " session.connection(),\n", + " )\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with SqliteSession() as session:\n", + " A = aliased(models.SpanAnnotation)\n", + " df = pd.read_sql(\n", + " select(models.Span.id, A.name, A.label)\n", + " .join(A)\n", + " .where(\n", + " and_(\n", + " A.name == \"Q&A Correctness\",\n", + " A.label == \"correct\",\n", + " ),\n", + " ),\n", + " session.connection(),\n", + " )\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# SELECT span_annotations.span_rowid,\n", + "# MAX(CASE WHEN name = 'Hallucination' and score = 0 THEN 1 ELSE 0 END) AS A,\n", + "# MAX(CASE WHEN name = 'Q&A Correctness' and label = 'correct' THEN 1 ELSE 0 END) AS B\n", + "# FROM span_annotations\n", + "# WHERE name in ('Hallucination', 'Q&A Correctness')\n", + "# GROUP BY span_annotations.span_rowid\n", + "# HAVING A = 1 and B = 1\n", + "# ORDER BY span_rowid;" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 0b645e916e7946d7aa7dcb11d50158fbed818173 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 18 Apr 2024 15:38:01 -0700 Subject: [PATCH 11/46] add postgres --- integration-tests/eval_query_testing.ipynb | 141 +++++++++------------ 1 file changed, 62 insertions(+), 79 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index 4f4e06751b..9292682c17 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "import pandas as pd\n", + "from pandas.testing import assert_frame_equal\n", "from phoenix.db import models\n", "from sqlalchemy import and_, create_engine, select\n", "from sqlalchemy.orm import aliased, sessionmaker" @@ -18,13 +19,13 @@ "metadata": {}, "outputs": [], "source": [ - "# 
PostgresSession = sessionmaker(\n", - "# create_engine(\n", - "# \"postgresql+psycopg://localhost:5432/postgres?user=postgres&password=mysecretpassword\",\n", - "# echo=True,\n", - "# ),\n", - "# expire_on_commit=False,\n", - "# )\n", + "PostgresSession = sessionmaker(\n", + " create_engine(\n", + " \"postgresql+psycopg://localhost:5432/postgres?user=postgres&password=mysecretpassword\",\n", + " echo=True,\n", + " ),\n", + " expire_on_commit=False,\n", + ")\n", "SqliteSession = sessionmaker(\n", " create_engine(\"sqlite:////Users/xandersong/.phoenix/phoenix.db\", echo=True),\n", " expire_on_commit=False,\n", @@ -56,29 +57,12 @@ "# WHERE B.score > 0.5 AND B.label == 'factual' AND spans.name == 'query';" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with SqliteSession() as session:\n", - " df = pd.read_sql(\n", - " select(models.Span.id)\n", - " .join(models.SpanAnnotation)\n", - " .where(models.SpanAnnotation.score > 0.5)\n", - " .where(models.SpanAnnotation.label == \"factual\"),\n", - " session.connection(),\n", - " )\n", - "df" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", - "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].label == \"hallucinated\"\n", + "evals[\"Q&A Correctness\"].label == \"correct\"\n", "```" ] }, @@ -88,24 +72,32 @@ "metadata": {}, "outputs": [], "source": [ - "with SqliteSession() as session:\n", - " A = aliased(models.SpanAnnotation)\n", - " B = aliased(models.SpanAnnotation)\n", - " df = pd.read_sql(\n", - " select(models.Span.id, A.name, A.label, B.name, B.label)\n", - " .join(A)\n", - " .join(B)\n", - " .where(\n", - " and_(\n", - " A.name == \"Q&A Correctness\",\n", - " A.label == \"correct\",\n", - " B.name == \"Hallucination\",\n", - " B.label == \"hallucinated\",\n", - " ),\n", + "stmt = (\n", + " select(models.Span.span_id)\n", + " .join(models.SpanAnnotation)\n", + " .where(\n", + " and_(\n", + " models.SpanAnnotation.name == \"Q&A Correctness\",\n", + " models.SpanAnnotation.label == \"correct\",\n", " ),\n", - " session.connection(),\n", " )\n", - "df" + ")\n", + "with SqliteSession() as sqlite_session:\n", + " sqlite_df = pd.read_sql(\n", + " stmt,\n", + " sqlite_session.connection(),\n", + " index_col=\"span_id\",\n", + " )\n", + "with PostgresSession() as postgres_session:\n", + " postgres_df = pd.read_sql(\n", + " stmt,\n", + " postgres_session.connection(),\n", + " index_col=\"span_id\",\n", + " )\n", + "assert_frame_equal(\n", + " sqlite_df.sort_index().sort_index(axis=1),\n", + " postgres_df.sort_index().sort_index(axis=1),\n", + ")" ] }, { @@ -113,7 +105,7 @@ "metadata": {}, "source": [ "```\n", - "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].score == 0\n", + "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].label == \"hallucinated\"\n", "```" ] }, @@ -123,46 +115,37 @@ "metadata": {}, "outputs": [], "source": [ - "with SqliteSession() as session:\n", - " A = aliased(models.SpanAnnotation)\n", - " B = aliased(models.SpanAnnotation)\n", - " df = pd.read_sql(\n", - " select(models.Span.id, A.name, A.label, B.name, B.label)\n", - " .join(A)\n", - " .join(B)\n", - " .where(\n", - " and_(\n", - " A.name == \"Q&A Correctness\",\n", - " A.label == \"correct\",\n", - " B.name == \"Hallucination\",\n", - " B.score == 0,\n", - " ),\n", + "A = aliased(models.SpanAnnotation)\n", + "B = aliased(models.SpanAnnotation)\n", + "stmt = (\n", + " select(models.Span.span_id)\n", + " 
.join(A)\n", + " .join(B)\n", + " .where(\n", + " and_(\n", + " A.name == \"Q&A Correctness\",\n", + " A.label == \"correct\",\n", + " B.name == \"Hallucination\",\n", + " B.label == \"hallucinated\",\n", " ),\n", - " session.connection(),\n", " )\n", - "df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with SqliteSession() as session:\n", - " A = aliased(models.SpanAnnotation)\n", - " df = pd.read_sql(\n", - " select(models.Span.id, A.name, A.label)\n", - " .join(A)\n", - " .where(\n", - " and_(\n", - " A.name == \"Q&A Correctness\",\n", - " A.label == \"correct\",\n", - " ),\n", - " ),\n", - " session.connection(),\n", + ")\n", + "with SqliteSession() as sqlite_session:\n", + " sqlite_df = pd.read_sql(\n", + " stmt,\n", + " sqlite_session.connection(),\n", + " index_col=\"span_id\",\n", + " )\n", + "with PostgresSession() as postgres_session:\n", + " postgres_df = pd.read_sql(\n", + " stmt,\n", + " postgres_session.connection(),\n", + " index_col=\"span_id\",\n", " )\n", - "df" + "assert_frame_equal(\n", + " sqlite_df.sort_index().sort_index(axis=1),\n", + " postgres_df.sort_index().sort_index(axis=1),\n", + ")" ] }, { From 6025925bfebfda72087e848baf1b617b79f5295f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 18 Apr 2024 17:11:07 -0700 Subject: [PATCH 12/46] setup it for postgres and sqlite --- integration-tests/eval_query_testing.ipynb | 44 ++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index 9292682c17..ba8e8328a3 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "import pandas as pd\n", + "import phoenix as px\n", "from pandas.testing import assert_frame_equal\n", "from phoenix.db import models\n", "from sqlalchemy import and_, create_engine, select\n", @@ -38,9 +39,9 @@ "metadata": {}, "outputs": [], "source": [ - "orig_endpoint = \"http://127.0.0.1:6008\"\n", - "postgres_endpoint = \"http://127.0.0.1:6007\"\n", - "sqlite_endpoint = \"http://127.0.0.1:6006\"" + "original_endpoint = \"http://127.0.0.1:6008\"\n", + "sqlite_endpoint = \"http://127.0.0.1:6006\"\n", + "postgres_endpoint = \"http://127.0.0.1:6007\"" ] }, { @@ -58,12 +59,41 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "```\n", - "evals[\"Q&A Correctness\"].label == \"correct\"\n", - "```" + "def get_spans_dataframe(endpoint: str, filter_condition: str):\n", + " return (\n", + " px.Client(endpoint=endpoint)\n", + " .get_spans_dataframe(filter_condition, root_spans_only=True)\n", + " .sort_index() # type: ignore\n", + " .sort_index(axis=1)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", + "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", + "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", + "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", + "print(f\"{original_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "assert_frame_equal(\n", + " original_df,\n", + " sqlite_df,\n", + ")\n", + 
"assert_frame_equal(\n", + " original_df,\n", + " postgres_df,\n", + ")" ] }, { From 7ce81143189ca000c3804e5df5195c6a1672b368 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Fri, 19 Apr 2024 07:33:53 -0700 Subject: [PATCH 13/46] stored parsed string in array --- src/phoenix/trace/dsl/filter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 5c78ad60dc..d96e9ac806 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -365,7 +365,8 @@ def visit_Attribute(self, node: ast.Attribute) -> typing.Any: if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): return ast.Name(id=replacement, ctx=ast.Load()) attr = "as_float" if source_segment in _FLOAT_ATTRIBUTES else "as_string" - elts = [ast.Constant(value=part, kind=None) for part in _split(source_segment)] + source_segment_parts = _split(source_segment) + elts = [ast.Constant(value=part, kind=None) for part in source_segment_parts] return ast.Call( func=ast.Attribute( value=ast.Subscript( From 6ddd8e2f494a0641f14b94532567fa7fc3c5bd07 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Fri, 19 Apr 2024 11:08:56 -0700 Subject: [PATCH 14/46] Revert "stored parsed string in array" This reverts commit 7ce81143189ca000c3804e5df5195c6a1672b368. --- src/phoenix/trace/dsl/filter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index d96e9ac806..5c78ad60dc 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -365,8 +365,7 @@ def visit_Attribute(self, node: ast.Attribute) -> typing.Any: if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): return ast.Name(id=replacement, ctx=ast.Load()) attr = "as_float" if source_segment in _FLOAT_ATTRIBUTES else "as_string" - source_segment_parts = _split(source_segment) - elts = [ast.Constant(value=part, kind=None) for part in source_segment_parts] + elts = [ast.Constant(value=part, kind=None) for part in _split(source_segment)] return ast.Call( func=ast.Attribute( value=ast.Subscript( From 29f65a556392a9d19b72c18873d0a0a0016036a0 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 20 Apr 2024 17:47:26 -0700 Subject: [PATCH 15/46] sqlite --- integration-tests/eval_query_testing.ipynb | 186 ++++++++++++++++++--- src/phoenix/trace/dsl/filter.py | 107 ++++++++++-- 2 files changed, 255 insertions(+), 38 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index ba8e8328a3..15127cd730 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -67,9 +67,9 @@ "def get_spans_dataframe(endpoint: str, filter_condition: str):\n", " return (\n", " px.Client(endpoint=endpoint)\n", - " .get_spans_dataframe(filter_condition, root_spans_only=True)\n", - " .sort_index() # type: ignore\n", - " .sort_index(axis=1)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + " .reindex(sorted(sqlite_df.columns), axis=1)\n", " )" ] }, @@ -82,18 +82,115 @@ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", - "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", - 
"print(f\"{original_df.shape=}\")\n", - "print(f\"{sqlite_df.shape=}\")\n", - "print(f\"{postgres_df.shape=}\")\n", - "assert_frame_equal(\n", - " original_df,\n", - " sqlite_df,\n", - ")\n", - "assert_frame_equal(\n", - " original_df,\n", - " postgres_df,\n", - ")" + "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df.compare(sqlite_df, result_names=(\"original\", \"sqlite\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "postgres_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df.compare(postgres_df, result_names=(\"original\", \"postgres\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sqlite_df[[\"attributes.openinference.span.kind\", \"events\", \"parent_id\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df[[\"attributes.openinference.span.kind\", \"events\", \"parent_id\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df[\"events\"].map(len).value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "type(sqlite_df[\"events\"].iloc[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df.equals(sqlite_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sqlite_df[\"parent_id\"].isna().sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df[\"attributes.openinference.span.kind\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "original_df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sqlite_df.columns" ] }, { @@ -145,20 +242,14 @@ "metadata": {}, "outputs": [], "source": [ - "A = aliased(models.SpanAnnotation)\n", + "A = aliased(models.SpanAnnotation, name=\"first-table\")\n", "B = aliased(models.SpanAnnotation)\n", "stmt = (\n", " select(models.Span.span_id)\n", - " .join(A)\n", - " .join(B)\n", - " .where(\n", - " and_(\n", - " A.name == \"Q&A Correctness\",\n", - " A.label == \"correct\",\n", - " B.name == \"Hallucination\",\n", - " B.label == \"hallucinated\",\n", - " ),\n", - " )\n", + " .join(A, onclause=A.name == \"Q&A Correctness\")\n", + " .join(B, onclause=B.name == \"Hallucination\")\n", + " .where(A.label == \"correct\")\n", + " .where(B.label == \"hallucinated\")\n", ")\n", "with SqliteSession() as sqlite_session:\n", " sqlite_df = pd.read_sql(\n", @@ -178,6 +269,49 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy import inspect\n", + "from sqlalchemy.orm import aliased\n", + "\n", + "insp = inspect(A)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy.orm.util import AliasedClass\n", + "\n", + "isinstance(A, AliasedClass)" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy.sql.roles import JoinTargetRole\n", + "\n", + "isinstance(A, JoinTargetRole)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "insp.name" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 3c704ca2b1..548fcec515 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -1,4 +1,5 @@ import ast +import re import sys import typing from dataclasses import dataclass, field @@ -6,8 +7,10 @@ from types import MappingProxyType import sqlalchemy +from sqlalchemy.orm import aliased +from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.expression import Select -from typing_extensions import TypeGuard, assert_never +from typing_extensions import TypeAlias, TypeGuard, assert_never import phoenix.trace.v1 as pb from phoenix.db import models @@ -18,6 +21,10 @@ ) +AnnotationAlias: TypeAlias = str +EvalExpression: TypeAlias = str +EvalName: TypeAlias = str + # Because postgresql is strongly typed, we cast JSON values to string # by default unless it's hinted otherwise as done here. _FLOAT_ATTRIBUTES: typing.FrozenSet[str] = frozenset( @@ -83,6 +90,12 @@ class SpanFilter: valid_eval_names: typing.Optional[typing.Sequence[str]] = None translated: ast.Expression = field(init=False, repr=False) compiled: typing.Any = field(init=False, repr=False) + aliased_annotations: typing.Tuple[AliasedClass[models.SpanAnnotation]] = field( + init=False, repr=False + ) + join_aliased_tables: typing.Callable[[Select[typing.Any]], Select[typing.Any]] = field( + init=False, repr=False + ) def __bool__(self) -> bool: return bool(self.condition) @@ -92,20 +105,30 @@ def __post_init__(self) -> None: return root = ast.parse(source, mode="eval") _validate_expression(root, source, valid_eval_names=self.valid_eval_names) - translated = _FilterTranslator(source).visit(root) + source, aliased_annotations, join_aliased_tables = _apply_aliases(source) + root = ast.parse(source, mode="eval") + translated = _FilterTranslator( + source=source, + annotation_aliases=[_get_alias(annotation) for annotation in aliased_annotations], + ).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") - object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) + object.__setattr__(self, "aliased_annotations", aliased_annotations) + object.__setattr__(self, "join_aliased_tables", join_aliased_tables) def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: return select - return select.where( + return self.join_aliased_tables(select).where( eval( self.compiled, { **_NAMES, + **{ + _get_alias(annotation): annotation + for annotation in self.aliased_annotations + }, "not_": sqlalchemy.not_, "and_": sqlalchemy.and_, "or_": sqlalchemy.or_, @@ -298,14 +321,6 @@ def visit_generic(self, node: ast.AST) -> typing.Any: def visit_Expression(self, node: ast.Expression) -> typing.Any: return ast.Expression(body=self.visit(node.body)) - def visit_Attribute(self, node: ast.Attribute) -> typing.Any: - source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): - return ast.Name(id=replacement, ctx=ast.Load()) - if (keys := 
_get_attribute_keys_list(node)) is not None: - return _as_attribute(keys) - raise SyntaxError(f"invalid expression: {source_segment}") - def visit_Name(self, node: ast.Name) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) if source_segment in _STRING_NAMES or source_segment in _FLOAT_NAMES: @@ -321,6 +336,20 @@ def visit_Subscript(self, node: ast.Subscript) -> typing.Any: # TODO(persistence): support `evals['name'].score` et. al. class _FilterTranslator(_ProjectionTranslator): + def __init__(self, source: str, annotation_aliases: typing.Sequence[str]) -> None: + super().__init__(source) + self._annotation_aliases = annotation_aliases + + def visit_Attribute(self, node: ast.Attribute) -> typing.Any: + source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) + if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): + return ast.Name(id=replacement, ctx=ast.Load()) + if (keys := _get_attribute_keys_list(node)) is None: + raise SyntaxError(f"invalid expression: {source_segment}") + if keys and keys[0].value in self._annotation_aliases: + return node + return _as_attribute(keys) + def visit_Compare(self, node: ast.Compare) -> typing.Any: if len(node.comparators) > 1: args: typing.List[typing.Any] = [] @@ -687,3 +716,57 @@ def _find_best_match( if score > best_score: best_choice, best_score = choice, score return best_choice, best_score + + +def _apply_aliases( + source: str, +) -> typing.Tuple[ + str, + typing.Tuple[AliasedClass[models.SpanAnnotation], ...], + typing.Callable[[Select[typing.Any]], Select[typing.Any]], +]: + aliased_annotations: typing.Dict[EvalName, AliasedClass[models.SpanAnnotation]] = {} + for eval_expression, eval_name in _parse_eval_expressions_and_names(source): + if (aliased_annotation := aliased_annotations.get(eval_name)) is None: + aliased_annotation = typing.cast( + AliasedClass[models.SpanAnnotation], + aliased(models.SpanAnnotation, name=f"span_annotation_{len(aliased_annotations)}"), + ) + aliased_annotations[eval_name] = aliased_annotation + alias = _get_alias(aliased_annotation) + source = source.replace(eval_expression, alias) + + def join_aliased_tables(stmt: Select[typing.Any]) -> Select[typing.Any]: + for eval_name, AliasedSpanAnnotation in MappingProxyType(aliased_annotations).items(): + stmt = stmt.join( + AliasedSpanAnnotation, + onclause=( + sqlalchemy.and_( + AliasedSpanAnnotation.span_rowid == models.Span.id, + AliasedSpanAnnotation.name == eval_name, + ) + ), + ) + return stmt + + return source, tuple(aliased_annotations.values()), join_aliased_tables + + +def _parse_eval_expressions_and_names( + source: str, +) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName]]: + for match in re.finditer(r"""(evals\[("(.*)"|'(.*)')\])""", source): + ( + eval_expression, + _, + double_quoted_eval_name, + single_quoted_eval_name, + ) = match.groups() + yield ( + eval_expression, + double_quoted_eval_name or single_quoted_eval_name or "", + ) + + +def _get_alias(aliased_annotation: AliasedClass[models.SpanAnnotation]) -> AnnotationAlias: + return str(sqlalchemy.inspect(aliased_annotation).name) From b8f5a169ffe2b2acc5d3f2f69cedf311ff4ba309 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 20 Apr 2024 18:29:40 -0700 Subject: [PATCH 16/46] add integration test case and correct error --- integration-tests/eval_query_testing.ipynb | 283 +++------------------ 1 file changed, 40 insertions(+), 243 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb 
b/integration-tests/eval_query_testing.ipynb index 15127cd730..8daa87b29a 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -6,12 +6,9 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", "import phoenix as px\n", - "from pandas.testing import assert_frame_equal\n", - "from phoenix.db import models\n", - "from sqlalchemy import and_, create_engine, select\n", - "from sqlalchemy.orm import aliased, sessionmaker" + "from sqlalchemy import create_engine\n", + "from sqlalchemy.orm import sessionmaker" ] }, { @@ -50,27 +47,26 @@ "metadata": {}, "outputs": [], "source": [ - "# SELECT * FROM spans\n", - "# JOIN (\n", - "# SELECT spans.id, sa.score, sa.label FROM spans\n", - "# JOIN span_annotations sa on spans.id = sa.span_rowid\n", - "# ) B ON spans.id == B.id\n", - "# WHERE B.score > 0.5 AND B.label == 'factual' AND spans.name == 'query';" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "COMMON_COLUMNS = [\n", + " \"attributes.__computed__\",\n", + " \"attributes.input.value\",\n", + " \"attributes.openinference.span.kind\",\n", + " \"attributes.output.value\",\n", + " \"context.span_id\",\n", + " \"context.trace_id\",\n", + " \"end_time\",\n", + " \"events\",\n", + " \"name\",\n", + " \"parent_id\",\n", + " \"start_time\",\n", + " \"status_code\",\n", + " \"status_message\",\n", + "]\n", + "\n", + "\n", "def get_spans_dataframe(endpoint: str, filter_condition: str):\n", - " return (\n", - " px.Client(endpoint=endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - " .reindex(sorted(sqlite_df.columns), axis=1)\n", - " )" + " df = px.Client(endpoint=endpoint).get_spans_dataframe(filter_condition)\n", + " return df.sort_index()" ] }, { @@ -82,70 +78,10 @@ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", - "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df.compare(sqlite_df, result_names=(\"original\", \"sqlite\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "postgres_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df.compare(postgres_df, result_names=(\"original\", \"postgres\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sqlite_df[[\"attributes.openinference.span.kind\", \"events\", \"parent_id\"]]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df[[\"attributes.openinference.span.kind\", \"events\", \"parent_id\"]]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df[\"events\"].map(len).value_counts()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "type(sqlite_df[\"events\"].iloc[0])" + "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", + 
"print(f\"{original_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")" ] }, { @@ -154,162 +90,27 @@ "metadata": {}, "outputs": [], "source": [ - "original_df.equals(sqlite_df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sqlite_df[\"parent_id\"].isna().sum()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df[\"attributes.openinference.span.kind\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "original_df.columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sqlite_df.columns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "stmt = (\n", - " select(models.Span.span_id)\n", - " .join(models.SpanAnnotation)\n", - " .where(\n", - " and_(\n", - " models.SpanAnnotation.name == \"Q&A Correctness\",\n", - " models.SpanAnnotation.label == \"correct\",\n", - " ),\n", - " )\n", - ")\n", - "with SqliteSession() as sqlite_session:\n", - " sqlite_df = pd.read_sql(\n", - " stmt,\n", - " sqlite_session.connection(),\n", - " index_col=\"span_id\",\n", - " )\n", - "with PostgresSession() as postgres_session:\n", - " postgres_df = pd.read_sql(\n", - " stmt,\n", - " postgres_session.connection(),\n", - " index_col=\"span_id\",\n", - " )\n", - "assert_frame_equal(\n", - " sqlite_df.sort_index().sort_index(axis=1),\n", - " postgres_df.sort_index().sort_index(axis=1),\n", + "sqlite_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"sqlite\", \"original\"),\n", ")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```\n", - "evals[\"Q&A Correctness\"].label == \"correct\" and evals[\"Hallucination\"].label == \"hallucinated\"\n", - "```" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "A = aliased(models.SpanAnnotation, name=\"first-table\")\n", - "B = aliased(models.SpanAnnotation)\n", - "stmt = (\n", - " select(models.Span.span_id)\n", - " .join(A, onclause=A.name == \"Q&A Correctness\")\n", - " .join(B, onclause=B.name == \"Hallucination\")\n", - " .where(A.label == \"correct\")\n", - " .where(B.label == \"hallucinated\")\n", + "filter_condition = (\n", + " \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", ")\n", - "with SqliteSession() as sqlite_session:\n", - " sqlite_df = pd.read_sql(\n", - " stmt,\n", - " sqlite_session.connection(),\n", - " index_col=\"span_id\",\n", - " )\n", - "with PostgresSession() as postgres_session:\n", - " postgres_df = pd.read_sql(\n", - " stmt,\n", - " postgres_session.connection(),\n", - " index_col=\"span_id\",\n", - " )\n", - "assert_frame_equal(\n", - " sqlite_df.sort_index().sort_index(axis=1),\n", - " postgres_df.sort_index().sort_index(axis=1),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sqlalchemy import inspect\n", - "from sqlalchemy.orm import aliased\n", - "\n", - "insp = inspect(A)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sqlalchemy.orm.util import AliasedClass\n", - "\n", - 
"isinstance(A, AliasedClass)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sqlalchemy.sql.roles import JoinTargetRole\n", - "\n", - "isinstance(A, JoinTargetRole)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "insp.name" + "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", + "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", + "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", + "print(f\"{original_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")" ] }, { @@ -318,14 +119,10 @@ "metadata": {}, "outputs": [], "source": [ - "# SELECT span_annotations.span_rowid,\n", - "# MAX(CASE WHEN name = 'Hallucination' and score = 0 THEN 1 ELSE 0 END) AS A,\n", - "# MAX(CASE WHEN name = 'Q&A Correctness' and label = 'correct' THEN 1 ELSE 0 END) AS B\n", - "# FROM span_annotations\n", - "# WHERE name in ('Hallucination', 'Q&A Correctness')\n", - "# GROUP BY span_annotations.span_rowid\n", - "# HAVING A = 1 and B = 1\n", - "# ORDER BY span_rowid;" + "sqlite_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"sqlite\", \"original\"),\n", + ")" ] } ], From 931b6a1e10fccdd826a225fcc271650beca6ab81 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 20 Apr 2024 18:48:38 -0700 Subject: [PATCH 17/46] postgres running on it --- integration-tests/eval_query_testing.ipynb | 51 +++++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index 8daa87b29a..8d748ccb61 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -61,12 +61,7 @@ " \"start_time\",\n", " \"status_code\",\n", " \"status_message\",\n", - "]\n", - "\n", - "\n", - "def get_spans_dataframe(endpoint: str, filter_condition: str):\n", - " df = px.Client(endpoint=endpoint).get_spans_dataframe(filter_condition)\n", - " return df.sort_index()" + "]" ] }, { @@ -76,9 +71,9 @@ "outputs": [], "source": [ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", - "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", - "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", - "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", + "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", + "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -90,12 +85,28 @@ "metadata": {}, "outputs": [], "source": [ + "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", + "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " 
result_names=(\"sqlite\", \"original\"),\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", + "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", + "postgres_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"postgres\", \"original\"),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -105,9 +116,9 @@ "filter_condition = (\n", " \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", ")\n", - "original_df = get_spans_dataframe(endpoint=original_endpoint, filter_condition=filter_condition)\n", - "postgres_df = get_spans_dataframe(endpoint=postgres_endpoint, filter_condition=filter_condition)\n", - "sqlite_df = get_spans_dataframe(endpoint=sqlite_endpoint, filter_condition=filter_condition)\n", + "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", + "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -119,11 +130,27 @@ "metadata": {}, "outputs": [], "source": [ + "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", + "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", + "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", + "postgres_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"postgres\", \"original\"),\n", + ")" + ] } ], "metadata": { From 94e192932b327b5a00ffc4b628e8dc1f9b2a00bc Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sat, 20 Apr 2024 18:52:55 -0700 Subject: [PATCH 18/46] add integration test for != --- integration-tests/eval_query_testing.ipynb | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index 8d748ccb61..4dbab991a5 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -138,6 +138,49 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", + "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", + "postgres_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"postgres\", \"original\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"filter_condition = \"\"\"evals['Q&A Correctness'].label != 'correct'\"\"\"\n", + "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", + "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "print(f\"{original_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", + "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", + "sqlite_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"sqlite\", \"original\"),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, From c1ed5435bca7f0da23a68fa729eb4c992566542d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 21 Apr 2024 12:23:25 -0700 Subject: [PATCH 19/46] more it --- integration-tests/eval_query_testing.ipynb | 86 ++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index 4dbab991a5..d2f0f0a9a1 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -181,6 +181,92 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", + "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", + "postgres_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"postgres\", \"original\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filter_condition = \"\"\"evals['Q&A Correctness'].label is not None\"\"\"\n", + "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", + "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "print(f\"{original_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", + "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", + "sqlite_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"sqlite\", \"original\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", + "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", + "postgres_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " 
result_names=(\"postgres\", \"original\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filter_condition = \"\"\"evals['Q&A Correctness'].score < evals[\"Hallucination\"].score\"\"\"\n", + "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", + "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "print(f\"{original_df.shape=}\")\n", + "print(f\"{postgres_df.shape=}\")\n", + "print(f\"{sqlite_df.shape=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", + "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", + "sqlite_df[COMMON_COLUMNS].compare(\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " result_names=(\"sqlite\", \"original\"),\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, From 017adb897a8319e9daa6ec44b03aa9164ebadb6a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 11:43:11 -0700 Subject: [PATCH 20/46] add sort_index to ensure the rows are similarly ordered for comparison in it --- integration-tests/eval_query_testing.ipynb | 50 +++++++++++++++------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index d2f0f0a9a1..c76ac7c903 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -71,9 +71,13 @@ "outputs": [], "source": [ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", - "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", - "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "original_df = (\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "postgres_df = (\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -116,9 +120,13 @@ "filter_condition = (\n", " \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", ")\n", - "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", - "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "original_df = (\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "postgres_df = (\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", 
"print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -159,9 +167,13 @@ "outputs": [], "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label != 'correct'\"\"\"\n", - "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", - "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "original_df = (\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "postgres_df = (\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -202,9 +214,13 @@ "outputs": [], "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label is not None\"\"\"\n", - "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", - "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "original_df = (\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "postgres_df = (\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -245,9 +261,13 @@ "outputs": [], "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].score < evals[\"Hallucination\"].score\"\"\"\n", - "original_df = px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition)\n", - "postgres_df = px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition)\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition)\n", + "original_df = (\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "postgres_df = (\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" From 7e1483b9d0d793fb6fbf03df714a416ba0f0c75e Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 12:54:39 -0700 Subject: [PATCH 21/46] add back visit_Attribute method --- src/phoenix/trace/dsl/filter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 548fcec515..f386fc30a2 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -113,6 +113,7 @@ def __post_init__(self) -> None: ).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") + object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) object.__setattr__(self, "aliased_annotations", aliased_annotations) 
object.__setattr__(self, "join_aliased_tables", join_aliased_tables) @@ -315,6 +316,14 @@ def __init__(self, source: str) -> None: # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). self._source = source + def visit_Attribute(self, node: ast.Attribute) -> typing.Any: + source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) + if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): + return ast.Name(id=replacement, ctx=ast.Load()) + if (keys := _get_attribute_keys_list(node)) is None: + raise SyntaxError(f"invalid expression: {source_segment}") + return _as_attribute(keys) + def visit_generic(self, node: ast.AST) -> typing.Any: raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}") From 205bd3cfc053680eee624c83be5720e105f3a815 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 13:45:49 -0700 Subject: [PATCH 22/46] fix unit test --- tests/trace/test_trace_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trace/test_trace_dataset.py b/tests/trace/test_trace_dataset.py index 335b85be65..e9940d8f78 100644 --- a/tests/trace/test_trace_dataset.py +++ b/tests/trace/test_trace_dataset.py @@ -303,7 +303,7 @@ def test_trace_dataset_load_logs_warning_when_an_evaluation_cannot_be_loaded(tmp with pytest.warns(UserWarning) as record: read_ds = TraceDataset.load(dataset_id, tmp_path) - assert len(record) == 1 + assert len(record) > 0 assert str(record[0].message).startswith("Failed to load"), "unexpected warning message" read_ds = TraceDataset.load(dataset_id, tmp_path) From cb5a6b4a03f27b2214560d456cdc32f6d00fbff0 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 14:23:22 -0700 Subject: [PATCH 23/46] remove accidentally added files --- src/phoenix/utilities/attributes.py | 278 ---------------------------- tests/utilities/test_attributes.py | 90 --------- 2 files changed, 368 deletions(-) delete mode 100644 src/phoenix/utilities/attributes.py delete mode 100644 tests/utilities/test_attributes.py diff --git a/src/phoenix/utilities/attributes.py b/src/phoenix/utilities/attributes.py deleted file mode 100644 index 5413726c78..0000000000 --- a/src/phoenix/utilities/attributes.py +++ /dev/null @@ -1,278 +0,0 @@ -import inspect -import json -from typing import ( - Any, - DefaultDict, - Dict, - Iterable, - Iterator, - List, - Mapping, - Optional, - Sequence, - Set, - Tuple, - Union, - cast, -) - -from openinference.semconv import trace -from openinference.semconv.trace import DocumentAttributes, SpanAttributes -from typing_extensions import assert_never - -DOCUMENT_METADATA = DocumentAttributes.DOCUMENT_METADATA -LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES -METADATA = SpanAttributes.METADATA -TOOL_PARAMETERS = SpanAttributes.TOOL_PARAMETERS - -# attributes interpreted as JSON strings during ingestion -JSON_STRING_ATTRIBUTES = ( - DOCUMENT_METADATA, - LLM_PROMPT_TEMPLATE_VARIABLES, - METADATA, - TOOL_PARAMETERS, -) - -SEMANTIC_CONVENTIONS: List[str] = sorted( - # e.g. "input.value", "llm.token_count.total", etc. 
- ( - cast(str, getattr(klass, attr)) - for name in dir(trace) - if name.endswith("Attributes") and inspect.isclass(klass := getattr(trace, name)) - for attr in dir(klass) - if attr.isupper() - ), - key=len, - reverse=True, -) # sorted so the longer strings go first - - -def unflatten( - key_value_pairs: Iterable[Tuple[str, Any]], - *, - prefix_exclusions: Sequence[str] = (), - separator: str = ".", -) -> Dict[str, Any]: - # `prefix_exclusions` is intended to contain the semantic conventions - trie = _build_trie(key_value_pairs, separator=separator, prefix_exclusions=prefix_exclusions) - return dict(_walk(trie, separator=separator)) - - -def flatten( - obj: Union[Mapping[str, Any], Iterable[Any]], - *, - prefix: str = "", - separator: str = ".", - recurse_on_sequence: bool = False, - json_string_attributes: Optional[Sequence[str]] = None, -) -> Iterator[Tuple[str, Any]]: - if isinstance(obj, Mapping): - yield from _flatten_mapping( - obj, - prefix=prefix, - recurse_on_sequence=recurse_on_sequence, - json_string_attributes=json_string_attributes, - separator=separator, - ) - elif isinstance(obj, Iterable): - yield from _flatten_sequence( - obj, - prefix=prefix, - recurse_on_sequence=recurse_on_sequence, - json_string_attributes=json_string_attributes, - separator=separator, - ) - else: - assert_never(obj) - - -def has_mapping(sequence: Iterable[Any]) -> bool: - for item in sequence: - if isinstance(item, Mapping): - return True - return False - - -def get_attribute_value( - attributes: Optional[Mapping[str, Any]], - key: str, - separator: str = ".", -) -> Optional[Any]: - if not attributes: - return None - sub_keys = key.split(separator) - for sub_key in sub_keys[:-1]: - attributes = attributes.get(sub_key) - if not attributes: - return None - return attributes.get(sub_keys[-1]) - - -def load_json_strings(key_values: Iterable[Tuple[str, Any]]) -> Iterator[Tuple[str, Any]]: - for key, value in key_values: - if key.endswith(JSON_STRING_ATTRIBUTES): - try: - dict_value = json.loads(value) - except Exception: - yield key, value - else: - if dict_value: - yield key, dict_value - else: - yield key, value - - -def _partition_with_prefix_exclusion( - key: str, - separator: str = ".", - prefix_exclusions: Sequence[str] = (), -) -> Tuple[str, str, str]: - # prefix_exclusions should be sorted by length from the longest to the shortest - for prefix in prefix_exclusions: - if key.startswith(prefix) and ( - len(key) == len(prefix) or key[len(prefix) :].startswith(separator) - ): - return prefix, separator, key[len(prefix) + len(separator) :] - return key.partition(separator) - - -class _Trie(DefaultDict[Union[str, int], "_Trie"]): - """Prefix Tree with special handling for indices (i.e. 
all-digit keys).""" - - def __init__(self) -> None: - super().__init__(_Trie) - self.value: Any = None - self.indices: Set[int] = set() - self.branches: Set[Union[str, int]] = set() - - def set_value(self, value: Any) -> None: - self.value = value - # value and indices must not coexist - self.branches.update(self.indices) - self.indices.clear() - - def add_index(self, index: int) -> "_Trie": - if self.value is not None: - self.branches.add(index) - elif index not in self.branches: - self.indices.add(index) - return self[index] - - def add_branch(self, branch: Union[str, int]) -> "_Trie": - if branch in self.indices: - self.indices.discard(cast(int, branch)) - self.branches.add(branch) - return self[branch] - - -def _build_trie( - key_value_pairs: Iterable[Tuple[str, Any]], - *, - prefix_exclusions: Sequence[str] = (), - separator: str = ".", -) -> _Trie: - """Build a Trie (a.k.a. prefix tree) from `key_value_pairs`, by partitioning the keys by - separator. Each partition is a branch in the Trie. Special handling is done for partitions - that are all digits, e.g. "0", "12", etc., which are converted to integers and collected - as indices. - """ - trie = _Trie() - for key, value in key_value_pairs: - if value is None: - continue - t = trie - while True: - prefix, _, suffix = _partition_with_prefix_exclusion( - key, - separator, - prefix_exclusions, - ) - if prefix.isdigit(): - index = int(prefix) - t = t.add_index(index) if suffix else t.add_branch(index) - else: - t = t.add_branch(prefix) - if not suffix: - break - key = suffix - t.set_value(value) - return trie - - -def _walk( - trie: _Trie, - *, - prefix: str = "", - separator: str = ".", -) -> Iterator[Tuple[str, Any]]: - if trie.value is not None: - yield prefix, trie.value - elif prefix and trie.indices: - yield ( - prefix, - [dict(_walk(trie[index], separator=separator)) for index in sorted(trie.indices)], - ) - elif trie.indices: - for index in trie.indices: - yield from _walk(trie[index], prefix=f"{index}", separator=separator) - elif prefix: - yield prefix, dict(_walk(trie, separator=separator)) - return - for branch in trie.branches: - new_prefix = f"{prefix}{separator}{branch}" if prefix else f"{branch}" - yield from _walk(trie[branch], prefix=new_prefix, separator=separator) - - -def _flatten_mapping( - mapping: Mapping[str, Any], - *, - prefix: str = "", - recurse_on_sequence: bool = False, - json_string_attributes: Optional[Sequence[str]] = None, - separator: str = ".", -) -> Iterator[Tuple[str, Any]]: - for key, value in mapping.items(): - prefixed_key = f"{prefix}{separator}{key}" if prefix else key - if isinstance(value, Mapping): - if json_string_attributes and prefixed_key.endswith(JSON_STRING_ATTRIBUTES): - yield prefixed_key, json.dumps(value) - else: - yield from _flatten_mapping( - value, - prefix=prefixed_key, - recurse_on_sequence=recurse_on_sequence, - json_string_attributes=json_string_attributes, - separator=separator, - ) - elif isinstance(value, Sequence) and recurse_on_sequence: - yield from _flatten_sequence( - value, - prefix=prefixed_key, - recurse_on_sequence=recurse_on_sequence, - json_string_attributes=json_string_attributes, - separator=separator, - ) - elif value is not None: - yield prefixed_key, value - - -def _flatten_sequence( - sequence: Iterable[Any], - *, - prefix: str = "", - recurse_on_sequence: bool = False, - json_string_attributes: Optional[Sequence[str]] = None, - separator: str = ".", -) -> Iterator[Tuple[str, Any]]: - if isinstance(sequence, str) or not has_mapping(sequence): - yield 
prefix, sequence - for idx, obj in enumerate(sequence): - if not isinstance(obj, Mapping): - continue - yield from _flatten_mapping( - obj, - prefix=f"{prefix}{separator}{idx}" if prefix else f"{idx}", - recurse_on_sequence=recurse_on_sequence, - json_string_attributes=json_string_attributes, - separator=separator, - ) diff --git a/tests/utilities/test_attributes.py b/tests/utilities/test_attributes.py deleted file mode 100644 index 85a31792e4..0000000000 --- a/tests/utilities/test_attributes.py +++ /dev/null @@ -1,90 +0,0 @@ -from random import random - -import pytest -from phoenix.utilities.attributes import unflatten - - -@pytest.mark.parametrize( - "key_value_pairs,desired", - [ - ((), {}), - ((("1", 0),), {"1": 0}), - ((("1.2", 0),), {"1": {"2": 0}}), - ((("1.0.2", 0),), {"1": [{"2": 0}]}), - ((("1.0.2.3", 0),), {"1": [{"2": {"3": 0}}]}), - ((("1.0.2.0.3", 0),), {"1": [{"2": [{"3": 0}]}]}), - ((("1.0.2.0.3.4", 0),), {"1": [{"2": [{"3": {"4": 0}}]}]}), - ((("1.0.2.0.3.0.4", 0),), {"1": [{"2": [{"3": [{"4": 0}]}]}]}), - ((("1.2", 1), ("1", 0)), {"1": 0, "1.2": 1}), - ((("1.2.3", 1), ("1", 0)), {"1": 0, "1.2": {"3": 1}}), - ((("1.2.3", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": 1}}), - ((("1.2.0.3", 1), ("1", 0)), {"1": 0, "1.2": [{"3": 1}]}), - ((("1.2.3.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": {"4": 1}}}), - ((("1.0.2.3", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": 1}]}), - ((("1.2.0.3.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": {"4": 1}}]}), - ((("1.2.3.0.4", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": 1}]}}), - ((("1.0.2.3.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": {"4": 1}}]}), - ((("1.0.2.3.4", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": 1}}]}), - ((("1.2.0.3.0.4", 1), ("1", 0)), {"1": 0, "1.2": [{"3": [{"4": 1}]}]}), - ((("1.2.3.0.4.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": {"5": 1}}]}}), - ((("1.0.2.3.0.4", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": 1}]}]}), - ((("1.0.2.3.4.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": {"5": 1}}}]}), - ((("1.0.2.0.3.4", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": 1}]}]}), - ((("1.2.0.3.0.4.5", 1), ("1", 0)), {"1": 0, "1.2": [{"3": [{"4": {"5": 1}}]}]}), - ((("1.2.3.0.4.0.5", 1), ("1.2", 0)), {"1": {"2": 0, "2.3": [{"4": [{"5": 1}]}]}}), - ((("1.0.2.3.0.4.5", 1), ("1.0.2", 0)), {"1": [{"2": 0, "2.3": [{"4": {"5": 1}}]}]}), - ((("1.0.2.3.4.0.5", 1), ("1.0.2.3", 0)), {"1": [{"2": {"3": 0, "3.4": [{"5": 1}]}}]}), - ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3", 0)), {"1": [{"2": [{"3": 0, "3.4": {"5": 1}}]}]}), - ((("1.0.2.0.3.4.5", 1), ("1.0.2.0.3.4", 0)), {"1": [{"2": [{"3": {"4": 0, "4.5": 1}}]}]}), - ( - (("1.0.2.3.4.5.6", 2), ("1.0.2.3.4", 1), ("1.0.2", 0)), - {"1": [{"2": 0, "2.3": {"4": 1, "4.5": {"6": 2}}}]}, - ), - ( - (("0.0.0.0.0", 4), ("0.0.0.0", 3), ("0.0.0", 2), ("0.0", 1), ("0", 0)), - {"0": 0, "0.0": 1, "0.0.0": 2, "0.0.0.0": 3, "0.0.0.0.0": 4}, - ), - ( - (("a.9999999.c", 2), ("a.9999999.b", 1), ("a.99999.b", 0)), - {"a": [{"b": 0}, {"b": 1, "c": 2}]}, - ), - ((("a", 0), ("c", 2), ("b", 1), ("d", 3)), {"a": 0, "b": 1, "c": 2, "d": 3}), - ( - (("a.b.c", 0), ("a.e", 2), ("a.b.d", 1), ("f", 3)), - {"a": {"b": {"c": 0, "d": 1}, "e": 2}, "f": 3}, - ), - ( - (("a.1.d", 3), ("a.0.d", 2), ("a.0.c", 1), ("a.b", 0)), - {"a.b": 0, "a": [{"c": 1, "d": 2}, {"d": 3}]}, - ), - ( - (("a.0.d", 3), ("a.0.c", 2), ("a.b", 1), ("a", 0)), - {"a": 0, "a.b": 1, "a.0": {"c": 2, "d": 3}}, - ), - ( - (("a.0.1.d", 3), ("a.0.0.c", 2), ("a", 1), ("a.b", 0)), - {"a.b": 0, "a": 1, "a.0": [{"c": 2}, {"d": 3}]}, - ), - ( - (("a.1.0.e", 3), 
("a.0.0.d", 2), ("a.0.0.c", 1), ("a.b", 0)), - {"a.b": 0, "a": [{"0": {"c": 1, "d": 2}}, {"0": {"e": 3}}]}, - ), - ( - (("a.b.1.e.0.f", 2), ("a.b.0.c", 0), ("a.b.0.d.e.0.f", 1)), - {"a": {"b": [{"c": 0, "d": {"e": [{"f": 1}]}}, {"e": [{"f": 2}]}]}}, - ), - ], -) -def test_unflatten(key_value_pairs, desired): - actual = dict(unflatten(key_value_pairs)) - assert actual == desired - actual = dict(unflatten(reversed(key_value_pairs))) - assert actual == desired - - -@pytest.mark.parametrize("key_value_pairs,desired", [((("1.0.2", 0),), {"1": [{"2": 0}]})]) -def test_unflatten_separator(key_value_pairs, desired): - separator = str(random()) - key_value_pairs = ((key.replace(".", separator), value) for key, value in key_value_pairs) - actual = dict(unflatten(key_value_pairs, separator=separator)) - assert actual == desired From a53732d2f0d01d898ec1888d6396c63c1844cbec Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 22:37:09 -0700 Subject: [PATCH 24/46] refactor --- integration-tests/eval_query_testing.ipynb | 114 ++++++++++---- src/phoenix/trace/dsl/filter.py | 169 ++++++++++++--------- 2 files changed, 179 insertions(+), 104 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index c76ac7c903..d04e0ddf83 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -72,12 +72,20 @@ "source": [ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=original_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + ")\n", + "sqlite_df = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -92,7 +100,9 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -106,7 +116,9 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -117,16 +129,22 @@ "metadata": {}, "outputs": [], "source": [ - "filter_condition = (\n", - " \"\"\"evals['Q&A Correctness'].label == 'correct' and 
evals[\"Hallucination\"].score < 0.5\"\"\"\n", - ")\n", + "filter_condition = \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=original_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + ")\n", + "sqlite_df = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -141,7 +159,9 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -155,7 +175,9 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -168,12 +190,20 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label != 'correct'\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=original_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + ")\n", + "sqlite_df = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -188,7 +218,9 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -202,7 +234,9 @@ 
"print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -215,12 +249,20 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label is not None\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=original_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + ")\n", + "sqlite_df = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -235,7 +277,9 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -249,7 +293,9 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -262,12 +308,20 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].score < evals[\"Hallucination\"].score\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=original_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", + " px.Client(endpoint=postgres_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", + ")\n", + "sqlite_df = (\n", + " px.Client(endpoint=sqlite_endpoint)\n", + " .get_spans_dataframe(filter_condition)\n", + " .sort_index()\n", ")\n", - "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -282,7 +336,9 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", 
"print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -296,7 +352,9 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", + " COMMON_COLUMNS\n", + " ],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index f386fc30a2..c990cd12dc 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -4,10 +4,12 @@ import typing from dataclasses import dataclass, field from difflib import SequenceMatcher +from itertools import chain from types import MappingProxyType +from uuid import uuid4 import sqlalchemy -from sqlalchemy.orm import aliased +from sqlalchemy.orm import Mapped, aliased from sqlalchemy.orm.util import AliasedClass from sqlalchemy.sql.expression import Select from typing_extensions import TypeAlias, TypeGuard, assert_never @@ -21,10 +23,46 @@ ) -AnnotationAlias: TypeAlias = str +EvalAttribute: TypeAlias = typing.Literal["label", "score"] EvalExpression: TypeAlias = str EvalName: TypeAlias = str + +@dataclass(frozen=True) +class EvalAlias: + eval_index: int + eval_name: EvalName + AliasedSpanAnnotation: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False) + _label_attribute_alias: str = field(init=False, repr=False) + _score_attribute_alias: str = field(init=False, repr=False) + + def __post_init__(self) -> None: + table_alias = f"span_annotation_{self.eval_index}" + alias_id = str(uuid4()).replace("-", "") # prevent conflicts with user-defined attributes + label_attribute_alias = f"{table_alias}_label_{alias_id}" + score_attribute_alias = f"{table_alias}_score_{alias_id}" + AliasedSpanAnnotation = aliased(models.SpanAnnotation, name=table_alias) + object.__setattr__(self, "_label_attribute_alias", label_attribute_alias) + object.__setattr__(self, "_score_attribute_alias", score_attribute_alias) + object.__setattr__( + self, + "AliasedSpanAnnotation", + AliasedSpanAnnotation, + ) + + @property + def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: + yield self._label_attribute_alias, self.AliasedSpanAnnotation.label + yield self._score_attribute_alias, self.AliasedSpanAnnotation.score + + def attribute_alias(self, attribute_name: str) -> str: + if attribute_name == "label": + return self._label_attribute_alias + if attribute_name == "score": + return self._score_attribute_alias + raise ValueError(f"Invalid attribute name: {attribute_name}") + + # Because postgresql is strongly typed, we cast JSON values to string # by default unless it's hinted otherwise as done here. 
_FLOAT_ATTRIBUTES: typing.FrozenSet[str] = frozenset( @@ -90,12 +128,7 @@ class SpanFilter: valid_eval_names: typing.Optional[typing.Sequence[str]] = None translated: ast.Expression = field(init=False, repr=False) compiled: typing.Any = field(init=False, repr=False) - aliased_annotations: typing.Tuple[AliasedClass[models.SpanAnnotation]] = field( - init=False, repr=False - ) - join_aliased_tables: typing.Callable[[Select[typing.Any]], Select[typing.Any]] = field( - init=False, repr=False - ) + eval_aliases: typing.Tuple[EvalAlias] = field(init=False, repr=False) def __bool__(self) -> bool: return bool(self.condition) @@ -105,31 +138,26 @@ def __post_init__(self) -> None: return root = ast.parse(source, mode="eval") _validate_expression(root, source, valid_eval_names=self.valid_eval_names) - source, aliased_annotations, join_aliased_tables = _apply_aliases(source) + source, eval_aliases = _apply_eval_aliases(source) + object.__setattr__(self, "eval_aliases", eval_aliases) root = ast.parse(source, mode="eval") translated = _FilterTranslator( - source=source, - annotation_aliases=[_get_alias(annotation) for annotation in aliased_annotations], + source=source, names=(alias for alias, _ in self.aliased_eval_attributes()) ).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) - object.__setattr__(self, "aliased_annotations", aliased_annotations) - object.__setattr__(self, "join_aliased_tables", join_aliased_tables) def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: return select - return self.join_aliased_tables(select).where( + return self.join_aliased_relations(select).where( eval( self.compiled, { **_NAMES, - **{ - _get_alias(annotation): annotation - for annotation in self.aliased_annotations - }, + **dict(self.aliased_eval_attributes()), "not_": sqlalchemy.not_, "and_": sqlalchemy.and_, "or_": sqlalchemy.or_, @@ -153,6 +181,24 @@ def from_dict( ) -> "SpanFilter": return cls(condition=obj.get("condition") or "") + def join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]: + for eval_alias in self.eval_aliases: + eval_name = eval_alias.eval_name + AliasedSpanAnnotation = eval_alias.AliasedSpanAnnotation + stmt = stmt.join( + AliasedSpanAnnotation, + onclause=( + sqlalchemy.and_( + AliasedSpanAnnotation.span_rowid == models.Span.id, + AliasedSpanAnnotation.name == eval_name, + ) + ), + ) + return stmt + + def aliased_eval_attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: + yield from chain.from_iterable(eval_alias.attributes for eval_alias in self.eval_aliases) + @dataclass(frozen=True) class Projector: @@ -310,19 +356,16 @@ def _is_float(node: typing.Any) -> TypeGuard[ast.Call]: class _ProjectionTranslator(ast.NodeTransformer): - def __init__(self, source: str) -> None: + def __init__(self, source: str, names: typing.Optional[typing.Iterable[str]] = None) -> None: # Regarding the need for `source: str` for getting source segments: # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). 
self._source = source - - def visit_Attribute(self, node: ast.Attribute) -> typing.Any: - source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): - return ast.Name(id=replacement, ctx=ast.Load()) - if (keys := _get_attribute_keys_list(node)) is None: - raise SyntaxError(f"invalid expression: {source_segment}") - return _as_attribute(keys) + self._names = ( + (tuple(names) if names is not None else ()) + + tuple(_STRING_NAMES.keys()) + + tuple(_FLOAT_NAMES.keys()) + ) def visit_generic(self, node: ast.AST) -> typing.Any: raise SyntaxError(f"invalid expression: {ast.get_source_segment(self._source, node)}") @@ -330,9 +373,17 @@ def visit_generic(self, node: ast.AST) -> typing.Any: def visit_Expression(self, node: ast.Expression) -> typing.Any: return ast.Expression(body=self.visit(node.body)) + def visit_Attribute(self, node: ast.Attribute) -> typing.Any: + source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) + if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): + return ast.Name(id=replacement, ctx=ast.Load()) + if (keys := _get_attribute_keys_list(node)) is not None: + return _as_attribute(keys) + raise SyntaxError(f"invalid expression: {source_segment}") + def visit_Name(self, node: ast.Name) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if source_segment in _STRING_NAMES or source_segment in _FLOAT_NAMES: + if source_segment in self._names: return node name = source_segment return _as_attribute([ast.Constant(value=name, kind=None)]) @@ -345,20 +396,6 @@ def visit_Subscript(self, node: ast.Subscript) -> typing.Any: # TODO(persistence): support `evals['name'].score` et. al. 
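For context, each evaluation referenced in a filter condition ultimately becomes a join against an aliased `span_annotations` relation. The following is a minimal, self-contained sketch of that join pattern, using simplified stand-in tables rather than the real `phoenix.db.models` definitions; all table and column names below are illustrative assumptions, not the library's API:

```python
import sqlalchemy as sa

# Hypothetical, simplified stand-ins for the real Span and SpanAnnotation tables.
metadata = sa.MetaData()
spans = sa.Table("spans", metadata, sa.Column("id", sa.Integer, primary_key=True))
span_annotations = sa.Table(
    "span_annotations",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("span_rowid", sa.Integer, sa.ForeignKey("spans.id")),
    sa.Column("name", sa.String),
    sa.Column("label", sa.String),
    sa.Column("score", sa.Float),
)

# One alias per evaluation name, so the same table can be joined more than once.
# Sketch of a condition like: evals["Hallucination"].score < 0.5
annotation_0 = span_annotations.alias("span_annotation_0")
stmt = (
    sa.select(spans.c.id)
    .join(
        annotation_0,
        onclause=sa.and_(
            annotation_0.c.span_rowid == spans.c.id,
            annotation_0.c.name == "Hallucination",
        ),
    )
    .where(annotation_0.c.score < 0.5)
)
print(stmt)  # SELECT spans.id FROM spans JOIN span_annotations AS span_annotation_0 ON ...
```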
class _FilterTranslator(_ProjectionTranslator): - def __init__(self, source: str, annotation_aliases: typing.Sequence[str]) -> None: - super().__init__(source) - self._annotation_aliases = annotation_aliases - - def visit_Attribute(self, node: ast.Attribute) -> typing.Any: - source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if replacement := _BACKWARD_COMPATIBILITY_REPLACEMENTS.get(source_segment): - return ast.Name(id=replacement, ctx=ast.Load()) - if (keys := _get_attribute_keys_list(node)) is None: - raise SyntaxError(f"invalid expression: {source_segment}") - if keys and keys[0].value in self._annotation_aliases: - return node - return _as_attribute(keys) - def visit_Compare(self, node: ast.Compare) -> typing.Any: if len(node.comparators) > 1: args: typing.List[typing.Any] = [] @@ -727,55 +764,35 @@ def _find_best_match( return best_choice, best_score -def _apply_aliases( +def _apply_eval_aliases( source: str, ) -> typing.Tuple[ str, - typing.Tuple[AliasedClass[models.SpanAnnotation], ...], - typing.Callable[[Select[typing.Any]], Select[typing.Any]], + typing.Tuple[EvalAlias, ...], ]: - aliased_annotations: typing.Dict[EvalName, AliasedClass[models.SpanAnnotation]] = {} - for eval_expression, eval_name in _parse_eval_expressions_and_names(source): - if (aliased_annotation := aliased_annotations.get(eval_name)) is None: - aliased_annotation = typing.cast( - AliasedClass[models.SpanAnnotation], - aliased(models.SpanAnnotation, name=f"span_annotation_{len(aliased_annotations)}"), - ) - aliased_annotations[eval_name] = aliased_annotation - alias = _get_alias(aliased_annotation) - source = source.replace(eval_expression, alias) - - def join_aliased_tables(stmt: Select[typing.Any]) -> Select[typing.Any]: - for eval_name, AliasedSpanAnnotation in MappingProxyType(aliased_annotations).items(): - stmt = stmt.join( - AliasedSpanAnnotation, - onclause=( - sqlalchemy.and_( - AliasedSpanAnnotation.span_rowid == models.Span.id, - AliasedSpanAnnotation.name == eval_name, - ) - ), - ) - return stmt - - return source, tuple(aliased_annotations.values()), join_aliased_tables + eval_aliases: typing.Dict[EvalName, EvalAlias] = {} + for eval_expression, eval_name, eval_attribute in _parse_eval_expressions_and_names(source): + if (eval_alias := eval_aliases.get(eval_name)) is None: + eval_alias = EvalAlias(eval_index=len(eval_aliases), eval_name=eval_name) + eval_aliases[eval_name] = eval_alias + alias_name = eval_alias.attribute_alias(eval_attribute) + source = source.replace(eval_expression, alias_name) + return source, tuple(eval_aliases.values()) def _parse_eval_expressions_and_names( source: str, -) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName]]: - for match in re.finditer(r"""(evals\[("(.*)"|'(.*)')\])""", source): +) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName, EvalAttribute]]: + for match in re.finditer(r"""(evals\[("(.*)"|'(.*)')\][.](label|score))""", source): ( eval_expression, _, double_quoted_eval_name, single_quoted_eval_name, + evaluation_attribute_name, ) = match.groups() yield ( eval_expression, double_quoted_eval_name or single_quoted_eval_name or "", + typing.cast(EvalAttribute, evaluation_attribute_name), ) - - -def _get_alias(aliased_annotation: AliasedClass[models.SpanAnnotation]) -> AnnotationAlias: - return str(sqlalchemy.inspect(aliased_annotation).name) From 3e5b16384481d22c125df5188695e9c2968b5ac3 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 22 Apr 2024 22:49:44 -0700 Subject: [PATCH 25/46] remove unnecessary check 
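As a standalone illustration of the parsing step above, the sketch below applies the evaluation-expression pattern to a filter condition. It uses the non-greedy `.*?` variant that a later commit in this series adopts, and it is not part of the patch itself:

```python
import re

# Mirrors the shape of _parse_eval_expressions_and_names; illustrative only, not the library API.
EVAL_EXPRESSION = re.compile(r"""(evals\[("(.*?)"|'(.*?)')\][.](label|score))""")

condition = """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5"""
for match in EVAL_EXPRESSION.finditer(condition):
    expression, _, double_quoted, single_quoted, attribute = match.groups()
    print(expression, "|", double_quoted or single_quoted, "|", attribute)
# evals['Q&A Correctness'].label | Q&A Correctness | label
# evals["Hallucination"].score | Hallucination | score
```

Each matched expression is then substituted in the source string with an attribute alias, so repeated references to the same evaluation name share a single aliased `span_annotation` relation.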
--- src/phoenix/trace/dsl/filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index c990cd12dc..167e587bf5 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -793,6 +793,6 @@ def _parse_eval_expressions_and_names( ) = match.groups() yield ( eval_expression, - double_quoted_eval_name or single_quoted_eval_name or "", + double_quoted_eval_name or single_quoted_eval_name, typing.cast(EvalAttribute, evaluation_attribute_name), ) From de7158e0d3897993ba0effa05ab61fa2347cf767 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 12:44:35 -0700 Subject: [PATCH 26/46] add unit test for regex and fix bug with greedy regex matching --- src/phoenix/trace/dsl/filter.py | 2 +- tests/trace/dsl/test_filter.py | 58 ++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 167e587bf5..c61adcfbc7 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -783,7 +783,7 @@ def _apply_eval_aliases( def _parse_eval_expressions_and_names( source: str, ) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName, EvalAttribute]]: - for match in re.finditer(r"""(evals\[("(.*)"|'(.*)')\][.](label|score))""", source): + for match in re.finditer(r"""(evals\[("(.*?)"|'(.*?)')\][.](label|score))""", source): ( eval_expression, _, diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index b8e369319d..b513615067 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -1,10 +1,13 @@ import ast import sys from typing import Any, List, Optional +from unittest.mock import patch +from uuid import UUID +import phoenix.trace.dsl.filter import pytest from phoenix.db import models -from phoenix.trace.dsl.filter import SpanFilter, _get_attribute_keys_list +from phoenix.trace.dsl.filter import SpanFilter, _apply_eval_aliases, _get_attribute_keys_list from sqlalchemy import select from sqlalchemy.orm import Session @@ -152,6 +155,59 @@ def test_filter_translated(session: Session, expression: str, expected: str) -> session.scalar(f(select(models.Span.id))) +@pytest.mark.parametrize( + "filter_condition,expected", + [ + pytest.param( + """evals["Q&A Correctness"].label is not None""", + "span_annotation_0_label_00000000000000000000000000000000 is not None", + id="double-quoted-eval-name", + ), + pytest.param( + """evals['Q&A Correctness'].label is not None""", + "span_annotation_0_label_00000000000000000000000000000000 is not None", + id="single-quoted-eval-name", + ), + pytest.param( + """evals[""].label is not None""", + "span_annotation_0_label_00000000000000000000000000000000 is not None", + id="empty-eval-name", + ), + pytest.param( + """evals['Hallucination'].label == 'correct' or evals['Hallucination'].score < 0.5""", # noqa E501 + "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + id="repeated-single-quoted-eval-name", + ), + pytest.param( + """evals["Hallucination"].label == 'correct' or evals["Hallucination"].score < 0.5""", # noqa E501 + "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + id="repeated-double-quoted-eval-name", + ), + pytest.param( + """evals['Hallucination'].label == 'correct' or 
evals["Hallucination"].score < 0.5""", # noqa E501 + "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + id="repeated-mixed-quoted-eval-name", + ), + pytest.param( + """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 + "span_annotation_0_label_00000000000000000000000000000000 == 'correct' and span_annotation_1_score_00000000000000000000000000000000 < 0.5", # noqa E501 + id="distinct-mixed-quoted-eval-names", + ), + pytest.param( + """evals["Hallucination].label is not None""", + """evals["Hallucination].label is not None""", + id="missing-quote", + ), + ], +) +def test_apply_eval_aliases(filter_condition: str, expected: str) -> None: + with patch.object( + phoenix.trace.dsl.filter, "uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000") + ): + aliased, _ = _apply_eval_aliases(filter_condition) + assert aliased == expected + + def _unparse(exp: Any) -> str: # `unparse` for python 3.8 outputs differently, # otherwise this function is unnecessary. From 5f4bee684ad3c560ca309bbd28b69795c6f27ac4 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 13:27:38 -0700 Subject: [PATCH 27/46] it notebook format --- integration-tests/eval_query_testing.ipynb | 114 +++++---------------- 1 file changed, 28 insertions(+), 86 deletions(-) diff --git a/integration-tests/eval_query_testing.ipynb b/integration-tests/eval_query_testing.ipynb index d04e0ddf83..c76ac7c903 100644 --- a/integration-tests/eval_query_testing.ipynb +++ b/integration-tests/eval_query_testing.ipynb @@ -72,20 +72,12 @@ "source": [ "filter_condition = \"evals['Q&A Correctness'].label == 'correct'\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - ")\n", - "sqlite_df = (\n", - " px.Client(endpoint=sqlite_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -100,9 +92,7 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -116,9 +106,7 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": 
\"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -129,22 +117,16 @@ "metadata": {}, "outputs": [], "source": [ - "filter_condition = \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", + "filter_condition = (\n", + " \"\"\"evals['Q&A Correctness'].label == 'correct' and evals[\"Hallucination\"].score < 0.5\"\"\"\n", + ")\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - ")\n", - "sqlite_df = (\n", - " px.Client(endpoint=sqlite_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -159,9 +141,7 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -175,9 +155,7 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -190,20 +168,12 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label != 'correct'\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - ")\n", - "sqlite_df = (\n", - " px.Client(endpoint=sqlite_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -218,9 +188,7 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": 
\"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -234,9 +202,7 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -249,20 +215,12 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].label is not None\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - ")\n", - "sqlite_df = (\n", - " px.Client(endpoint=sqlite_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", + "sqlite_df = px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -277,9 +235,7 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -293,9 +249,7 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] @@ -308,20 +262,12 @@ "source": [ "filter_condition = \"\"\"evals['Q&A Correctness'].score < evals[\"Hallucination\"].score\"\"\"\n", "original_df = (\n", - " px.Client(endpoint=original_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=original_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", "postgres_df = (\n", - " px.Client(endpoint=postgres_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", - ")\n", - "sqlite_df = (\n", - " px.Client(endpoint=sqlite_endpoint)\n", - " .get_spans_dataframe(filter_condition)\n", - " .sort_index()\n", + " px.Client(endpoint=postgres_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", ")\n", + "sqlite_df = 
px.Client(endpoint=sqlite_endpoint).get_spans_dataframe(filter_condition).sort_index()\n", "print(f\"{original_df.shape=}\")\n", "print(f\"{postgres_df.shape=}\")\n", "print(f\"{sqlite_df.shape=}\")" @@ -336,9 +282,7 @@ "print(f\"{set(original_df.columns).difference(set(sqlite_df.columns))=}\")\n", "print(f\"{set(sqlite_df.columns).difference(set(original_df.columns))=}\")\n", "sqlite_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"sqlite\", \"original\"),\n", ")" ] @@ -352,9 +296,7 @@ "print(f\"{set(original_df.columns).difference(set(postgres_df.columns))=}\")\n", "print(f\"{set(postgres_df.columns).difference(set(original_df.columns))=}\")\n", "postgres_df[COMMON_COLUMNS].compare(\n", - " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[\n", - " COMMON_COLUMNS\n", - " ],\n", + " original_df.rename(columns={\"span_kind\": \"attributes.openinference.span.kind\"})[COMMON_COLUMNS],\n", " result_names=(\"postgres\", \"original\"),\n", ")" ] From 5371477e934dcdf268805041436e33f45fe6b86a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 14:26:43 -0700 Subject: [PATCH 28/46] rename variables --- src/phoenix/trace/dsl/filter.py | 42 +++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 9952f0fea8..eedc643bb9 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -29,15 +29,15 @@ @dataclass(frozen=True) -class EvalAlias: - eval_index: int - eval_name: EvalName +class AliasedAnnotationRelation: + index: int + name: str AliasedSpanAnnotation: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False) _label_attribute_alias: str = field(init=False, repr=False) _score_attribute_alias: str = field(init=False, repr=False) def __post_init__(self) -> None: - table_alias = f"span_annotation_{self.eval_index}" + table_alias = f"span_annotation_{self.index}" alias_id = str(uuid4()).replace("-", "") # prevent conflicts with user-defined attributes label_attribute_alias = f"{table_alias}_label_{alias_id}" score_attribute_alias = f"{table_alias}_score_{alias_id}" @@ -128,7 +128,9 @@ class SpanFilter: valid_eval_names: typing.Optional[typing.Sequence[str]] = None translated: ast.Expression = field(init=False, repr=False) compiled: typing.Any = field(init=False, repr=False) - eval_aliases: typing.Tuple[EvalAlias] = field(init=False, repr=False) + aliased_annotation_relations: typing.Tuple[AliasedAnnotationRelation] = field( + init=False, repr=False + ) def __bool__(self) -> bool: return bool(self.condition) @@ -138,11 +140,11 @@ def __post_init__(self) -> None: return root = ast.parse(source, mode="eval") _validate_expression(root, source, valid_eval_names=self.valid_eval_names) - source, eval_aliases = _apply_eval_aliases(source) - object.__setattr__(self, "eval_aliases", eval_aliases) + source, aliased_annotation_relations = _apply_eval_aliases(source) + object.__setattr__(self, "aliased_annotation_relations", aliased_annotation_relations) root = ast.parse(source, mode="eval") translated = _FilterTranslator( - source=source, names=(alias for alias, _ in self.aliased_eval_attributes()) + source=source, names=(alias for alias, _ in self._aliased_annotation_attributes()) 
).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") @@ -152,12 +154,12 @@ def __post_init__(self) -> None: def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: return select - return self.join_aliased_relations(select).where( + return self._join_aliased_relations(select).where( eval( self.compiled, { **_NAMES, - **dict(self.aliased_eval_attributes()), + **dict(self._aliased_annotation_attributes()), "not_": sqlalchemy.not_, "and_": sqlalchemy.and_, "or_": sqlalchemy.or_, @@ -182,9 +184,9 @@ def from_dict( ) -> "SpanFilter": return cls(condition=obj.get("condition") or "") - def join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]: - for eval_alias in self.eval_aliases: - eval_name = eval_alias.eval_name + def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]: + for eval_alias in self.aliased_annotation_relations: + eval_name = eval_alias.name AliasedSpanAnnotation = eval_alias.AliasedSpanAnnotation stmt = stmt.join( AliasedSpanAnnotation, @@ -197,8 +199,12 @@ def join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any] ) return stmt - def aliased_eval_attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: - yield from chain.from_iterable(eval_alias.attributes for eval_alias in self.eval_aliases) + def _aliased_annotation_attributes( + self, + ) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: + yield from chain.from_iterable( + eval_alias.attributes for eval_alias in self.aliased_annotation_relations + ) @dataclass(frozen=True) @@ -764,12 +770,12 @@ def _apply_eval_aliases( source: str, ) -> typing.Tuple[ str, - typing.Tuple[EvalAlias, ...], + typing.Tuple[AliasedAnnotationRelation, ...], ]: - eval_aliases: typing.Dict[EvalName, EvalAlias] = {} + eval_aliases: typing.Dict[EvalName, AliasedAnnotationRelation] = {} for eval_expression, eval_name, eval_attribute in _parse_eval_expressions_and_names(source): if (eval_alias := eval_aliases.get(eval_name)) is None: - eval_alias = EvalAlias(eval_index=len(eval_aliases), eval_name=eval_name) + eval_alias = AliasedAnnotationRelation(index=len(eval_aliases), name=eval_name) eval_aliases[eval_name] = eval_alias alias_name = eval_alias.attribute_alias(eval_attribute) source = source.replace(eval_expression, alias_name) From 4f02631435a4bf4d6df7fe3e70e2e39318ae8449 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 14:59:25 -0700 Subject: [PATCH 29/46] add docstrings and use frozenset --- src/phoenix/trace/dsl/filter.py | 68 +++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index eedc643bb9..f12a943194 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -30,6 +30,13 @@ @dataclass(frozen=True) class AliasedAnnotationRelation: + """ + Represents an aliased `span_annotation` relation (i.e., SQL table). Used to + perform joins on span evaluations during filtering. An alias is required + because the `span_annotation` may be joined multiple times for different + evaluation names. 
+ """ + index: int name: str AliasedSpanAnnotation: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False) @@ -52,10 +59,17 @@ def __post_init__(self) -> None: @property def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: + """ + Alias names and attributes (i.e., columns) of the `span_annotation` + relation. + """ yield self._label_attribute_alias, self.AliasedSpanAnnotation.label yield self._score_attribute_alias, self.AliasedSpanAnnotation.score def attribute_alias(self, attribute_name: str) -> str: + """ + Returns an alias for the given attribute (i.e., column). + """ if attribute_name == "label": return self._label_attribute_alias if attribute_name == "score": @@ -185,6 +199,21 @@ def from_dict( return cls(condition=obj.get("condition") or "") def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any]: + """ + Joins the aliased relations to the given statement. E.g., for the filter condition: + + ``` + evals["Hallucination"].score > 0.5 + ``` + + an alias (e.g., `A`) is generated for the `span_annotations` relation. An input statement + `select(Span)` is transformed to: + + ``` + A = aliased(SpanAnnotation) + select(Span).join(A, onclause=(and_(Span.id == A.span_rowid, A.name == "Hallucination"))) + ``` + """ for eval_alias in self.aliased_annotation_relations: eval_name = eval_alias.name AliasedSpanAnnotation = eval_alias.AliasedSpanAnnotation @@ -202,6 +231,10 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any def _aliased_annotation_attributes( self, ) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: + """ + Yields all alias names and attributes (i.e., columns) for the aliased + annotation relations (tables). + """ yield from chain.from_iterable( eval_alias.attributes for eval_alias in self.aliased_annotation_relations ) @@ -368,10 +401,12 @@ def __init__(self, source: str, names: typing.Optional[typing.Iterable[str]] = N # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). self._source = source - self._names = ( - (tuple(names) if names is not None else ()) - + tuple(_STRING_NAMES.keys()) - + tuple(_FLOAT_NAMES.keys()) + self._names = frozenset( + chain( + (iter(names) if names is not None else ()), + _STRING_NAMES.keys(), + _FLOAT_NAMES.keys(), + ) ) def visit_generic(self, node: ast.AST) -> typing.Any: @@ -772,6 +807,24 @@ def _apply_eval_aliases( str, typing.Tuple[AliasedAnnotationRelation, ...], ]: + """ + Substitutes `evals[].` with aliases. Returns the + updated source code in addition to the aliased relations. + + Example: + + input: + + ``` + evals['Hallucination'].label == 'correct' or evals['Hallucination'].score < 0.5 + ``` + + output: + + ``` + span_annotation_0_label_123 == 'correct' or span_annotation_0_score_456 < 0.5 + ``` + """ eval_aliases: typing.Dict[EvalName, AliasedAnnotationRelation] = {} for eval_expression, eval_name, eval_attribute in _parse_eval_expressions_and_names(source): if (eval_alias := eval_aliases.get(eval_name)) is None: @@ -785,6 +838,13 @@ def _apply_eval_aliases( def _parse_eval_expressions_and_names( source: str, ) -> typing.Iterator[typing.Tuple[EvalExpression, EvalName, EvalAttribute]]: + """ + Parses filter conditions for evaluation expressions of the form: + + ``` + evals[""]. 
+ ``` + """ for match in re.finditer(r"""(evals\[("(.*?)"|'(.*?)')\][.](label|score))""", source): ( eval_expression, From f5f616d95c118dc0d481d254c892404e39af7c55 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 15:17:24 -0700 Subject: [PATCH 30/46] add eval test case for filter --- tests/trace/dsl/test_filter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index 1242943544..f4bc9aa42a 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -146,10 +146,17 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) if sys.version_info >= (3, 9) else "and_((attributes[['attributes']].as_string() == attributes[['attributes']].as_string()), (attributes[['attributes']].as_string() != attributes[['attributes', 'attributes']].as_string()))", # noqa E501 ), + ( + """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 + "and_(span_annotation_0_label_00000000000000000000000000000000 == 'correct', cast(span_annotation_1_score_00000000000000000000000000000000, Float) < 0.5)", # noqa E501 + ), ], ) def test_filter_translated(session: Session, expression: str, expected: str) -> None: - f = SpanFilter(expression) + with patch.object( + phoenix.trace.dsl.filter, "uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000") + ): + f = SpanFilter(expression) assert _unparse(f.translated) == expected # next line is only to test that the syntax is accepted session.scalar(f(select(models.Span.id))) From 5f8f2181ca52c3d79a55fd5058bafb1b5b8bf63e Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 15:26:32 -0700 Subject: [PATCH 31/46] use shorter random id --- src/phoenix/trace/dsl/filter.py | 4 ++-- tests/trace/dsl/test_filter.py | 25 ++++++++++++------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index f12a943194..31307c7c36 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -5,8 +5,8 @@ from dataclasses import dataclass, field from difflib import SequenceMatcher from itertools import chain +from random import randint from types import MappingProxyType -from uuid import uuid4 import sqlalchemy from sqlalchemy.orm import Mapped, aliased @@ -45,7 +45,7 @@ class AliasedAnnotationRelation: def __post_init__(self) -> None: table_alias = f"span_annotation_{self.index}" - alias_id = str(uuid4()).replace("-", "") # prevent conflicts with user-defined attributes + alias_id = f"{randint(0, 10**6):06d}" # prevent conflicts with user-defined attributes label_attribute_alias = f"{table_alias}_label_{alias_id}" score_attribute_alias = f"{table_alias}_score_{alias_id}" AliasedSpanAnnotation = aliased(models.SpanAnnotation, name=table_alias) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index f4bc9aa42a..d7deb8e074 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -2,7 +2,6 @@ import sys from typing import Any, List, Optional from unittest.mock import patch -from uuid import UUID import phoenix.trace.dsl.filter import pytest @@ -148,13 +147,15 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) ), ( """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 - "and_(span_annotation_0_label_00000000000000000000000000000000 == 
'correct', cast(span_annotation_1_score_00000000000000000000000000000000, Float) < 0.5)", # noqa E501 + "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000 Float) < 0.5)", # noqa E501 ), ], ) def test_filter_translated(session: Session, expression: str, expected: str) -> None: with patch.object( - phoenix.trace.dsl.filter, "uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000") + phoenix.trace.dsl.filter, + "randint", + return_value=0, ): f = SpanFilter(expression) assert _unparse(f.translated) == expected @@ -167,37 +168,37 @@ def test_filter_translated(session: Session, expression: str, expected: str) -> [ pytest.param( """evals["Q&A Correctness"].label is not None""", - "span_annotation_0_label_00000000000000000000000000000000 is not None", + "span_annotation_0_label_000000 is not None", id="double-quoted-eval-name", ), pytest.param( """evals['Q&A Correctness'].label is not None""", - "span_annotation_0_label_00000000000000000000000000000000 is not None", + "span_annotation_0_label_000000 is not None", id="single-quoted-eval-name", ), pytest.param( """evals[""].label is not None""", - "span_annotation_0_label_00000000000000000000000000000000 is not None", + "span_annotation_0_label_000000 is not None", id="empty-eval-name", ), pytest.param( """evals['Hallucination'].label == 'correct' or evals['Hallucination'].score < 0.5""", # noqa E501 - "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + "span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501 id="repeated-single-quoted-eval-name", ), pytest.param( """evals["Hallucination"].label == 'correct' or evals["Hallucination"].score < 0.5""", # noqa E501 - "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + "span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501 id="repeated-double-quoted-eval-name", ), pytest.param( """evals['Hallucination'].label == 'correct' or evals["Hallucination"].score < 0.5""", # noqa E501 - "span_annotation_0_label_00000000000000000000000000000000 == 'correct' or span_annotation_0_score_00000000000000000000000000000000 < 0.5", # noqa E501 + "span_annotation_0_label_000000 == 'correct' or span_annotation_0_score_000000 < 0.5", # noqa E501 id="repeated-mixed-quoted-eval-name", ), pytest.param( """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 - "span_annotation_0_label_00000000000000000000000000000000 == 'correct' and span_annotation_1_score_00000000000000000000000000000000 < 0.5", # noqa E501 + "span_annotation_0_label_000000 == 'correct' and span_annotation_1_score_000000 < 0.5", # noqa E501 id="distinct-mixed-quoted-eval-names", ), pytest.param( @@ -208,9 +209,7 @@ def test_filter_translated(session: Session, expression: str, expected: str) -> ], ) def test_apply_eval_aliases(filter_condition: str, expected: str) -> None: - with patch.object( - phoenix.trace.dsl.filter, "uuid4", return_value=UUID("00000000-0000-0000-0000-000000000000") - ): + with patch.object(phoenix.trace.dsl.filter, "randint", return_value=0): aliased, _ = _apply_eval_aliases(filter_condition) assert aliased == expected From 54a8e5c22f09c9e9314b998c5cc839dfa920cfa8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 16:25:14 -0700 Subject: [PATCH 
32/46] make test case cover python3.8 --- tests/trace/dsl/test_filter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index d7deb8e074..721dde3269 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -147,7 +147,9 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) ), ( """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 - "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000 Float) < 0.5)", # noqa E501 + "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000 Float) < 0.5)" # noqa E501 + if sys.version_info >= (3, 9) + else "and_((span_annotation_0_label_000000 == 'correct'), (cast(span_annotation_1_score_000000 Float)) < 0.5)", # noqa E501 ), ], ) From 7b4c8444cabe2497fecad93f49be9bf7b3b2312b Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 16:28:03 -0700 Subject: [PATCH 33/46] fix error in test case --- tests/trace/dsl/test_filter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index 721dde3269..560c6ea081 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -147,9 +147,9 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) ), ( """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 - "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000 Float) < 0.5)" # noqa E501 + "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000, Float) < 0.5)" # noqa E501 if sys.version_info >= (3, 9) - else "and_((span_annotation_0_label_000000 == 'correct'), (cast(span_annotation_1_score_000000 Float)) < 0.5)", # noqa E501 + else "and_((span_annotation_0_label_000000 == 'correct'), (cast(span_annotation_1_score_000000, Float)) < 0.5)", # noqa E501 ), ], ) From 092d52df672fc5bcfd469a338418d6e7973815c2 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 16:46:16 -0700 Subject: [PATCH 34/46] add word boundary to regex and corresponding tests --- src/phoenix/trace/dsl/filter.py | 2 +- tests/trace/dsl/test_filter.py | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 31307c7c36..d3857e601f 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -845,7 +845,7 @@ def _parse_eval_expressions_and_names( evals[""]. 
``` """ - for match in re.finditer(r"""(evals\[("(.*?)"|'(.*?)')\][.](label|score))""", source): + for match in re.finditer(r"""\b(evals\[("(.*?)"|'(.*?)')\][.](label|score))\b""", source): ( eval_expression, _, diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index 560c6ea081..6ad0b63bc4 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -206,7 +206,27 @@ def test_filter_translated(session: Session, expression: str, expected: str) -> pytest.param( """evals["Hallucination].label is not None""", """evals["Hallucination].label is not None""", - id="missing-quote", + id="missing-right-quotation-mark", + ), + pytest.param( + """evals["Hallucination"].label == 'correct' orevals["Hallucination"].score < 0.5""", # noqa E501 + """span_annotation_0_label_000000 == 'correct' orevals["Hallucination"].score < 0.5""", # noqa E501 + id="no-word-boundary-on-the-left", + ), + pytest.param( + """evals["Hallucination"].scoreq < 0.5""", # noqa E501 + """evals["Hallucination"].scoreq < 0.5""", # noqa E501 + id="no-word-boundary-on-the-right", + ), + pytest.param( + """0.5 Date: Tue, 23 Apr 2024 16:52:04 -0700 Subject: [PATCH 35/46] compile regex --- src/phoenix/trace/dsl/filter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index d3857e601f..72357f7157 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -845,7 +845,8 @@ def _parse_eval_expressions_and_names( evals[""]. ``` """ - for match in re.finditer(r"""\b(evals\[("(.*?)"|'(.*?)')\][.](label|score))\b""", source): + pattern = re.compile(r"""\b(evals\[("(.*?)"|'(.*?)')\][.](label|score))\b""") + for match in pattern.finditer(source): ( eval_expression, _, From e25b66b8021485401b7934d4608d15f619306fca Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 16:55:13 -0700 Subject: [PATCH 36/46] simplify regex --- src/phoenix/trace/dsl/filter.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 72357f7157..373791b663 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -845,17 +845,15 @@ def _parse_eval_expressions_and_names( evals[""]. 
``` """ - pattern = re.compile(r"""\b(evals\[("(.*?)"|'(.*?)')\][.](label|score))\b""") + pattern = re.compile(r"""\b(evals\[(".*?"|'.*?')\][.](label|score))\b""") for match in pattern.finditer(source): ( eval_expression, - _, - double_quoted_eval_name, - single_quoted_eval_name, + quoted_eval_name, evaluation_attribute_name, ) = match.groups() yield ( eval_expression, - double_quoted_eval_name or single_quoted_eval_name, + quoted_eval_name[1:-1], typing.cast(EvalAttribute, evaluation_attribute_name), ) From db5d25017535622254871c5d8b8212390eedd7b9 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:03:22 -0700 Subject: [PATCH 37/46] improve variable naming --- src/phoenix/trace/dsl/filter.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 373791b663..6816864428 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -158,7 +158,8 @@ def __post_init__(self) -> None: object.__setattr__(self, "aliased_annotation_relations", aliased_annotation_relations) root = ast.parse(source, mode="eval") translated = _FilterTranslator( - source=source, names=(alias for alias, _ in self._aliased_annotation_attributes()) + source=source, + reserved_keywords=(alias for alias, _ in self._aliased_annotation_attributes()), ).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") @@ -396,14 +397,16 @@ def _is_float(node: typing.Any) -> TypeGuard[ast.Call]: class _ProjectionTranslator(ast.NodeTransformer): - def __init__(self, source: str, names: typing.Optional[typing.Iterable[str]] = None) -> None: + def __init__( + self, source: str, reserved_keywords: typing.Optional[typing.Iterable[str]] = None + ) -> None: # Regarding the need for `source: str` for getting source segments: # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). 
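# Illustrative sketch (not part of the patch): the simplified pattern above captures
# three groups -- the whole eval expression, the still-quoted eval name, and the
# attribute -- so the name is unquoted with a plain [1:-1] slice instead of separate
# capture groups per quote style. The `_sketch_*` names are placeholders:
import re

_sketch_pattern = re.compile(r"""\b(evals\[(".*?"|'.*?')\][.](label|score))\b""")
_sketch_match = _sketch_pattern.search("""evals["Hallucination"].score < 0.5""")
assert _sketch_match is not None
_expression, _quoted_name, _attribute = _sketch_match.groups()
assert _expression == 'evals["Hallucination"].score'
assert _quoted_name[1:-1] == "Hallucination"
assert _attribute == "score"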
self._source = source - self._names = frozenset( + self._reserved_keywords = frozenset( chain( - (iter(names) if names is not None else ()), + (iter(reserved_keywords) if reserved_keywords is not None else ()), _STRING_NAMES.keys(), _FLOAT_NAMES.keys(), ) @@ -425,7 +428,7 @@ def visit_Attribute(self, node: ast.Attribute) -> typing.Any: def visit_Name(self, node: ast.Name) -> typing.Any: source_segment = typing.cast(str, ast.get_source_segment(self._source, node)) - if source_segment in self._names: + if source_segment in self._reserved_keywords: return node name = source_segment return _as_attribute([ast.Constant(value=name, kind=None)]) From a10fc8b61a0499dc79fd3addade14ab46ff0e278 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:06:08 -0700 Subject: [PATCH 38/46] simplify types --- src/phoenix/trace/dsl/filter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 6816864428..cbe368d5dc 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -397,16 +397,14 @@ def _is_float(node: typing.Any) -> TypeGuard[ast.Call]: class _ProjectionTranslator(ast.NodeTransformer): - def __init__( - self, source: str, reserved_keywords: typing.Optional[typing.Iterable[str]] = None - ) -> None: + def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> None: # Regarding the need for `source: str` for getting source segments: # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). self._source = source self._reserved_keywords = frozenset( chain( - (iter(reserved_keywords) if reserved_keywords is not None else ()), + reserved_keywords, _STRING_NAMES.keys(), _FLOAT_NAMES.keys(), ) From c1bea9d203ece5a1f415c23683731af8ed2e69ad Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:12:29 -0700 Subject: [PATCH 39/46] use literal type and assert_never --- src/phoenix/trace/dsl/filter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index cbe368d5dc..73366d6d8b 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -66,15 +66,15 @@ def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: yield self._label_attribute_alias, self.AliasedSpanAnnotation.label yield self._score_attribute_alias, self.AliasedSpanAnnotation.score - def attribute_alias(self, attribute_name: str) -> str: + def attribute_alias(self, attribute: EvalAttribute) -> str: """ Returns an alias for the given attribute (i.e., column). 
""" - if attribute_name == "label": + if attribute == "label": return self._label_attribute_alias - if attribute_name == "score": + if attribute == "score": return self._score_attribute_alias - raise ValueError(f"Invalid attribute name: {attribute_name}") + assert_never(attribute) # Because postgresql is strongly typed, we cast JSON values to string From 31cff060028f8373467bcd23a568338e8c366ef8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:14:21 -0700 Subject: [PATCH 40/46] change function name --- src/phoenix/trace/dsl/filter.py | 4 ++-- tests/trace/dsl/test_filter.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 73366d6d8b..8c204c810f 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -154,7 +154,7 @@ def __post_init__(self) -> None: return root = ast.parse(source, mode="eval") _validate_expression(root, source, valid_eval_names=self.valid_eval_names) - source, aliased_annotation_relations = _apply_eval_aliases(source) + source, aliased_annotation_relations = _apply_eval_aliasing(source) object.__setattr__(self, "aliased_annotation_relations", aliased_annotation_relations) root = ast.parse(source, mode="eval") translated = _FilterTranslator( @@ -802,7 +802,7 @@ def _find_best_match( return best_choice, best_score -def _apply_eval_aliases( +def _apply_eval_aliasing( source: str, ) -> typing.Tuple[ str, diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index 6ad0b63bc4..d5e98ad411 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -6,7 +6,7 @@ import phoenix.trace.dsl.filter import pytest from phoenix.db import models -from phoenix.trace.dsl.filter import SpanFilter, _apply_eval_aliases, _get_attribute_keys_list +from phoenix.trace.dsl.filter import SpanFilter, _apply_eval_aliasing, _get_attribute_keys_list from sqlalchemy import select from sqlalchemy.orm import Session @@ -230,9 +230,9 @@ def test_filter_translated(session: Session, expression: str, expected: str) -> ), ], ) -def test_apply_eval_aliases(filter_condition: str, expected: str) -> None: +def test_apply_eval_aliasing(filter_condition: str, expected: str) -> None: with patch.object(phoenix.trace.dsl.filter, "randint", return_value=0): - aliased, _ = _apply_eval_aliases(filter_condition) + aliased, _ = _apply_eval_aliasing(filter_condition) assert aliased == expected From 063fcc96ef51d1d2eadcc1e5f24d581769936494 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:21:11 -0700 Subject: [PATCH 41/46] remove chain, move out regex --- src/phoenix/trace/dsl/filter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 8c204c810f..b8f2f6a62d 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -27,6 +27,8 @@ EvalExpression: TypeAlias = str EvalName: TypeAlias = str +EVAL_EXPRESSION_PATTERN = re.compile(r"""\b(evals\[(".*?"|'.*?')\][.](label|score))\b""") + @dataclass(frozen=True) class AliasedAnnotationRelation: @@ -403,11 +405,9 @@ def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). 
self._source = source self._reserved_keywords = frozenset( - chain( - reserved_keywords, - _STRING_NAMES.keys(), - _FLOAT_NAMES.keys(), - ) + *reserved_keywords, + *_STRING_NAMES.keys(), + *_FLOAT_NAMES.keys(), ) def visit_generic(self, node: ast.AST) -> typing.Any: @@ -846,8 +846,7 @@ def _parse_eval_expressions_and_names( evals[""]. ``` """ - pattern = re.compile(r"""\b(evals\[(".*?"|'.*?')\][.](label|score))\b""") - for match in pattern.finditer(source): + for match in EVAL_EXPRESSION_PATTERN.finditer(source): ( eval_expression, quoted_eval_name, From eb26ce0806c53d47667edaa8269bd6399df77dc3 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 17:36:58 -0700 Subject: [PATCH 42/46] change variable name --- src/phoenix/trace/dsl/filter.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index b8f2f6a62d..a66fcc28fc 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -41,7 +41,7 @@ class AliasedAnnotationRelation: index: int name: str - AliasedSpanAnnotation: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False) + table: AliasedClass[models.SpanAnnotation] = field(init=False, repr=False) _label_attribute_alias: str = field(init=False, repr=False) _score_attribute_alias: str = field(init=False, repr=False) @@ -50,13 +50,13 @@ def __post_init__(self) -> None: alias_id = f"{randint(0, 10**6):06d}" # prevent conflicts with user-defined attributes label_attribute_alias = f"{table_alias}_label_{alias_id}" score_attribute_alias = f"{table_alias}_score_{alias_id}" - AliasedSpanAnnotation = aliased(models.SpanAnnotation, name=table_alias) + table = aliased(models.SpanAnnotation, name=table_alias) object.__setattr__(self, "_label_attribute_alias", label_attribute_alias) object.__setattr__(self, "_score_attribute_alias", score_attribute_alias) object.__setattr__( self, - "AliasedSpanAnnotation", - AliasedSpanAnnotation, + "table", + table, ) @property @@ -65,8 +65,8 @@ def attributes(self) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: Alias names and attributes (i.e., columns) of the `span_annotation` relation. """ - yield self._label_attribute_alias, self.AliasedSpanAnnotation.label - yield self._score_attribute_alias, self.AliasedSpanAnnotation.score + yield self._label_attribute_alias, self.table.label + yield self._score_attribute_alias, self.table.score def attribute_alias(self, attribute: EvalAttribute) -> str: """ @@ -219,7 +219,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any """ for eval_alias in self.aliased_annotation_relations: eval_name = eval_alias.name - AliasedSpanAnnotation = eval_alias.AliasedSpanAnnotation + AliasedSpanAnnotation = eval_alias.table stmt = stmt.join( AliasedSpanAnnotation, onclause=( @@ -405,9 +405,11 @@ def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). 
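# Illustrative sketch (not part of the patch): the aliasing helper, now named
# `_apply_eval_aliasing`, can be exercised the same way the tests above do -- patching
# `randint` to 0 makes the generated suffix deterministic ("000000"), so the rewritten
# condition is predictable:
from unittest.mock import patch

import phoenix.trace.dsl.filter
from phoenix.trace.dsl.filter import _apply_eval_aliasing

with patch.object(phoenix.trace.dsl.filter, "randint", return_value=0):
    _aliased_source, _relations = _apply_eval_aliasing(
        "evals['Hallucination'].label == 'correct'"
    )
assert _aliased_source == "span_annotation_0_label_000000 == 'correct'"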
self._source = source self._reserved_keywords = frozenset( - *reserved_keywords, - *_STRING_NAMES.keys(), - *_FLOAT_NAMES.keys(), + chain( + reserved_keywords, + _STRING_NAMES.keys(), + _FLOAT_NAMES.keys(), + ) ) def visit_generic(self, node: ast.AST) -> typing.Any: From 4f0d1d3a625b00ee603405254f1e460c04fcd356 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 18:07:48 -0700 Subject: [PATCH 43/46] compute _aliased_annotation_attributes and store on SpanFilter --- src/phoenix/trace/dsl/filter.py | 34 +++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index a66fcc28fc..ef56bec84d 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -144,7 +144,10 @@ class SpanFilter: valid_eval_names: typing.Optional[typing.Sequence[str]] = None translated: ast.Expression = field(init=False, repr=False) compiled: typing.Any = field(init=False, repr=False) - aliased_annotation_relations: typing.Tuple[AliasedAnnotationRelation] = field( + _aliased_annotation_relations: typing.Tuple[AliasedAnnotationRelation] = field( + init=False, repr=False + ) + _aliased_annotation_attributes: typing.Dict[str, Mapped[typing.Any]] = field( init=False, repr=False ) @@ -157,16 +160,26 @@ def __post_init__(self) -> None: root = ast.parse(source, mode="eval") _validate_expression(root, source, valid_eval_names=self.valid_eval_names) source, aliased_annotation_relations = _apply_eval_aliasing(source) - object.__setattr__(self, "aliased_annotation_relations", aliased_annotation_relations) root = ast.parse(source, mode="eval") translated = _FilterTranslator( source=source, - reserved_keywords=(alias for alias, _ in self._aliased_annotation_attributes()), + reserved_keywords=( + alias + for aliased_annotation in aliased_annotation_relations + for alias, _ in aliased_annotation.attributes + ), ).visit(root) ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") + aliased_annotation_attributes = { + attribute + for aliased_annotation in aliased_annotation_relations + for attribute in aliased_annotation.attributes + } object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) + object.__setattr__(self, "_aliased_annotation_relations", aliased_annotation_relations) + object.__setattr__(self, "_aliased_annotation_attributes", aliased_annotation_attributes) def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: if not self.condition: @@ -176,7 +189,7 @@ def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]: self.compiled, { **_NAMES, - **dict(self._aliased_annotation_attributes()), + **self._aliased_annotation_attributes, "not_": sqlalchemy.not_, "and_": sqlalchemy.and_, "or_": sqlalchemy.or_, @@ -217,7 +230,7 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any select(Span).join(A, onclause=(and_(Span.id == A.span_rowid, A.name == "Hallucination"))) ``` """ - for eval_alias in self.aliased_annotation_relations: + for eval_alias in self._aliased_annotation_relations: eval_name = eval_alias.name AliasedSpanAnnotation = eval_alias.table stmt = stmt.join( @@ -231,17 +244,6 @@ def _join_aliased_relations(self, stmt: Select[typing.Any]) -> Select[typing.Any ) return stmt - def _aliased_annotation_attributes( - self, - ) -> typing.Iterator[typing.Tuple[str, Mapped[typing.Any]]]: - """ - Yields all alias names and attributes (i.e., columns) for the 
aliased - annotation relations (tables). - """ - yield from chain.from_iterable( - eval_alias.attributes for eval_alias in self.aliased_annotation_relations - ) - @dataclass(frozen=True) class Projector: From e3a19b8ea6fcb935e17d078d8fbbf9ec81ef8302 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 18:11:53 -0700 Subject: [PATCH 44/46] remove chain --- src/phoenix/trace/dsl/filter.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index ef56bec84d..7103384ee1 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -4,7 +4,6 @@ import typing from dataclasses import dataclass, field from difflib import SequenceMatcher -from itertools import chain from random import randint from types import MappingProxyType @@ -406,12 +405,10 @@ def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). self._source = source - self._reserved_keywords = frozenset( - chain( - reserved_keywords, - _STRING_NAMES.keys(), - _FLOAT_NAMES.keys(), - ) + self._reserved_keywords: typing.FrozenSet[str] = frozenset( + *reserved_keywords, + *_STRING_NAMES.keys(), + *_FLOAT_NAMES.keys(), ) def visit_generic(self, node: ast.AST) -> typing.Any: From 14e7233bbc45119a13e734a31d2964282af8bfd2 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 18:19:55 -0700 Subject: [PATCH 45/46] Revert "remove chain" This reverts commit e3a19b8ea6fcb935e17d078d8fbbf9ec81ef8302. --- src/phoenix/trace/dsl/filter.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index 7103384ee1..ef56bec84d 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -4,6 +4,7 @@ import typing from dataclasses import dataclass, field from difflib import SequenceMatcher +from itertools import chain from random import randint from types import MappingProxyType @@ -405,10 +406,12 @@ def __init__(self, source: str, reserved_keywords: typing.Iterable[str] = ()) -> # In Python 3.8, we have to use `ast.get_source_segment(source, node)`. # In Python 3.9+, we can use `ast.unparse(node)` (no need for `source`). 
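# Illustrative note (not part of the patch): the starred frozenset(*...) call being
# reverted here cannot work, because frozenset() accepts at most one iterable argument,
# so unpacking several iterables into the call raises a TypeError at runtime; chaining
# them into a single iterable first is the working alternative. The column names below
# are placeholders:
from itertools import chain

_reserved = ("span_annotation_0_label_000000",)
_string_names = ("some_string_column",)  # stands in for _STRING_NAMES.keys()
_float_names = ("some_float_column",)  # stands in for _FLOAT_NAMES.keys()
_merged = frozenset(chain(_reserved, _string_names, _float_names))
assert "span_annotation_0_label_000000" in _merged
# frozenset(*_reserved, *_string_names)  # TypeError: frozenset expected at most 1 argument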
self._source = source - self._reserved_keywords: typing.FrozenSet[str] = frozenset( - *reserved_keywords, - *_STRING_NAMES.keys(), - *_FLOAT_NAMES.keys(), + self._reserved_keywords = frozenset( + chain( + reserved_keywords, + _STRING_NAMES.keys(), + _FLOAT_NAMES.keys(), + ) ) def visit_generic(self, node: ast.AST) -> typing.Any: From d684051c5fd054bcf5266a711b884d42edadfb23 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 23 Apr 2024 18:24:37 -0700 Subject: [PATCH 46/46] fix tests --- src/phoenix/trace/dsl/filter.py | 4 ++-- tests/trace/dsl/test_filter.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/phoenix/trace/dsl/filter.py b/src/phoenix/trace/dsl/filter.py index ef56bec84d..63f38d73de 100644 --- a/src/phoenix/trace/dsl/filter.py +++ b/src/phoenix/trace/dsl/filter.py @@ -172,9 +172,9 @@ def __post_init__(self) -> None: ast.fix_missing_locations(translated) compiled = compile(translated, filename="", mode="eval") aliased_annotation_attributes = { - attribute + alias: attribute for aliased_annotation in aliased_annotation_relations - for attribute in aliased_annotation.attributes + for alias, attribute in aliased_annotation.attributes } object.__setattr__(self, "translated", translated) object.__setattr__(self, "compiled", compiled) diff --git a/tests/trace/dsl/test_filter.py b/tests/trace/dsl/test_filter.py index d5e98ad411..f508a83c8d 100644 --- a/tests/trace/dsl/test_filter.py +++ b/tests/trace/dsl/test_filter.py @@ -149,7 +149,7 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]]) """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5""", # noqa E501 "and_(span_annotation_0_label_000000 == 'correct', cast(span_annotation_1_score_000000, Float) < 0.5)" # noqa E501 if sys.version_info >= (3, 9) - else "and_((span_annotation_0_label_000000 == 'correct'), (cast(span_annotation_1_score_000000, Float)) < 0.5)", # noqa E501 + else "and_((span_annotation_0_label_000000 == 'correct'), (cast(span_annotation_1_score_000000, Float) < 0.5))", # noqa E501 ), ], )
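As a closing illustration of what `_join_aliased_relations` builds (a sketch based on its docstring, not code from any commit; the alias name and eval name are examples), each `evals[...]` reference contributes one aliased `span_annotations` join keyed on the span rowid and the eval name:

```
from sqlalchemy import and_, select
from sqlalchemy.orm import aliased

from phoenix.db import models

# One alias per distinct eval name referenced by the filter condition.
A = aliased(models.SpanAnnotation, name="span_annotation_0")
stmt = select(models.Span).join(
    A,
    onclause=and_(
        models.Span.id == A.span_rowid,
        A.name == "Hallucination",
    ),
)
```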
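End to end, the public entry point ties validation, aliasing, translation, and joining together. The sketch below mirrors the usage exercised by the tests; the `session` argument and the helper name are illustrative, assuming an existing SQLAlchemy `Session` bound to a Phoenix database:

```
from typing import List

from sqlalchemy import select
from sqlalchemy.orm import Session

from phoenix.db import models
from phoenix.trace.dsl.filter import SpanFilter


def matching_span_rowids(session: Session) -> List[int]:
    # SpanFilter validates the condition, aliases each `evals[...]` reference,
    # and compiles the remainder of the expression.
    span_filter = SpanFilter(
        """evals['Q&A Correctness'].label == 'correct' and evals["Hallucination"].score < 0.5"""
    )
    # Calling the filter joins one aliased span_annotations relation per eval name
    # and appends the translated WHERE clause.
    stmt = span_filter(select(models.Span.id))
    return list(session.scalars(stmt))
```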