From a1886a022056444c79f244e7750bb9f745ced64b Mon Sep 17 00:00:00 2001
From: Parker Stafford <52351508+Parker-Stafford@users.noreply.github.com>
Date: Fri, 25 Oct 2024 08:32:41 -0700
Subject: [PATCH] feat(playground): plumb through message tool_calls from span
 to playground (#5197)

* feat(playground): plumb through message tool_calls from span to playground

* cleanup
---
 .../__tests__/playgroundUtils.test.ts         | 281 ++++++++++++++++++
 app/src/pages/playground/playgroundUtils.ts   |  44 ++-
 app/src/pages/playground/schemas.ts           |  19 +-
 app/src/store/playground/types.ts             |   2 +-
 4 files changed, 334 insertions(+), 12 deletions(-)

diff --git a/app/src/pages/playground/__tests__/playgroundUtils.test.ts b/app/src/pages/playground/__tests__/playgroundUtils.test.ts
index d93f8a8a71..4ed94250cc 100644
--- a/app/src/pages/playground/__tests__/playgroundUtils.test.ts
+++ b/app/src/pages/playground/__tests__/playgroundUtils.test.ts
@@ -1,3 +1,4 @@
+import { TemplateLanguage } from "@phoenix/components/templateEditor/types";
 import { DEFAULT_MODEL_PROVIDER } from "@phoenix/constants/generativeConstants";
 import {
   _resetInstanceId,
@@ -14,8 +15,13 @@ import {
   SPAN_ATTRIBUTES_PARSING_ERROR,
 } from "../constants";
 import {
+  extractVariablesFromInstances,
   getChatRole,
+  getModelConfigFromAttributes,
   getModelProviderFromModelName,
+  getOutputFromAttributes,
+  getTemplateMessagesFromAttributes,
+  processAttributeToolCalls,
   transformSpanAttributesToPlaygroundInstance,
 } from "../playgroundUtils";
 
@@ -24,6 +30,25 @@ import {
   spanAttributesWithInputMessages,
 } from "./fixtures";
 
+const baseTestPlaygroundInstance: PlaygroundInstance = {
+  id: 0,
+  activeRunId: null,
+  isRunning: false,
+  model: {
+    provider: "OPENAI",
+    modelName: "gpt-3.5-turbo",
+    invocationParameters: {},
+  },
+  input: { variablesValueCache: {} },
+  tools: [],
+  toolChoice: "auto",
+  spanId: null,
+  template: {
+    __type: "chat",
+    messages: [],
+  },
+};
+
 const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
   id: 0,
   activeRunId: null,
@@ -410,3 +435,259 @@ describe("getModelProviderFromModelName", () => {
     );
   });
 });
