Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(playground): save model config by provider in preferences #5216

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions app/src/constants/generativeConstants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ export const ModelProviders: Record<ModelProvider, string> = {
* The default model provider
*/
export const DEFAULT_MODEL_PROVIDER: ModelProvider = "OPENAI";
/**
* The default model name
*/
export const DEFAULT_MODEL_NAME = "gpt-4o";

export const DEFAULT_CHAT_ROLE: ChatMessageRole = "user";

Expand Down
4 changes: 2 additions & 2 deletions app/src/contexts/PlaygroundContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { useZustand } from "use-zustand";

import {
createPlaygroundStore,
PlaygroundProps,
InitialPlaygroundState,
PlaygroundState,
PlaygroundStore,
} from "@phoenix/store";
Expand All @@ -13,7 +13,7 @@ export const PlaygroundContext = createContext<PlaygroundStore | null>(null);
export function PlaygroundProvider({
children,
...props
}: PropsWithChildren<Partial<PlaygroundProps>>) {
}: PropsWithChildren<InitialPlaygroundState>) {
const [store] = useState<PlaygroundStore>(() => createPlaygroundStore(props));
return (
<PlaygroundContext.Provider value={store}>
Expand Down
89 changes: 78 additions & 11 deletions app/src/pages/playground/ModelConfigButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ import {
Picker,
Text,
TextField,
Tooltip,
TooltipTrigger,
View,
} from "@arizeai/components";

import {
AZURE_OPENAI_API_VERSIONS,
ModelProviders,
} from "@phoenix/constants/generativeConstants";
import { useNotifySuccess } from "@phoenix/contexts";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { usePreferencesContext } from "@phoenix/contexts/PreferencesContext";
import { PlaygroundInstance } from "@phoenix/store";

import { ModelConfigButtonDialogQuery } from "./__generated__/ModelConfigButtonDialogQuery.graphql";
Expand All @@ -43,6 +47,9 @@ function AzureOpenAiModelConfigFormField({
instance: PlaygroundInstance;
}) {
const updateModel = usePlaygroundContext((state) => state.updateModel);
const modelConfigByProvider = usePreferencesContext(
(state) => state.modelConfigByProvider
);

const updateModelConfig = useCallback(
({
Expand All @@ -58,9 +65,10 @@ function AzureOpenAiModelConfigFormField({
...instance.model,
[configKey]: value,
},
modelConfigByProvider,
});
},
[instance.id, instance.model, updateModel]
[instance.id, instance.model, modelConfigByProvider, updateModel]
);

return (
Expand Down Expand Up @@ -128,13 +136,7 @@ export function ModelConfigButton(props: ModelConfigButtonProps) {
size="compact"
onClick={() => {
startTransition(() => {
setDialog(
<Dialog title="Model Configuration" size="M">
<Suspense>
<ModelConfigDialogContent {...props} />
</Suspense>
</Dialog>
);
setDialog(<ModelConfigDialog {...props} />);
});
}}
>
Expand All @@ -156,18 +158,75 @@ export function ModelConfigButton(props: ModelConfigButtonProps) {
);
}

interface ModelConfigDialogProps extends ModelConfigButtonProps {}
function ModelConfigDialog(props: ModelConfigDialogProps) {
const instance = usePlaygroundContext((state) =>
state.instances.find(
(instance) => instance.id === props.playgroundInstanceId
)
);

if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}
const setModelConfigForProvider = usePreferencesContext(
(state) => state.setModelConfigForProvider
);

const notifySuccess = useNotifySuccess();
const onSaveConfig = useCallback(() => {
setModelConfigForProvider({
provider: instance.model.provider,
modelConfig: instance.model,
});
notifySuccess({
title: "Model Configuration Saved",
message: `${ModelProviders[instance.model.provider]} model configuration saved`,
expireMs: 3000,
});
}, [instance.model, notifySuccess, setModelConfigForProvider]);
return (
<Dialog
title="Model Configuration"
size="M"
extra={
<TooltipTrigger delay={0} offset={5}>
<Button size={"compact"} variant="default" onClick={onSaveConfig}>
Save Config
</Button>
<Tooltip>
Remember configuration for{" "}
{ModelProviders[instance.model.provider] ?? "this provider"}.
</Tooltip>
</TooltipTrigger>
}
>
<Suspense>
<ModelConfigDialogContent {...props} />
</Suspense>
</Dialog>
);
}

interface ModelConfigDialogContentProps extends ModelConfigButtonProps {}
function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
const { playgroundInstanceId } = props;
const updateModel = usePlaygroundContext((state) => state.updateModel);
const instance = usePlaygroundContext((state) =>
state.instances.find((instance) => instance.id === playgroundInstanceId)
);

if (!instance) {
throw new Error(
`Playground instance ${props.playgroundInstanceId} not found`
);
}
const modelConfigByProvider = usePreferencesContext(
(state) => state.modelConfigByProvider
);
const updateModel = usePlaygroundContext((state) => state.updateModel);

const query = useLazyLoadQuery<ModelConfigButtonDialogQuery>(
graphql`
query ModelConfigButtonDialogQuery($providerKey: GenerativeProviderKey!) {
Expand All @@ -186,9 +245,15 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
provider: instance.model.provider,
modelName,
},
modelConfigByProvider,
});
},
[instance.model.provider, playgroundInstanceId, updateModel]
[
instance.model.provider,
modelConfigByProvider,
playgroundInstanceId,
updateModel,
]
);

const onInvocationParametersChange: InvocationParametersChangeHandler =
Expand All @@ -203,9 +268,10 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
[parameter]: value,
},
},
modelConfigByProvider,
});
},
[instance.model, playgroundInstanceId, updateModel]
[instance.model, modelConfigByProvider, playgroundInstanceId, updateModel]
);

return (
Expand All @@ -221,6 +287,7 @@ function ModelConfigDialogContent(props: ModelConfigDialogContentProps) {
provider,
modelName: null,
},
modelConfigByProvider,
});
}}
/>
Expand Down
13 changes: 10 additions & 3 deletions app/src/pages/playground/Playground.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import {
PlaygroundProvider,
usePlaygroundContext,
} from "@phoenix/contexts/PlaygroundContext";
import { InitialPlaygroundState } from "@phoenix/store";
import { usePreferencesContext } from "@phoenix/contexts/PreferencesContext";
import { PlaygroundProps } from "@phoenix/store";

import { NUM_MAX_PLAYGROUND_INSTANCES } from "./constants";
import { PlaygroundCredentialsDropdown } from "./PlaygroundCredentialsDropdown";
Expand All @@ -39,7 +40,10 @@ const playgroundWrapCSS = css`
height: 100%;
`;

export function Playground(props: InitialPlaygroundState) {
export function Playground(props: Partial<PlaygroundProps>) {
const modelConfigByProvider = usePreferencesContext(
(state) => state.modelConfigByProvider
);
const showStreamToggle = useFeatureFlag("playgroundNonStreaming");
const [, setSearchParams] = useSearchParams();

Expand All @@ -56,7 +60,10 @@ export function Playground(props: InitialPlaygroundState) {
}, [setSearchParams]);

return (
<PlaygroundProvider {...props}>
<PlaygroundProvider
{...props}
modelConfigByProvider={modelConfigByProvider}
>
<div css={playgroundWrapCSS}>
<View
borderBottomColor="dark"
Expand Down
106 changes: 106 additions & 0 deletions app/src/store/playground/__tests__/playgroundStore.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import {
DEFAULT_MODEL_NAME,
DEFAULT_MODEL_PROVIDER,
} from "@phoenix/constants/generativeConstants";

import {
_resetInstanceId,
createPlaygroundInstance,
getInitialInstances,
} from "../playgroundStore";
import { InitialPlaygroundState } from "../types";

describe("getInitialInstances", () => {
beforeEach(() => {
_resetInstanceId();
});
it("should return instances from initialProps if they exist", () => {
const existingInstance = {
...createPlaygroundInstance(),
model: {
modelName: "test-model",
provider: "OPENAI" as const,
invocationParameters: {},
},
};
const initialProps: InitialPlaygroundState = {
instances: [existingInstance],
modelConfigByProvider: {},
};

const instances = getInitialInstances(initialProps);

expect(instances).toEqual([existingInstance]);
});

it("should create a new default instance if no instances exist in initialProps and there are no saved modelConfigs", () => {
const initialProps: InitialPlaygroundState = {
modelConfigByProvider: {},
};
const instances = getInitialInstances(initialProps);

expect(instances).toHaveLength(1);
expect(instances[0].id).toBe(0);
expect(instances[0].model.provider).toBe(DEFAULT_MODEL_PROVIDER);
expect(instances[0].model.modelName).toBe(DEFAULT_MODEL_NAME);
});

it("should use saved model config if available", () => {
const initialProps: InitialPlaygroundState = {
modelConfigByProvider: {
OPENAI: {
modelName: "test-model",
provider: "OPENAI",
invocationParameters: {},
},
},
};

const instances = getInitialInstances(initialProps);

expect(instances).toHaveLength(1);
expect(instances[0].model.provider).toBe("OPENAI");
expect(instances[0].model.modelName).toBe("test-model");
});

it("should use default model provider config if available", () => {
const initialProps: InitialPlaygroundState = {
modelConfigByProvider: {
OPENAI: {
modelName: "test-model-openai",
provider: "OPENAI",
invocationParameters: {},
},
ANTHROPIC: {
modelName: "test-model-anthropic",
provider: "ANTHROPIC",
invocationParameters: {},
},
},
};

const instances = getInitialInstances(initialProps);

expect(instances).toHaveLength(1);
expect(instances[0].model.provider).toBe("OPENAI");
expect(instances[0].model.modelName).toBe("test-model-openai");
});

it("should use any saved config if available if the default provider config is not", () => {
const initialProps: InitialPlaygroundState = {
modelConfigByProvider: {
ANTHROPIC: {
modelName: "test-model-anthropic",
provider: "ANTHROPIC",
invocationParameters: {},
},
},
};

const instances = getInitialInstances(initialProps);

expect(instances).toHaveLength(1);
expect(instances[0].model.provider).toBe("ANTHROPIC");
expect(instances[0].model.modelName).toBe("test-model-anthropic");
});
});
Loading
Loading