From 4b0c2d8c96a6fa4db92028df33d608fe113854d2 Mon Sep 17 00:00:00 2001 From: Tony Powell Date: Fri, 18 Oct 2024 14:23:23 -0400 Subject: [PATCH] Constrain invocation params by appropriate schema --- .../playground/InvocationParametersForm.tsx | 86 +++---------------- app/src/pages/playground/PlaygroundOutput.tsx | 29 ++++++- app/src/pages/playground/playgroundUtils.ts | 29 +++++++ app/src/pages/playground/schemas.ts | 62 ++++++++++++- 4 files changed, 128 insertions(+), 78 deletions(-) diff --git a/app/src/pages/playground/InvocationParametersForm.tsx b/app/src/pages/playground/InvocationParametersForm.tsx index 3ca04dd32e..057b10bf66 100644 --- a/app/src/pages/playground/InvocationParametersForm.tsx +++ b/app/src/pages/playground/InvocationParametersForm.tsx @@ -1,76 +1,20 @@ import React from "react"; -import { z } from "zod"; import { Flex, Slider, TextField } from "@arizeai/components"; import { ModelConfig } from "@phoenix/store"; -import { schemaForType } from "@phoenix/typeUtils"; +import { Mutable } from "@phoenix/typeUtils"; -import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql"; - -/** - * Model invocation parameters schema in zod. - * - * Includes all keys besides toolChoice - */ -const invocationParameterSchema = schemaForType()( - z.object({ - temperature: z.number().optional(), - topP: z.number().optional(), - maxTokens: z.number().optional(), - stop: z.array(z.string()).optional(), - seed: z.number().optional(), - maxCompletionTokens: z.number().optional(), - }) -); - -type InvocationParametersSchema = z.infer; - -/** - * Default set of invocation parameters for all providers and models. - */ -const baseInvocationParameterSchema = invocationParameterSchema.omit({ - maxCompletionTokens: true, -}); - -type BaseInvocationParameters = z.infer; - -/** - * Invocation parameters for O1 models. - */ -const o1BaseInvocationParameterSchema = baseInvocationParameterSchema - .extend({ - maxCompletionTokens: z.number().optional(), - }) - .omit({ maxTokens: true }); - -/** - * Provider schemas for all models and optionally for a specific model. - */ -const providerSchemas = { - OPENAI: { - default: baseInvocationParameterSchema, - "o1-preview": o1BaseInvocationParameterSchema, - "o1-preview-2024-09-12": o1BaseInvocationParameterSchema, - "o1-mini": o1BaseInvocationParameterSchema, - "o1-mini-2024-09-12": o1BaseInvocationParameterSchema, - }, - AZURE_OPENAI: { - default: baseInvocationParameterSchema, - }, - ANTHROPIC: { - default: baseInvocationParameterSchema, - }, -} satisfies Record< - ModelProvider, - Record> ->; +import { getInvocationParametersSchema } from "./playgroundUtils"; +import { InvocationParametersSchema } from "./schemas"; /** * Form field for a single invocation parameter. * * TODO(apowell): Should this be generic over the schema field data type? There * probably aren't enough fields for that to be worthwhile at the moment. + * TODO(apowell): Read disabled state and default values from schema and apply them + * to the input. */ const FormField = ({ field, @@ -167,25 +111,19 @@ export const InvocationParametersForm = ({ }: InvocationParametersFormProps) => { const { invocationParameters, provider, modelName } = model; // Get the schema for the incoming provider and model combination. - const schema = - providerSchemas?.[provider]?.[ - (modelName || "default") as keyof (typeof providerSchemas)[ModelProvider] - ] ?? providerSchemas[provider].default; + const schema = getInvocationParametersSchema({ + modelProvider: provider, + modelName: modelName || "default", + }); - // TODO(apowell): Should we instead fail open here and display all inputs? - const valid = schema.safeParse(invocationParameters); - if (!valid.success) { - return null; - } - // Generate form fields for all invocation parameters, constrained by the schema. - const parameters = valid.data; const fieldsForSchema = Object.keys(schema.shape).map((field) => { - const fieldKey = field as keyof typeof parameters; + const fieldKey = field as keyof (typeof schema)["shape"]; + const value = invocationParameters[fieldKey]; return ( )} onChange={(value) => onChange(fieldKey, value)} /> ); diff --git a/app/src/pages/playground/PlaygroundOutput.tsx b/app/src/pages/playground/PlaygroundOutput.tsx index 5520bdee26..6a55249eec 100644 --- a/app/src/pages/playground/PlaygroundOutput.tsx +++ b/app/src/pages/playground/PlaygroundOutput.tsx @@ -21,7 +21,10 @@ import { PlaygroundOutputSubscription$data, PlaygroundOutputSubscription$variables, } from "./__generated__/PlaygroundOutputSubscription.graphql"; -import { isChatMessages } from "./playgroundUtils"; +import { + getInvocationParametersSchema, + isChatMessages, +} from "./playgroundUtils"; import { RunMetadataFooter } from "./RunMetadataFooter"; import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex"; import { PlaygroundInstanceProps } from "./types"; @@ -265,9 +268,26 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { } : {}; - const invocationParameters: InvocationParameters = { + const invocationParametersSchema = getInvocationParametersSchema({ + modelProvider: instance.model.provider, + modelName: instance.model.modelName || "default", + }); + + // Constrain the invocation parameters to the schema. + // This prevents us from sending invalid parameters to the LLM. + let invocationParameters: InvocationParameters = { ...instance.model.invocationParameters, }; + + const valid = invocationParametersSchema.safeParse(invocationParameters); + if (!valid.success) { + // TODO(apowell): We should fail open here and display all inputs. + // eslint-disable-next-line no-console + console.error(valid.error); + // Fall back to the model's invocation parameters. + invocationParameters = instance.model.invocationParameters; + } + if (instance.tools.length) { invocationParameters["toolChoice"] = instance.toolChoice; } @@ -337,7 +357,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) { onCompleted: () => { markPlaygroundInstanceComplete(props.playgroundInstanceId); }, - onFailed: () => { + onFailed: (error) => { + // TODO(apowell): We should display this error to the user after formatting it nicely. + // eslint-disable-next-line no-console + console.error(error); markPlaygroundInstanceComplete(props.playgroundInstanceId); updateInstance({ instanceId: props.playgroundInstanceId, diff --git a/app/src/pages/playground/playgroundUtils.ts b/app/src/pages/playground/playgroundUtils.ts index 5ac2803c5e..70298fd324 100644 --- a/app/src/pages/playground/playgroundUtils.ts +++ b/app/src/pages/playground/playgroundUtils.ts @@ -30,6 +30,7 @@ import { MessageSchema, modelConfigSchema, outputSchema, + providerSchemas, } from "./schemas"; import { PlaygroundSpan } from "./spanPlaygroundPageLoader"; @@ -290,3 +291,31 @@ export const extractVariablesFromInstances = ({ return Array.from(variables); }; + +/** + * Gets the invocation parameters schema for a given model provider and model name. + * + * Falls back to the default schema for provider if the model name is not found. + * + * Falls back to the default schema for all providers if provider is not found. + */ +export const getInvocationParametersSchema = ({ + modelProvider, + modelName, +}: { + modelProvider: ModelProvider; + modelName: string; +}) => { + const providerSupported = modelProvider in providerSchemas; + if (!providerSupported) { + return providerSchemas[DEFAULT_MODEL_PROVIDER].default; + } + + const byProvider = providerSchemas[modelProvider]; + const modelSupported = modelName in byProvider; + if (!modelSupported) { + return byProvider.default; + } + + return byProvider[modelName as keyof typeof byProvider]; +}; diff --git a/app/src/pages/playground/schemas.ts b/app/src/pages/playground/schemas.ts index 087c975d9e..c03f84aa40 100644 --- a/app/src/pages/playground/schemas.ts +++ b/app/src/pages/playground/schemas.ts @@ -7,7 +7,9 @@ import { } from "@arizeai/openinference-semantic-conventions"; import { ChatMessage } from "@phoenix/store"; -import { schemaForType } from "@phoenix/typeUtils"; +import { Mutable, schemaForType } from "@phoenix/typeUtils"; + +import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql"; /** * The zod schema for llm tool calls in an input message @@ -112,3 +114,61 @@ export const modelConfigSchema = z.object({ .default("{}"), }), }); + +/** + * Model invocation parameters schema in zod. + * + * Includes all keys besides toolChoice + */ +const invocationParameterSchema = schemaForType< + Mutable +>()( + z.object({ + temperature: z.number().optional(), + topP: z.number().optional(), + maxTokens: z.number().optional(), + stop: z.array(z.string()).optional(), + seed: z.number().optional(), + maxCompletionTokens: z.number().optional(), + }) +); + +export type InvocationParametersSchema = z.infer< + typeof invocationParameterSchema +>; + +/** + * Default set of invocation parameters for all providers and models. + */ +const baseInvocationParameterSchema = invocationParameterSchema.omit({ + maxCompletionTokens: true, +}); + +/** + * Invocation parameters for O1 models. + */ +const o1BaseInvocationParameterSchema = invocationParameterSchema.pick({ + maxCompletionTokens: true, +}); + +/** + * Provider schemas for all models and optionally for a specific model. + */ +export const providerSchemas = { + OPENAI: { + default: baseInvocationParameterSchema, + "o1-preview": o1BaseInvocationParameterSchema, + "o1-preview-2024-09-12": o1BaseInvocationParameterSchema, + "o1-mini": o1BaseInvocationParameterSchema, + "o1-mini-2024-09-12": o1BaseInvocationParameterSchema, + }, + AZURE_OPENAI: { + default: baseInvocationParameterSchema, + }, + ANTHROPIC: { + default: baseInvocationParameterSchema, + }, +} satisfies Record< + ModelProvider, + Record> +>;