Skip to content

Commit

Permalink
refactor: Render invocation parameters form using data from api (#5165)
Browse files Browse the repository at this point in the history
* refactor: Render invocation parameters form using data from api

* Map model invocation parameters to invocation parameter inputs
  • Loading branch information
cephalization authored Oct 24, 2024
1 parent b4a1d3c commit 866d0b7
Show file tree
Hide file tree
Showing 10 changed files with 512 additions and 194 deletions.
261 changes: 195 additions & 66 deletions app/src/pages/playground/InvocationParametersForm.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
import React from "react";
import React, { useCallback, useEffect } from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import { Flex, Slider, TextField } from "@arizeai/components";

import { ModelConfig } from "@phoenix/store";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { PlaygroundInstance } from "@phoenix/store";
import { Mutable } from "@phoenix/typeUtils";

import { getInvocationParametersSchema } from "./playgroundUtils";
import { InvocationParametersSchema } from "./schemas";
import {
InvocationParametersFormQuery,
InvocationParametersFormQuery$data,
} from "./__generated__/InvocationParametersFormQuery.graphql";
import { InvocationParameterInput } from "./__generated__/PlaygroundOutputSubscription.graphql";

export type InvocationParameter = Mutable<
InvocationParametersFormQuery$data["modelInvocationParameters"]
>[number];

export type HandleInvocationParameterChange = (
parameter: InvocationParameter,
value: string | number | string[] | boolean | undefined
) => void;

/**
* Form field for a single invocation parameter.
Expand All @@ -16,113 +30,228 @@ const FormField = ({
value,
onChange,
}: {
field: keyof InvocationParametersSchema;
value: InvocationParametersSchema[keyof InvocationParametersSchema];
onChange: (
value: InvocationParametersSchema[keyof InvocationParametersSchema]
) => void;
field: InvocationParameter;
value: string | number | string[] | boolean | undefined;
onChange: (value: string | number | string[] | boolean | undefined) => void;
}) => {
switch (field) {
case "temperature":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Temperature"
value={value}
step={0.1}
minValue={0}
maxValue={2}
onChange={(value) => onChange(value)}
/>
);
case "topP":
const { __typename } = field;
switch (__typename) {
case "InvocationParameterBase":
return null;
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Top P"
label={field.label}
isRequired={field.required}
value={value}
step={0.1}
minValue={0}
maxValue={1}
minValue={field.minValue}
maxValue={field.maxValue}
onChange={(value) => onChange(value)}
/>
);
case "maxCompletionTokens":
case "IntInvocationParameter":
return (
<TextField
label="Max Completion Tokens"
label={field.label}
isRequired={field.required}
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "maxTokens":
return (
<TextField
label="Max Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "stop":
case "StringListInvocationParameter":
if (!Array.isArray(value) && value !== undefined) return null;
return (
<TextField
label="Stop"
label={field.label}
isRequired={field.required}
defaultValue={value?.join(", ") || ""}
onChange={(value) => onChange(value.split(/, */g))}
/>
);
case "seed":
case "StringInvocationParameter":
return (
<TextField
label="Seed"
label={field.label}
isRequired={field.required}
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
type="text"
onChange={(value) => onChange(value)}
/>
);
case "BooleanInvocationParameter":
// TODO: add checkbox
return null;
default:
return null;
}
};

export type InvocationParametersChangeHandler = <
T extends keyof ModelConfig["invocationParameters"],
>(
parameter: T,
value: ModelConfig["invocationParameters"][T]
) => void;
const getInvocationParameterValue = (
field: InvocationParameter,
parameterInput: InvocationParameterInput
): string | number | string[] | boolean | null | undefined => {
const type = field.__typename;
switch (type) {
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
return parameterInput.valueFloat;
case "IntInvocationParameter":
return parameterInput.valueInt;
case "StringInvocationParameter":
return parameterInput.valueString;
case "StringListInvocationParameter":
return parameterInput.valueStringList as string[] | undefined | null;
case "BooleanInvocationParameter":
return parameterInput.valueBool;
default:
throw new Error(`Unsupported invocation parameter type: ${type}`);
}
};

const makeInvocationParameterInput = (
field: InvocationParameter,
value: string | number | string[] | boolean | undefined
): InvocationParameterInput => {
if (field.invocationName === undefined) {
throw new Error("Invocation name is required");
}
const type = field.__typename;
switch (type) {
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
return {
invocationName: field.invocationName,
valueFloat: value === undefined ? undefined : Number(value),
};
case "IntInvocationParameter":
return {
invocationName: field.invocationName,
valueInt: value === undefined ? undefined : Number(value),
};
case "StringInvocationParameter":
return {
invocationName: field.invocationName,
valueString: value === undefined ? undefined : String(value),
};
case "StringListInvocationParameter":
return {
invocationName: field.invocationName,
valueStringList: Array.isArray(value) ? value : undefined,
};
case "BooleanInvocationParameter":
return {
invocationName: field.invocationName,
valueBool: value === undefined ? undefined : Boolean(value),
};
default:
throw new Error(`Unsupported invocation parameter type: ${type}`);
}
};

type InvocationParametersFormProps = {
model: ModelConfig;
onChange: InvocationParametersChangeHandler;
instance: PlaygroundInstance;
};

export const InvocationParametersForm = ({
model,
onChange,
instance,
}: InvocationParametersFormProps) => {
const { invocationParameters, provider, modelName } = model;
// Get the schema for the incoming provider and model combination.
const schema = getInvocationParametersSchema({
modelProvider: provider,
modelName: modelName || "default",
});
const { model } = instance;
const updateInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.updateInstanceModelInvocationParameters
);
const filterInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.filterInstanceModelInvocationParameters
);
const { modelInvocationParameters } =
useLazyLoadQuery<InvocationParametersFormQuery>(
graphql`
query InvocationParametersFormQuery($input: ModelsInput!) {
modelInvocationParameters(input: $input) {
__typename
... on InvocationParameterBase {
invocationName
label
required
}
... on BoundedFloatInvocationParameter {
minValue
maxValue
}
}
}
`,
{ input: { providerKey: model.provider } }
);

useEffect(() => {
// filter invocation parameters to only include those that are supported by the model
if (modelInvocationParameters) {
filterInstanceModelInvocationParameters({
instanceId: instance.id,
modelSupportedInvocationParameters:
modelInvocationParameters as Mutable<
typeof modelInvocationParameters
>,
});
}
}, [
filterInstanceModelInvocationParameters,
instance.id,
modelInvocationParameters,
]);

const fieldsForSchema = Object.keys(schema.shape).map((field) => {
const fieldKey = field as keyof (typeof schema)["shape"];
const value = invocationParameters[fieldKey];
const onChange = useCallback(
(
field: InvocationParameter,
value: string | number | string[] | boolean | undefined
) => {
const existingParameter = instance.model.invocationParameters.find(
(p) => p.invocationName === field.invocationName
);

if (existingParameter) {
updateInstanceModelInvocationParameters({
instanceId: instance.id,
invocationParameters: instance.model.invocationParameters.map((p) =>
p.invocationName === field.invocationName
? makeInvocationParameterInput(field, value)
: p
),
});
} else {
updateInstanceModelInvocationParameters({
instanceId: instance.id,
invocationParameters: [
...instance.model.invocationParameters,
makeInvocationParameterInput(field, value),
],
});
}
},
[instance, updateInstanceModelInvocationParameters]
);

const fieldsForSchema = modelInvocationParameters.map((field) => {
const existingParameter = instance.model.invocationParameters.find(
(p) => p.invocationName === field.invocationName
);
const value = existingParameter
? getInvocationParameterValue(field, existingParameter)
: undefined;
return (
<FormField
key={fieldKey}
field={fieldKey}
value={value === null ? undefined : (value as Mutable<typeof value>)}
onChange={(value) => onChange(fieldKey, value)}
key={field.invocationName}
field={field}
value={value === null ? undefined : value}
onChange={(value) => onChange(field, value)}
/>
);
});

return (
<Flex direction="column" gap="size-200">
{fieldsForSchema}
Expand Down
22 changes: 11 additions & 11 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import { PlaygroundInstance } from "@phoenix/store";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import {
InvocationParametersChangeHandler,
HandleInvocationParameterChange,
InvocationParametersForm,
} from "./InvocationParametersForm";
import { ModelPicker } from "./ModelPicker";
Expand Down Expand Up @@ -191,17 +191,16 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
[instance.model.provider, playgroundInstanceId, updateModel]
);

const onInvocationParametersChange: InvocationParametersChangeHandler =
const onInvocationParametersChange: HandleInvocationParameterChange =
useCallback(
(parameter, value) => {
// TODO(apowell): implement
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(parameterDefinition, value) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
...instance.model,
invocationParameters: {
...instance.model.invocationParameters,
[parameter]: value,
},
invocationParameters: [],
},
});
},
Expand Down Expand Up @@ -234,10 +233,11 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
onChange={onModelNameChange}
/>
)}
<InvocationParametersForm
model={instance.model}
onChange={onInvocationParametersChange}
/>
{instance.model.modelName ? (
<InvocationParametersForm instance={instance} />
) : (
<></>
)}
</Form>
</View>
);
Expand Down
Loading

0 comments on commit 866d0b7

Please sign in to comment.