+
+const testSpanToolCall = {
+  tool_call: {
+    id: "1",
+    function: {
+      name: "functionName",
+      arguments: JSON.stringify({ arg1: "value1" }),
+    },
+  },
+};
+
+const expectedTestToolCall = {
+  id: "1",
+  function: {
+    name: "functionName",
+    arguments: JSON.stringify({ arg1: "value1" }),
+  },
+};
+describe("processAttributeToolCalls", () => {
+  it("should transform tool calls correctly", () => {
+    const toolCalls = [testSpanToolCall];
+    expect(processAttributeToolCalls(toolCalls)).toEqual([
+      expectedTestToolCall,
+    ]);
+  });
+
+  it("should filter out nullish tool calls", () => {
+    const toolCalls = [{}, testSpanToolCall];
+    expect(processAttributeToolCalls(toolCalls)).toEqual([
+      expectedTestToolCall,
+    ]);
+  });
+});
+
+describe("getTemplateMessagesFromAttributes", () => {
+  it("should return parsing errors if input messages are invalid", () => {
+    const parsedAttributes = { llm: { input_messages: "invalid" } };
+    expect(getTemplateMessagesFromAttributes(parsedAttributes)).toEqual({
+      messageParsingErrors: [INPUT_MESSAGES_PARSING_ERROR],
+      messages: null,
+    });
+  });
+
+  it("should return parsed messages as ChatMessages if input messages are valid", () => {
+    const parsedAttributes = {
+      llm: {
+        input_messages: [
+          {
+            message: {
+              role: "human",
+              content: "Hello",
+              tool_calls: [testSpanToolCall],
+            },
+          },
+        ],
+      },
+    };
+    expect(getTemplateMessagesFromAttributes(parsedAttributes)).toEqual({
+      messageParsingErrors: [],
+      messages: [
+        {
+          id: expect.any(Number),
+          role: "user",
+          content: "Hello",
+          toolCalls: [expectedTestToolCall],
+        },
+      ],
+    });
+  });
+});
+
+describe("getOutputFromAttributes", () => {
+  it("should return parsing errors if output messages are invalid", () => {
+    const parsedAttributes = { llm: { output_messages: "invalid" } };
+    expect(getOutputFromAttributes(parsedAttributes)).toEqual({
+      output: undefined,
+      outputParsingErrors: [
+        OUTPUT_MESSAGES_PARSING_ERROR,
+        OUTPUT_VALUE_PARSING_ERROR,
+      ],
+    });
+  });
+
+  it("should return parsed output if output messages are valid", () => {
+    const parsedAttributes = {
+      llm: {
+        output_messages: [
+          {
+            message: {
+              role: "ai",
+              content: "This is an AI Answer",
+            },
+          },
+        ],
+      },
+    };
+    expect(getOutputFromAttributes(parsedAttributes)).toEqual({
+      output: [
+        {
+          id: expect.any(Number),
+          role: "ai",
+          content: "This is an AI Answer",
+        },
+      ],
+      outputParsingErrors: [],
+    });
+  });
+
+  it("should fallback to output.value if output_messages is not present", () => {
+    const parsedAttributes = {
+      output: {
+        value: "This is an AI Answer",
+      },
+    };
+    expect(getOutputFromAttributes(parsedAttributes)).toEqual({
+      output: "This is an AI Answer",
+      outputParsingErrors: [OUTPUT_MESSAGES_PARSING_ERROR],
+    });
+  });
+});
+
+describe("getModelConfigFromAttributes", () => {
+  it("should return parsing errors if model config is invalid", () => {
+    const parsedAttributes = { llm: { model_name: 123 } };
+    expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
+      modelConfig: null,
+      parsingErrors: [MODEL_CONFIG_PARSING_ERROR],
+    });
+  });
+
+  it("should return parsed model config if valid with the provider inferred", () => {
+    const parsedAttributes = {
+      llm: {
+        model_name: "gpt-3.5-turbo",
+        invocation_parameters: '{"top_p": 0.5, "max_tokens": 100}',
+      },
+    };
+    expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
+      modelConfig: {
+        modelName: "gpt-3.5-turbo",
+        provider: "OPENAI",
+        invocationParameters: {
+          topP: 0.5,
+          maxTokens: 100,
+        },
+      },
+      parsingErrors: [],
+    });
+  });
+
+  it("should return invocation parameters parsing errors if they are malformed", () => {
+    const parsedAttributes = {
+      llm: {
+        model_name: "gpt-3.5-turbo",
+        invocation_parameters: 100,
+      },
+    };
+    expect(getModelConfigFromAttributes(parsedAttributes)).toEqual({
+      modelConfig: {
+        modelName: "gpt-3.5-turbo",
+        provider: "OPENAI",
+        invocationParameters: {},
+      },
+      parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR],
+    });
+  });
+});
+
+describe("extractVariablesFromInstances", () => {
+  it("should extract variables from chat messages", () => {
+    const instances: PlaygroundInstance[] = [
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "chat",
+          messages: [
+            { id: 0, content: "Hello {{name}}", role: "user" },
+            { id: 1, content: "How are you, {{name}}?", role: "ai" },
+          ],
+        },
+      },
+    ];
+    const templateLanguage = "MUSTACHE";
+    expect(
+      extractVariablesFromInstances({ instances, templateLanguage })
+    ).toEqual(["name"]);
+  });
+
+  it("should extract variables from text completion prompts", () => {
+    const instances: PlaygroundInstance[] = [
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "text_completion",
+          prompt: "Hello {{name}}",
+        },
+      },
+    ];
+    const templateLanguage = "MUSTACHE";
+    expect(
+      extractVariablesFromInstances({ instances, templateLanguage })
+    ).toEqual(["name"]);
+  });
+
+  it("should handle multiple instances and variable extraction", () => {
+    const instances: PlaygroundInstance[] = [
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "chat",
+          messages: [
+            { id: 0, content: "Hello {{name}}", role: "user" },
+            { id: 1, content: "How are you, {{name}}?", role: "ai" },
+          ],
+        },
+      },
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "text_completion",
+          prompt: "Your age is {{age}}",
+        },
+      },
+    ];
+    const templateLanguage = "MUSTACHE";
+    expect(
+      extractVariablesFromInstances({ instances, templateLanguage })
+    ).toEqual(["name", "age"]);
+  });
+
+  it("should handle multiple instances and variable extraction with fstring", () => {
+    const instances: PlaygroundInstance[] = [
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "chat",
+          messages: [
+            { id: 0, content: "Hello {name}", role: "user" },
+            { id: 1, content: "How are you, {{escaped}}?", role: "ai" },
+          ],
+        },
+      },
+      {
+        ...baseTestPlaygroundInstance,
+        template: {
+          __type: "text_completion",
+          prompt: "Your age is {age}",
+        },
+      },
+    ];
+    const templateLanguage: TemplateLanguage = "F_STRING";
+    expect(
+      extractVariablesFromInstances({ instances, templateLanguage })
+    ).toEqual(["name", "age"]);
+  });
+});
diff --git a/app/src/pages/playground/playgroundUtils.ts b/app/src/pages/playground/playgroundUtils.ts
index a2b888b291..a81ba16fe2 100644
--- a/app/src/pages/playground/playgroundUtils.ts
+++ b/app/src/pages/playground/playgroundUtils.ts
@@ -64,6 +64,37 @@ export function getChatRole(role: string): ChatMessageRole {
   return DEFAULT_CHAT_ROLE;
 }
