diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index 5262e7602d6..4153db0da4e 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -19,7 +19,6 @@ import { moonshotModels, openAiNativeModels, qwenCodeModels, - rooModels, sambaNovaModels, vertexModels, vscodeLlmModels, @@ -49,6 +48,7 @@ export const dynamicProviders = [ "requesty", "unbound", "glama", + "roo", ] as const export type DynamicProvider = (typeof dynamicProviders)[number] @@ -677,7 +677,7 @@ export const MODELS_BY_PROVIDER: Record< models: Object.keys(openAiNativeModels), }, "qwen-code": { id: "qwen-code", label: "Qwen Code", models: Object.keys(qwenCodeModels) }, - roo: { id: "roo", label: "Roo", models: Object.keys(rooModels) }, + roo: { id: "roo", label: "Roo Code Cloud", models: [] }, sambanova: { id: "sambanova", label: "SambaNova", diff --git a/packages/types/src/providers/roo.ts b/packages/types/src/providers/roo.ts index fd705b1eb97..0b7ed89bd92 100644 --- a/packages/types/src/providers/roo.ts +++ b/packages/types/src/providers/roo.ts @@ -1,53 +1,49 @@ +import { z } from "zod" + import type { ModelInfo } from "../model.js" -export type RooModelId = - | "xai/grok-code-fast-1" - | "roo/code-supernova-1-million" - | "xai/grok-4-fast" - | "deepseek/deepseek-chat-v3.1" - -export const rooDefaultModelId: RooModelId = "xai/grok-code-fast-1" - -export const rooModels = { - "xai/grok-code-fast-1": { - maxTokens: 16_384, - contextWindow: 262_144, - supportsImages: false, - supportsPromptCache: true, - inputPrice: 0, - outputPrice: 0, - description: - "A reasoning model that is blazing fast and excels at agentic coding, accessible for free through Roo Code Cloud for a limited time. (Note: the free prompts and completions are logged by xAI and used to improve the model.)", - }, - "roo/code-supernova-1-million": { - maxTokens: 30_000, - contextWindow: 1_000_000, - supportsImages: true, - supportsPromptCache: true, - inputPrice: 0, - outputPrice: 0, - description: - "A versatile agentic coding stealth model with a 1M token context window that supports image inputs, accessible for free through Roo Code Cloud for a limited time. (Note: the free prompts and completions are logged by the model provider and used to improve the model.)", - }, - "xai/grok-4-fast": { - maxTokens: 30_000, - contextWindow: 2_000_000, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0, - outputPrice: 0, - description: - "Grok 4 Fast is xAI's latest multimodal model with SOTA cost-efficiency and a 2M token context window. (Note: prompts and completions are logged by xAI and used to improve the model.)", - deprecated: true, - }, - "deepseek/deepseek-chat-v3.1": { - maxTokens: 16_384, - contextWindow: 163_840, - supportsImages: false, - supportsPromptCache: false, - inputPrice: 0, - outputPrice: 0, - description: - "DeepSeek-V3.1 is a large hybrid reasoning model (671B parameters, 37B active). It extends the DeepSeek-V3 base with a two-phase long-context training process, reaching up to 128K tokens, and uses FP8 microscaling for efficient inference.", - }, -} as const satisfies Record +/** + * Roo Code Cloud is a dynamic provider - models are loaded from the /v1/models API endpoint. + * Default model ID used as fallback when no model is specified. + */ +export const rooDefaultModelId = "xai/grok-code-fast-1" + +/** + * Empty models object maintained for type compatibility. + * All model data comes dynamically from the API. 
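+ * Models are fetched at runtime from the provider's /v1/models endpoint (see RooModelsResponseSchema below) and cached via the shared model cache.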
+ */ +export const rooModels = {} as const satisfies Record<string, ModelInfo> + +/** + * Roo Code Cloud API response schemas + */ + +export const RooPricingSchema = z.object({ + input: z.string(), + output: z.string(), + input_cache_read: z.string().optional(), + input_cache_write: z.string().optional(), +}) + +export const RooModelSchema = z.object({ + id: z.string(), + object: z.literal("model"), + created: z.number(), + owned_by: z.string(), + name: z.string(), + description: z.string(), + context_window: z.number(), + max_tokens: z.number(), + type: z.literal("language"), + tags: z.array(z.string()).optional(), + pricing: RooPricingSchema, + deprecated: z.boolean().optional(), +}) + +export const RooModelsResponseSchema = z.object({ + object: z.literal("list"), + data: z.array(RooModelSchema), +}) + +export type RooModel = z.infer<typeof RooModelSchema> +export type RooModelsResponse = z.infer<typeof RooModelsResponseSchema> diff --git a/src/api/providers/__tests__/roo.spec.ts b/src/api/providers/__tests__/roo.spec.ts index d4affa2beaf..c209aa51cdc 100644 --- a/src/api/providers/__tests__/roo.spec.ts +++ b/src/api/providers/__tests__/roo.spec.ts @@ -1,7 +1,7 @@ // npx vitest run api/providers/__tests__/roo.spec.ts import { Anthropic } from "@anthropic-ai/sdk" -import { rooDefaultModelId, rooModels } from "@roo-code/types" +import { rooDefaultModelId } from "@roo-code/types" import { ApiHandlerOptions } from "../../../shared/api" @@ -301,8 +301,9 @@ describe("RooHandler", () => { const modelInfo = handler.getModel() expect(modelInfo.id).toBe(mockOptions.apiModelId) expect(modelInfo.info).toBeDefined() - // xai/grok-code-fast-1 is a valid model in rooModels - expect(modelInfo.info).toBe(rooModels["xai/grok-code-fast-1"]) + // Models are loaded dynamically, so we just verify the structure + expect(modelInfo.info.maxTokens).toBeDefined() + expect(modelInfo.info.contextWindow).toBeDefined() }) it("should return default model when no model specified", () => { @@ -310,7 +311,9 @@ const modelInfo = handlerWithoutModel.getModel() expect(modelInfo.id).toBe(rooDefaultModelId) expect(modelInfo.info).toBeDefined() - expect(modelInfo.info).toBe(rooModels[rooDefaultModelId]) + // Models are loaded dynamically + expect(modelInfo.info.maxTokens).toBeDefined() + expect(modelInfo.info.contextWindow).toBeDefined() }) it("should handle unknown model ID with fallback info", () => { @@ -320,24 +323,27 @@ const modelInfo = handlerWithUnknownModel.getModel() expect(modelInfo.id).toBe("unknown-model-id") expect(modelInfo.info).toBeDefined() - // Should return fallback info for unknown models - expect(modelInfo.info.maxTokens).toBe(16_384) - expect(modelInfo.info.contextWindow).toBe(262_144) - expect(modelInfo.info.supportsImages).toBe(false) - expect(modelInfo.info.supportsPromptCache).toBe(true) - expect(modelInfo.info.inputPrice).toBe(0) - expect(modelInfo.info.outputPrice).toBe(0) + // Should return fallback info for unknown models (dynamic models will be merged in real usage) + expect(modelInfo.info.maxTokens).toBeDefined() + expect(modelInfo.info.contextWindow).toBeDefined() + expect(modelInfo.info.supportsImages).toBeDefined() + expect(modelInfo.info.supportsPromptCache).toBeDefined() + expect(modelInfo.info.inputPrice).toBeDefined() + expect(modelInfo.info.outputPrice).toBeDefined() }) - it("should return correct model info for all Roo models", () => { - // Test each model in rooModels - const modelIds = Object.keys(rooModels) as Array<keyof typeof rooModels> + it("should handle any model ID since models are loaded dynamically", () => { +
// Test with various model IDs - they should all work since models are loaded dynamically + const testModelIds = ["xai/grok-code-fast-1", "roo/sonic", "deepseek/deepseek-chat-v3.1"] - for (const modelId of modelIds) { + for (const modelId of testModelIds) { const handlerWithModel = new RooHandler({ apiModelId: modelId }) const modelInfo = handlerWithModel.getModel() expect(modelInfo.id).toBe(modelId) - expect(modelInfo.info).toBe(rooModels[modelId]) + expect(modelInfo.info).toBeDefined() + // Verify the structure has required fields + expect(modelInfo.info.maxTokens).toBeDefined() + expect(modelInfo.info.contextWindow).toBeDefined() } }) }) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index 2ccb73a4551..55b5bc3a304 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -24,6 +24,7 @@ import { getLMStudioModels } from "./lmstudio" import { getIOIntelligenceModels } from "./io-intelligence" import { getDeepInfraModels } from "./deepinfra" import { getHuggingFaceModels } from "./huggingface" +import { getRooModels } from "./roo" const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 }) @@ -99,6 +100,13 @@ export const getModels = async (options: GetModelsOptions): Promise case "huggingface": models = await getHuggingFaceModels() break + case "roo": { + // Roo Code Cloud provider requires baseUrl and optional apiKey + const rooBaseUrl = + options.baseUrl ?? process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy" + models = await getRooModels(rooBaseUrl, options.apiKey) + break + } default: { // Ensures router is exhaustively checked if RouterName is a strict union. const exhaustiveCheck: never = provider diff --git a/src/api/providers/fetchers/roo.ts b/src/api/providers/fetchers/roo.ts new file mode 100644 index 00000000000..5836c28b2d8 --- /dev/null +++ b/src/api/providers/fetchers/roo.ts @@ -0,0 +1,119 @@ +import { RooModelsResponseSchema } from "@roo-code/types" + +import type { ModelRecord } from "../../../shared/api" + +import { DEFAULT_HEADERS } from "../constants" + +/** + * Fetches available models from the Roo Code Cloud provider + * + * @param baseUrl The base URL of the Roo Code Cloud provider + * @param apiKey The API key (session token) for the Roo Code Cloud provider + * @returns A promise that resolves to a record of model IDs to model info + * @throws Will throw an error if the request fails or the response is not as expected. 
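+ * + * Expected response shape (validated against RooModelsResponseSchema), e.g.: + * { "object": "list", "data": [{ "id": "xai/grok-code-fast-1", "object": "model", "context_window": 262144, "max_tokens": 16384, "pricing": { "input": "0", "output": "0" }, ... }] }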
+ */ +export async function getRooModels(baseUrl: string, apiKey?: string): Promise<ModelRecord> { + try { + const headers: Record<string, string> = { + "Content-Type": "application/json", + ...DEFAULT_HEADERS, + } + + if (apiKey) { + headers["Authorization"] = `Bearer ${apiKey}` + } + + // Construct the models endpoint URL + // Strip trailing /v1 or /v1/ to avoid /v1/v1/models + const normalizedBase = baseUrl.replace(/\/?v1\/?$/, "") + const url = `${normalizedBase}/v1/models` + + // Use fetch with AbortController for better timeout handling + const controller = new AbortController() + const timeoutId = setTimeout(() => controller.abort(), 10000) + + try { + const response = await fetch(url, { + headers, + signal: controller.signal, + }) + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`) + } + + const data = await response.json() + const models: ModelRecord = {} + + // Validate response against schema + const parsed = RooModelsResponseSchema.safeParse(data) + + if (!parsed.success) { + console.error("Error fetching Roo Code Cloud models: Unexpected response format", data) + console.error("Validation errors:", parsed.error.format()) + throw new Error("Failed to fetch Roo Code Cloud models: Unexpected response format.") + } + + // Process the validated model data + for (const model of parsed.data.data) { + const modelId = model.id + + if (!modelId) continue + + // Extract model data from the validated API response + // All required fields are guaranteed by the schema + const contextWindow = model.context_window + const maxTokens = model.max_tokens + const tags = model.tags || [] + const pricing = model.pricing + + // Determine if the model supports images based on tags + const supportsImages = tags.includes("vision") + + // Parse pricing (API returns strings, convert to numbers) + const inputPrice = parseFloat(pricing.input) + const outputPrice = parseFloat(pricing.output) + const cacheReadPrice = pricing.input_cache_read ? parseFloat(pricing.input_cache_read) : undefined + const cacheWritePrice = pricing.input_cache_write ? parseFloat(pricing.input_cache_write) : undefined + + models[modelId] = { + maxTokens, + contextWindow, + supportsImages, + supportsPromptCache: Boolean(cacheReadPrice !== undefined), + inputPrice, + outputPrice, + cacheWritesPrice: cacheWritePrice, + cacheReadsPrice: cacheReadPrice, + description: model.description || model.name, + deprecated: model.deprecated || false, + } + } + + return models + } finally { + clearTimeout(timeoutId) + } + } catch (error: any) { + console.error("Error fetching Roo Code Cloud models:", error.message ? error.message : error) + + // Handle abort/timeout + if (error.name === "AbortError") { + throw new Error("Failed to fetch Roo Code Cloud models: Request timed out after 10 seconds.") + } + + // Handle fetch errors + if (error.message?.includes("HTTP")) { + throw new Error(`Failed to fetch Roo Code Cloud models: ${error.message}. Check base URL and API key.`) + } + + // Handle network errors + if (error instanceof TypeError) { + throw new Error( + "Failed to fetch Roo Code Cloud models: No response from server. 
Check Roo Code Cloud server status and base URL.", + ) + } + + throw new Error(`Failed to fetch Roo Code Cloud models: ${error.message || "An unknown error occurred."}`) + } +} diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts index 6f10157a313..6deaa286e9c 100644 --- a/src/api/providers/roo.ts +++ b/src/api/providers/roo.ts @@ -1,18 +1,20 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { AuthState, rooDefaultModelId, rooModels, type RooModelId } from "@roo-code/types" +import { AuthState, rooDefaultModelId, type ModelInfo } from "@roo-code/types" import { CloudService } from "@roo-code/cloud" -import type { ApiHandlerOptions } from "../../shared/api" +import type { ApiHandlerOptions, ModelRecord } from "../../shared/api" import { ApiStream } from "../transform/stream" import type { ApiHandlerCreateMessageMetadata } from "../index" import { DEFAULT_HEADERS } from "./constants" import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" +import { getModels, flushModels, getModelsFromCache } from "../providers/fetchers/modelCache" -export class RooHandler extends BaseOpenAiCompatibleProvider { +export class RooHandler extends BaseOpenAiCompatibleProvider { private authStateListener?: (state: { state: AuthState }) => void + private fetcherBaseURL: string constructor(options: ApiHandlerOptions) { let sessionToken: string | undefined = undefined @@ -21,34 +23,62 @@ export class RooHandler extends BaseOpenAiCompatibleProvider { sessionToken = CloudService.instance.authService?.getSessionToken() } + let baseURL = process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy" + + // Ensure baseURL ends with /v1 for OpenAI client, but don't duplicate it + if (!baseURL.endsWith("/v1")) { + baseURL = `${baseURL}/v1` + } + // Always construct the handler, even without a valid token. // The provider-proxy server will return 401 if authentication fails. super({ ...options, providerName: "Roo Code Cloud", - baseURL: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy/v1", + baseURL, // Already has /v1 suffix apiKey: sessionToken || "unauthenticated", // Use a placeholder if no token. defaultProviderModelId: rooDefaultModelId, - providerModels: rooModels, + providerModels: {}, defaultTemperature: 0.7, }) + // Load dynamic models asynchronously - strip /v1 from baseURL for fetcher + this.fetcherBaseURL = baseURL.endsWith("/v1") ? baseURL.slice(0, -3) : baseURL + this.loadDynamicModels(this.fetcherBaseURL, sessionToken).catch((error) => { + console.error("[RooHandler] Failed to load dynamic models:", error) + }) + if (CloudService.hasInstance()) { const cloudService = CloudService.instance this.authStateListener = (state: { state: AuthState }) => { if (state.state === "active-session") { + const newToken = cloudService.authService?.getSessionToken() this.client = new OpenAI({ baseURL: this.baseURL, - apiKey: cloudService.authService?.getSessionToken() ?? "unauthenticated", + apiKey: newToken ?? 
"unauthenticated", defaultHeaders: DEFAULT_HEADERS, }) + + // Flush cache and reload models with the new auth token + flushModels("roo") + .then(() => { + return this.loadDynamicModels(this.fetcherBaseURL, newToken) + }) + .catch((error) => { + console.error("[RooHandler] Failed to reload models after auth:", error) + }) } else if (state.state === "logged-out") { this.client = new OpenAI({ baseURL: this.baseURL, apiKey: "unauthenticated", defaultHeaders: DEFAULT_HEADERS, }) + + // Flush cache when logged out + flushModels("roo").catch((error) => { + console.error("[RooHandler] Failed to flush models on logout:", error) + }) } } @@ -103,17 +133,33 @@ export class RooHandler extends BaseOpenAiCompatibleProvider { } } + private async loadDynamicModels(baseURL: string, apiKey?: string): Promise { + try { + // Fetch models and cache them in the shared cache + await getModels({ + provider: "roo", + baseUrl: baseURL, + apiKey, + }) + } catch (error) { + console.error("[RooHandler] Error loading dynamic models:", error) + } + } + override getModel() { const modelId = this.options.apiModelId || rooDefaultModelId - const modelInfo = this.providerModels[modelId as RooModelId] ?? this.providerModels[rooDefaultModelId] + + // Get models from shared cache + const models = getModelsFromCache("roo") || {} + const modelInfo = models[modelId] if (modelInfo) { - return { id: modelId as RooModelId, info: modelInfo } + return { id: modelId, info: modelInfo } } // Return the requested model ID even if not found, with fallback info. return { - id: modelId as RooModelId, + id: modelId, info: { maxTokens: 16_384, contextWindow: 262_144, diff --git a/src/core/config/ProviderSettingsManager.ts b/src/core/config/ProviderSettingsManager.ts index 357a04b33a4..72e4b577e9c 100644 --- a/src/core/config/ProviderSettingsManager.ts +++ b/src/core/config/ProviderSettingsManager.ts @@ -11,7 +11,6 @@ import { DEFAULT_CONSECUTIVE_MISTAKE_LIMIT, getModelId, type ProviderName, - type RooModelId, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" @@ -24,7 +23,7 @@ type ModelMigrations = { const MODEL_MIGRATIONS: ModelMigrations = { roo: { - "roo/code-supernova": "roo/code-supernova-1-million" as RooModelId, + "roo/code-supernova": "roo/code-supernova-1-million", }, } as const satisfies ModelMigrations diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index 71b79e35add..3d68fac2acb 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -2689,6 +2689,13 @@ describe("ClineProvider - Router Models", () => { expect(getModels).toHaveBeenCalledWith({ provider: "glama" }) expect(getModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) expect(getModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" }) + expect(getModels).toHaveBeenCalledWith({ provider: "deepinfra" }) + expect(getModels).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "roo", + baseUrl: expect.any(String), + }), + ) expect(getModels).toHaveBeenCalledWith({ provider: "litellm", apiKey: "litellm-key", @@ -2704,6 +2711,7 @@ describe("ClineProvider - Router Models", () => { requesty: mockModels, glama: mockModels, unbound: mockModels, + roo: mockModels, litellm: mockModels, ollama: {}, lmstudio: {}, @@ -2742,6 +2750,7 @@ describe("ClineProvider - Router Models", () => { .mockRejectedValueOnce(new Error("Unbound API error")) // unbound fail .mockResolvedValueOnce(mockModels) // 
vercel-ai-gateway success .mockResolvedValueOnce(mockModels) // deepinfra success + .mockResolvedValueOnce(mockModels) // roo success .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm fail await messageHandler({ type: "requestRouterModels" }) @@ -2755,6 +2764,7 @@ describe("ClineProvider - Router Models", () => { requesty: {}, glama: mockModels, unbound: {}, + roo: mockModels, ollama: {}, lmstudio: {}, litellm: {}, @@ -2869,6 +2879,7 @@ describe("ClineProvider - Router Models", () => { requesty: mockModels, glama: mockModels, unbound: mockModels, + roo: mockModels, litellm: {}, ollama: {}, lmstudio: {}, diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index 469eb68d65b..749e8d090d8 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -218,12 +218,18 @@ describe("webviewMessageHandler - requestRouterModels", () => { }) // Verify getModels was called for each provider - expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "openrouter" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "requesty", apiKey: "requesty-key" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "glama" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "unbound", apiKey: "unbound-key" }) expect(mockGetModels).toHaveBeenCalledWith({ provider: "vercel-ai-gateway" }) + expect(mockGetModels).toHaveBeenCalledWith({ provider: "deepinfra" }) + expect(mockGetModels).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "roo", + baseUrl: expect.any(String), + }), + ) expect(mockGetModels).toHaveBeenCalledWith({ provider: "litellm", apiKey: "litellm-key", @@ -242,6 +248,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { glama: mockModels, unbound: mockModels, litellm: mockModels, + roo: mockModels, ollama: {}, lmstudio: {}, "vercel-ai-gateway": mockModels, @@ -332,6 +339,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { requesty: mockModels, glama: mockModels, unbound: mockModels, + roo: mockModels, litellm: {}, ollama: {}, lmstudio: {}, @@ -360,6 +368,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { .mockRejectedValueOnce(new Error("Unbound API error")) // unbound .mockResolvedValueOnce(mockModels) // vercel-ai-gateway .mockResolvedValueOnce(mockModels) // deepinfra + .mockResolvedValueOnce(mockModels) // roo .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm await webviewMessageHandler(mockClineProvider, { @@ -375,6 +384,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { requesty: {}, glama: mockModels, unbound: {}, + roo: mockModels, litellm: {}, ollama: {}, lmstudio: {}, @@ -416,6 +426,7 @@ describe("webviewMessageHandler - requestRouterModels", () => { .mockRejectedValueOnce(new Error("Unbound API error")) // unbound .mockRejectedValueOnce(new Error("Vercel AI Gateway error")) // vercel-ai-gateway .mockRejectedValueOnce(new Error("DeepInfra API error")) // deepinfra + .mockRejectedValueOnce(new Error("Roo API error")) // roo .mockRejectedValueOnce(new Error("LiteLLM connection failed")) // litellm await webviewMessageHandler(mockClineProvider, { @@ -458,6 +469,20 @@ describe("webviewMessageHandler - requestRouterModels", () => { values: { provider: "deepinfra" }, }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + 
type: "singleRouterModelFetchResponse", + success: false, + error: "Vercel AI Gateway error", + values: { provider: "vercel-ai-gateway" }, + }) + + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "singleRouterModelFetchResponse", + success: false, + error: "Roo API error", + values: { provider: "roo" }, + }) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ type: "singleRouterModelFetchResponse", success: false, diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 77cdc786502..38b51c71238 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -769,6 +769,7 @@ export const webviewMessageHandler = async ( glama: {}, ollama: {}, lmstudio: {}, + roo: {}, } const safeGetModels = async (options: GetModelsOptions): Promise => { @@ -805,6 +806,16 @@ export const webviewMessageHandler = async ( baseUrl: apiConfiguration.deepInfraBaseUrl, }, }, + { + key: "roo", + options: { + provider: "roo", + baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", + apiKey: CloudService.hasInstance() + ? CloudService.instance.authService?.getSessionToken() + : undefined, + }, + }, ] // Add IO Intelligence if API key is provided. @@ -919,6 +930,38 @@ export const webviewMessageHandler = async ( } break } + case "requestRooModels": { + // Specific handler for Roo models only - flushes cache to ensure fresh auth token is used + try { + // Flush cache first to ensure fresh models with current auth state + await flushModels("roo") + + const rooModels = await getModels({ + provider: "roo", + baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", + apiKey: CloudService.hasInstance() + ? CloudService.instance.authService?.getSessionToken() + : undefined, + }) + + // Always send a response, even if no models are returned + provider.postMessageToWebview({ + type: "singleRouterModelFetchResponse", + success: true, + values: { provider: "roo", models: rooModels }, + }) + } catch (error) { + // Send error response + const errorMessage = error instanceof Error ? 
error.message : String(error) + provider.postMessageToWebview({ + type: "singleRouterModelFetchResponse", + success: false, + error: errorMessage, + values: { provider: "roo" }, + }) + } + break + } case "requestOpenAiModels": if (message?.values?.baseUrl && message?.values?.apiKey) { const openAiModels = await getOpenAiModels( diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index ea3649aaa5e..a88d883f87e 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -69,6 +69,7 @@ export interface WebviewMessage { | "requestOpenAiModels" | "requestOllamaModels" | "requestLmStudioModels" + | "requestRooModels" | "requestVsCodeLmModels" | "requestHuggingFaceModels" | "openImage" diff --git a/src/shared/api.ts b/src/shared/api.ts index 79001cb0ad0..8b18e7f50d8 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -163,6 +163,7 @@ const dynamicProviderExtras = { glama: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type ollama: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type lmstudio: {} as {}, // eslint-disable-line @typescript-eslint/no-empty-object-type + roo: {} as { apiKey?: string; baseUrl?: string }, } as const satisfies Record // Build the dynamic options union from the map, intersected with CommonFetchParams diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index 37c1c286b98..7efa253abc4 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -1,7 +1,7 @@ import React, { memo, useCallback, useEffect, useMemo, useState } from "react" import { convertHeadersToObject } from "./utils/headers" import { useDebounce } from "react-use" -import { VSCodeLink, VSCodeButton } from "@vscode/webview-ui-toolkit/react" +import { VSCodeLink } from "@vscode/webview-ui-toolkit/react" import { ExternalLinkIcon } from "@radix-ui/react-icons" import { @@ -85,6 +85,7 @@ import { OpenRouter, QwenCode, Requesty, + Roo, SambaNova, Unbound, Vertex, @@ -228,7 +229,11 @@ const ApiOptions = ({ vscode.postMessage({ type: "requestLmStudioModels" }) } else if (selectedProvider === "vscode-lm") { vscode.postMessage({ type: "requestVsCodeLmModels" }) - } else if (selectedProvider === "litellm" || selectedProvider === "deepinfra") { + } else if ( + selectedProvider === "litellm" || + selectedProvider === "deepinfra" || + selectedProvider === "roo" + ) { vscode.postMessage({ type: "requestRouterModels" }) } }, @@ -667,22 +672,14 @@ const ApiOptions = ({ )} {selectedProvider === "roo" && ( -
- {cloudIsAuthenticated ? ( - - {t("settings:providers.roo.authenticatedMessage")} - - ) : ( - <VSCodeButton - onClick={() => vscode.postMessage({ type: "rooCloudSignIn" })} - className="w-fit"> - {t("settings:providers.roo.connectButton")} - </VSCodeButton> - - )} -
+ <Roo + apiConfiguration={apiConfiguration} + setApiConfigurationField={setApiConfigurationField} + routerModels={routerModels} + cloudIsAuthenticated={cloudIsAuthenticated} + organizationAllowList={organizationAllowList} + modelValidationError={modelValidationError} + /> )} {selectedProvider === "featherless" && ( diff --git a/webview-ui/src/components/settings/ModelPicker.tsx b/webview-ui/src/components/settings/ModelPicker.tsx index b5813047448..6020a260bd3 100644 --- a/webview-ui/src/components/settings/ModelPicker.tsx +++ b/webview-ui/src/components/settings/ModelPicker.tsx @@ -37,6 +37,7 @@ type ModelIdKey = keyof Pick< | "deepInfraModelId" | "ioIntelligenceModelId" | "vercelAiGatewayModelId" + | "apiModelId" > interface ModelPickerProps { diff --git a/webview-ui/src/components/settings/constants.ts b/webview-ui/src/components/settings/constants.ts index ae336730ff5..1ba0497edfa 100644 --- a/webview-ui/src/components/settings/constants.ts +++ b/webview-ui/src/components/settings/constants.ts @@ -19,7 +19,6 @@ import { doubaoModels, internationalZAiModels, fireworksModels, - rooModels, featherlessModels, } from "@roo-code/types" @@ -42,7 +41,6 @@ export const MODELS_BY_PROVIDER: Partial< - roo: rooModels, diff --git a/webview-ui/src/components/settings/providers/Roo.tsx b/webview-ui/src/components/settings/providers/Roo.tsx new file mode 100644 --- /dev/null +++ b/webview-ui/src/components/settings/providers/Roo.tsx +interface RooProps { + apiConfiguration: ProviderSettings + setApiConfigurationField: <K extends keyof ProviderSettings>(field: K, value: ProviderSettings[K]) => void + routerModels?: RouterModels + cloudIsAuthenticated: boolean + organizationAllowList: OrganizationAllowList + modelValidationError?: string +} + +export const Roo = ({ + apiConfiguration, + setApiConfigurationField, + routerModels, + cloudIsAuthenticated, + organizationAllowList, + modelValidationError, +}: RooProps) => { + const { t } = useAppTranslation() + + return ( + <> + {cloudIsAuthenticated ? ( +
+ {t("settings:providers.roo.authenticatedMessage")} + + ) : ( + <VSCodeButton + onClick={() => vscode.postMessage({ type: "rooCloudSignIn" })} + className="w-fit"> + {t("settings:providers.roo.connectButton")} + </VSCodeButton>
+ )} + + + ) +} diff --git a/webview-ui/src/components/settings/providers/index.ts b/webview-ui/src/components/settings/providers/index.ts index fe0e6cecf96..d6423b0c8a6 100644 --- a/webview-ui/src/components/settings/providers/index.ts +++ b/webview-ui/src/components/settings/providers/index.ts @@ -18,6 +18,7 @@ export { OpenAI } from "./OpenAI" export { OpenAICompatible } from "./OpenAICompatible" export { OpenRouter } from "./OpenRouter" export { QwenCode } from "./QwenCode" +export { Roo } from "./Roo" export { Requesty } from "./Requesty" export { SambaNova } from "./SambaNova" export { Unbound } from "./Unbound" diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index a3ce1e63e4e..df674893efc 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -51,7 +51,6 @@ import { ioIntelligenceDefaultModelId, ioIntelligenceModels, rooDefaultModelId, - rooModels, qwenCodeDefaultModelId, qwenCodeModels, vercelAiGatewayDefaultModelId, @@ -330,21 +329,10 @@ function getSelectedModel({ return { id, info } } case "roo": { - const requestedId = apiConfiguration.apiModelId - - // Check if the requested model exists in rooModels - if (requestedId && rooModels[requestedId as keyof typeof rooModels]) { - return { - id: requestedId, - info: rooModels[requestedId as keyof typeof rooModels], - } - } - - // Fallback to default model if requested model doesn't exist or is not specified - return { - id: rooDefaultModelId, - info: rooModels[rooDefaultModelId as keyof typeof rooModels], - } + // Roo is a dynamic provider - models are loaded from API + const id = apiConfiguration.apiModelId ?? rooDefaultModelId + const info = routerModels.roo[id] + return { id, info } } case "qwen-code": { const id = apiConfiguration.apiModelId ?? qwenCodeDefaultModelId diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index 37592c2349c..05868c31a5d 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -289,6 +289,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode global: {}, }) const [includeTaskHistoryInEnhance, setIncludeTaskHistoryInEnhance] = useState(true) + const [prevCloudIsAuthenticated, setPrevCloudIsAuthenticated] = useState(false) const setListApiConfigMeta = useCallback( (value: ProviderSettingsEntry[]) => setState((prevState) => ({ ...prevState, listApiConfigMeta: value })), @@ -420,6 +421,16 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode vscode.postMessage({ type: "webviewDidLaunch" }) }, []) + // Watch for authentication state changes and refresh Roo models + useEffect(() => { + const currentAuth = state.cloudIsAuthenticated ?? false + if (!prevCloudIsAuthenticated && currentAuth) { + // User just authenticated - refresh Roo models with the new auth token + vscode.postMessage({ type: "requestRooModels" }) + } + setPrevCloudIsAuthenticated(currentAuth) + }, [state.cloudIsAuthenticated, prevCloudIsAuthenticated]) + const contextValue: ExtensionStateContextType = { ...state, reasoningBlockCollapsed: state.reasoningBlockCollapsed ?? 
true, diff --git a/webview-ui/src/utils/__tests__/validate.test.ts b/webview-ui/src/utils/__tests__/validate.test.ts index 33ede230531..c2451dcd6f3 100644 --- a/webview-ui/src/utils/__tests__/validate.test.ts +++ b/webview-ui/src/utils/__tests__/validate.test.ts @@ -43,6 +43,7 @@ describe("Model Validation Functions", () => { "io-intelligence": {}, "vercel-ai-gateway": {}, huggingface: {}, + roo: {}, } const allowAllOrganization: OrganizationAllowList = {