Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions packages/types/src/__tests__/provider-settings.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { describe, it, expect } from "vitest"
import { getApiProtocol } from "../provider-settings.js"

describe("getApiProtocol", () => {
describe("Anthropic-style providers", () => {
it("should return 'anthropic' for anthropic provider", () => {
expect(getApiProtocol("anthropic")).toBe("anthropic")
expect(getApiProtocol("anthropic", "gpt-4")).toBe("anthropic")
})

it("should return 'anthropic' for claude-code provider", () => {
expect(getApiProtocol("claude-code")).toBe("anthropic")
expect(getApiProtocol("claude-code", "some-model")).toBe("anthropic")
})
})

describe("Vertex provider with Claude models", () => {
it("should return 'anthropic' for vertex provider with claude models", () => {
expect(getApiProtocol("vertex", "claude-3-opus")).toBe("anthropic")
expect(getApiProtocol("vertex", "Claude-3-Sonnet")).toBe("anthropic")
expect(getApiProtocol("vertex", "CLAUDE-instant")).toBe("anthropic")
expect(getApiProtocol("vertex", "anthropic/claude-3-haiku")).toBe("anthropic")
})

it("should return 'openai' for vertex provider with non-claude models", () => {
expect(getApiProtocol("vertex", "gpt-4")).toBe("openai")
expect(getApiProtocol("vertex", "gemini-pro")).toBe("openai")
expect(getApiProtocol("vertex", "llama-2")).toBe("openai")
})
})

describe("Bedrock provider with Claude models", () => {
it("should return 'anthropic' for bedrock provider with claude models", () => {
expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
expect(getApiProtocol("bedrock", "Claude-3-Sonnet")).toBe("anthropic")
expect(getApiProtocol("bedrock", "CLAUDE-instant")).toBe("anthropic")
expect(getApiProtocol("bedrock", "anthropic.claude-v2")).toBe("anthropic")
})

it("should return 'openai' for bedrock provider with non-claude models", () => {
expect(getApiProtocol("bedrock", "gpt-4")).toBe("openai")
expect(getApiProtocol("bedrock", "titan-text")).toBe("openai")
expect(getApiProtocol("bedrock", "llama-2")).toBe("openai")
})
})

describe("Other providers with Claude models", () => {
it("should return 'openai' for non-vertex/bedrock providers with claude models", () => {
expect(getApiProtocol("openrouter", "claude-3-opus")).toBe("openai")
expect(getApiProtocol("openai", "claude-3-sonnet")).toBe("openai")
expect(getApiProtocol("litellm", "claude-instant")).toBe("openai")
expect(getApiProtocol("ollama", "claude-model")).toBe("openai")
})
})

describe("Edge cases", () => {
it("should return 'openai' when provider is undefined", () => {
expect(getApiProtocol(undefined)).toBe("openai")
expect(getApiProtocol(undefined, "claude-3-opus")).toBe("openai")
})

it("should return 'openai' when model is undefined", () => {
expect(getApiProtocol("openai")).toBe("openai")
expect(getApiProtocol("vertex")).toBe("openai")
expect(getApiProtocol("bedrock")).toBe("openai")
})

it("should handle empty strings", () => {
expect(getApiProtocol("vertex", "")).toBe("openai")
expect(getApiProtocol("bedrock", "")).toBe("openai")
})

it("should be case-insensitive for claude detection", () => {
expect(getApiProtocol("vertex", "CLAUDE-3-OPUS")).toBe("anthropic")
expect(getApiProtocol("bedrock", "claude-3-opus")).toBe("anthropic")
expect(getApiProtocol("vertex", "ClAuDe-InStAnT")).toBe("anthropic")
})
})
})
22 changes: 19 additions & 3 deletions packages/types/src/provider-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,23 @@ export const getModelId = (settings: ProviderSettings): string | undefined => {
// Providers that use Anthropic-style API protocol
export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"]

// Helper function to determine API protocol for a provider
export const getApiProtocol = (provider: ProviderName | undefined): "anthropic" | "openai" => {
return provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider) ? "anthropic" : "openai"
// Helper function to determine API protocol for a provider and model
export const getApiProtocol = (provider: ProviderName | undefined, modelId?: string): "anthropic" | "openai" => {
// First check if the provider is an Anthropic-style provider
if (provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider)) {
return "anthropic"
}

// For vertex and bedrock providers, check if the model ID contains "claude" (case-insensitive)
if (
provider &&
(provider === "vertex" || provider === "bedrock") &&
modelId &&
modelId.toLowerCase().includes("claude")
) {
return "anthropic"
}

// Default to OpenAI protocol
return "openai"
}
6 changes: 4 additions & 2 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
TelemetryEventName,
TodoItem,
getApiProtocol,
getModelId,
} from "@roo-code/types"
import { TelemetryService } from "@roo-code/telemetry"
import { CloudService } from "@roo-code/cloud"
Expand Down Expand Up @@ -1211,8 +1212,9 @@ export class Task extends EventEmitter<ClineEvents> {
// take a few seconds. For the best UX we show a placeholder api_req_started
// message with a loading spinner as this happens.

// Determine API protocol based on provider
const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider)
// Determine API protocol based on provider and model
const modelId = getModelId(this.apiConfiguration)
const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider, modelId)

