Merged
4 changes: 2 additions & 2 deletions src/api/providers/fetchers/lmstudio.ts
@@ -18,8 +18,8 @@ export const forceFullModelDetailsLoad = async (baseUrl: string, modelId: string

const client = new LMStudioClient({ baseUrl: lmsUrl })
await client.llm.model(modelId)
await flushModels("lmstudio")
await getModels({ provider: "lmstudio" }) // Force cache update now.
// Flush and refresh cache to get updated model details
await flushModels("lmstudio", true)

// Mark this model as having full details loaded.
modelsWithLoadedDetails.add(modelId)
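Caller-side effect of this change, as a sketch (both functions are defined in modelCache.ts below): the old pair of calls dropped the cache and then blocked on a full refetch, while the new single call keeps the cached entry readable and schedules the refetch in the background.

```ts
// Old behavior (sketch): the cache is empty until getModels() completes.
await flushModels("lmstudio")
await getModels({ provider: "lmstudio" })

// New behavior (sketch): flushModels resolves immediately; the stale entry
// keeps serving reads until refreshModels() atomically swaps in fresh data.
await flushModels("lmstudio", true)
```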
194 changes: 139 additions & 55 deletions src/api/providers/fetchers/modelCache.ts
@@ -49,6 +49,74 @@ async function readModels(router: RouterName): Promise<ModelRecord | undefined>
return exists ? JSON.parse(await fs.readFile(filePath, "utf8")) : undefined
}

+ /**
+ * Fetch models from the provider API.
+ * Extracted to avoid duplication between getModels() and refreshModels().
+ *
+ * @param options - Provider options for fetching models
+ * @returns Fresh models from the provider API
+ */
+ async function fetchModelsFromProvider(options: GetModelsOptions): Promise<ModelRecord> {
+ const { provider } = options
+
+ let models: ModelRecord
+
+ switch (provider) {
+ case "openrouter":
+ models = await getOpenRouterModels()
+ break
+ case "requesty":
+ // Requesty models endpoint requires an API key for per-user custom policies.
+ models = await getRequestyModels(options.baseUrl, options.apiKey)
+ break
+ case "glama":
+ models = await getGlamaModels()
+ break
+ case "unbound":
+ // Unbound models endpoint requires an API key to fetch application specific models.
+ models = await getUnboundModels(options.apiKey)
+ break
+ case "litellm":
+ // Type safety ensures apiKey and baseUrl are always provided for LiteLLM.
+ models = await getLiteLLMModels(options.apiKey, options.baseUrl)
+ break
+ case "ollama":
+ models = await getOllamaModels(options.baseUrl, options.apiKey)
+ break
+ case "lmstudio":
+ models = await getLMStudioModels(options.baseUrl)
+ break
+ case "deepinfra":
+ models = await getDeepInfraModels(options.apiKey, options.baseUrl)
+ break
+ case "io-intelligence":
+ models = await getIOIntelligenceModels(options.apiKey)
+ break
+ case "vercel-ai-gateway":
+ models = await getVercelAiGatewayModels()
+ break
+ case "huggingface":
+ models = await getHuggingFaceModels()
+ break
+ case "roo": {
+ // Roo Code Cloud provider requires baseUrl and optional apiKey
+ const rooBaseUrl = options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
+ models = await getRooModels(rooBaseUrl, options.apiKey)
+ break
+ }
+ case "chutes":
+ models = await getChutesModels(options.apiKey)
+ break
+ default: {
+ // Ensures router is exhaustively checked if RouterName is a strict union.
+ const exhaustiveCheck: never = provider
+ throw new Error(`Unknown provider: ${exhaustiveCheck}`)
+ }
+ }
+
+ return models
+ }
+
/**
* Get models from the cache or fetch them from the provider and cache them.
* There are two caches:
@@ -70,59 +138,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
}

try {
- switch (provider) {
- case "openrouter":
- models = await getOpenRouterModels()
- break
- case "requesty":
- // Requesty models endpoint requires an API key for per-user custom policies.
- models = await getRequestyModels(options.baseUrl, options.apiKey)
- break
- case "glama":
- models = await getGlamaModels()
- break
- case "unbound":
- // Unbound models endpoint requires an API key to fetch application specific models.
- models = await getUnboundModels(options.apiKey)
- break
- case "litellm":
- // Type safety ensures apiKey and baseUrl are always provided for LiteLLM.
- models = await getLiteLLMModels(options.apiKey, options.baseUrl)
- break
- case "ollama":
- models = await getOllamaModels(options.baseUrl, options.apiKey)
- break
- case "lmstudio":
- models = await getLMStudioModels(options.baseUrl)
- break
- case "deepinfra":
- models = await getDeepInfraModels(options.apiKey, options.baseUrl)
- break
- case "io-intelligence":
- models = await getIOIntelligenceModels(options.apiKey)
- break
- case "vercel-ai-gateway":
- models = await getVercelAiGatewayModels()
- break
- case "huggingface":
- models = await getHuggingFaceModels()
- break
- case "roo": {
- // Roo Code Cloud provider requires baseUrl and optional apiKey
- const rooBaseUrl =
- options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
- models = await getRooModels(rooBaseUrl, options.apiKey)
- break
- }
- case "chutes":
- models = await getChutesModels(options.apiKey)
- break
- default: {
- // Ensures router is exhaustively checked if RouterName is a strict union.
- const exhaustiveCheck: never = provider
- throw new Error(`Unknown provider: ${exhaustiveCheck}`)
- }
- }
+ models = await fetchModelsFromProvider(options)

// Cache the fetched models (even if empty, to signify a successful fetch with no models).
memoryCache.set(provider, models)
@@ -145,13 +161,81 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
}
}

+ /**
+ * Force-refresh models from API, bypassing cache.
+ * Uses atomic writes so cache remains available during refresh.
+ *
+ * @param options - Provider options for fetching models
+ * @returns Fresh models from API
+ */
+ export const refreshModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
+ const { provider } = options
+
+ try {
+ // Force fresh API fetch - skip getModelsFromCache() check
+ const models = await fetchModelsFromProvider(options)
+
+ // Update memory cache first
+ memoryCache.set(provider, models)
+
+ // Atomically write to disk (safeWriteJson handles atomic writes)
+ await writeModels(provider, models).catch((err) =>
+ console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
+ )
+
+ return models
+ } catch (error) {
+ console.debug(`[refreshModels] Failed to refresh ${provider}:`, error)
+ // On error, return existing cache if available (graceful degradation)
+ return getModelsFromCache(provider) || {}
+ }
+ }
+
+ /**
+ * Initialize background model cache refresh.
+ * Refreshes public provider caches without blocking or requiring auth.
+ * Should be called once during extension activation.
+ */
+ export async function initializeModelCacheRefresh(): Promise<void> {
+ // Wait for extension to fully activate before refreshing
+ setTimeout(async () => {
+ // Providers that work without API keys
+ const publicProviders: Array<{ provider: RouterName; options: GetModelsOptions }> = [
+ { provider: "openrouter", options: { provider: "openrouter" } },
+ { provider: "glama", options: { provider: "glama" } },
+ { provider: "vercel-ai-gateway", options: { provider: "vercel-ai-gateway" } },
+ ]
+
+ // Refresh each provider in background (fire and forget)
+ for (const { options } of publicProviders) {
+ refreshModels(options).catch(() => {
+ // Silent fail - old cache remains available
+ })
+
+ // Small delay between refreshes to avoid API rate limits
+ await new Promise((resolve) => setTimeout(resolve, 500))
+ }
+ }, 2000)
+ }
+
/**
* Flush models memory cache for a specific router.
*
* @param router - The router to flush models for.
+ * @param refresh - If true, immediately fetch fresh data from API
*/
- export const flushModels = async (router: RouterName) => {
- memoryCache.del(router)
+ export const flushModels = async (router: RouterName, refresh: boolean = false): Promise<void> => {
+ if (refresh) {
+ // Don't delete memory cache - let refreshModels atomically replace it
+ // This prevents a race condition where getModels() might be called
+ // before refresh completes, avoiding a gap in cache availability
+ refreshModels({ provider: router } as GetModelsOptions).catch((error) => {
+ console.error(`[flushModels] Refresh failed for ${router}:`, error)
+ })
+ } else {
+ // Only delete memory cache when not refreshing
+ memoryCache.del(router)
+ }
}

/**
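A side note on the `default` branch in `fetchModelsFromProvider`: assigning the leftover value to `never` is the standard TypeScript exhaustiveness-check idiom. A minimal standalone sketch of the pattern — the two-member union and endpoint strings here are illustrative, not the real `RouterName`:

```ts
type Provider = "openrouter" | "glama" // illustrative union; the real RouterName is wider

function endpointFor(provider: Provider): string {
	switch (provider) {
		case "openrouter":
			return "/openrouter/models"
		case "glama":
			return "/glama/models"
		default: {
			// If Provider gains a member that is not handled above, the leftover
			// type is no longer `never`, and this assignment fails to compile.
			const exhaustiveCheck: never = provider
			throw new Error(`Unknown provider: ${exhaustiveCheck}`)
		}
	}
}
```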
14 changes: 7 additions & 7 deletions src/core/webview/webviewMessageHandler.ts
@@ -790,7 +790,7 @@
break
case "flushRouterModels":
const routerNameFlush: RouterName = toRouterName(message.text)
- await flushModels(routerNameFlush)
+ await flushModels(routerNameFlush, true)
break
case "requestRouterModels":
const { apiConfiguration } = await provider.getState()
@@ -932,8 +932,8 @@
// Specific handler for Ollama models only.
const { apiConfiguration: ollamaApiConfig } = await provider.getState()
try {
- // Flush cache first to ensure fresh models.
- await flushModels("ollama")
+ // Flush cache and refresh to ensure fresh models.
+ await flushModels("ollama", true)

const ollamaModels = await getModels({
provider: "ollama",
@@ -954,8 +954,8 @@
// Specific handler for LM Studio models only.
const { apiConfiguration: lmStudioApiConfig } = await provider.getState()
try {
- // Flush cache first to ensure fresh models.
- await flushModels("lmstudio")
+ // Flush cache and refresh to ensure fresh models.
+ await flushModels("lmstudio", true)

const lmStudioModels = await getModels({
provider: "lmstudio",
@@ -977,8 +977,8 @@
case "requestRooModels": {
// Specific handler for Roo models only - flushes cache to ensure fresh auth token is used
try {
- // Flush cache first to ensure fresh models with current auth state
- await flushModels("roo")
+ // Flush cache and refresh to ensure fresh models with current auth state
+ await flushModels("roo", true)

const rooModels = await getModels({
provider: "roo",
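One timing subtlety in these handlers, going by the modelCache.ts changes above: `flushModels(router, true)` fires `refreshModels()` without awaiting it, so the `getModels()` call that follows may still be answered from the previous cache; the webview then picks up the refreshed list on a subsequent request. A sketch of the sequence:

```ts
// flushModels(router, true) keeps the old entry and starts a background refresh.
await flushModels("ollama", true)

// May still serve the stale-but-valid entry if the refresh has not landed yet;
// refreshModels() swaps the memory cache atomically once its fetch completes.
const ollamaModels = await getModels({ provider: "ollama" })
```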
17 changes: 7 additions & 10 deletions src/extension.ts
@@ -40,7 +40,7 @@ import {
CodeActionProvider,
} from "./activate"
import { initializeI18n } from "./i18n"
- import { flushModels, getModels } from "./api/providers/fetchers/modelCache"
+ import { flushModels, getModels, initializeModelCacheRefresh } from "./api/providers/fetchers/modelCache"

/**
* Built using https://github.com/microsoft/vscode-webview-ui-toolkit
@@ -145,17 +145,11 @@ export async function activate(context: vscode.ExtensionContext) {
// Handle Roo models cache based on auth state
const handleRooModelsCache = async () => {
try {
await flushModels("roo")
// Flush and refresh cache on auth state changes
await flushModels("roo", true)

if (data.state === "active-session") {
- // Reload models with the new auth token
- const sessionToken = cloudService?.authService?.getSessionToken()
- await getModels({
- provider: "roo",
- baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
- apiKey: sessionToken,
- })
- cloudLogger(`[authStateChangedHandler] Reloaded Roo models cache for active session`)
+ cloudLogger(`[authStateChangedHandler] Refreshed Roo models cache for active session`)
} else {
cloudLogger(`[authStateChangedHandler] Flushed Roo models cache on logout`)
}
@@ -353,6 +347,9 @@
})
}

+ // Initialize background model cache refresh
+ initializeModelCacheRefresh()
+
return new API(outputChannel, provider, socketPath, enableLogging)
}

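`initializeModelCacheRefresh()` is deliberately fire-and-forget: activation returns immediately, the public-provider caches start warming 2 s later, and refreshes are spaced 500 ms apart to stay under provider rate limits. A generic, self-contained sketch of that staggering pattern (names other than `refreshModels` are hypothetical):

```ts
// Staggered background warmup: run tasks sequentially with a fixed gap so a
// burst of startup requests does not trip provider rate limits.
const delay = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms))

async function warmUp(tasks: Array<() => Promise<unknown>>, gapMs = 500): Promise<void> {
	for (const task of tasks) {
		task().catch(() => {}) // fire and forget; a failed refresh leaves the old cache intact
		await delay(gapMs)
	}
}

// Usage sketch: schedule after activation without blocking it.
setTimeout(() => {
	void warmUp([() => refreshModels({ provider: "openrouter" }), () => refreshModels({ provider: "glama" })])
}, 2000)
```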