Skip to content

Commit

Permalink
Constrain invocation params by appropriate schema
Browse files Browse the repository at this point in the history
  • Loading branch information
cephalization committed Oct 21, 2024
1 parent 403ee75 commit 4b0c2d8
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 78 deletions.
86 changes: 12 additions & 74 deletions app/src/pages/playground/InvocationParametersForm.tsx
Original file line number Diff line number Diff line change
@@ -1,76 +1,20 @@
import React from "react";
import { z } from "zod";

import { Flex, Slider, TextField } from "@arizeai/components";

import { ModelConfig } from "@phoenix/store";
import { schemaForType } from "@phoenix/typeUtils";
import { Mutable } from "@phoenix/typeUtils";

import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql";

/**
* Model invocation parameters schema in zod.
*
* Includes all keys besides toolChoice
*/
const invocationParameterSchema = schemaForType<InvocationParameters>()(
z.object({
temperature: z.number().optional(),
topP: z.number().optional(),
maxTokens: z.number().optional(),
stop: z.array(z.string()).optional(),
seed: z.number().optional(),
maxCompletionTokens: z.number().optional(),
})
);

type InvocationParametersSchema = z.infer<typeof invocationParameterSchema>;

/**
* Default set of invocation parameters for all providers and models.
*/
const baseInvocationParameterSchema = invocationParameterSchema.omit({
maxCompletionTokens: true,
});

type BaseInvocationParameters = z.infer<typeof baseInvocationParameterSchema>;

/**
* Invocation parameters for O1 models.
*/
const o1BaseInvocationParameterSchema = baseInvocationParameterSchema
.extend({
maxCompletionTokens: z.number().optional(),
})
.omit({ maxTokens: true });

/**
* Provider schemas for all models and optionally for a specific model.
*/
const providerSchemas = {
OPENAI: {
default: baseInvocationParameterSchema,
"o1-preview": o1BaseInvocationParameterSchema,
"o1-preview-2024-09-12": o1BaseInvocationParameterSchema,
"o1-mini": o1BaseInvocationParameterSchema,
"o1-mini-2024-09-12": o1BaseInvocationParameterSchema,
},
AZURE_OPENAI: {
default: baseInvocationParameterSchema,
},
ANTHROPIC: {
default: baseInvocationParameterSchema,
},
} satisfies Record<
ModelProvider,
Record<string, z.ZodType<BaseInvocationParameters>>
>;
import { getInvocationParametersSchema } from "./playgroundUtils";
import { InvocationParametersSchema } from "./schemas";

