-
Notifications
You must be signed in to change notification settings - Fork 285
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* feat: Scaffold model invocation params form (#5040) * fix: Fix types * fix: Fix test outputs by updating model schema * Send invocation parameters with completion stream params * Generate invocation parameters form based on provider/model schema * Constrain invocation params by appropriate schema * Update playground span parsing error text * Improve commenting, delete out of date comments * Read and validate invocation parameters from span into store * Improve playground span transformation test * --amend * Safely parse invocation params schema, separately from other model config
- Loading branch information
1 parent
b632b68
commit 6efc700
Showing
9 changed files
with
447 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import React from "react"; | ||
|
||
import { Flex, Slider, TextField } from "@arizeai/components"; | ||
|
||
import { ModelConfig } from "@phoenix/store"; | ||
import { Mutable } from "@phoenix/typeUtils"; | ||
|
||
import { getInvocationParametersSchema } from "./playgroundUtils"; | ||
import { InvocationParametersSchema } from "./schemas"; | ||
|
||
/** | ||
* Form field for a single invocation parameter. | ||
*/ | ||
const FormField = ({ | ||
field, | ||
value, | ||
onChange, | ||
}: { | ||
field: keyof InvocationParametersSchema; | ||
value: InvocationParametersSchema[keyof InvocationParametersSchema]; | ||
onChange: ( | ||
value: InvocationParametersSchema[keyof InvocationParametersSchema] | ||
) => void; | ||
}) => { | ||
switch (field) { | ||
case "temperature": | ||
if (typeof value !== "number" && value !== undefined) return null; | ||
return ( | ||
<Slider | ||
label="Temperature" | ||
value={value} | ||
step={0.1} | ||
minValue={0} | ||
maxValue={2} | ||
onChange={(value) => onChange(value)} | ||
/> | ||
); | ||
case "topP": | ||
if (typeof value !== "number" && value !== undefined) return null; | ||
return ( | ||
<Slider | ||
label="Top P" | ||
value={value} | ||
step={0.1} | ||
minValue={0} | ||
maxValue={1} | ||
onChange={(value) => onChange(value)} | ||
/> | ||
); | ||
case "maxCompletionTokens": | ||
return ( | ||
<TextField | ||
label="Max Completion Tokens" | ||
value={value?.toString() || ""} | ||
type="number" | ||
onChange={(value) => onChange(Number(value))} | ||
/> | ||
); | ||
case "maxTokens": | ||
return ( | ||
<TextField | ||
label="Max Tokens" | ||
value={value?.toString() || ""} | ||
type="number" | ||
onChange={(value) => onChange(Number(value))} | ||
/> | ||
); | ||
case "stop": | ||
if (!Array.isArray(value) && value !== undefined) return null; | ||
return ( | ||
<TextField | ||
label="Stop" | ||
defaultValue={value?.join(", ") || ""} | ||
onChange={(value) => onChange(value.split(/, */g))} | ||
/> | ||
); | ||
case "seed": | ||
return ( | ||
<TextField | ||
label="Seed" | ||
value={value?.toString() || ""} | ||
type="number" | ||
onChange={(value) => onChange(Number(value))} | ||
/> | ||
); | ||
default: | ||
return null; | ||
} | ||
}; | ||
|
||
export type InvocationParametersChangeHandler = < | ||
T extends keyof ModelConfig["invocationParameters"], | ||
>( | ||
parameter: T, | ||
value: ModelConfig["invocationParameters"][T] | ||
) => void; | ||
|
||
type InvocationParametersFormProps = { | ||
model: ModelConfig; | ||
onChange: InvocationParametersChangeHandler; | ||
}; | ||
|
||
export const InvocationParametersForm = ({ | ||
model, | ||
onChange, | ||
}: InvocationParametersFormProps) => { | ||
const { invocationParameters, provider, modelName } = model; | ||
// Get the schema for the incoming provider and model combination. | ||
const schema = getInvocationParametersSchema({ | ||
modelProvider: provider, | ||
modelName: modelName || "default", | ||
}); | ||
|
||
const fieldsForSchema = Object.keys(schema.shape).map((field) => { | ||
const fieldKey = field as keyof (typeof schema)["shape"]; | ||
const value = invocationParameters[fieldKey]; | ||
return ( | ||
<FormField | ||
key={fieldKey} | ||
field={fieldKey} | ||
value={value === null ? undefined : (value as Mutable<typeof value>)} | ||
onChange={(value) => onChange(fieldKey, value)} | ||
/> | ||
); | ||
}); | ||
return ( | ||
<Flex direction="column" gap="size-200"> | ||
{fieldsForSchema} | ||
</Flex> | ||
); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.