Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Render invocation parameters form using data from api #5165

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 195 additions & 66 deletions app/src/pages/playground/InvocationParametersForm.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
import React from "react";
import React, { useCallback, useEffect } from "react";
import { graphql, useLazyLoadQuery } from "react-relay";

import { Flex, Slider, TextField } from "@arizeai/components";

import { ModelConfig } from "@phoenix/store";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { PlaygroundInstance } from "@phoenix/store";
import { Mutable } from "@phoenix/typeUtils";

import { getInvocationParametersSchema } from "./playgroundUtils";
import { InvocationParametersSchema } from "./schemas";
import {
InvocationParametersFormQuery,
InvocationParametersFormQuery$data,
} from "./__generated__/InvocationParametersFormQuery.graphql";
import { InvocationParameterInput } from "./__generated__/PlaygroundOutputSubscription.graphql";

export type InvocationParameter = Mutable<
InvocationParametersFormQuery$data["modelInvocationParameters"]
>[number];

export type HandleInvocationParameterChange = (
parameter: InvocationParameter,
value: string | number | string[] | boolean | undefined
) => void;

/**
* Form field for a single invocation parameter.
Expand All @@ -16,113 +30,228 @@ const FormField = ({
value,
onChange,
}: {
field: keyof InvocationParametersSchema;
value: InvocationParametersSchema[keyof InvocationParametersSchema];
onChange: (
value: InvocationParametersSchema[keyof InvocationParametersSchema]
) => void;
field: InvocationParameter;
value: string | number | string[] | boolean | undefined;
onChange: (value: string | number | string[] | boolean | undefined) => void;
}) => {
switch (field) {
case "temperature":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Temperature"
value={value}
step={0.1}
minValue={0}
maxValue={2}
onChange={(value) => onChange(value)}
/>
);
case "topP":
const { __typename } = field;
switch (__typename) {
case "InvocationParameterBase":
return null;
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
if (typeof value !== "number" && value !== undefined) return null;
return (
<Slider
label="Top P"
label={field.label}
isRequired={field.required}
value={value}
step={0.1}
minValue={0}
maxValue={1}
minValue={field.minValue}
maxValue={field.maxValue}
onChange={(value) => onChange(value)}
/>
);
case "maxCompletionTokens":
case "IntInvocationParameter":
return (
<TextField
label="Max Completion Tokens"
label={field.label}
isRequired={field.required}
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "maxTokens":
return (
<TextField
label="Max Tokens"
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
/>
);
case "stop":
case "StringListInvocationParameter":
if (!Array.isArray(value) && value !== undefined) return null;
return (
<TextField
label="Stop"
label={field.label}
isRequired={field.required}
defaultValue={value?.join(", ") || ""}
onChange={(value) => onChange(value.split(/, */g))}
/>
);
case "seed":
case "StringInvocationParameter":
return (
<TextField
label="Seed"
label={field.label}
isRequired={field.required}
value={value?.toString() || ""}
type="number"
onChange={(value) => onChange(Number(value))}
type="text"
onChange={(value) => onChange(value)}
/>
);
case "BooleanInvocationParameter":
// TODO: add checkbox
return null;
default:
return null;
}
};

export type InvocationParametersChangeHandler = <
T extends keyof ModelConfig["invocationParameters"],
>(
parameter: T,
value: ModelConfig["invocationParameters"][T]
) => void;
const getInvocationParameterValue = (
field: InvocationParameter,
parameterInput: InvocationParameterInput
): string | number | string[] | boolean | null | undefined => {
const type = field.__typename;
switch (type) {
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
return parameterInput.valueFloat;
case "IntInvocationParameter":
return parameterInput.valueInt;
case "StringInvocationParameter":
return parameterInput.valueString;
case "StringListInvocationParameter":
return parameterInput.valueStringList as string[] | undefined | null;
case "BooleanInvocationParameter":
return parameterInput.valueBool;
default:
throw new Error(`Unsupported invocation parameter type: ${type}`);
}
};

const makeInvocationParameterInput = (
field: InvocationParameter,
value: string | number | string[] | boolean | undefined
): InvocationParameterInput => {
if (field.invocationName === undefined) {
throw new Error("Invocation name is required");
}
const type = field.__typename;
switch (type) {
case "FloatInvocationParameter":
case "BoundedFloatInvocationParameter":
return {
invocationName: field.invocationName,
valueFloat: value === undefined ? undefined : Number(value),
};
case "IntInvocationParameter":
return {
invocationName: field.invocationName,
valueInt: value === undefined ? undefined : Number(value),
};
case "StringInvocationParameter":
return {
invocationName: field.invocationName,
valueString: value === undefined ? undefined : String(value),
};
case "StringListInvocationParameter":
return {
invocationName: field.invocationName,
valueStringList: Array.isArray(value) ? value : undefined,
};
case "BooleanInvocationParameter":
return {
invocationName: field.invocationName,
valueBool: value === undefined ? undefined : Boolean(value),
};
default:
throw new Error(`Unsupported invocation parameter type: ${type}`);
}
};

