diff --git a/packages/types/src/__tests__/provider-settings.test.ts b/packages/types/src/__tests__/provider-settings.test.ts
new file mode 100644
index 00000000000..87c5bbcc1c8
--- /dev/null
+++ b/packages/types/src/__tests__/provider-settings.test.ts
@@ -0,0 +1,79 @@
+import { describe, it, expect } from "vitest"
+import { getApiProtocol } from "../provider-settings.js"
+
+describe("getApiProtocol", () => {
+	describe("Anthropic-style providers", () => {
+		it("should return 'anthropic' for anthropic provider", () => {
+			expect(getApiProtocol("anthropic")).toBe("anthropic")
+			expect(getApiProtocol("anthropic", "gpt-4")).toBe("anthropic")
+		})
+
+		it("should return 'anthropic' for claude-code provider", () => {
+			expect(getApiProtocol("claude-code")).toBe("anthropic")
+			expect(getApiProtocol("claude-code", "some-model")).toBe("anthropic")
+		})
+	})
+
+	describe("Vertex provider with Claude models", () => {
+		it("should return 'anthropic' for vertex provider with claude models", () => {
+			expect(getApiProtocol("vertex", "claude-3-opus")).toBe("anthropic")
+			expect(getApiProtocol("vertex", "Claude-3-Sonnet")).toBe("anthropic")
+			expect(getApiProtocol("vertex", "CLAUDE-instant")).toBe("anthropic")
+			expect(getApiProtocol("vertex", "anthropic/claude-3-haiku")).toBe("anthropic")
+		})
+
+		it("should return 'openai' for vertex provider with non-claude models", () => {
+			expect(getApiProtocol("vertex", "gpt-4")).toBe("openai")
+			expect(getApiProtocol("vertex", "gemini-pro")).toBe("openai")
+			expect(getApiProtocol("vertex", "llama-2")).toBe("openai")
+		})
+	})
+
+	describe("Bedrock provider with Claude models", () => {
+		it("should return 'anthropic' for bedrock provider with claude models", () => {
+			expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
+			expect(getApiProtocol("bedrock", "Claude-3-Sonnet")).toBe("anthropic")
+			expect(getApiProtocol("bedrock", "CLAUDE-instant")).toBe("anthropic")
+			expect(getApiProtocol("bedrock", "anthropic.claude-v2")).toBe("anthropic")
+		})
+
+		it("should return 'openai' for bedrock provider with non-claude models", () => {
+			expect(getApiProtocol("bedrock", "gpt-4")).toBe("openai")
+			expect(getApiProtocol("bedrock", "titan-text")).toBe("openai")
+			expect(getApiProtocol("bedrock", "llama-2")).toBe("openai")
+		})
+	})
+
+	describe("Other providers with Claude models", () => {
+		it("should return 'openai' for non-vertex/bedrock providers with claude models", () => {
+			expect(getApiProtocol("openrouter", "claude-3-opus")).toBe("openai")
+			expect(getApiProtocol("openai", "claude-3-sonnet")).toBe("openai")
+			expect(getApiProtocol("litellm", "claude-instant")).toBe("openai")
+			expect(getApiProtocol("ollama", "claude-model")).toBe("openai")
+		})
+	})
+
+	describe("Edge cases", () => {
+		it("should return 'openai' when provider is undefined", () => {
+			expect(getApiProtocol(undefined)).toBe("openai")
+			expect(getApiProtocol(undefined, "claude-3-opus")).toBe("openai")
+		})
+
+		it("should return 'openai' when model is undefined", () => {
+			expect(getApiProtocol("openai")).toBe("openai")
+			expect(getApiProtocol("vertex")).toBe("openai")
+			expect(getApiProtocol("bedrock")).toBe("openai")
+		})
+
+		it("should handle empty strings", () => {
+			expect(getApiProtocol("vertex", "")).toBe("openai")
+			expect(getApiProtocol("bedrock", "")).toBe("openai")
+		})
+
+		it("should be case-insensitive for claude detection", () => {
+			expect(getApiProtocol("vertex", "CLAUDE-3-OPUS")).toBe("anthropic")
+			expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
+			expect(getApiProtocol("vertex", "ClAuDe-InStAnT")).toBe("anthropic")
+		})
+	})
+})
diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts
index 3b53627295e..be74ae6bb4c 100644
--- a/packages/types/src/provider-settings.ts
+++ b/packages/types/src/provider-settings.ts
@@ -303,7 +303,23 @@ export const getModelId = (settings: ProviderSettings): string | undefined => {
 // Providers that use Anthropic-style API protocol
 export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"]
 
-// Helper function to determine API protocol for a provider
-export const getApiProtocol = (provider: ProviderName | undefined): "anthropic" | "openai" => {
-	return provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider) ? "anthropic" : "openai"
+// Helper function to determine API protocol for a provider and model
+export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => {
+	// First check if the provider is an Anthropic-style provider
+	if (provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider)) {
+		return "anthropic"
+	}
+
+	// For vertex and bedrock providers, check if the model ID contains "claude" (case-insensitive)
+	if (
+		provider &&
+		(provider === "vertex" || provider === "bedrock") &&
+		modelId &&
+		modelId.toLowerCase().includes("claude")
+	) {
+		return "anthropic"
+	}
+
+	// Default to OpenAI protocol
+	return "openai"
 }
diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts
index 8a1bf1101d4..53b8ef5b87d 100644
--- a/src/core/task/Task.ts
+++ b/src/core/task/Task.ts
@@ -23,6 +23,7 @@ import {
 	TelemetryEventName,
 	TodoItem,
 	getApiProtocol,
+	getModelId,
 } from "@roo-code/types"
 import { TelemetryService } from "@roo-code/telemetry"
 import { CloudService } from "@roo-code/cloud"
@@ -1211,8 +1212,9 @@ export class Task extends EventEmitter {
 		// take a few seconds. For the best UX we show a placeholder api_req_started
 		// message with a loading spinner as this happens.
 
-		// Determine API protocol based on provider
-		const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider)
+		// Determine API protocol based on provider and model
+		const modelId = getModelId(this.apiConfiguration)
+		const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider, modelId)
 
 		await this.say(
 			"api_req_started",
diff --git a/src/core/task/__tests__/Task.spec.ts b/src/core/task/__tests__/Task.spec.ts
index 797714cde8f..9aa5a8d7a89 100644
--- a/src/core/task/__tests__/Task.spec.ts
+++ b/src/core/task/__tests__/Task.spec.ts
@@ -1398,5 +1398,100 @@ describe("Cline", () => {
 				expect(task.diffStrategy).toBeUndefined()
 			})
 		})
+
+		describe("getApiProtocol", () => {
+			it("should determine API protocol based on provider and model", async () => {
+				// Test with Anthropic provider
+				const anthropicConfig = {
+					...mockApiConfig,
+					apiProvider: "anthropic" as const,
+					apiModelId: "gpt-4",
+				}
+				const anthropicTask = new Task({
+					provider: mockProvider,
+					apiConfiguration: anthropicConfig,
+					task: "test task",
+					startTask: false,
+				})
+				// Should use anthropic protocol even with non-claude model
+				expect(anthropicTask.apiConfiguration.apiProvider).toBe("anthropic")
+
+				// Test with OpenRouter provider and Claude model
+				const openrouterClaudeConfig = {
+					apiProvider: "openrouter" as const,
+					openRouterModelId: "anthropic/claude-3-opus",
+				}
+				const openrouterClaudeTask = new Task({
+					provider: mockProvider,
+					apiConfiguration: openrouterClaudeConfig,
+					task: "test task",
+					startTask: false,
+				})
+				expect(openrouterClaudeTask.apiConfiguration.apiProvider).toBe("openrouter")
+
+				// Test with OpenRouter provider and non-Claude model
+				const openrouterGptConfig = {
+					apiProvider: "openrouter" as const,
+					openRouterModelId: "openai/gpt-4",
+				}
+				const openrouterGptTask = new Task({
+					provider: mockProvider,
+					apiConfiguration: openrouterGptConfig,
+					task: "test task",
+					startTask: false,
+				})
+				expect(openrouterGptTask.apiConfiguration.apiProvider).toBe("openrouter")
+
+				// Test with various Claude model formats
+				const claudeModelFormats = [
+					"claude-3-opus",
+					"Claude-3-Sonnet",
+					"CLAUDE-instant",
+					"anthropic/claude-3-haiku",
+					"some-provider/claude-model",
+				]
+
+				for (const modelId of claudeModelFormats) {
+					const config = {
+						apiProvider: "openai" as const,
+						openAiModelId: modelId,
+					}
+					const task = new Task({
+						provider: mockProvider,
+						apiConfiguration: config,
+						task: "test task",
+						startTask: false,
+					})
+					// Verify the model ID contains claude (case-insensitive)
+					expect(modelId.toLowerCase()).toContain("claude")
+				}
+			})
+
+			it("should handle edge cases for API protocol detection", async () => {
+				// Test with undefined provider
+				const undefinedProviderConfig = {
+					apiModelId: "claude-3-opus",
+				}
+				const undefinedProviderTask = new Task({
+					provider: mockProvider,
+					apiConfiguration: undefinedProviderConfig,
+					task: "test task",
+					startTask: false,
+				})
+				expect(undefinedProviderTask.apiConfiguration.apiProvider).toBeUndefined()
+
+				// Test with no model ID
+				const noModelConfig = {
+					apiProvider: "openai" as const,
+				}
+				const noModelTask = new Task({
+					provider: mockProvider,
+					apiConfiguration: noModelConfig,
+					task: "test task",
+					startTask: false,
+				})
+				expect(noModelTask.apiConfiguration.apiProvider).toBe("openai")
+			})
+		})
 	})
 })
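A quick usage sketch of the updated helper follows (not part of the patch). It assumes `getApiProtocol`, `getModelId`, and the `ProviderSettings` type are all exported from `@roo-code/types` (the Task.ts hunk above imports the first two from there) and that a Bedrock configuration carries its model in `apiModelId`; the model IDs are purely illustrative.

```typescript
import { getApiProtocol, getModelId, type ProviderSettings } from "@roo-code/types"

// Bedrock serving a Claude model: the new model-ID check routes this to the
// Anthropic-style protocol even though the provider itself is not "anthropic".
// (Using apiModelId as the Bedrock model field is an assumption for this sketch.)
const bedrockClaude: ProviderSettings = {
	apiProvider: "bedrock",
	apiModelId: "anthropic.claude-v2",
}
console.log(getApiProtocol(bedrockClaude.apiProvider, getModelId(bedrockClaude))) // "anthropic"

// Same provider with a non-Claude model falls through to the OpenAI-style default.
console.log(getApiProtocol("bedrock", "titan-text")) // "openai"

// Providers other than vertex/bedrock ignore the model ID entirely.
console.log(getApiProtocol("openrouter", "anthropic/claude-3-opus")) // "openai"
```

Note the deliberate scoping that the tests pin down: only `vertex` and `bedrock` consult the model ID, so a Claude model reached through OpenRouter, LiteLLM, Ollama, or an OpenAI-compatible endpoint still resolves to the `openai` protocol.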