diff --git a/.changeset/seven-kids-return.md b/.changeset/seven-kids-return.md new file mode 100644 index 00000000000..d4da5cbc031 --- /dev/null +++ b/.changeset/seven-kids-return.md @@ -0,0 +1,10 @@ +--- +"roo-cline": minor +--- + +Adds refresh models button for Unbound provider +Adds a button above model picker to refresh models based on the current API Key. + +1. Clicking the refresh button saves the API Key and calls /models endpoint using that. +2. Gets the new models and updates the current model if it is invalid for the given API Key. +3. The refresh button also flushes existing Unbound models and refetches them. diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 9ab4b851fc5..c3c662415b0 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -65,7 +65,8 @@ export const getModels = async ( models = await getGlamaModels() break case "unbound": - models = await getUnboundModels() + // Unbound models endpoint requires an API key to fetch application specific models + models = await getUnboundModels(apiKey) break case "litellm": if (apiKey && baseUrl) { diff --git a/src/api/providers/fetchers/unbound.ts b/src/api/providers/fetchers/unbound.ts index 73a8c2f8970..7834debf355 100644 --- a/src/api/providers/fetchers/unbound.ts +++ b/src/api/providers/fetchers/unbound.ts @@ -2,11 +2,17 @@ import axios from "axios" import { ModelInfo } from "../../../shared/api" -export async function getUnboundModels(): Promise> { +export async function getUnboundModels(apiKey?: string | null): Promise> { const models: Record = {} try { - const response = await axios.get("https://api.getunbound.ai/models") + const headers: Record = {} + + if (apiKey) { + headers["Authorization"] = `Bearer ${apiKey}` + } + + const response = await axios.get("https://api.getunbound.ai/models", { headers }) if (response.data) { const rawModels: Record = response.data @@ -40,6 +46,7 @@ export async function getUnboundModels(): Promise> { } } catch (error) { console.error(`Error fetching Unbound models: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`) + throw new Error(`Failed to fetch Unbound models: ${error instanceof Error ? error.message : "Unknown error"}`) } return models diff --git a/webview-ui/src/components/settings/providers/Unbound.tsx b/webview-ui/src/components/settings/providers/Unbound.tsx index 77b24bb7cc5..3d5aa0c67a9 100644 --- a/webview-ui/src/components/settings/providers/Unbound.tsx +++ b/webview-ui/src/components/settings/providers/Unbound.tsx @@ -1,10 +1,13 @@ -import { useCallback } from "react" +import { useCallback, useState, useRef } from "react" import { VSCodeTextField } from "@vscode/webview-ui-toolkit/react" +import { useQueryClient } from "@tanstack/react-query" import { ProviderSettings, RouterModels, unboundDefaultModelId } from "@roo/shared/api" import { useAppTranslation } from "@src/i18n/TranslationContext" import { VSCodeButtonLink } from "@src/components/common/VSCodeButtonLink" +import { vscode } from "@src/utils/vscode" +import { Button } from "@src/components/ui" import { inputEventTransform } from "../transforms" import { ModelPicker } from "../ModelPicker" @@ -17,6 +20,13 @@ type UnboundProps = { export const Unbound = ({ apiConfiguration, setApiConfigurationField, routerModels }: UnboundProps) => { const { t } = useAppTranslation() + const [didRefetch, setDidRefetch] = useState() + const [isInvalidKey, setIsInvalidKey] = useState(false) + const queryClient = useQueryClient() + + // Add refs to store timer IDs + const didRefetchTimerRef = useRef() + const invalidKeyTimerRef = useRef() const handleInputChange = useCallback( ( @@ -29,6 +39,90 @@ export const Unbound = ({ apiConfiguration, setApiConfigurationField, routerMode [setApiConfigurationField], ) + const saveConfiguration = useCallback(async () => { + vscode.postMessage({ + type: "upsertApiConfiguration", + text: "default", + apiConfiguration: apiConfiguration, + }) + + const waitForStateUpdate = new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + window.removeEventListener("message", messageHandler) + reject(new Error("Timeout waiting for state update")) + }, 10000) // 10 second timeout + + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === "state") { + clearTimeout(timeoutId) + window.removeEventListener("message", messageHandler) + resolve() + } + } + window.addEventListener("message", messageHandler) + }) + + try { + await waitForStateUpdate + } catch (error) { + console.error("Failed to save configuration:", error) + } + }, [apiConfiguration]) + + const requestModels = useCallback(async () => { + vscode.postMessage({ type: "flushRouterModels", text: "unbound" }) + + const modelsPromise = new Promise((resolve) => { + const messageHandler = (event: MessageEvent) => { + const message = event.data + if (message.type === "routerModels") { + window.removeEventListener("message", messageHandler) + resolve() + } + } + window.addEventListener("message", messageHandler) + }) + + vscode.postMessage({ type: "requestRouterModels" }) + + await modelsPromise + + await queryClient.invalidateQueries({ queryKey: ["routerModels"] }) + + // After refreshing models, check if current model is in the updated list + // If not, select the first available model + const updatedModels = queryClient.getQueryData<{ unbound: RouterModels }>(["routerModels"])?.unbound + if (updatedModels && Object.keys(updatedModels).length > 0) { + const currentModelId = apiConfiguration?.unboundModelId + const modelExists = currentModelId && Object.prototype.hasOwnProperty.call(updatedModels, currentModelId) + + if (!currentModelId || !modelExists) { + const firstAvailableModelId = Object.keys(updatedModels)[0] + setApiConfigurationField("unboundModelId", firstAvailableModelId) + } + } + + if (!updatedModels || Object.keys(updatedModels).includes("error")) { + return false + } else { + return true + } + }, [queryClient, apiConfiguration, setApiConfigurationField]) + + const handleRefresh = useCallback(async () => { + await saveConfiguration() + const requestModelsResult = await requestModels() + + if (requestModelsResult) { + setDidRefetch(true) + didRefetchTimerRef.current = setTimeout(() => setDidRefetch(false), 3000) + } else { + setIsInvalidKey(true) + invalidKeyTimerRef.current = setTimeout(() => setIsInvalidKey(false), 3000) + } + }, [saveConfiguration, requestModels]) + return ( <> )} +
+ +
+ {didRefetch && ( +
+ {t("settings:providers.unboundRefreshModelsSuccess")} +
+ )} + {isInvalidKey && ( +
+ {t("settings:providers.unboundInvalidApiKey")} +
+ )}