Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Scaffold model invocation params form (#5040) #5045

Merged
merged 12 commits into from
Oct 22, 2024
Merged
131 changes: 131 additions & 0 deletions app/src/pages/playground/InvocationParametersForm.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import React from "react";

import { Flex, Slider, TextField } from "@arizeai/components";

import { ModelConfig } from "@phoenix/store";
import { Mutable } from "@phoenix/typeUtils";

import { getInvocationParametersSchema } from "./playgroundUtils";
import { InvocationParametersSchema } from "./schemas";

/**
* Form field for a single invocation parameter.
*/
const FormField = ({
field,
value,
onChange,
}: {
field: keyof InvocationParametersSchema;
value: InvocationParametersSchema[keyof InvocationParametersSchema];
onChange: (
value: InvocationParametersSchema[keyof InvocationParametersSchema]
) => void;
}) => {
switch (field) {
case "temperature":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Temperature"
value={value}
step={0.1}
minValue={0}
maxValue={2}
onChange={(value) => onChange(value)}
/>
);
case "topP":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Top P"
value={value}
step={0.1}
minValue={0}
maxValue={1}
onChange={(value) => onChange(value)}
/>
);
case "maxCompletionTokens":
return (
<TextField
label="Max Completion Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "maxTokens":
return (
<TextField
label="Max Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "stop":
if (!Array.isArray(value) && value !== undefined) return null;
return (
<TextField
label="Stop"
defaultValue={value?.join(", ") || ""}
onChange={(value) => onChange(value.split(/, */g))}
/>
);
case "seed":
return (
<TextField
label="Seed"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
default:
return null;
}
};

export type InvocationParametersChangeHandler = <
T extends keyof ModelConfig["invocationParameters"],
>(
parameter: T,
value: ModelConfig["invocationParameters"][T]
) => void;

type InvocationParametersFormProps = {
model: ModelConfig;
onChange: InvocationParametersChangeHandler;
};

export const InvocationParametersForm = ({
model,
onChange,
}: InvocationParametersFormProps) => {
const { invocationParameters, provider, modelName } = model;
// Get the schema for the incoming provider and model combination.
const schema = getInvocationParametersSchema({
modelProvider: provider,
modelName: modelName || "default",
});

const fieldsForSchema = Object.keys(schema.shape).map((field) => {
const fieldKey = field as keyof (typeof schema)["shape"];
const value = invocationParameters[fieldKey];
return (
<FormField
key={fieldKey}
field={fieldKey}
value={value === null ? undefined : (value as Mutable<typeof value>)}
onChange={(value) => onChange(fieldKey, value)}
/>
);
});
return (
<Flex direction="column" gap="size-200">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a form? We use https://www.react-hook-form.com/ react hook form, seem slike this could be used here iwth controllers etc. might be tough though with the genericness of it, but would probably give it a try, can be afollow up

{fieldsForSchema}
</Flex>
);
};
27 changes: 26 additions & 1 deletion app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { PlaygroundInstance } from "@phoenix/store";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import {
InvocationParametersChangeHandler,
InvocationParametersForm,
} from "./InvocationParametersForm";
import { ModelPicker } from "./ModelPicker";
import { ModelProviderPicker } from "./ModelProviderPicker";
import { PlaygroundInstanceProps } from "./types";
Expand Down Expand Up @@ -187,8 +191,25 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
[instance.model.provider, playgroundInstanceId, updateModel]
);

const onInvocationParametersChange: InvocationParametersChangeHandler =
useCallback(
(parameter, value) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
...instance.model,
invocationParameters: {
...instance.model.invocationParameters,
[parameter]: value,
},
},
});
},
[instance.model, playgroundInstanceId, updateModel]
);

return (
<View padding="size-200">
<View padding="size-200" overflow="auto">
<Form>
<ModelProviderPicker
provider={instance.model.provider}
Expand All @@ -213,6 +234,10 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
onChange={onModelNameChange}
/>
)}
<InvocationParametersForm
model={instance.model}
onChange={onInvocationParametersChange}
/>
</Form>
</View>
);
Expand Down
30 changes: 27 additions & 3 deletions app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ import {
PlaygroundOutputSubscription$data,
PlaygroundOutputSubscription$variables,
} from "./__generated__/PlaygroundOutputSubscription.graphql";
import { isChatMessages } from "./playgroundUtils";
import {
getInvocationParametersSchema,
isChatMessages,
} from "./playgroundUtils";
import { RunMetadataFooter } from "./RunMetadataFooter";
import { TitleWithAlphabeticIndex } from "./TitleWithAlphabeticIndex";
import { PlaygroundInstanceProps } from "./types";
Expand Down Expand Up @@ -269,7 +272,25 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
}
: {};

const invocationParameters: InvocationParameters = {};
const invocationParametersSchema = getInvocationParametersSchema({
modelProvider: instance.model.provider,
modelName: instance.model.modelName || "default",
});