await this.say(
"api_req_started",
Expand Down
95 changes: 95 additions & 0 deletions src/core/task/__tests__/Task.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1398,5 +1398,100 @@ describe("Cline", () => {
expect(task.diffStrategy).toBeUndefined()
})
})

describe("getApiProtocol", () => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests in the 'getApiProtocol' describe block verify that the configuration's provider field remains unchanged, but they do not directly assert the computed API protocol. Consider adding tests that call getApiProtocol(provider, modelId) explicitly to ensure it returns the expected protocol based on the model name.

it("should determine API protocol based on provider and model", async () => {
// Test with Anthropic provider
const anthropicConfig = {
...mockApiConfig,
apiProvider: "anthropic" as const,
apiModelId: "gpt-4",
}
const anthropicTask = new Task({
provider: mockProvider,
apiConfiguration: anthropicConfig,
task: "test task",
startTask: false,
})
// Should use anthropic protocol even with non-claude model
expect(anthropicTask.apiConfiguration.apiProvider).toBe("anthropic")

// Test with OpenRouter provider and Claude model
const openrouterClaudeConfig = {
apiProvider: "openrouter" as const,
openRouterModelId: "anthropic/claude-3-opus",
}
const openrouterClaudeTask = new Task({
provider: mockProvider,
apiConfiguration: openrouterClaudeConfig,
task: "test task",
startTask: false,
})
expect(openrouterClaudeTask.apiConfiguration.apiProvider).toBe("openrouter")

// Test with OpenRouter provider and non-Claude model
const openrouterGptConfig = {
apiProvider: "openrouter" as const,
openRouterModelId: "openai/gpt-4",
}
const openrouterGptTask = new Task({
provider: mockProvider,
apiConfiguration: openrouterGptConfig,
task: "test task",
startTask: false,
})
expect(openrouterGptTask.apiConfiguration.apiProvider).toBe("openrouter")

// Test with various Claude model formats
const claudeModelFormats = [
"claude-3-opus",
"Claude-3-Sonnet",
"CLAUDE-instant",
"anthropic/claude-3-haiku",
"some-provider/claude-model",
]

for (const modelId of claudeModelFormats) {
const config = {
apiProvider: "openai" as const,
openAiModelId: modelId,
}
const task = new Task({
provider: mockProvider,
apiConfiguration: config,
task: "test task",
startTask: false,
})
// Verify the model ID contains claude (case-insensitive)
expect(modelId.toLowerCase()).toContain("claude")
}
})

it("should handle edge cases for API protocol detection", async () => {
// Test with undefined provider
const undefinedProviderConfig = {
apiModelId: "claude-3-opus",
}
const undefinedProviderTask = new Task({
provider: mockProvider,
apiConfiguration: undefinedProviderConfig,
task: "test task",
startTask: false,
})
expect(undefinedProviderTask.apiConfiguration.apiProvider).toBeUndefined()

// Test with no model ID
const noModelConfig = {
apiProvider: "openai" as const,
}
const noModelTask = new Task({
provider: mockProvider,
apiConfiguration: noModelConfig,
task: "test task",
startTask: false,
})
expect(noModelTask.apiConfiguration.apiProvider).toBe("openai")
})
})
})
})