From 6efc70087538265068092b6d72b2dbfd6f927688 Mon Sep 17 00:00:00 2001 From: Anthony Powell Date: Tue, 22 Oct 2024 11:28:12 -0400 Subject: [PATCH] feat: Scaffold model invocation params form (#5040) (#5045) * feat: Scaffold model invocation params form (#5040) * fix: Fix types * fix: Fix test outputs by updating model schema * Send invocation parameters with completion stream params * Generate invocation parameters form based on provider/model schema * Constrain invocation params by appropriate schema * Update playground span parsing error text * Improve commenting, delete out of date comments * Read and validate invocation parameters from span into store * Improve playground span transformation test * --amend * Safely parse invocation params schema, separately from other model config --- .../playground/InvocationParametersForm.tsx | 131 ++++++++++++++++++ .../pages/playground/ModelConfigButton.tsx | 27 +++- app/src/pages/playground/PlaygroundOutput.tsx | 30 +++- .../__tests__/playgroundUtils.test.ts | 81 ++++++++++- app/src/pages/playground/constants.tsx | 6 +- app/src/pages/playground/playgroundUtils.ts | 64 +++++++-- app/src/pages/playground/schemas.ts | 119 +++++++++++++++- app/src/store/playground/playgroundStore.tsx | 10 +- app/src/store/playground/types.ts | 2 + 9 files changed, 447 insertions(+), 23 deletions(-) create mode 100644 app/src/pages/playground/InvocationParametersForm.tsx diff --git a/app/src/pages/playground/InvocationParametersForm.tsx b/app/src/pages/playground/InvocationParametersForm.tsx new file mode 100644 index 0000000000..44d4c48636 --- /dev/null +++ b/app/src/pages/playground/InvocationParametersForm.tsx @@ -0,0 +1,131 @@ +import React from "react"; + +import { Flex, Slider, TextField } from "@arizeai/components"; + +import { ModelConfig } from "@phoenix/store"; +import { Mutable } from "@phoenix/typeUtils"; + +import { getInvocationParametersSchema } from "./playgroundUtils"; +import { InvocationParametersSchema } from "./schemas"; + +/** + * Form field for a single invocation parameter. + */ +const FormField = ({ + field, + value, + onChange, +}: { + field: keyof InvocationParametersSchema; + value: InvocationParametersSchema[keyof InvocationParametersSchema]; + onChange: ( + value: InvocationParametersSchema[keyof InvocationParametersSchema] + ) => void; +}) => { + switch (field) { + case "temperature": + if (typeof value !== "number" && value !== undefined) return null; + return ( + onChange(value)} + /> + ); + case "topP": + if (typeof value !== "number" && value !== undefined) return null; + return ( + onChange(value)} + /> + ); + case "maxCompletionTokens": + return ( + onChange(Number(value))} + /> + ); + case "maxTokens": + return ( + onChange(Number(value))} + /> + ); + case "stop": + if (!Array.isArray(value) && value !== undefined) return null; + return ( + onChange(value.split(/, */g))} + /> + ); + case "seed": + return ( + onChange(Number(value))} + /> + ); + default: + return null; + } +}; + +export type InvocationParametersChangeHandler = < + T extends keyof ModelConfig["invocationParameters"], +>( + parameter: T, + value: ModelConfig["invocationParameters"][T] +) => void; + +type InvocationParametersFormProps = { + model: ModelConfig; + onChange: InvocationParametersChangeHandler; +}; + +export const InvocationParametersForm = ({ + model, + onChange, +}: InvocationParametersFormProps) => { + const { invocationParameters, provider, modelName } = model; + // Get the schema for the incoming provider and model combination. + const schema = getInvocationParametersSchema({ + modelProvider: provider, + modelName: modelName || "default", + }); + + const fieldsForSchema = Object.keys(schema.shape).map((field) => { + const fieldKey = field as keyof (typeof schema)["shape"]; + const value = invocationParameters[fieldKey]; + return ( + )} + onChange={(value) => onChange(fieldKey, value)} + /> + ); + }); + return ( + + {fieldsForSchema} + + ); +}; diff --git a/app/src/pages/playground/ModelConfigButton.tsx b/app/src/pages/playground/ModelConfigButton.tsx index 7a64b8ac6d..3f147e0030 100644 --- a/app/src/pages/playground/ModelConfigButton.tsx +++ b/app/src/pages/playground/ModelConfigButton.tsx @@ -29,6 +29,10 @@ import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext"; import { PlaygroundInstance } from "@phoenix/store"; import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql"; +import { + InvocationParametersChangeHandler, + InvocationParametersForm, +} from "./InvocationParametersForm"; import { ModelPicker } from "./ModelPicker"; import { ModelProviderPicker } from "./ModelProviderPicker"; import { PlaygroundInstanceProps } from "./types"; @@ -187,8 +191,25 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) { [instance.model.provider, playgroundInstanceId, updateModel] ); + const onInvocationParametersChange: InvocationParametersChangeHandler = + useCallback( + (parameter, value) => { + updateModel({ + instanceId: playgroundInstanceId, + model: { + ...instance.model, + invocationParameters: { + ...instance.model.invocationParameters, + [parameter]: value, + }, + }, + }); + }, + [instance.model, playgroundInstanceId, updateModel] + ); + return ( - +
)} +
); diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index bb0a000145..9e9bd0eba0 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -21,7 +21,10 @@ import { PlaygroundOutputSubscription$data, PlaygroundOutputSubscription$variables, } from "./__generated__/PlaygroundOutputSubscription.graphql"; -import { isChatMessages } from "./playgroundUtils"; +import { + getInvocationParametersSchema, + isChatMessages, +} from "./playgroundUtils"; import { RunMetadataFooter } from "./RunMetadataFooter"; import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex"; import { PlaygroundInstanceProps } from "./types"; @@ -269,7 +272,25 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { } : {}; - const invocationParameters: InvocationParameters = {}; + const invocationParametersSchema = getInvocationParametersSchema({ + modelProvider: instance.model.provider, + modelName: instance.model.modelName || "default", + }); + + let invocationParameters: InvocationParameters = { + ...instance.model.invocationParameters, + }; + + // Constrain the invocation parameters to the schema. + // This prevents us from sending invalid parameters to the LLM since we may be + // storing parameters from previously selected models/providers within this instance. + const valid = invocationParametersSchema.safeParse(invocationParameters); + if (!valid.success) { + // If we cannot successfully parse the invocation parameters, just send them + // all and let the API fail if they are invalid. + invocationParameters = instance.model.invocationParameters; + } + if (instance.tools.length) { invocationParameters["toolChoice"] = instance.toolChoice; } @@ -339,7 +360,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { onCompleted: () => { markPlaygroundInstanceComplete(props.playgroundInstanceId); }, - onFailed: () => { + onFailed: (error) => { + // TODO(apowell): We should display this error to the user after formatting it nicely. + // eslint-disable-next-line no-console + console.error(error); markPlaygroundInstanceComplete(props.playgroundInstanceId); updateInstance({ instanceId: props.playgroundInstanceId, diff --git a/app/src/pages/playground/__tests__/playgroundUtils.test.ts b/app/src/pages/playground/__tests__/playgroundUtils.test.ts index 4c3cc82e05..d93f8a8a71 100644 --- a/app/src/pages/playground/__tests__/playgroundUtils.test.ts +++ b/app/src/pages/playground/__tests__/playgroundUtils.test.ts @@ -7,7 +7,8 @@ import { import { INPUT_MESSAGES_PARSING_ERROR, - MODEL_NAME_PARSING_ERROR, + MODEL_CONFIG_PARSING_ERROR, + MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR, OUTPUT_MESSAGES_PARSING_ERROR, OUTPUT_VALUE_PARSING_ERROR, SPAN_ATTRIBUTES_PARSING_ERROR, @@ -30,6 +31,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = { model: { provider: "OPENAI", modelName: "gpt-3.5-turbo", + invocationParameters: {}, }, input: { variablesValueCache: {} }, tools: [], @@ -79,6 +81,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { model: { provider: "OPENAI", modelName: "gpt-4o", + invocationParameters: {}, }, template: defaultTemplate, output: undefined, @@ -96,6 +99,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { playgroundInstance: { ...expectedPlaygroundInstanceWithIO, model: { + ...expectedPlaygroundInstanceWithIO.model, provider: "OPENAI", modelName: "gpt-4o", }, @@ -107,7 +111,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { INPUT_MESSAGES_PARSING_ERROR, OUTPUT_MESSAGES_PARSING_ERROR, OUTPUT_VALUE_PARSING_ERROR, - MODEL_NAME_PARSING_ERROR, + MODEL_CONFIG_PARSING_ERROR, ], }); }); @@ -200,6 +204,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { playgroundInstance: { ...expectedPlaygroundInstanceWithIO, model: { + ...expectedPlaygroundInstanceWithIO.model, provider: "OPENAI", modelName: "gpt-4o", }, @@ -251,6 +256,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { playgroundInstance: { ...expectedPlaygroundInstanceWithIO, model: { + ...expectedPlaygroundInstanceWithIO.model, provider: "OPENAI", modelName: "gpt-3.5-turbo", }, @@ -270,6 +276,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { playgroundInstance: { ...expectedPlaygroundInstanceWithIO, model: { + ...expectedPlaygroundInstanceWithIO.model, provider: "ANTHROPIC", modelName: "claude-3-5-sonnet-20240620", }, @@ -289,6 +296,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { playgroundInstance: { ...expectedPlaygroundInstanceWithIO, model: { + ...expectedPlaygroundInstanceWithIO.model, provider: DEFAULT_MODEL_PROVIDER, modelName: "test-my-deployment", }, @@ -296,6 +304,75 @@ describe("transformSpanAttributesToPlaygroundInstance", () => { parsingErrors: [], }); }); + + it("should correctly parse the invocation parameters", () => { + const span = { + ...basePlaygroundSpan, + attributes: JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + // note that snake case keys are automatically converted to camel case + invocation_parameters: + '{"top_p": 0.5, "max_tokens": 100, "seed": 12345, "stop": ["stop", "me"]}', + }, + }), + }; + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + model: { + ...expectedPlaygroundInstanceWithIO.model, + invocationParameters: { + topP: 0.5, + maxTokens: 100, + seed: 12345, + stop: ["stop", "me"], + }, + }, + }, + parsingErrors: [], + }); + }); + + it("should still parse the model name and provider even if invocation parameters are malformed", () => { + const span = { + ...basePlaygroundSpan, + attributes: JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + invocation_parameters: "invalid json", + }, + }), + }; + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + }, + parsingErrors: [], + }); + }); + + it("should return invocation parameters parsing errors if the invocation parameters are the wrong type", () => { + const span = { + ...basePlaygroundSpan, + attributes: JSON.stringify({ + ...spanAttributesWithInputMessages, + llm: { + ...spanAttributesWithInputMessages.llm, + invocation_parameters: null, + }, + }), + }; + + expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({ + playgroundInstance: { + ...expectedPlaygroundInstanceWithIO, + }, + parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR], + }); + }); }); describe("getChatRole", () => { diff --git a/app/src/pages/playground/constants.tsx b/app/src/pages/playground/constants.tsx index 90be404948..5adaf3da81 100644 --- a/app/src/pages/playground/constants.tsx +++ b/app/src/pages/playground/constants.tsx @@ -22,8 +22,10 @@ export const OUTPUT_VALUE_PARSING_ERROR = "Unable to parse span output expected output.value to be present."; export const SPAN_ATTRIBUTES_PARSING_ERROR = "Unable to parse span attributes, attributes must be valid JSON."; -export const MODEL_NAME_PARSING_ERROR = - "Unable to parse model name, expected llm.model_name to be present."; +export const MODEL_CONFIG_PARSING_ERROR = + "Unable to parse model config, expected llm.model_name to be present."; +export const MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR = + "Unable to parse model config, expected llm.invocation_parameters json string to be present."; export const modelProviderToModelPrefixMap: Record = { AZURE_OPENAI: [], diff --git a/app/src/pages/playground/playgroundUtils.ts b/app/src/pages/playground/playgroundUtils.ts index 7063ed6b6a..a2b888b291 100644 --- a/app/src/pages/playground/playgroundUtils.ts +++ b/app/src/pages/playground/playgroundUtils.ts @@ -16,7 +16,8 @@ import { safelyParseJSON } from "@phoenix/utils/jsonUtils"; import { ChatRoleMap, INPUT_MESSAGES_PARSING_ERROR, - MODEL_NAME_PARSING_ERROR, + MODEL_CONFIG_PARSING_ERROR, + MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR, modelProviderToModelPrefixMap, OUTPUT_MESSAGES_PARSING_ERROR, OUTPUT_VALUE_PARSING_ERROR, @@ -28,8 +29,10 @@ import { llmInputMessageSchema, llmOutputMessageSchema, MessageSchema, - modelNameSchema, + modelConfigSchema, + modelConfigWithInvocationParametersSchema, outputSchema, + providerSchemas, } from "./schemas"; import { PlaygroundSpan } from "./spanPlaygroundPageLoader"; @@ -155,26 +158,37 @@ export function getModelProviderFromModelName( } /** - * Attempts to get the llm.model_name and inferred provider from the span attributes. + * Attempts to get the llm.model_name, inferred provider, and invocation parameters from the span attributes. * @param parsedAttributes the JSON parsed span attributes * @returns the model config if it exists or parsing errors if it does not */ -function getModelConfigFromAttributes( - parsedAttributes: unknown -): - | { modelConfig: ModelConfig; parsingErrors: never[] } - | { modelConfig: null; parsingErrors: string[] } { - const { success, data } = modelNameSchema.safeParse(parsedAttributes); +function getModelConfigFromAttributes(parsedAttributes: unknown): { + modelConfig: ModelConfig | null; + parsingErrors: string[]; +} { + const { success, data } = modelConfigSchema.safeParse(parsedAttributes); if (success) { + // parse invocation params separately, to avoid throwing away other model config if invocation params are invalid + const { + success: invocationParametersSuccess, + data: invocationParametersData, + } = modelConfigWithInvocationParametersSchema.safeParse(parsedAttributes); + const parsingErrors: string[] = []; + if (!invocationParametersSuccess) { + parsingErrors.push(MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR); + } return { modelConfig: { modelName: data.llm.model_name, provider: getModelProviderFromModelName(data.llm.model_name), + invocationParameters: invocationParametersSuccess + ? invocationParametersData.llm.invocation_parameters + : {}, }, - parsingErrors: [], + parsingErrors, }; } - return { modelConfig: null, parsingErrors: [MODEL_NAME_PARSING_ERROR] }; + return { modelConfig: null, parsingErrors: [MODEL_CONFIG_PARSING_ERROR] }; } /** @@ -289,3 +303,31 @@ export const extractVariablesFromInstances = ({ return Array.from(variables); }; + +/** + * Gets the invocation parameters schema for a given model provider and model name. + * + * Falls back to the default schema for provider if the model name is not found. + * + * Falls back to the default schema for all providers if provider is not found. + */ +export const getInvocationParametersSchema = ({ + modelProvider, + modelName, +}: { + modelProvider: ModelProvider; + modelName: string; +}) => { + const providerSupported = modelProvider in providerSchemas; + if (!providerSupported) { + return providerSchemas[DEFAULT_MODEL_PROVIDER].default; + } + + const byProvider = providerSchemas[modelProvider]; + const modelSupported = modelName in byProvider; + if (!modelSupported) { + return byProvider.default; + } + + return byProvider[modelName as keyof typeof byProvider]; +}; diff --git a/app/src/pages/playground/schemas.ts b/app/src/pages/playground/schemas.ts index cb650ab9cf..e5991126fa 100644 --- a/app/src/pages/playground/schemas.ts +++ b/app/src/pages/playground/schemas.ts @@ -7,7 +7,9 @@ import { } from "@arizeai/openinference-semantic-conventions"; import { ChatMessage } from "@phoenix/store"; -import { schemaForType } from "@phoenix/typeUtils"; +import { Mutable, schemaForType } from "@phoenix/typeUtils"; + +import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql"; /** * The zod schema for llm tool calls in an input message @@ -93,11 +95,122 @@ const chatMessageSchema = schemaForType()( export const chatMessagesSchema = z.array(chatMessageSchema); /** - * The zod schema for llm model name + * Model graphql invocation parameters schema in zod. + * + * Includes all keys besides toolChoice + */ +const invocationParameterSchema = schemaForType< + Mutable +>()( + z.object({ + temperature: z.coerce.number().optional(), + topP: z.coerce.number().optional(), + maxTokens: z.coerce.number().optional(), + stop: z.array(z.string()).optional(), + seed: z.coerce.number().optional(), + maxCompletionTokens: z.coerce.number().optional(), + }) +); + +/** + * The type of the invocation parameters schema + */ +export type InvocationParametersSchema = z.infer< + typeof invocationParameterSchema +>; + +/** + * Transform a string to an invocation parameters schema. + * + * If the string is not valid JSON, return an empty object. + * If the string is valid JSON, but does not match the invocation parameters schema, + * map the snake cased keys to camel case and return the result. + */ +const stringToInvocationParametersSchema = z + .string() + .transform((s) => { + let json; + try { + json = JSON.parse(s); + } catch (e) { + return {}; + } + // using the invocationParameterSchema as a base, + // apply all matching keys from the input string, + // and then map snake cased keys to camel case on top + return ( + invocationParameterSchema + .passthrough() + .transform((o) => ({ + ...o, + // map snake cased keys to camel case, the first char after each _ is uppercase + ...Object.fromEntries( + Object.entries(o).map(([k, v]) => [ + k.replace(/_([a-z])/g, (_, char) => char.toUpperCase()), + v, + ]) + ), + })) + // reparse the object to ensure the mapped keys are also validated + .transform(invocationParameterSchema.parse) + .parse(json) + ); + }) + .default("{}"); + +/** + * The zod schema for llm model config * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} */ -export const modelNameSchema = z.object({ +export const modelConfigSchema = z.object({ [SemanticAttributePrefixes.llm]: z.object({ [LLMAttributePostfixes.model_name]: z.string(), }), }); + +/** + * The zod schema for llm.invocation_parameters attributes + * @see {@link https://github.com/Arize-ai/openinference/blob/main/spec/semantic_conventions.md|Semantic Conventions} + */ +export const modelConfigWithInvocationParametersSchema = z.object({ + [SemanticAttributePrefixes.llm]: z.object({ + [LLMAttributePostfixes.invocation_parameters]: + stringToInvocationParametersSchema, + }), +}); + +/** + * Default set of invocation parameters for all providers and models. + */ +const baseInvocationParameterSchema = invocationParameterSchema.omit({ + maxCompletionTokens: true, +}); + +/** + * Invocation parameters for O1 models. + */ +const o1BaseInvocationParameterSchema = invocationParameterSchema.pick({ + maxCompletionTokens: true, +}); + +/** + * Provider schemas for all models and optionally for a specific model. + */ +export const providerSchemas = { + OPENAI: { + default: baseInvocationParameterSchema, + "o1-preview": o1BaseInvocationParameterSchema, + "o1-preview-2024-09-12": o1BaseInvocationParameterSchema, + "o1-mini": o1BaseInvocationParameterSchema, + "o1-mini-2024-09-12": o1BaseInvocationParameterSchema, + }, + AZURE_OPENAI: { + default: baseInvocationParameterSchema, + }, + ANTHROPIC: { + default: baseInvocationParameterSchema, + }, +} satisfies Record< + ModelProvider, + Record> +>; diff --git a/app/src/store/playground/playgroundStore.tsx b/app/src/store/playground/playgroundStore.tsx index 7d3712f748..8d3775b7e4 100644 --- a/app/src/store/playground/playgroundStore.tsx +++ b/app/src/store/playground/playgroundStore.tsx @@ -94,7 +94,11 @@ export function createPlaygroundInstance(): PlaygroundInstance { return { id: generateInstanceId(), template: generateChatCompletionTemplate(), - model: { provider: DEFAULT_MODEL_PROVIDER, modelName: "gpt-4o" }, + model: { + provider: DEFAULT_MODEL_PROVIDER, + modelName: "gpt-4o", + invocationParameters: {}, + }, tools: [], // Default to auto tool choice as you are probably testing the LLM for it's ability to pick toolChoice: "auto", @@ -221,6 +225,10 @@ export const createPlaygroundStore = ( model: { ...instance.model, ...model, + invocationParameters: { + ...instance.model.invocationParameters, + ...model.invocationParameters, + }, }, }; } diff --git a/app/src/store/playground/types.ts b/app/src/store/playground/types.ts index 4210941b6d..72778d9474 100644 --- a/app/src/store/playground/types.ts +++ b/app/src/store/playground/types.ts @@ -1,4 +1,5 @@ import { TemplateLanguage } from "@phoenix/components/templateEditor/types"; +import { InvocationParameters } from "@phoenix/pages/playground/__generated__/PlaygroundOutputSubscription.graphql"; import { OpenAIToolCall, OpenAIToolDefinition } from "@phoenix/schemas"; export type GenAIOperationType = "chat" | "text_completion"; @@ -76,6 +77,7 @@ export type ModelConfig = { modelName: string | null; endpoint?: string | null; apiVersion?: string | null; + invocationParameters: Partial>; }; /**