diff --git a/src/__tests__/extension.spec.ts b/src/__tests__/extension.spec.ts
index 3144a717d26..5b072672699 100644
--- a/src/__tests__/extension.spec.ts
+++ b/src/__tests__/extension.spec.ts
@@ -384,7 +384,7 @@ describe("extension.ts", () => {
 			})
 
 			// Verify flushModels was called to clear the cache on logout
-			expect(flushModels).toHaveBeenCalledWith("roo", false)
+			expect(flushModels).toHaveBeenCalledWith({ provider: "roo" }, false)
 		})
 	})
 })
diff --git a/src/api/providers/__tests__/deepinfra.spec.ts b/src/api/providers/__tests__/deepinfra.spec.ts
index 92b2362b3ec..1df6ffee60a 100644
--- a/src/api/providers/__tests__/deepinfra.spec.ts
+++ b/src/api/providers/__tests__/deepinfra.spec.ts
@@ -26,6 +26,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 	getModels: vitest.fn().mockResolvedValue({
 		[deepInfraDefaultModelId]: deepInfraDefaultModelInfo,
 	}),
+	getModelsFromCache: vitest.fn().mockReturnValue(undefined),
 }))
 
 import OpenAI from "openai"
diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts
index fe62ad3922c..a95118469e2 100644
--- a/src/api/providers/__tests__/lite-llm.spec.ts
+++ b/src/api/providers/__tests__/lite-llm.spec.ts
@@ -42,6 +42,7 @@ vi.mock("../fetchers/modelCache", () => ({
 			"gpt-4-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
 		})
 	}),
+	getModelsFromCache: vi.fn().mockReturnValue(undefined),
 }))
 
 describe("LiteLLMHandler", () => {
diff --git a/src/api/providers/__tests__/unbound.spec.ts b/src/api/providers/__tests__/unbound.spec.ts
index e5d76b48336..f03442a7040 100644
--- a/src/api/providers/__tests__/unbound.spec.ts
+++ b/src/api/providers/__tests__/unbound.spec.ts
@@ -71,6 +71,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 			},
 		})
 	}),
+	getModelsFromCache: vitest.fn().mockReturnValue(undefined),
 }))
 
 // Mock OpenAI client
diff --git a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts
index 63e2f93d0f5..3c6b1c10696 100644
--- a/src/api/providers/__tests__/vercel-ai-gateway.spec.ts
+++ b/src/api/providers/__tests__/vercel-ai-gateway.spec.ts
@@ -51,6 +51,7 @@ vitest.mock("../fetchers/modelCache", () => ({
 			},
 		})
 	}),
+	getModelsFromCache: vitest.fn().mockReturnValue(undefined),
 }))
 
 vitest.mock("../../transform/caching/vercel-ai-gateway", () => ({
diff --git a/src/api/providers/fetchers/lmstudio.ts b/src/api/providers/fetchers/lmstudio.ts
index 3068a962d85..73cb60e88e3 100644
--- a/src/api/providers/fetchers/lmstudio.ts
+++ b/src/api/providers/fetchers/lmstudio.ts
@@ -19,7 +19,7 @@ export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string
 	const client = new LMStudioClient({ baseUrl: lmsUrl })
 	await client.llm.model(modelId)
 	// Flush and refresh cache to get updated model details
-	await flushModels("lmstudio", true)
+	await flushModels({ provider: "lmstudio", baseUrl }, true)
 
 	// Mark this model as having full details loaded.
 	modelsWithLoadedDetails.add(modelId)
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index 6bf31b64c1f..d22abf9c91c 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -267,20 +267,20 @@ export async function initializeModelCacheRefresh(): Promise<void> {
 /**
  * Flush models memory cache for a specific router.
  *
- * @param router - The router to flush models for.
+ * @param options - The options for fetching models, including provider, apiKey, and baseUrl
  * @param refresh - If true, immediately fetch fresh data from API
  */
-export const flushModels = async (router: RouterName, refresh: boolean = false): Promise<void> => {
+export const flushModels = async (options: GetModelsOptions, refresh: boolean = false): Promise<void> => {
+	const { provider } = options
 	if (refresh) {
 		// Don't delete memory cache - let refreshModels atomically replace it
 		// This prevents a race condition where getModels() might be called
 		// before refresh completes, avoiding a gap in cache availability
-		refreshModels({ provider: router } as GetModelsOptions).catch((error) => {
-			console.error(`[flushModels] Refresh failed for ${router}:`, error)
-		})
+		// Await the refresh to ensure the cache is updated before returning
+		await refreshModels(options)
 	} else {
 		// Only delete memory cache when not refreshing
 		memoryCache.del(provider)
 	}
 }
 
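A behavioral note on the modelCache.ts hunk above: the old implementation kicked off `refreshModels(...)` in the background and swallowed failures with `.catch(...)`, while the new one awaits the refresh and no longer catches, so rejections now surface at the call site. A minimal sketch of how a caller can reproduce the old log-and-continue behavior (the import path and the Ollama base URL are illustrative assumptions, not taken from the patch):

```ts
import { flushModels } from "./modelCache" // hypothetical sibling-module import, as in lmstudio.ts

try {
	// refresh=true now awaits refreshModels(options), so a failed fetch
	// rejects here instead of being logged in the background.
	await flushModels({ provider: "ollama", baseUrl: "http://localhost:11434" }, true)
} catch (error) {
	// Opt back in to the old fire-and-forget semantics explicitly.
	console.error("[flushModels] Refresh failed for ollama:", error)
}
```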
diff --git a/src/api/providers/router-provider.ts b/src/api/providers/router-provider.ts
index 25e9a11e1b2..e43f49aa2c2 100644
--- a/src/api/providers/router-provider.ts
+++ b/src/api/providers/router-provider.ts
@@ -5,7 +5,7 @@ import type { ModelInfo } from "@roo-code/types"
 
 import { ApiHandlerOptions, RouterName, ModelRecord } from "../../shared/api"
 
 import { BaseProvider } from "./base-provider"
-import { getModels } from "./fetchers/modelCache"
+import { getModels, getModelsFromCache } from "./fetchers/modelCache"
 import { DEFAULT_HEADERS } from "./constants"
 
@@ -63,9 +63,22 @@ export abstract class RouterProvider extends BaseProvider {
 	override getModel(): { id: string; info: ModelInfo } {
 		const id = this.modelId ?? this.defaultModelId
 
-		return this.models[id]
-			? { id, info: this.models[id] }
-			: { id: this.defaultModelId, info: this.defaultModelInfo }
+		// First check instance models (populated by fetchModel)
+		if (this.models[id]) {
+			return { id, info: this.models[id] }
+		}
+
+		// Fall back to global cache (synchronous disk/memory cache)
+		// This ensures models are available before fetchModel() is called
+		const cachedModels = getModelsFromCache(this.name)
+		if (cachedModels?.[id]) {
+			// Also populate instance models for future calls
+			this.models = cachedModels
+			return { id, info: cachedModels[id] }
+		}
+
+		// Last resort: return default model
+		return { id: this.defaultModelId, info: this.defaultModelInfo }
 	}
 
 	protected supportsTemperature(modelId: string): boolean {
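The `getModel()` rewrite above is a three-tier lookup: instance state first, then the shared synchronous cache (hydrating instance state on a hit), then the provider defaults. A self-contained sketch of the pattern, with simplified stand-ins — `ModelInfo`, the `Map`-backed cache, and `FakeRouterProvider` are illustrations, not the real types:

```ts
type ModelInfo = { contextWindow: number; maxTokens?: number }
type ModelRecord = Record<string, ModelInfo>

// Stands in for the memory cache that modelCache.ts maintains.
const memoryCache = new Map<string, ModelRecord>()

function getModelsFromCache(provider: string): ModelRecord | undefined {
	return memoryCache.get(provider)
}

class FakeRouterProvider {
	private models: ModelRecord = {}

	constructor(
		private readonly name: string,
		private readonly modelId: string | undefined,
		private readonly defaultModelId: string,
		private readonly defaultModelInfo: ModelInfo,
	) {}

	getModel(): { id: string; info: ModelInfo } {
		const id = this.modelId ?? this.defaultModelId

		// Tier 1: instance models, populated by a prior fetch.
		if (this.models[id]) {
			return { id, info: this.models[id] }
		}

		// Tier 2: the shared synchronous cache; hydrate instance state on a hit.
		const cached = getModelsFromCache(this.name)
		if (cached?.[id]) {
			this.models = cached
			return { id, info: cached[id] }
		}

		// Tier 3: fall back to the provider's default model.
		return { id: this.defaultModelId, info: this.defaultModelInfo }
	}
}

// Usage: a warm cache makes the configured model visible before any fetch completes.
memoryCache.set("litellm", { "gpt-4-turbo": { contextWindow: 128_000, maxTokens: 8192 } })
const provider = new FakeRouterProvider("litellm", "gpt-4-turbo", "default-model", { contextWindow: 8192 })
console.log(provider.getModel().id) // "gpt-4-turbo"
```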
diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts
index 85e144290f0..9eb01406ff5 100644
--- a/src/core/webview/__tests__/ClineProvider.spec.ts
+++ b/src/core/webview/__tests__/ClineProvider.spec.ts
@@ -237,6 +237,7 @@ vi.mock("../../../integrations/misc/extract-text", () => ({
 vi.mock("../../../api/providers/fetchers/modelCache", () => ({
 	getModels: vi.fn().mockResolvedValue({}),
 	flushModels: vi.fn(),
+	getModelsFromCache: vi.fn().mockReturnValue(undefined),
 }))
 
 vi.mock("../../../shared/modes", () => ({
@@ -308,6 +309,7 @@ vi.mock("../../../integrations/misc/extract-text", () => ({
 vi.mock("../../../api/providers/fetchers/modelCache", () => ({
 	getModels: vi.fn().mockResolvedValue({}),
 	flushModels: vi.fn(),
+	getModelsFromCache: vi.fn().mockReturnValue(undefined),
 }))
 
 vi.mock("../diff/strategies/multi-search-replace", () => ({
diff --git a/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts b/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts
index 7ff767a9733..3f820aace15 100644
--- a/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts
+++ b/src/core/webview/__tests__/ClineProvider.sticky-mode.spec.ts
@@ -151,6 +151,7 @@ vi.mock("../../prompts/system", () => ({
 vi.mock("../../../api/providers/fetchers/modelCache", () => ({
 	getModels: vi.fn().mockResolvedValue({}),
 	flushModels: vi.fn(),
+	getModelsFromCache: vi.fn().mockReturnValue(undefined),
 }))
 
 vi.mock("../../../integrations/misc/extract-text", () => ({
diff --git a/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts
index dfbce361e45..5b89c723d4c 100644
--- a/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts
+++ b/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts
@@ -29,6 +29,7 @@ vi.mock("../../task-persistence", () => ({
 vi.mock("../../../api/providers/fetchers/modelCache", () => ({
 	getModels: vi.fn(),
 	flushModels: vi.fn(),
+	getModelsFromCache: vi.fn().mockReturnValue(undefined),
 }))
 
 vi.mock("../checkpointRestoreHandler", () => ({
diff --git a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts
index 7a69adbdde6..df2616a8425 100644
--- a/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts
+++ b/src/core/webview/__tests__/webviewMessageHandler.routerModels.spec.ts
@@ -177,8 +177,11 @@ describe("webviewMessageHandler - requestRouterModels provider filter", () => {
 			} as any,
 		)
 
-		// flushModels should have been called for litellm with refresh=true
-		expect(flushModelsMock).toHaveBeenCalledWith("litellm", true)
+		// flushModels should have been called for litellm with refresh=true and credentials
+		expect(flushModelsMock).toHaveBeenCalledWith(
+			{ provider: "litellm", apiKey: "test-api-key", baseUrl: "http://localhost:4000" },
+			true,
+		)
 
 		// getModels should have been called with the provided credentials
 		const litellmCalls = getModelsMock.mock.calls.filter((c: any[]) => c[0]?.provider === "litellm")
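All four specs above grow the same stub. Because the module factory passed to `vi.mock`/`vitest.mock` replaces the real exports wholesale, a suite that omits `getModelsFromCache` would hand the new `getModel()` fallback an undefined import and fail at the synchronous cache lookup. The shared mock shape, shown with the path used in ClineProvider.spec.ts:

```ts
import { vi } from "vitest"

vi.mock("../../../api/providers/fetchers/modelCache", () => ({
	getModels: vi.fn().mockResolvedValue({}),
	flushModels: vi.fn(),
	// Returning undefined simulates a cold cache, steering getModel()
	// to its default-model branch in tests.
	getModelsFromCache: vi.fn().mockReturnValue(undefined),
}))
```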
diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts
index e1640d3f2a8..e408414879f 100644
--- a/src/core/webview/webviewMessageHandler.ts
+++ b/src/core/webview/webviewMessageHandler.ts
@@ -777,7 +777,9 @@ export const webviewMessageHandler = async (
 			break
 		case "flushRouterModels":
 			const routerNameFlush: RouterName = toRouterName(message.text)
-			await flushModels(routerNameFlush, true)
+			// Note: flushRouterModels is a generic flush without credentials
+			// For providers that need credentials, use their specific handlers
+			await flushModels({ provider: routerNameFlush } as GetModelsOptions, true)
 			break
 		case "requestRouterModels":
 			const { apiConfiguration } = await provider.getState()
@@ -869,7 +871,7 @@ export const webviewMessageHandler = async (
 				// If explicit credentials are provided in message.values (from Refresh Models button),
 				// flush the cache first to ensure we fetch fresh data with the new credentials
 				if (message?.values?.litellmApiKey || message?.values?.litellmBaseUrl) {
-					await flushModels("litellm", true)
+					await flushModels({ provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl }, true)
 				}
 
 				candidates.push({
@@ -923,14 +925,15 @@ export const webviewMessageHandler = async (
 			// Specific handler for Ollama models only.
 			const { apiConfiguration: ollamaApiConfig } = await provider.getState()
 			try {
-				// Flush cache and refresh to ensure fresh models.
-				await flushModels("ollama", true)
-
-				const ollamaModels = await getModels({
-					provider: "ollama",
+				const ollamaOptions = {
+					provider: "ollama" as const,
 					baseUrl: ollamaApiConfig.ollamaBaseUrl,
 					apiKey: ollamaApiConfig.ollamaApiKey,
-				})
+				}
+				// Flush cache and refresh to ensure fresh models.
+				await flushModels(ollamaOptions, true)
+
+				const ollamaModels = await getModels(ollamaOptions)
 
 				if (Object.keys(ollamaModels).length > 0) {
 					provider.postMessageToWebview({ type: "ollamaModels", ollamaModels: ollamaModels })
@@ -945,13 +948,14 @@ export const webviewMessageHandler = async (
 			// Specific handler for LM Studio models only.
 			const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
 			try {
+				const lmStudioOptions = {
+					provider: "lmstudio" as const,
+					baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
+				}
 				// Flush cache and refresh to ensure fresh models.
-				await flushModels("lmstudio", true)
+				await flushModels(lmStudioOptions, true)
 
-				const lmStudioModels = await getModels({
-					provider: "lmstudio",
-					baseUrl: lmStudioApiConfig.lmStudioBaseUrl,
-				})
+				const lmStudioModels = await getModels(lmStudioOptions)
 
 				if (Object.keys(lmStudioModels).length > 0) {
 					provider.postMessageToWebview({
@@ -968,16 +972,17 @@ export const webviewMessageHandler = async (
 		case "requestRooModels": {
 			// Specific handler for Roo models only - flushes cache to ensure fresh auth token is used
 			try {
-				// Flush cache and refresh to ensure fresh models with current auth state
-				await flushModels("roo", true)
-
-				const rooModels = await getModels({
-					provider: "roo",
+				const rooOptions = {
+					provider: "roo" as const,
 					baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
 					apiKey: CloudService.hasInstance()
 						? CloudService.instance.authService?.getSessionToken()
 						: undefined,
-				})
+				}
+				// Flush cache and refresh to ensure fresh models with current auth state
+				await flushModels(rooOptions, true)
+
+				const rooModels = await getModels(rooOptions)
 
 				// Always send a response, even if no models are returned
 				provider.postMessageToWebview({
diff --git a/src/extension.ts b/src/extension.ts
index dcf53a2aeaf..f601b9d7cff 100644
--- a/src/extension.ts
+++ b/src/extension.ts
@@ -161,7 +161,7 @@ export async function activate(context: vscode.ExtensionContext) {
 				})
 			} else {
 				// Flush without refresh on logout
-				await flushModels("roo", false)
+				await flushModels({ provider: "roo" }, false)
 			}
 		} catch (error) {
 			cloudLogger(
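Taken together, the webviewMessageHandler.ts and extension.ts hunks exercise both modes of the new signature. A condensed sketch of the two call shapes — `sessionToken` is a stand-in for `CloudService.instance.authService?.getSessionToken()`, and the import path is as seen from src/extension.ts:

```ts
import { flushModels } from "./api/providers/fetchers/modelCache"

declare const sessionToken: string | undefined // stand-in for the Cloud session token

// Logout: refresh=false only evicts the memory-cache entry, so a bare
// { provider } without credentials is sufficient.
await flushModels({ provider: "roo" }, false)

// Login/refresh: refresh=true awaits refreshModels, which needs the full
// options so the fetch runs with the current auth state.
await flushModels(
	{
		provider: "roo",
		baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
		apiKey: sessionToken,
	},
	true,
)
```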