+/**
+ * Takes tool calls on a message from span attributes and transforms them into tool calls for a message in the playground
+ * @param toolCalls Tool calls from a span's message, to be transformed into tool calls for a chat message in the playground
+ * @returns Tool calls for a message in the playground
+ *
+ * NB: Only exported for testing
+ */
+export function processAttributeToolCalls(
+  toolCalls?: MessageSchema["message"]["tool_calls"]
+): ChatMessage["toolCalls"] {
+  if (toolCalls == null) {
+    return;
+  }
+  return toolCalls
+    .map(({ tool_call }) => {
+      if (tool_call == null) {
+        return null;
+      }
+      return {
+        id: tool_call.id ?? "",
+        function: {
+          name: tool_call.function?.name ?? "",
+          arguments: tool_call.function?.arguments ?? {},
+        },
+      };
+    })
+    .filter((toolCall): toolCall is NonNullable<typeof toolCall> => {
+      return toolCall != null;
+    });
+}
+
 /**
  * Takes a list of messages from span attributes and transforms them into a list of {@link ChatMessage|ChatMessages}
  * @param messages messages from attributes either input or output @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions}}
  */
@@ -77,6 +108,7 @@ function processAttributeMessagesToChatMessage(
       id: generateMessageId(),
       role: getChatRole(message.role),
       content: message.content,
+      toolCalls: processAttributeToolCalls(message.tool_calls),
     };
   });
 }
@@ -85,8 +117,10 @@
 /**
  * Attempts to parse the input messages from the span attributes.
  * @param parsedAttributes the JSON parsed span attributes
  * @returns an object containing the parsed {@link ChatMessage|ChatMessages} and any parsing errors
+ *
+ * NB: Only exported for testing
  */
-function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
+export function getTemplateMessagesFromAttributes(parsedAttributes: unknown) {
   const inputMessages = llmInputMessageSchema.safeParse(parsedAttributes);
   if (!inputMessages.success) {
     return {
@@ -107,8 +141,10 @@
 /**
  * Attempts to get llm.output_messages then output.value from the span attributes.
  * @param parsedAttributes the JSON parsed span attributes
  * @returns an object containing the parsed output and any parsing errors
+ *
+ * NB: Only exported for testing
  */
-function getOutputFromAttributes(parsedAttributes: unknown) {
+export function getOutputFromAttributes(parsedAttributes: unknown) {
   const outputParsingErrors: string[] = [];
   const outputMessages = llmOutputMessageSchema.safeParse(parsedAttributes);
   if (outputMessages.success) {
@@ -161,8 +197,10 @@ export function getModelProviderFromModelName(
  * Attempts to get the llm.model_name, inferred provider, and invocation parameters from the span attributes.
  * @param parsedAttributes the JSON parsed span attributes
  * @returns the model config if it exists or parsing errors if it does not
+ *
+ * NB: Only exported for testing
  */
-function getModelConfigFromAttributes(parsedAttributes: unknown): {
+export function getModelConfigFromAttributes(parsedAttributes: unknown): {
   modelConfig: ModelConfig | null;
   parsingErrors: string[];
 } {
diff --git a/app/src/pages/playground/schemas.ts b/app/src/pages/playground/schemas.ts
index e5991126fa..04b81143e9 100644
--- a/app/src/pages/playground/schemas.ts
+++ b/app/src/pages/playground/schemas.ts
@@ -8,6 +8,7 @@
 import { ChatMessage } from "@phoenix/store";
 import { Mutable, schemaForType } from "@phoenix/typeUtils";
+import { safelyParseJSON } from "@phoenix/utils/jsonUtils";
 
 import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql";
 
@@ -17,10 +18,15 @@ import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscripti
  */
 const toolCallSchema = z
   .object({
-    function: z
+    tool_call: z
       .object({
-        name: z.string(),
-        arguments: z.string(),
+        id: z.string().optional(),
+        function: z
+          .object({
+            name: z.string(),
+            arguments: z.string(),
+          })
+          .partial(),
       })
       .partial(),
   })
@@ -129,10 +135,8 @@ export type InvocationParametersSchema = z.infer<
 const stringToInvocationParametersSchema = z
   .string()
   .transform((s) => {
-    let json;
-    try {
-      json = JSON.parse(s);
-    } catch (e) {
+    const { json } = safelyParseJSON(s);
+    if (json == null) {
       return {};
     }
     // using the invocationParameterSchema as a base,
@@ -157,7 +161,6 @@ const stringToInvocationParametersSchema = z
     );
   })
   .default("{}");
-
 /**
  * The zod schema for llm model config
  * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions}
diff --git a/app/src/store/playground/types.ts b/app/src/store/playground/types.ts
index 72778d9474..0e0758fc43 100644
--- a/app/src/store/playground/types.ts
+++ b/app/src/store/playground/types.ts
@@ -109,7 +109,7 @@ export interface PlaygroundInstance {
   toolChoice: ToolChoice;
   input: PlaygroundInput;
   model: ModelConfig;
-  output: ChatMessage[] | undefined | string;
+  output?: ChatMessage[] | string;
   spanId: string | null;
   activeRunId: number | null;
   /**