diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index bfbfdc424e4..a7824b1fa28 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -5,7 +5,7 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { renderHook } from "@testing-library/react" import type { Mock } from "vitest" -import { ProviderSettings, ModelInfo, BEDROCK_1M_CONTEXT_MODEL_IDS } from "@roo-code/types" +import { ProviderSettings, ModelInfo, BEDROCK_1M_CONTEXT_MODEL_IDS, litellmDefaultModelInfo } from "@roo-code/types" import { useSelectedModel } from "../useSelectedModel" import { useRouterModels } from "../useRouterModels" @@ -540,4 +540,119 @@ describe("useSelectedModel", () => { expect(result.current.info?.contextWindow).toBe(200_000) }) }) + + describe("litellm provider", () => { + beforeEach(() => { + mockUseOpenRouterModelProviders.mockReturnValue({ + data: {}, + isLoading: false, + isError: false, + } as any) + }) + + it("should use litellmDefaultModelInfo as fallback when routerModels.litellm is empty", () => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + unbound: {}, + litellm: {}, + "io-intelligence": {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "litellm", + litellmModelId: "some-model", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("litellm") + // Should fall back to default model ID since "some-model" doesn't exist in empty litellm models + expect(result.current.id).toBe("claude-3-7-sonnet-20250219") + // Should use litellmDefaultModelInfo as fallback + expect(result.current.info).toEqual(litellmDefaultModelInfo) + expect(result.current.info?.supportsNativeTools).toBe(true) + }) + + it("should use litellmDefaultModelInfo when selected model not found in routerModels", () => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + unbound: {}, + litellm: { + "existing-model": { + maxTokens: 4096, + contextWindow: 8192, + supportsImages: false, + supportsPromptCache: false, + supportsNativeTools: true, + }, + }, + "io-intelligence": {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "litellm", + litellmModelId: "non-existing-model", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("litellm") + // Falls back to default model ID + expect(result.current.id).toBe("claude-3-7-sonnet-20250219") + // Should use litellmDefaultModelInfo as fallback since default model also not in router models + expect(result.current.info).toEqual(litellmDefaultModelInfo) + expect(result.current.info?.supportsNativeTools).toBe(true) + }) + + it("should use model info from routerModels when model exists", () => { + const customModelInfo: ModelInfo = { + maxTokens: 16384, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: true, + supportsNativeTools: true, + description: "Custom LiteLLM model", + } + + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + unbound: {}, + litellm: { + "custom-model": customModelInfo, + }, + "io-intelligence": {}, + }, + isLoading: false, + isError: false, + } as any) + + const apiConfiguration: ProviderSettings = { + apiProvider: "litellm", + litellmModelId: "custom-model", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("litellm") + expect(result.current.id).toBe("custom-model") + // Should use the model info from routerModels, not the fallback + expect(result.current.info).toEqual(customModelInfo) + expect(result.current.info?.supportsNativeTools).toBe(true) + }) + }) }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index 79804165dc4..010adc3155b 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -27,6 +27,7 @@ import { ioIntelligenceModels, basetenModels, qwenCodeModels, + litellmDefaultModelInfo, BEDROCK_1M_CONTEXT_MODEL_IDS, isDynamicProvider, getProviderDefaultModelId, @@ -164,7 +165,7 @@ function getSelectedModel({ } case "litellm": { const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId) - const info = routerModels.litellm?.[id] + const info = routerModels.litellm?.[id] ?? litellmDefaultModelInfo return { id, info } } case "xai": {