diff --git a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts index 9907017b53d..4dca874ee20 100644 --- a/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts +++ b/webview-ui/src/components/ui/hooks/__tests__/useSelectedModel.spec.ts @@ -5,7 +5,13 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query" import { renderHook } from "@testing-library/react" import type { Mock } from "vitest" -import { ProviderSettings, ModelInfo, BEDROCK_1M_CONTEXT_MODEL_IDS, litellmDefaultModelInfo } from "@roo-code/types" +import { + ProviderSettings, + ModelInfo, + BEDROCK_1M_CONTEXT_MODEL_IDS, + litellmDefaultModelInfo, + openAiModelInfoSaneDefaults, +} from "@roo-code/types" import { useSelectedModel } from "../useSelectedModel" import { useRouterModels } from "../useRouterModels" @@ -661,4 +667,100 @@ describe("useSelectedModel", () => { expect(result.current.info?.defaultToolProtocol).toBe("native") }) }) + + describe("openai provider", () => { + beforeEach(() => { + mockUseRouterModels.mockReturnValue({ + data: { + openrouter: {}, + requesty: {}, + unbound: {}, + litellm: {}, + "io-intelligence": {}, + }, + isLoading: false, + isError: false, + } as any) + + mockUseOpenRouterModelProviders.mockReturnValue({ + data: {}, + isLoading: false, + isError: false, + } as any) + }) + + it("should use openAiModelInfoSaneDefaults when no custom model info is provided", () => { + const apiConfiguration: ProviderSettings = { + apiProvider: "openai", + openAiModelId: "gpt-4o", + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("openai") + expect(result.current.id).toBe("gpt-4o") + expect(result.current.info).toEqual(openAiModelInfoSaneDefaults) + expect(result.current.info?.supportsNativeTools).toBe(true) + expect(result.current.info?.defaultToolProtocol).toBe("native") + }) + + it("should merge native tool defaults with custom model info", () => { + const customModelInfo: ModelInfo = { + maxTokens: 16384, + contextWindow: 128000, + supportsImages: true, + supportsPromptCache: false, + inputPrice: 0.01, + outputPrice: 0.03, + description: "Custom OpenAI-compatible model", + } + + const apiConfiguration: ProviderSettings = { + apiProvider: "openai", + openAiModelId: "custom-model", + openAiCustomModelInfo: customModelInfo, + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("openai") + expect(result.current.id).toBe("custom-model") + // Should merge native tool defaults with custom model info + const nativeToolDefaults = { + supportsNativeTools: openAiModelInfoSaneDefaults.supportsNativeTools, + defaultToolProtocol: openAiModelInfoSaneDefaults.defaultToolProtocol, + } + expect(result.current.info).toEqual({ ...nativeToolDefaults, ...customModelInfo }) + expect(result.current.info?.supportsNativeTools).toBe(true) + expect(result.current.info?.defaultToolProtocol).toBe("native") + }) + + it("should allow custom model info to override native tool defaults", () => { + const customModelInfo: ModelInfo = { + maxTokens: 8192, + contextWindow: 32000, + supportsImages: false, + supportsPromptCache: false, + supportsNativeTools: false, // Explicitly disable + defaultToolProtocol: "xml", // Override default to use XML instead of native + } + + const apiConfiguration: ProviderSettings = { + apiProvider: "openai", + openAiModelId: "custom-model-no-tools", + openAiCustomModelInfo: customModelInfo, + } + + const wrapper = createWrapper() + const { result } = renderHook(() => useSelectedModel(apiConfiguration), { wrapper }) + + expect(result.current.provider).toBe("openai") + expect(result.current.id).toBe("custom-model-no-tools") + // Custom model info should override the native tool defaults + expect(result.current.info?.supportsNativeTools).toBe(false) + expect(result.current.info?.defaultToolProtocol).toBe("xml") + }) + }) }) diff --git a/webview-ui/src/components/ui/hooks/useSelectedModel.ts b/webview-ui/src/components/ui/hooks/useSelectedModel.ts index f405dd78930..7748de49bde 100644 --- a/webview-ui/src/components/ui/hooks/useSelectedModel.ts +++ b/webview-ui/src/components/ui/hooks/useSelectedModel.ts @@ -279,7 +279,13 @@ function getSelectedModel({ } case "openai": { const id = apiConfiguration.openAiModelId ?? "" - const info = apiConfiguration?.openAiCustomModelInfo ?? openAiModelInfoSaneDefaults + const customInfo = apiConfiguration?.openAiCustomModelInfo + // Only merge native tool call defaults, not prices or other model-specific info + const nativeToolDefaults = { + supportsNativeTools: openAiModelInfoSaneDefaults.supportsNativeTools, + defaultToolProtocol: openAiModelInfoSaneDefaults.defaultToolProtocol, + } + const info = customInfo ? { ...nativeToolDefaults, ...customInfo } : openAiModelInfoSaneDefaults return { id, info } } case "ollama": {