/**
* Form field for a single invocation parameter.
*
* TODO(apowell): Should this be generic over the schema field data type? There
* probably aren't enough fields for that to be worthwhile at the moment.
* TODO(apowell): Read disabled state and default values from schema and apply them
* to the input.
*/
const FormField = ({
field,
Expand Down Expand Up @@ -167,25 +111,19 @@ export const InvocationParametersForm = ({
}: InvocationParametersFormProps) => {
const { invocationParameters, provider, modelName } = model;
// Get the schema for the incoming provider and model combination.
const schema =
providerSchemas?.[provider]?.[
(modelName || "default") as keyof (typeof providerSchemas)[ModelProvider]
] ?? providerSchemas[provider].default;
const schema = getInvocationParametersSchema({
modelProvider: provider,
modelName: modelName || "default",
});

// TODO(apowell): Should we instead fail open here and display all inputs?
const valid = schema.safeParse(invocationParameters);
if (!valid.success) {
return null;
}
// Generate form fields for all invocation parameters, constrained by the schema.
const parameters = valid.data;
const fieldsForSchema = Object.keys(schema.shape).map((field) => {
const fieldKey = field as keyof typeof parameters;
const fieldKey = field as keyof (typeof schema)["shape"];
const value = invocationParameters[fieldKey];
return (
<FormField
key={fieldKey}
field={fieldKey}
value={parameters[fieldKey]}
value={value === null ? undefined : (value as Mutable<typeof value>)}
onChange={(value) => onChange(fieldKey, value)}
/>
);
Expand Down
29 changes: 26 additions & 3 deletions app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ import {
PlaygroundOutputSubscription$data,
PlaygroundOutputSubscription$variables,
} from "./__generated__/PlaygroundOutputSubscription.graphql";
import { isChatMessages } from "./playgroundUtils";
import {
getInvocationParametersSchema,
isChatMessages,
} from "./playgroundUtils";
import { RunMetadataFooter } from "./RunMetadataFooter";
import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex";
import { PlaygroundInstanceProps } from "./types";
Expand Down Expand Up @@ -265,9 +268,26 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
}
: {};

const invocationParameters: InvocationParameters = {
const invocationParametersSchema = getInvocationParametersSchema({
modelProvider: instance.model.provider,
modelName: instance.model.modelName || "default",
});

// Constrain the invocation parameters to the schema.
// This prevents us from sending invalid parameters to the LLM.
let invocationParameters: InvocationParameters = {
...instance.model.invocationParameters,
};

const valid = invocationParametersSchema.safeParse(invocationParameters);
if (!valid.success) {
// TODO(apowell): We should fail open here and display all inputs.
// eslint-disable-next-line no-console
console.error(valid.error);
// Fall back to the model's invocation parameters.
invocationParameters = instance.model.invocationParameters;
}

if (instance.tools.length) {
invocationParameters["toolChoice"] = instance.toolChoice;
}
Expand Down Expand Up @@ -337,7 +357,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
onCompleted: () => {
markPlaygroundInstanceComplete(props.playgroundInstanceId);
},
onFailed: () => {
onFailed: (error) => {
// TODO(apowell): We should display this error to the user after formatting it nicely.
// eslint-disable-next-line no-console
console.error(error);
markPlaygroundInstanceComplete(props.playgroundInstanceId);
updateInstance({
instanceId: props.playgroundInstanceId,
Expand Down
29 changes: 29 additions & 0 deletions app/src/pages/playground/playgroundUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
MessageSchema,
modelConfigSchema,
outputSchema,
providerSchemas,
} from "./schemas";
import { PlaygroundSpan } from "./spanPlaygroundPageLoader";

Expand Down Expand Up @@ -290,3 +291,31 @@ export const extractVariablesFromInstances = ({

return Array.from(variables);
};

/**
* Gets the invocation parameters schema for a given model provider and model name.
*
* Falls back to the default schema for provider if the model name is not found.
*
* Falls back to the default schema for all providers if provider is not found.
*/
export const getInvocationParametersSchema = ({
modelProvider,
modelName,
}: {
modelProvider: ModelProvider;
modelName: string;
}) => {
const providerSupported = modelProvider in providerSchemas;
if (!providerSupported) {
return providerSchemas[DEFAULT_MODEL_PROVIDER].default;
}

const byProvider = providerSchemas[modelProvider];
const modelSupported = modelName in byProvider;
if (!modelSupported) {
return byProvider.default;
}

return byProvider[modelName as keyof typeof byProvider];
};
62 changes: 61 additions & 1 deletion app/src/pages/playground/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import {
} from "@arizeai/openinference-semantic-conventions";

import { ChatMessage } from "@phoenix/store";
import { schemaForType } from "@phoenix/typeUtils";
import { Mutable, schemaForType } from "@phoenix/typeUtils";

import { InvocationParameters } from "./__generated__/PlaygroundOutputSubscription.graphql";

/**
* The zod schema for llm tool calls in an input message
Expand Down Expand Up @@ -112,3 +114,61 @@ export const modelConfigSchema = z.object({
.default("{}"),
}),
});

/**
* Model invocation parameters schema in zod.
*
* Includes all keys besides toolChoice
*/
const invocationParameterSchema = schemaForType<
Mutable<InvocationParameters>
>()(
z.object({
temperature: z.number().optional(),
topP: z.number().optional(),
maxTokens: z.number().optional(),
stop: z.array(z.string()).optional(),
seed: z.number().optional(),
maxCompletionTokens: z.number().optional(),
})
);

export type InvocationParametersSchema = z.infer<
typeof invocationParameterSchema
>;

/**
* Default set of invocation parameters for all providers and models.
*/
const baseInvocationParameterSchema = invocationParameterSchema.omit({
maxCompletionTokens: true,
});

/**
* Invocation parameters for O1 models.
*/
const o1BaseInvocationParameterSchema = invocationParameterSchema.pick({
maxCompletionTokens: true,
});

/**
* Provider schemas for all models and optionally for a specific model.
*/
export const providerSchemas = {
OPENAI: {
default: baseInvocationParameterSchema,
"o1-preview": o1BaseInvocationParameterSchema,
"o1-preview-2024-09-12": o1BaseInvocationParameterSchema,
"o1-mini": o1BaseInvocationParameterSchema,
"o1-mini-2024-09-12": o1BaseInvocationParameterSchema,
},
AZURE_OPENAI: {
default: baseInvocationParameterSchema,
},
ANTHROPIC: {
default: baseInvocationParameterSchema,
},
} satisfies Record<
ModelProvider,
Record<string, z.ZodType<InvocationParametersSchema>>
>;

0 comments on commit 4b0c2d8

Please sign in to comment.