diff --git a/packages/types/src/telemetry.ts b/packages/types/src/telemetry.ts
index 29612d42a2f..233edfe499e 100644
--- a/packages/types/src/telemetry.ts
+++ b/packages/types/src/telemetry.ts
@@ -72,6 +72,7 @@ export enum TelemetryEventName {
 	CONSECUTIVE_MISTAKE_ERROR = "Consecutive Mistake Error",
 	CODE_INDEX_ERROR = "Code Index Error",
 	TELEMETRY_SETTINGS_CHANGED = "Telemetry Settings Changed",
+	MODEL_CACHE_EMPTY_RESPONSE = "Model Cache Empty Response",
 }
 
 /**
@@ -196,6 +197,7 @@ export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [
 		TelemetryEventName.SHELL_INTEGRATION_ERROR,
 		TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR,
 		TelemetryEventName.CODE_INDEX_ERROR,
+		TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE,
 		TelemetryEventName.CONTEXT_CONDENSED,
 		TelemetryEventName.SLIDING_WINDOW_TRUNCATION,
 		TelemetryEventName.TAB_SHOWN,
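For reference, the new event is captured with free-form properties rather than a dedicated schema field. A minimal sketch of a call site, assuming only the singleton TelemetryService API that the patch itself uses below:

// Sketch only: mirrors the captureEvent call shape introduced in modelCache.ts.
import { TelemetryEventName } from "@roo-code/types"
import { TelemetryService } from "@roo-code/telemetry"

TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
	provider: "openrouter", // which model router returned the empty list
	context: "getModels", // the code path that observed it
	hasExistingCache: false, // whether stale data was available as a fallback
})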
diff --git a/src/api/providers/fetchers/__tests__/modelCache.spec.ts b/src/api/providers/fetchers/__tests__/modelCache.spec.ts
index 530ea8de977..a3df3e1e49c 100644
--- a/src/api/providers/fetchers/__tests__/modelCache.spec.ts
+++ b/src/api/providers/fetchers/__tests__/modelCache.spec.ts
@@ -1,5 +1,14 @@
 // Mocks must come first, before imports
 
+// Mock TelemetryService
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureEvent: vi.fn(),
+		},
+	},
+}))
+
 // Mock NodeCache to allow controlling cache behavior
 vi.mock("node-cache", () => {
 	const mockGet = vi.fn().mockReturnValue(undefined)
@@ -301,3 +310,187 @@ describe("getModelsFromCache disk fallback", () => {
 		consoleErrorSpy.mockRestore()
 	})
 })
+
+describe("empty cache protection", () => {
+	let mockCache: any
+	let mockGet: Mock
+	let mockSet: Mock
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		// Get the mock cache instance
+		const MockedNodeCache = vi.mocked(NodeCache)
+		mockCache = new MockedNodeCache()
+		mockGet = mockCache.get
+		mockSet = mockCache.set
+		// Reset memory cache to always miss by default
+		mockGet.mockReturnValue(undefined)
+	})
+
+	describe("getModels", () => {
+		it("does not cache empty API responses", async () => {
+			// API returns empty object (simulating failure)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const result = await getModels({ provider: "openrouter" })
+
+			// Should return empty but NOT cache it
+			expect(result).toEqual({})
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("caches non-empty API responses", async () => {
+			const mockModels = {
+				"openrouter/model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "OpenRouter model",
+				},
+			}
+			mockGetOpenRouterModels.mockResolvedValue(mockModels)
+
+			const result = await getModels({ provider: "openrouter" })
+
+			expect(result).toEqual(mockModels)
+			expect(mockSet).toHaveBeenCalledWith("openrouter", mockModels)
+		})
+	})
+
+	describe("refreshModels", () => {
+		it("keeps existing cache when API returns empty response", async () => {
+			const existingModels = {
+				"openrouter/existing-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Existing cached model",
+				},
+			}
+
+			// Memory cache has existing data
+			mockGet.mockReturnValue(existingModels)
+			// API returns empty (failure)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return existing cache, not empty
+			expect(result).toEqual(existingModels)
+			// Should NOT update cache with empty data
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("updates cache when API returns valid non-empty response", async () => {
+			const existingModels = {
+				"openrouter/old-model": {
+					maxTokens: 4096,
+					contextWindow: 64000,
+					supportsPromptCache: false,
+					description: "Old model",
+				},
+			}
+			const newModels = {
+				"openrouter/new-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: true,
+					description: "New model",
+				},
+			}
+
+			mockGet.mockReturnValue(existingModels)
+			mockGetOpenRouterModels.mockResolvedValue(newModels)
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return new models
+			expect(result).toEqual(newModels)
+			// Should update cache with new data
+			expect(mockSet).toHaveBeenCalledWith("openrouter", newModels)
+		})
+
+		it("returns existing cache on API error", async () => {
+			const existingModels = {
+				"openrouter/cached-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Cached model",
+				},
+			}
+
+			mockGet.mockReturnValue(existingModels)
+			mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return existing cache on error
+			expect(result).toEqual(existingModels)
+		})
+
+		it("returns empty object when API errors and no cache exists", async () => {
+			mockGet.mockReturnValue(undefined)
+			mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return empty when no cache and API fails
+			expect(result).toEqual({})
+		})
+
+		it("does not cache empty response when no existing cache", async () => {
+			// Both memory and disk cache are empty (initial state)
+			mockGet.mockReturnValue(undefined)
+			// API returns empty (failure/rate limit)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return empty but NOT cache it
+			expect(result).toEqual({})
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("reuses in-flight request for concurrent calls to same provider", async () => {
+			const mockModels = {
+				"openrouter/model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "OpenRouter model",
+				},
+			}
+
+			// Create a delayed response to simulate API latency
+			let resolvePromise: (value: typeof mockModels) => void
+			const delayedPromise = new Promise<typeof mockModels>((resolve) => {
+				resolvePromise = resolve
+			})
+			mockGetOpenRouterModels.mockReturnValue(delayedPromise)
+			mockGet.mockReturnValue(undefined)
+
+			const { refreshModels } = await import("../modelCache")
+
+			// Start two concurrent refresh calls
+			const promise1 = refreshModels({ provider: "openrouter" })
+			const promise2 = refreshModels({ provider: "openrouter" })
+
+			// API should only be called once (second call reuses in-flight request)
+			expect(mockGetOpenRouterModels).toHaveBeenCalledTimes(1)
+
+			// Resolve the API call
+			resolvePromise!(mockModels)
+
+			// Both promises should resolve to the same result
+			const [result1, result2] = await Promise.all([promise1, promise2])
+			expect(result1).toEqual(mockModels)
+			expect(result2).toEqual(mockModels)
+		})
+	})
+})
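The concurrency test above drives the mock with a manually resolved ("deferred") promise so it can assert while the request is still in flight. The same idea as a small reusable helper, sketched here; `deferred` is an illustrative name, not part of the patch:

// Deferred-promise helper: exposes a promise together with its resolve function,
// letting a test assert on in-flight state before completing the simulated API call.
function deferred<T>() {
	let resolve!: (value: T) => void
	const promise = new Promise<T>((res) => {
		resolve = res
	})
	return { promise, resolve }
}

A test hands `promise` to the mock, asserts call counts while it is pending, and only then calls `resolve(...)`.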
diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts
index 50edbf274a8..ec9c214cca3 100644
--- a/src/api/providers/fetchers/modelCache.ts
+++ b/src/api/providers/fetchers/modelCache.ts
@@ -6,7 +6,8 @@ import NodeCache from "node-cache"
 import { z } from "zod"
 
 import type { ProviderName } from "@roo-code/types"
-import { modelInfoSchema } from "@roo-code/types"
+import { modelInfoSchema, TelemetryEventName } from "@roo-code/types"
+import { TelemetryService } from "@roo-code/telemetry"
 
 import { safeWriteJson } from "../../../utils/safeWriteJson"
@@ -35,6 +36,10 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 // Zod schema for validating ModelRecord structure from disk cache
 const modelRecordSchema = z.record(z.string(), modelInfoSchema)
 
+// Track in-flight refresh requests to prevent concurrent API calls for the same provider
+// This prevents race conditions where multiple calls might overwrite each other's results
+const inFlightRefresh = new Map<string, Promise<ModelRecord>>()
+
 async function writeModels(router: RouterName, data: ModelRecord) {
 	const filename = `${router}_models.json`
 	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
@@ -139,20 +144,25 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 
 	try {
 		models = await fetchModelsFromProvider(options)
-
-		// Cache the fetched models (even if empty, to signify a successful fetch with no models).
-		memoryCache.set(provider, models)
-
-		await writeModels(provider, models).catch((err) =>
-			console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
-		)
-
-		try {
-			models = await readModels(provider)
-		} catch (error) {
-			console.error(`[getModels] error reading ${provider} models from file cache`, error)
+		const modelCount = Object.keys(models).length
+
+		// Only cache non-empty results to prevent persisting failed API responses
+		// Empty results could indicate API failure rather than "no models exist"
+		if (modelCount > 0) {
+			memoryCache.set(provider, models)
+
+			await writeModels(provider, models).catch((err) =>
+				console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
+			)
+		} else {
+			TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
+				provider,
+				context: "getModels",
+				hasExistingCache: false,
+			})
 		}
-		return models || {}
+
+		return models
 	} catch (error) {
 		// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
 		console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
@@ -164,31 +174,71 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 /**
  * Force-refresh models from API, bypassing cache.
  * Uses atomic writes so cache remains available during refresh.
+ * This function also prevents concurrent API calls for the same provider using
+ * in-flight request tracking to avoid race conditions.
  *
  * @param options - Provider options for fetching models
- * @returns Fresh models from API
+ * @returns Fresh models from API, or the existing cache if the refresh fails or returns an empty response
  */
 export const refreshModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 	const { provider } = options
 
-	try {
-		// Force fresh API fetch - skip getModelsFromCache() check
-		const models = await fetchModelsFromProvider(options)
+	// Check if there's already an in-flight refresh for this provider
+	// This prevents race conditions where multiple concurrent refreshes might
+	// overwrite each other's results
+	const existingRequest = inFlightRefresh.get(provider)
+	if (existingRequest) {
+		return existingRequest
+	}
 
-		// Update memory cache first
-		memoryCache.set(provider, models)
+	// Create the refresh promise and track it
+	const refreshPromise = (async (): Promise<ModelRecord> => {
+		try {
+			// Force fresh API fetch - skip getModelsFromCache() check
+			const models = await fetchModelsFromProvider(options)
+			const modelCount = Object.keys(models).length
+
+			// Get existing cached data for comparison
+			const existingCache = getModelsFromCache(provider)
+			const existingCount = existingCache ? Object.keys(existingCache).length : 0
+
+			if (modelCount === 0) {
+				TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
+					provider,
+					context: "refreshModels",
+					hasExistingCache: existingCount > 0,
+					existingCacheSize: existingCount,
+				})
+				if (existingCount > 0) {
+					return existingCache!
+				} else {
+					return {}
+				}
+			}
 
-		// Atomically write to disk (safeWriteJson handles atomic writes)
-		await writeModels(provider, models).catch((err) =>
-			console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
-		)
+			// Update memory cache first
+			memoryCache.set(provider, models)
 
-		return models
-	} catch (error) {
-		console.debug(`[refreshModels] Failed to refresh ${provider}:`, error)
-		// On error, return existing cache if available (graceful degradation)
-		return getModelsFromCache(provider) || {}
-	}
+			// Atomically write to disk (safeWriteJson handles atomic writes)
+			await writeModels(provider, models).catch((err) =>
+				console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
+			)
+
+			return models
+		} catch (error) {
+			// Log the error for debugging, then return existing cache if available (graceful degradation)
+			console.error(`[refreshModels] Failed to refresh ${provider} models:`, error)
+			return getModelsFromCache(provider) || {}
+		} finally {
+			// Always clean up the in-flight tracking
+			inFlightRefresh.delete(provider)
+		}
+	})()
+
+	// Track the in-flight request
+	inFlightRefresh.set(provider, refreshPromise)
+
+	return refreshPromise
 }
 
 /**
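Stripped of telemetry and disk-cache details, the core of the refreshModels change is promise-based request deduplication. A self-contained sketch of the pattern (the names here are illustrative, not the shipped code):

// Concurrent callers for the same key share a single in-flight promise.
// The finally block clears the entry so the next call triggers a fresh fetch.
const inFlight = new Map<string, Promise<string[]>>()

async function fetchOnce(key: string, fetcher: () => Promise<string[]>): Promise<string[]> {
	const existing = inFlight.get(key)
	if (existing) {
		return existing
	}
	const request = (async () => {
		try {
			return await fetcher()
		} finally {
			inFlight.delete(key)
		}
	})()
	inFlight.set(key, request)
	return request
}

Because JavaScript is single-threaded, no other caller can observe the map between the promise's creation and the set call, so the check-then-set sequence needs no locking.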