type InvocationParametersFormProps = {
model: ModelConfig;
onChange: InvocationParametersChangeHandler;
instance: PlaygroundInstance;
};

export const InvocationParametersForm = ({
model,
onChange,
instance,
}: InvocationParametersFormProps) => {
const { invocationParameters, provider, modelName } = model;
// Get the schema for the incoming provider and model combination.
const schema = getInvocationParametersSchema({
modelProvider: provider,
modelName: modelName || "default",
});
const { model } = instance;
const updateInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.updateInstanceModelInvocationParameters
);
const filterInstanceModelInvocationParameters = usePlaygroundContext(
(state) => state.filterInstanceModelInvocationParameters
);
const { modelInvocationParameters } =
useLazyLoadQuery<InvocationParametersFormQuery>(
graphql`
query InvocationParametersFormQuery($input: ModelsInput!) {
modelInvocationParameters(input: $input) {
__typename
... on InvocationParameterBase {
invocationName
label
required
}
... on BoundedFloatInvocationParameter {
minValue
maxValue
}
}
}
`,
{ input: { providerKey: model.provider } }
);

useEffect(() => {
// filter invocation parameters to only include those that are supported by the model
if (modelInvocationParameters) {
filterInstanceModelInvocationParameters({
instanceId: instance.id,
modelSupportedInvocationParameters:
modelInvocationParameters as Mutable<
typeof modelInvocationParameters
>,
});
}
}, [
filterInstanceModelInvocationParameters,
instance.id,
modelInvocationParameters,
]);

const fieldsForSchema = Object.keys(schema.shape).map((field) => {
const fieldKey = field as keyof (typeof schema)["shape"];
const value = invocationParameters[fieldKey];
const onChange = useCallback(
(
field: InvocationParameter,
value: string | number | string[] | boolean | undefined
) => {
const existingParameter = instance.model.invocationParameters.find(
(p) => p.invocationName === field.invocationName
);

if (existingParameter) {
updateInstanceModelInvocationParameters({
instanceId: instance.id,
invocationParameters: instance.model.invocationParameters.map((p) =>
p.invocationName === field.invocationName
? makeInvocationParameterInput(field, value)
: p
),
});
} else {
updateInstanceModelInvocationParameters({
instanceId: instance.id,
invocationParameters: [
...instance.model.invocationParameters,
makeInvocationParameterInput(field, value),
],
});
}
},
[instance, updateInstanceModelInvocationParameters]
);

const fieldsForSchema = modelInvocationParameters.map((field) => {
const existingParameter = instance.model.invocationParameters.find(
(p) => p.invocationName === field.invocationName
);
const value = existingParameter
? getInvocationParameterValue(field, existingParameter)
: undefined;
return (
<FormField
key={fieldKey}
field={fieldKey}
value={value === null ? undefined : (value as Mutable<typeof value>)}
onChange={(value) => onChange(fieldKey, value)}
key={field.invocationName}
field={field}
value={value === null ? undefined : value}
onChange={(value) => onChange(field, value)}
/>
);
});

return (
<Flex direction="column" gap="size-200">
{fieldsForSchema}
Expand Down
22 changes: 11 additions & 11 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
import {
InvocationParametersChangeHandler,
HandleInvocationParameterChange,
InvocationParametersForm,
} from "./InvocationParametersForm";
import { ModelPicker } from "./ModelPicker";
Expand Down Expand Up @@ -191,17 +191,16 @@
[instance.model.provider, playgroundInstanceId, updateModel]
);

const onInvocationParametersChange: InvocationParametersChangeHandler =
const onInvocationParametersChange: HandleInvocationParameterChange =

Check failure on line 194 in app/src/pages/playground/ModelConfigButton.tsx

View workflow job for this annotation

GitHub Actions / CI Typescript

'onInvocationParametersChange' is assigned a value but never used
useCallback(
(parameter, value) => {
// TODO(apowell): implement
// eslint-disable-next-line @typescript-eslint/no-unused-vars
(parameterDefinition, value) => {
updateModel({
instanceId: playgroundInstanceId,
model: {
...instance.model,
invocationParameters: {
...instance.model.invocationParameters,
[parameter]: value,
},
invocationParameters: [],
},
});
},
Expand Down Expand Up @@ -234,10 +233,11 @@
onChange={onModelNameChange}
/>
)}
<InvocationParametersForm
model={instance.model}
onChange={onInvocationParametersChange}
/>
{instance.model.modelName ? (
<InvocationParametersForm instance={instance} />
) : (
<></>
)}
</Form>
</View>
);
Expand Down
Loading
Loading