let invocationParameters: InvocationParameters = {
...instance.model.invocationParameters,
};

// Constrain the invocation parameters to the schema.
// This prevents us from sending invalid parameters to the LLM since we may be
// storing parameters from previously selected models/providers within this instance.
const valid = invocationParametersSchema.safeParse(invocationParameters);
if (!valid.success) {
// If we cannot successfully parse the invocation parameters, just send them
// all and let the API fail if they are invalid.
invocationParameters = instance.model.invocationParameters;
}

if (instance.tools.length) {
invocationParameters["toolChoice"] = instance.toolChoice;
}
Expand Down Expand Up @@ -339,7 +360,10 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
onCompleted: () => {
markPlaygroundInstanceComplete(props.playgroundInstanceId);
},
onFailed: () => {
onFailed: (error) => {
// TODO(apowell): We should display this error to the user after formatting it nicely.
// eslint-disable-next-line no-console
console.error(error);
markPlaygroundInstanceComplete(props.playgroundInstanceId);
updateInstance({
instanceId: props.playgroundInstanceId,
Expand Down
81 changes: 79 additions & 2 deletions app/src/pages/playground/__tests__/playgroundUtils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import {

import {
INPUT_MESSAGES_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
MODEL_CONFIG_PARSING_ERROR,
MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
SPAN_ATTRIBUTES_PARSING_ERROR,
Expand All @@ -30,6 +31,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
model: {
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
invocationParameters: {},
},
input: { variablesValueCache: {} },
tools: [],
Expand Down Expand Up @@ -79,6 +81,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
model: {
provider: "OPENAI",
modelName: "gpt-4o",
invocationParameters: {},
},
template: defaultTemplate,
output: undefined,
Expand All @@ -96,6 +99,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-4o",
},
Expand All @@ -107,7 +111,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
INPUT_MESSAGES_PARSING_ERROR,
OUTPUT_MESSAGES_PARSING_ERROR,
OUTPUT_VALUE_PARSING_ERROR,
MODEL_NAME_PARSING_ERROR,
MODEL_CONFIG_PARSING_ERROR,
],
});
});
Expand Down Expand Up @@ -200,6 +204,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-4o",
},
Expand Down Expand Up @@ -251,6 +256,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "OPENAI",
modelName: "gpt-3.5-turbo",
},
Expand All @@ -270,6 +276,7 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: "ANTHROPIC",
modelName: "claude-3-5-sonnet-20240620",
},
Expand All @@ -289,13 +296,83 @@ describe("transformSpanAttributesToPlaygroundInstance", () => {
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
provider: DEFAULT_MODEL_PROVIDER,
modelName: "test-my-deployment",
},
},
parsingErrors: [],
});
});

it("should correctly parse the invocation parameters", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
// note that snake case keys are automatically converted to camel case
invocation_parameters:
'{"top_p": 0.5, "max_tokens": 100, "seed": 12345, "stop": ["stop", "me"]}',
},
}),
};
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
model: {
...expectedPlaygroundInstanceWithIO.model,
invocationParameters: {
topP: 0.5,
maxTokens: 100,
seed: 12345,
stop: ["stop", "me"],
},
},
},
parsingErrors: [],
});
});

it("should still parse the model name and provider even if invocation parameters are malformed", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
invocation_parameters: "invalid json",
},
}),
};
expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
},
parsingErrors: [],
});
});

it("should return invocation parameters parsing errors if the invocation parameters are the wrong type", () => {
const span = {
...basePlaygroundSpan,
attributes: JSON.stringify({
...spanAttributesWithInputMessages,
llm: {
...spanAttributesWithInputMessages.llm,
invocation_parameters: null,
},
}),
};

expect(transformSpanAttributesToPlaygroundInstance(span)).toEqual({
playgroundInstance: {
...expectedPlaygroundInstanceWithIO,
},
parsingErrors: [MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR],
});
});
});

describe("getChatRole", () => {
Expand Down
6 changes: 4 additions & 2 deletions app/src/pages/playground/constants.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ export const OUTPUT_VALUE_PARSING_ERROR =
"Unable to parse span output expected output.value to be present.";
export const SPAN_ATTRIBUTES_PARSING_ERROR =
"Unable to parse span attributes, attributes must be valid JSON.";
export const MODEL_NAME_PARSING_ERROR =
"Unable to parse model name, expected llm.model_name to be present.";
export const MODEL_CONFIG_PARSING_ERROR =
"Unable to parse model config, expected llm.model_name to be present.";
export const MODEL_CONFIG_WITH_INVOCATION_PARAMETERS_PARSING_ERROR =
"Unable to parse model config, expected llm.invocation_parameters json string to be present.";

export const modelProviderToModelPrefixMap: Record<ModelProvider, string[]> = {
AZURE_OPENAI: [],
Expand Down
Loading
Loading