2 changes: 2 additions & 0 deletions packages/types/src/telemetry.ts
@@ -72,6 +72,7 @@ export enum TelemetryEventName {
CONSECUTIVE_MISTAKE_ERROR = "Consecutive Mistake Error",
CODE_INDEX_ERROR = "Code Index Error",
TELEMETRY_SETTINGS_CHANGED = "Telemetry Settings Changed",
MODEL_CACHE_EMPTY_RESPONSE = "Model Cache Empty Response",
}

/**
@@ -196,6 +197,7 @@ export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [
TelemetryEventName.SHELL_INTEGRATION_ERROR,
TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR,
TelemetryEventName.CODE_INDEX_ERROR,
TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE,
TelemetryEventName.CONTEXT_CONDENSED,
TelemetryEventName.SLIDING_WINDOW_TRUNCATION,
TelemetryEventName.TAB_SHOWN,
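Reviewer note: the new event name has to be registered in two places, the TelemetryEventName enum and the allowed-event list inside rooCodeTelemetryEventSchema; an event that only exists in the enum would still be rejected at validation time. A minimal, self-contained zod sketch of that pattern (the payload shape and field names here are assumed for illustration, not the actual Roo Code schema):

import { z } from "zod"

// Simplified stand-in for rooCodeTelemetryEventSchema: each allowed event is a
// member of the discriminated union, so unregistered names fail to parse.
const eventSchema = z.discriminatedUnion("type", [
	z.object({
		type: z.literal("Code Index Error"),
		properties: z.record(z.string(), z.unknown()).optional(),
	}),
	z.object({
		type: z.literal("Model Cache Empty Response"), // the newly added event
		properties: z.record(z.string(), z.unknown()).optional(),
	}),
])

// Accepted only because the new literal is part of the union:
const parsed = eventSchema.safeParse({
	type: "Model Cache Empty Response",
	properties: { provider: "openrouter", context: "getModels", hasExistingCache: false },
})
console.log(parsed.success) // true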
193 changes: 193 additions & 0 deletions src/api/providers/fetchers/__tests__/modelCache.spec.ts
@@ -1,5 +1,14 @@
// Mocks must come first, before imports

// Mock TelemetryService
vi.mock("@roo-code/telemetry", () => ({
TelemetryService: {
instance: {
captureEvent: vi.fn(),
},
},
}))

// Mock NodeCache to allow controlling cache behavior
vi.mock("node-cache", () => {
const mockGet = vi.fn().mockReturnValue(undefined)
@@ -301,3 +310,187 @@ describe("getModelsFromCache disk fallback", () => {
consoleErrorSpy.mockRestore()
})
})

describe("empty cache protection", () => {
let mockCache: any
let mockGet: Mock
let mockSet: Mock

beforeEach(() => {
vi.clearAllMocks()
// Get the mock cache instance
const MockedNodeCache = vi.mocked(NodeCache)
mockCache = new MockedNodeCache()
mockGet = mockCache.get
mockSet = mockCache.set
// Reset memory cache to always miss by default
mockGet.mockReturnValue(undefined)
})

describe("getModels", () => {
it("does not cache empty API responses", async () => {
// API returns empty object (simulating failure)
mockGetOpenRouterModels.mockResolvedValue({})

const result = await getModels({ provider: "openrouter" })

// Should return empty but NOT cache it
expect(result).toEqual({})
expect(mockSet).not.toHaveBeenCalled()
})

it("caches non-empty API responses", async () => {
const mockModels = {
"openrouter/model": {
maxTokens: 8192,
contextWindow: 128000,
supportsPromptCache: false,
description: "OpenRouter model",
},
}
mockGetOpenRouterModels.mockResolvedValue(mockModels)

const result = await getModels({ provider: "openrouter" })

expect(result).toEqual(mockModels)
expect(mockSet).toHaveBeenCalledWith("openrouter", mockModels)
})
})

describe("refreshModels", () => {
it("keeps existing cache when API returns empty response", async () => {
const existingModels = {
"openrouter/existing-model": {
maxTokens: 8192,
contextWindow: 128000,
supportsPromptCache: false,
description: "Existing cached model",
},
}

// Memory cache has existing data
mockGet.mockReturnValue(existingModels)
// API returns empty (failure)
mockGetOpenRouterModels.mockResolvedValue({})

const { refreshModels } = await import("../modelCache")
const result = await refreshModels({ provider: "openrouter" })

// Should return existing cache, not empty
expect(result).toEqual(existingModels)
// Should NOT update cache with empty data
expect(mockSet).not.toHaveBeenCalled()
})

it("updates cache when API returns valid non-empty response", async () => {
const existingModels = {
"openrouter/old-model": {
maxTokens: 4096,
contextWindow: 64000,
supportsPromptCache: false,
description: "Old model",
},
}
const newModels = {
"openrouter/new-model": {
maxTokens: 8192,
contextWindow: 128000,
supportsPromptCache: true,
description: "New model",
},
}

mockGet.mockReturnValue(existingModels)
mockGetOpenRouterModels.mockResolvedValue(newModels)

const { refreshModels } = await import("../modelCache")
const result = await refreshModels({ provider: "openrouter" })

// Should return new models
expect(result).toEqual(newModels)
// Should update cache with new data
expect(mockSet).toHaveBeenCalledWith("openrouter", newModels)
})

it("returns existing cache on API error", async () => {
const existingModels = {
"openrouter/cached-model": {
maxTokens: 8192,
contextWindow: 128000,
supportsPromptCache: false,
description: "Cached model",
},
}

mockGet.mockReturnValue(existingModels)
mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))

const { refreshModels } = await import("../modelCache")
const result = await refreshModels({ provider: "openrouter" })

// Should return existing cache on error
expect(result).toEqual(existingModels)
})

it("returns empty object when API errors and no cache exists", async () => {
mockGet.mockReturnValue(undefined)
mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))

const { refreshModels } = await import("../modelCache")
const result = await refreshModels({ provider: "openrouter" })

// Should return empty when no cache and API fails
expect(result).toEqual({})
})

it("does not cache empty response when no existing cache", async () => {
// Both memory and disk cache are empty (initial state)
mockGet.mockReturnValue(undefined)
// API returns empty (failure/rate limit)
mockGetOpenRouterModels.mockResolvedValue({})

const { refreshModels } = await import("../modelCache")
const result = await refreshModels({ provider: "openrouter" })

// Should return empty but NOT cache it
expect(result).toEqual({})
expect(mockSet).not.toHaveBeenCalled()
})

it("reuses in-flight request for concurrent calls to same provider", async () => {
const mockModels = {
"openrouter/model": {
maxTokens: 8192,
contextWindow: 128000,
supportsPromptCache: false,
description: "OpenRouter model",
},
}

// Create a delayed response to simulate API latency
let resolvePromise: (value: typeof mockModels) => void
const delayedPromise = new Promise<typeof mockModels>((resolve) => {
resolvePromise = resolve
})
mockGetOpenRouterModels.mockReturnValue(delayedPromise)
mockGet.mockReturnValue(undefined)

const { refreshModels } = await import("../modelCache")

// Start two concurrent refresh calls
const promise1 = refreshModels({ provider: "openrouter" })
const promise2 = refreshModels({ provider: "openrouter" })

// API should only be called once (second call reuses in-flight request)
expect(mockGetOpenRouterModels).toHaveBeenCalledTimes(1)

// Resolve the API call
resolvePromise!(mockModels)

// Both promises should resolve to the same result
const [result1, result2] = await Promise.all([promise1, promise2])
expect(result1).toEqual(mockModels)
expect(result2).toEqual(mockModels)
})
})
})
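Reviewer note: the concurrency test works by holding on to the resolver of a promise so the test decides exactly when the mocked API settles. A generic "deferred" helper capturing the same pattern, shown as an illustrative sketch only (the test inlines this logic rather than using a helper):

// Illustrative helper, not part of the PR: exposes resolve/reject so a test
// can settle a mocked promise on demand.
function deferred<T>() {
	let resolve!: (value: T) => void
	let reject!: (reason?: unknown) => void
	const promise = new Promise<T>((res, rej) => {
		resolve = res
		reject = rej
	})
	return { promise, resolve, reject }
}

// Hypothetical usage mirroring the test above:
// const api = deferred<ModelRecord>()
// mockGetOpenRouterModels.mockReturnValue(api.promise)
// const p1 = refreshModels({ provider: "openrouter" })
// const p2 = refreshModels({ provider: "openrouter" })
// api.resolve(mockModels) // both p1 and p2 settle with the same result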
110 changes: 80 additions & 30 deletions src/api/providers/fetchers/modelCache.ts
@@ -6,7 +6,8 @@ import NodeCache from "node-cache"
import { z } from "zod"

import type { ProviderName } from "@roo-code/types"
import { modelInfoSchema } from "@roo-code/types"
import { modelInfoSchema, TelemetryEventName } from "@roo-code/types"
import { TelemetryService } from "@roo-code/telemetry"

import { safeWriteJson } from "../../../utils/safeWriteJson"

@@ -35,6 +36,10 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
// Zod schema for validating ModelRecord structure from disk cache
const modelRecordSchema = z.record(z.string(), modelInfoSchema)

// Track in-flight refresh requests to prevent concurrent API calls for the same provider
// This prevents race conditions where multiple calls might overwrite each other's results
const inFlightRefresh = new Map<RouterName, Promise<ModelRecord>>()

async function writeModels(router: RouterName, data: ModelRecord) {
const filename = `${router}_models.json`
const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
@@ -139,20 +144,25 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>

try {
models = await fetchModelsFromProvider(options)

// Cache the fetched models (even if empty, to signify a successful fetch with no models).
memoryCache.set(provider, models)

await writeModels(provider, models).catch((err) =>
console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
)

try {
models = await readModels(provider)
} catch (error) {
console.error(`[getModels] error reading ${provider} models from file cache`, error)
const modelCount = Object.keys(models).length

// Only cache non-empty results to prevent persisting failed API responses
// Empty results could indicate API failure rather than "no models exist"
if (modelCount > 0) {
memoryCache.set(provider, models)

await writeModels(provider, models).catch((err) =>
console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
)
} else {
TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
provider,
context: "getModels",
hasExistingCache: false,
})
}
return models || {}

return models
} catch (error) {
// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
@@ -164,31 +174,71 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
/**
* Force-refresh models from API, bypassing cache.
* Uses atomic writes so cache remains available during refresh.
Concurrent refreshes for the same provider are deduplicated via in-flight
request tracking, so racing API calls cannot overwrite each other's results.
*
* @param options - Provider options for fetching models
* @returns Fresh models from API
@returns Fresh models from the API, or the existing cache when the refresh fails or returns an empty response
*/
export const refreshModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
const { provider } = options

try {
// Force fresh API fetch - skip getModelsFromCache() check
const models = await fetchModelsFromProvider(options)
// Check if there's already an in-flight refresh for this provider
// This prevents race conditions where multiple concurrent refreshes might
// overwrite each other's results
const existingRequest = inFlightRefresh.get(provider)
if (existingRequest) {
return existingRequest
}

// Update memory cache first
memoryCache.set(provider, models)
// Create the refresh promise and track it
const refreshPromise = (async (): Promise<ModelRecord> => {
try {
// Force fresh API fetch - skip getModelsFromCache() check
const models = await fetchModelsFromProvider(options)
const modelCount = Object.keys(models).length

// Get existing cached data for comparison
const existingCache = getModelsFromCache(provider)
const existingCount = existingCache ? Object.keys(existingCache).length : 0

if (modelCount === 0) {
TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
provider,
context: "refreshModels",
hasExistingCache: existingCount > 0,
existingCacheSize: existingCount,
})
if (existingCount > 0) {
return existingCache!
} else {
return {}
}
}

// Atomically write to disk (safeWriteJson handles atomic writes)
await writeModels(provider, models).catch((err) =>
console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
)
// Update memory cache first
memoryCache.set(provider, models)

return models
} catch (error) {
console.debug(`[refreshModels] Failed to refresh ${provider}:`, error)
// On error, return existing cache if available (graceful degradation)
return getModelsFromCache(provider) || {}
}
// Atomically write to disk (safeWriteJson handles atomic writes)
await writeModels(provider, models).catch((err) =>
console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
)

return models
} catch (error) {
// Log the error for debugging, then return existing cache if available (graceful degradation)
console.error(`[refreshModels] Failed to refresh ${provider} models:`, error)
return getModelsFromCache(provider) || {}
} finally {
// Always clean up the in-flight tracking
inFlightRefresh.delete(provider)
}
})()

// Track the in-flight request
inFlightRefresh.set(provider, refreshPromise)

return refreshPromise
}

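Reviewer note: the inFlightRefresh map introduced at the top of modelCache.ts is an instance of the general share-the-in-flight-promise pattern. A standalone sketch of the idea under assumed names (dedupe and inFlight are illustrative identifiers, not the PR's API):

// Concurrent callers with the same key share a single promise; the map entry
// is cleared in finally so later calls start a fresh request even after errors.
const inFlight = new Map<string, Promise<unknown>>()

async function dedupe<T>(key: string, fn: () => Promise<T>): Promise<T> {
	const existing = inFlight.get(key)
	if (existing) {
		return existing as Promise<T>
	}
	const request = (async () => {
		try {
			return await fn()
		} finally {
			inFlight.delete(key)
		}
	})()
	inFlight.set(key, request)
	return request
}

The PR's refreshModels follows the same shape: cleanup happens in a finally block, so a failed refresh never leaves a stale in-flight entry that would block subsequent refreshes.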