diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts index 746a013f0f1..12ffb025883 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts @@ -36,9 +36,10 @@ vi.mock("vscode", () => ({ // Mock modelCache getModels/flushModels used by the handler const getModelsMock = vi.fn() +const flushModelsMock = vi.fn() vi.mock("../../../api/providers/fetchers/modelCache", () => ({ getModels: (...args: any[]) => getModelsMock(...args), - flushModels: vi.fn(), + flushModels: (...args: any[]) => flushModelsMock(...args), })) describe("webviewMessageHandler - requestRouterModels provider filter", () => { @@ -164,4 +165,60 @@ describe("webviewMessageHandler - requestRouterModels provider filter", () => { const providersCalled = getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider) expect(providersCalled).toEqual(["openrouter"]) }) + + it("flushes cache when LiteLLM credentials are provided in message values", async () => { + // Provide LiteLLM credentials via message.values (simulating Refresh Models button) + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + values: { + litellmApiKey: "test-api-key", + litellmBaseUrl: "http://localhost:4000", + }, + } as any, + ) + + // flushModels should have been called for litellm with refresh=true + expect(flushModelsMock).toHaveBeenCalledWith("litellm", true) + + // getModels should have been called with the provided credentials + const litellmCalls = getModelsMock.mock.calls.filter((c: any[]) => c[0]?.provider === "litellm") + expect(litellmCalls.length).toBe(1) + expect(litellmCalls[0][0]).toEqual({ + provider: "litellm", + apiKey: "test-api-key", + baseUrl: "http://localhost:4000", + }) + }) + + it("does not flush cache when using stored LiteLLM credentials", async () => { + // Provide stored credentials via apiConfiguration + mockProvider.getState.mockResolvedValue({ + apiConfiguration: { + litellmApiKey: "stored-api-key", + litellmBaseUrl: "http://stored:4000", + }, + }) + + await webviewMessageHandler( + mockProvider as any, + { + type: "requestRouterModels", + } as any, + ) + + // flushModels should NOT have been called for litellm + const litellmFlushCalls = flushModelsMock.mock.calls.filter((c: any[]) => c[0] === "litellm") + expect(litellmFlushCalls.length).toBe(0) + + // getModels should still have been called with stored credentials + const litellmCalls = getModelsMock.mock.calls.filter((c: any[]) => c[0]?.provider === "litellm") + expect(litellmCalls.length).toBe(1) + expect(litellmCalls[0][0]).toEqual({ + provider: "litellm", + apiKey: "stored-api-key", + baseUrl: "http://stored:4000", + }) + }) }) diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index c769fc05fd7..7ccee8c100d 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -882,6 +882,12 @@ export const webviewMessageHandler = async ( const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl if (litellmApiKey && litellmBaseUrl) { + // If explicit credentials are provided in message.values (from Refresh Models button), + // flush the cache first to ensure we fetch fresh data with the new credentials + if (message?.values?.litellmApiKey || message?.values?.litellmBaseUrl) { + await flushModels("litellm", true) + } + candidates.push({ key: "litellm", options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },