diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 06dbc035025..944fb3d5a9c 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -136,6 +136,48 @@ describe("webviewMessageHandler - requestLmStudioModels", () => { }) }) +describe("webviewMessageHandler - requestOllamaModels", () => { + beforeEach(() => { + vi.clearAllMocks() + mockClineProvider.getState = vi.fn().mockResolvedValue({ + apiConfiguration: { + ollamaModelId: "model-1", + ollamaBaseUrl: "http://localhost:1234", + }, + }) + }) + + it("successfully fetches models from Ollama", async () => { + const mockModels: ModelRecord = { + "model-1": { + maxTokens: 4096, + contextWindow: 8192, + supportsPromptCache: false, + description: "Test model 1", + }, + "model-2": { + maxTokens: 8192, + contextWindow: 16384, + supportsPromptCache: false, + description: "Test model 2", + }, + } + + mockGetModels.mockResolvedValue(mockModels) + + await webviewMessageHandler(mockClineProvider, { + type: "requestOllamaModels", + }) + + expect(mockGetModels).toHaveBeenCalledWith({ provider: "ollama", baseUrl: "http://localhost:1234" }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "ollamaModels", + ollamaModels: mockModels, + }) + }) +}) + describe("webviewMessageHandler - requestRouterModels", () => { beforeEach(() => { vi.clearAllMocks() diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a495489cc1d..1173174888a 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -624,7 +624,7 @@ export const webviewMessageHandler = async ( if (routerName === "ollama" && Object.keys(result.value.models).length > 0) { provider.postMessageToWebview({ type: "ollamaModels", - ollamaModels: Object.keys(result.value.models), + ollamaModels: result.value.models, }) } else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) { provider.postMessageToWebview({ @@ -669,7 +669,7 @@ export const webviewMessageHandler = async ( if (Object.keys(ollamaModels).length > 0) { provider.postMessageToWebview({ type: "ollamaModels", - ollamaModels: Object.keys(ollamaModels), + ollamaModels: ollamaModels, }) } } catch (error) { diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index d4caf2f6746..c1838739be6 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -148,7 +148,7 @@ export interface ExtensionMessage { clineMessage?: ClineMessage routerModels?: RouterModels openAiModels?: string[] - ollamaModels?: string[] + ollamaModels?: ModelRecord lmStudioModels?: ModelRecord vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] huggingFaceModels?: Array<{ diff --git a/webview-ui/src/components/settings/providers/Ollama.tsx b/webview-ui/src/components/settings/providers/Ollama.tsx index b09ecad5d62..b3ff00ccdda 100644 --- a/webview-ui/src/components/settings/providers/Ollama.tsx +++ b/webview-ui/src/components/settings/providers/Ollama.tsx @@ -11,6 +11,7 @@ import { useRouterModels } from "@src/components/ui/hooks/useRouterModels" import { vscode } from "@src/utils/vscode" import { inputEventTransform } from "../transforms" +import { ModelRecord } from "@roo/api" type OllamaProps = { apiConfiguration: ProviderSettings @@ -20,7 +21,7 @@ type OllamaProps = { export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaProps) => { const { t } = useAppTranslation() - const [ollamaModels, setOllamaModels] = useState([]) + const [ollamaModels, setOllamaModels] = useState({}) const routerModels = useRouterModels() const handleInputChange = useCallback( @@ -40,7 +41,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro switch (message.type) { case "ollamaModels": { - const newModels = message.ollamaModels ?? [] + const newModels = message.ollamaModels ?? {} setOllamaModels(newModels) } break @@ -61,7 +62,7 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro if (!selectedModel) return false // Check if model exists in local ollama models - if (ollamaModels.length > 0 && ollamaModels.includes(selectedModel)) { + if (Object.keys(ollamaModels).length > 0 && selectedModel in ollamaModels) { return false // Model is available locally } @@ -116,15 +117,13 @@ export const Ollama = ({ apiConfiguration, setApiConfigurationField }: OllamaPro )} - {ollamaModels.length > 0 && ( + {Object.keys(ollamaModels).length > 0 && ( - {ollamaModels.map((model) => ( + {Object.keys(ollamaModels).map((model) => ( {model} diff --git a/webview-ui/src/components/ui/hooks/useOllamaModels.ts b/webview-ui/src/components/ui/hooks/useOllamaModels.ts new file mode 100644 index 00000000000..67a172b0d83 --- /dev/null +++ b/webview-ui/src/components/ui/hooks/useOllamaModels.ts @@ -0,0 +1,39 @@ +import { useQuery } from "@tanstack/react-query" + +import { ModelRecord } from "@roo/api" +import { ExtensionMessage } from "@roo/ExtensionMessage" + +import { vscode } from "@src/utils/vscode" + +const getOllamaModels = async () => + new Promise((resolve, reject) => { + const cleanup = () => { + window.removeEventListener("message", handler) + } + + const timeout = setTimeout(() => { + cleanup() + reject(new Error("Ollama models request timed out")) + }, 10000) + + const handler = (event: MessageEvent) => { + const message: ExtensionMessage = event.data + + if (message.type === "ollamaModels") { + clearTimeout(timeout) + cleanup() + + if (message.ollamaModels) { + resolve(message.ollamaModels) + } else { + reject(new Error("No Ollama models in response")) + } + } + } + + window.addEventListener("message", handler) + vscode.postMessage({ type: "requestOllamaModels" }) + }) + +export const useOllamaModels = (modelId?: string) => + useQuery({ queryKey: ["ollamaModels"], queryFn: () => (modelId ? getOllamaModels() : {}) }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index e9470e09026..e69f78a6962 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -63,19 +63,23 @@ import type { ModelRecord, RouterModels } from "@roo/api" import { useRouterModels } from "./useRouterModels" import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders" import { useLmStudioModels } from "./useLmStudioModels" +import { useOllamaModels } from "./useOllamaModels" export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { const provider = apiConfiguration?.apiProvider || "anthropic" const openRouterModelId = provider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined const lmStudioModelId = provider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined + const ollamaModelId = provider === "ollama" ? apiConfiguration?.ollamaModelId : undefined const routerModels = useRouterModels() const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId) const lmStudioModels = useLmStudioModels(lmStudioModelId) + const ollamaModels = useOllamaModels(ollamaModelId) const { id, info } = apiConfiguration && (typeof lmStudioModelId === "undefined" || typeof lmStudioModels.data !== "undefined") && + (typeof ollamaModelId === "undefined" || typeof ollamaModels.data !== "undefined") && typeof routerModels.data !== "undefined" && typeof openRouterModelProviders.data !== "undefined" ? getSelectedModel({ @@ -84,6 +88,7 @@ export const useSelectedModel = (apiConfiguration?: ProviderSettings) => { routerModels: routerModels.data, openRouterModelProviders: openRouterModelProviders.data, lmStudioModels: lmStudioModels.data, + ollamaModels: ollamaModels.data, }) : { id: anthropicDefaultModelId, info: undefined } @@ -108,12 +113,14 @@ function getSelectedModel({ routerModels, openRouterModelProviders, lmStudioModels, + ollamaModels, }: { provider: ProviderName apiConfiguration: ProviderSettings routerModels: RouterModels openRouterModelProviders: Record lmStudioModels: ModelRecord | undefined + ollamaModels: ModelRecord | undefined }): { id: string; info: ModelInfo | undefined } { // the `undefined` case are used to show the invalid selection to prevent // users from seeing the default model if their selection is invalid @@ -254,7 +261,7 @@ function getSelectedModel({ } case "ollama": { const id = apiConfiguration.ollamaModelId ?? "" - const info = routerModels.ollama && routerModels.ollama[id] + const info = ollamaModels && ollamaModels[apiConfiguration.ollamaModelId!] return { id, info: info || undefined,