feat: Scaffold model invocation params form (#5040) (#5045)
* feat: Scaffold model invocation params form (#5040)

* fix: Fix types

* fix: Fix test outputs by updating model schema

* Send invocation parameters with completion stream params

* Generate invocation parameters form based on provider/model schema

* Constrain invocation params by appropriate schema

* Update playground span parsing error text

* Improve commenting, delete out of date comments

* Read and validate invocation parameters from span into store

* Improve playground span transformation test

* --amend

* Safely parse invocation params schema, separately from other model config
cephalization authored Oct 22, 2024
1 parent b632b68 commit 6efc700
Showing 9 changed files with 447 additions and 23 deletions.
131 changes: 131 additions & 0 deletions app/src/pages/playground/InvocationParametersForm.tsx
@@ -0,0 +1,131 @@
import React from "react";

import { Flex, Slider, TextField } from "@arizeai/components";

import { ModelConfig } from "@phoenix/store";
import { Mutable } from "@phoenix/typeUtils";

import { getInvocationParametersSchema } from "./playgroundUtils";
import { InvocationParametersSchema } from "./schemas";

/**
* Form field for a single invocation parameter.
*/
const FormField = ({
field,
value,
onChange,
}: {
field: keyof InvocationParametersSchema;
value: InvocationParametersSchema[keyof InvocationParametersSchema];
onChange: (
value: InvocationParametersSchema[keyof InvocationParametersSchema]
) => void;
}) => {
switch (field) {
case "temperature":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Temperature"
value={value}
step={0.1}
minValue={0}
maxValue={2}
onChange={(value) => onChange(value)}
/>
);
case "topP":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Top P"
value={value}
step={0.1}
minValue={0}
maxValue={1}
onChange={(value) => onChange(value)}
/>
);
case "maxCompletionTokens":
return (
<TextField
label="Max Completion Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "maxTokens":
return (
<TextField
label="Max Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "stop":
if (!Array.isArray(value) && value !== undefined) return null;
return (
<TextField
label="Stop"
defaultValue={value?.join(", ") || ""}
onChange={(value) => onChange(value.split(/, */g))}
/>
);
case "seed":
return (
<TextField
label="Seed"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
default:
return null;
}
};

export type InvocationParametersChangeHandler = <
T extends keyof ModelConfig["invocationParameters"],
>(
parameter: T,
value: ModelConfig["invocationParameters"][T]
) => void;

type InvocationParametersFormProps = {
model: ModelConfig;
onChange: InvocationParametersChangeHandler;
};

export const InvocationParametersForm = ({
model,
onChange,
}: InvocationParametersFormProps) => {
const { invocationParameters, provider, modelName } = model;
// Get the schema for the incoming provider and model combination.
const schema = getInvocationParametersSchema({
modelProvider: provider,
modelName: modelName || "default",
});

const fieldsForSchema = Object.keys(schema.shape).map((field) => {
const fieldKey = field as keyof (typeof schema)["shape"];
const value = invocationParameters[fieldKey];
return (
<FormField
key={fieldKey}
field={fieldKey}
value={value === null ? undefined : (value as Mutable<typeof value>)}
onChange={(value) => onChange(fieldKey, value)}
/>
);
});
return (
<Flex direction="column" gap="size-200">
{fieldsForSchema}
</Flex>
);
};
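
The schemas module imported above (InvocationParametersSchema from ./schemas and getInvocationParametersSchema from ./playgroundUtils) is not among the files shown in this commit, so its exact contents are an assumption. A minimal combined zod sketch that would be consistent with the field names the form switches on and with the schema.shape / safeParse usage elsewhere in the diff might look like the following; the bounds, field set, and lookup logic are illustrative only:

import { z } from "zod";

// Assumed shape of a provider's invocation parameter schema; the real
// schemas.ts may use different bounds and per-provider/per-model variants.
export const invocationParametersBaseSchema = z.object({
  temperature: z.number().min(0).max(2).optional(),
  topP: z.number().min(0).max(1).optional(),
  maxTokens: z.number().int().positive().optional(),
  maxCompletionTokens: z.number().int().positive().optional(),
  stop: z.array(z.string()).optional(),
  seed: z.number().int().optional(),
});

// The form derives its field names and value types from this inferred type.
export type InvocationParametersSchema = z.infer<
  typeof invocationParametersBaseSchema
>;

// Assumed lookup used by the form and by PlaygroundOutput below: resolve a
// zod schema for a provider/model pair, falling back to the base schema.
export function getInvocationParametersSchema(_args: {
  modelProvider: string;
  modelName: string;
}) {
  return invocationParametersBaseSchema;
}
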
27 changes: 26 additions & 1 deletion app/src/pages/playground/ModelConfigButton.tsx
@@ -29,6 +29,10 @@ import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { PlaygroundInstance } from "@phoenix/store";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import {
InvocationParametersChangeHandler,
InvocationParametersForm,
} from "./InvocationParametersForm";
import { ModelPicker } from "./ModelPicker";
import { ModelProviderPicker } from "./ModelProviderPicker";
import { PlaygroundInstanceProps } from "./types";
@@ -187,8 +191,25 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
[instance.model.provider, playgroundInstanceId, updateModel]
);

const onInvocationParametersChange: InvocationParametersChangeHandler =
useCallback(
(parameter, value) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
...instance.model,
invocationParameters: {
...instance.model.invocationParameters,
[parameter]: value,
},
},
});
},
[instance.model, playgroundInstanceId, updateModel]
);

return (
<View padding="size-200">
<View padding="size-200" overflow="auto">
<Form>
<ModelProviderPicker
provider={instance.model.provider}
Expand All @@ -213,6 +234,10 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
onChange={onModelNameChange}
/>
)}
<InvocationParametersForm
model={instance.model}
onChange={onInvocationParametersChange}
/>
</Form>
</View>
);
30 changes: 27 additions & 3 deletions app/src/pages/playground/PlaygroundOutput.tsx
@@ -21,7 +21,10 @@ import {
PlaygroundOutputSubscription$data,
PlaygroundOutputSubscription$variables,
} from "./__generated__/PlaygroundOutputSubscription.graphql";
import { isChatMessages } from "./playgroundUtils";
import {
getInvocationParametersSchema,
isChatMessages,
} from "./playgroundUtils";
import { RunMetadataFooter } from "./RunMetadataFooter";
import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex";
import { PlaygroundInstanceProps } from "./types";
@@ -269,7 +272,25 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
}
: {};

const invocationParameters: InvocationParameters = {};
const invocationParametersSchema = getInvocationParametersSchema({
modelProvider: instance.model.provider,
modelName: instance.model.modelName || "default",
});

let invocationParameters: InvocationParameters = {
...instance.model.invocationParameters,
};

// Constrain the invocation parameters to the schema.
// This prevents us from sending invalid parameters to the LLM since we may be
// storing parameters from previously selected models/providers within this instance.
const valid = invocationParametersSchema.safeParse(invocationParameters);
if (!valid.success) {
// If we cannot successfully parse the invocation parameters, just send them
// all and let the API fail if they are invalid.
invocationParameters = instance.model.invocationParameters;
}

if (instance.tools.length) {
invocationParameters["toolChoice"] = instance.toolChoice;
}
@@ -339,7 +360,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
onCompleted: () => {
markPlaygroundInstanceComplete(props.playgroundInstanceId);
},
onFailed: () => {
onFailed: (error) => {
// TODO(apowell): We should display this error to the user after formatting it nicely.
// eslint-disable-next-line no-console
console.error(error);
markPlaygroundInstanceComplete(props.playgroundInstanceId);
updateInstance({
instanceId: props.playgroundInstanceId,
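
The constraint above depends on how zod treats keys that are not part of the resolved schema: only a strict object schema rejects unknown keys, and this commit does not show whether the provider schemas are declared strict. A small illustrative sketch of the fallback path, under that assumption:

import { z } from "zod";

// Hypothetical strict schema for a provider that has no "seed" parameter.
const providerSchema = z
  .object({
    temperature: z.number().min(0).max(1).optional(),
    maxTokens: z.number().int().optional(),
  })
  .strict();

// A parameter left over from a previously selected model on this instance.
const stored = { temperature: 0.7, seed: 12345 };

const result = providerSchema.safeParse(stored);
// result.success === false because "seed" is unknown to this strict schema,
// so the component above falls back to sending all stored parameters and
// lets the API reject them if they are truly invalid.
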
81 changes: 79 additions & 2 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
@@ -7,7 +7,8 @@ import {

import {
INPUT_MESSAGES_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
MODEL_CONFIG_PARSING_ERROR,
MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
SPAN_ATTRIBUTES_PARSING_ERROR,
Expand All @@ -30,6 +31,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
model: {
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
invocationParameters: {},
},
input: { variablesValueCache: {} },
tools: [],
@@ -79,6 +81,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
model: {
provider: "OPENAI",
modelName: "gpt-4o",
invocationParameters: {},
},
template: defaultTemplate,
output: undefined,
@@ -96,6 +99,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-4o",
},
@@ -107,7 +111,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
INPUT_MESSAGES_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
MODEL_CONFIG_PARSING_ERROR,
],
});
});
@@ -200,6 +204,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-4o",
},
@@ -251,6 +256,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
},
@@ -270,6 +276,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "ANTHROPIC",
modelName: "claude-3-5-sonnet-20240620",
},
@@ -289,13 +296,83 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: DEFAULT_MODEL_PROVIDER,
modelName: "test-my-deployment",
},
},
parsingErrors: [],
});
});

it("should correctly parse the invocation parameters", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
// note that snake case keys are automatically converted to camel case
invocation_parameters:
'{"top_p": 0.5, "max_tokens": 100, "seed": 12345, "stop": ["stop", "me"]}',
},
}),
};
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
invocationParameters: {
topP: 0.5,
maxTokens: 100,
seed: 12345,
stop: ["stop", "me"],
},
},
},
parsingErrors: [],
});
});

it("should still parse the model name and provider even if invocation parameters are malformed", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
invocation_parameters: "invalid json",
},
}),
};
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
},
parsingErrors: [],
});
});

it("should return invocation parameters parsing errors if the invocation parameters are the wrong type", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
invocation_parameters: null,
},
}),
};

expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
},
parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR],
});
});
});

describe("getChatRole", () => {
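
The transformation these tests exercise lives in playgroundUtils.ts, which is collapsed in this view, so the following is only a sketch of how llm.invocation_parameters might be read into the store: parse the JSON string, convert snake_case keys to camelCase, and record a parsing error only when the attribute is missing or not a string (matching the null-attribute case above). The camelCaseKeys helper is assumed, not part of the commit, and the real code may do all of this through a zod schema instead:

import { MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR } from "./constants";

// Assumed helper: convert snake_case keys ("top_p") to camelCase ("topP").
const camelCaseKeys = (obj: Record<string, unknown>) =>
  Object.fromEntries(
    Object.entries(obj).map(([key, value]) => [
      key.replace(/_([a-z])/g, (_, char: string) => char.toUpperCase()),
      value,
    ])
  );

// Hypothetical sketch of the span-attribute parsing.
function parseInvocationParametersAttribute(
  rawInvocationParameters: unknown,
  parsingErrors: string[]
): Record<string, unknown> {
  if (typeof rawInvocationParameters !== "string") {
    // Missing or wrong-typed attribute: record a model config parsing error.
    parsingErrors.push(MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR);
    return {};
  }
  try {
    return camelCaseKeys(JSON.parse(rawInvocationParameters));
  } catch {
    // Malformed JSON does not block model name/provider parsing; fall back to
    // empty invocation parameters without recording an error.
    return {};
  }
}
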
6 changes: 4 additions & 2 deletions app/src/pages/playground/constants.tsx
@@ -22,8 +22,10 @@ export const OUTPUT_VALUE_PARSING_ERROR =
"Unable to parse span output expected output.value to be present.";
export const SPAN_ATTRIBUTES_PARSING_ERROR =
"Unable to parse span attributes, attributes must be valid JSON.";
export const MODEL_NAME_PARSING_ERROR =
"Unable to parse model name, expected llm.model_name to be present.";
export const MODEL_CONFIG_PARSING_ERROR =
"Unable to parse model config, expected llm.model_name to be present.";
export const MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR =
"Unable to parse model config, expected llm.invocation_parameters json string to be present.";

export const modelProviderToModelPrefixMap: Record<ModelProvider, string[]> = {
AZURE_OPENAI: [],