1 change: 1 addition & 0 deletions packages/types/src/providers/unbound.ts
@@ -7,6 +7,7 @@ export const unboundDefaultModelInfo: ModelInfo = {
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3.0,
outputPrice: 15.0,
cacheWritesPrice: 3.75,
226 changes: 226 additions & 0 deletions src/api/providers/__tests__/unbound.spec.ts
@@ -15,6 +15,7 @@ vitest.mock("../fetchers/modelCache", () => ({
contextWindow: 200000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3,
outputPrice: 15,
cacheWritesPrice: 3.75,
@@ -27,6 +28,7 @@ vitest.mock("../fetchers/modelCache", () => ({
contextWindow: 200000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3,
outputPrice: 15,
cacheWritesPrice: 3.75,
@@ -39,6 +41,7 @@ vitest.mock("../fetchers/modelCache", () => ({
contextWindow: 200000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 3,
outputPrice: 15,
cacheWritesPrice: 3.75,
@@ -51,6 +54,7 @@ vitest.mock("../fetchers/modelCache", () => ({
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 5,
outputPrice: 15,
description: "GPT-4o",
@@ -60,6 +64,7 @@ vitest.mock("../fetchers/modelCache", () => ({
contextWindow: 128000,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 1,
outputPrice: 3,
description: "O3 Mini",
@@ -328,4 +333,225 @@ describe("UnboundHandler", () => {
expect(modelInfo.info).toBeDefined()
})
})

describe("Native Tool Calling", () => {
const testTools = [
{
type: "function" as const,
function: {
name: "test_tool",
description: "A test tool",
parameters: {
type: "object",
properties: {
arg1: { type: "string", description: "First argument" },
},
required: ["arg1"],
},
},
},
]

it("should include tools in request when model supports native tools and tools are provided", async () => {
mockWithResponse.mockResolvedValueOnce({
data: {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
},
})

const messageGenerator = handler.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "native",
})
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
tools: expect.arrayContaining([
expect.objectContaining({
type: "function",
function: expect.objectContaining({
name: "test_tool",
}),
}),
]),
parallel_tool_calls: false,
}),
expect.objectContaining({
headers: {
"X-Unbound-Metadata": expect.stringContaining("roo-code"),
},
}),
)
})

it("should include tool_choice when provided", async () => {
mockWithResponse.mockResolvedValueOnce({
data: {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
},
})

const messageGenerator = handler.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "native",
tool_choice: "auto",
})
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
tool_choice: "auto",
}),
expect.objectContaining({
headers: {
"X-Unbound-Metadata": expect.stringContaining("roo-code"),
},
}),
)
})

it("should not include tools when toolProtocol is xml", async () => {
mockWithResponse.mockResolvedValueOnce({
data: {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
},
})

const messageGenerator = handler.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "xml",
})
await messageGenerator.next()

const callArgs = mockCreate.mock.calls[mockCreate.mock.calls.length - 1][0]
expect(callArgs).not.toHaveProperty("tools")
expect(callArgs).not.toHaveProperty("tool_choice")
})

it("should yield tool_call_partial chunks during streaming", async () => {
mockWithResponse.mockResolvedValueOnce({
data: {
[Symbol.asyncIterator]: () => ({
next: vi
.fn()
.mockResolvedValueOnce({
done: false,
value: {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_123",
function: {
name: "test_tool",
arguments: '{"arg1":',
},
},
],
},
},
],
},
})
.mockResolvedValueOnce({
done: false,
value: {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
function: {
arguments: '"value"}',
},
},
],
},
},
],
},
})
.mockResolvedValueOnce({ done: true }),
}),
},
})

const stream = handler.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "native",
})

const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks).toContainEqual({
type: "tool_call_partial",
index: 0,
id: "call_123",
name: "test_tool",
arguments: '{"arg1":',
})

expect(chunks).toContainEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: '"value"}',
})
})

it("should set parallel_tool_calls based on metadata", async () => {
mockWithResponse.mockResolvedValueOnce({
data: {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
},
})

const messageGenerator = handler.createMessage("test prompt", [], {
taskId: "test-task-id",
tools: testTools,
toolProtocol: "native",
parallelToolCalls: true,
})
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
parallel_tool_calls: true,
}),
expect.objectContaining({
headers: {
"X-Unbound-Metadata": expect.stringContaining("roo-code"),
},
}),
)
})
})
})
1 change: 1 addition & 0 deletions src/api/providers/fetchers/unbound.ts
@@ -23,6 +23,7 @@ export async function getUnboundModels(apiKey?: string | null): Promise<Record<s
contextWindow: model?.contextWindow ? parseInt(model.contextWindow) : 0,
supportsImages: model?.supportsImages ?? false,
supportsPromptCache: model?.supportsPromptCaching ?? false,
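// Native tool calling is assumed available for every model returned by the Unbound endpoint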
supportsNativeTools: true,
inputPrice: model?.inputTokenPrice ? parseFloat(model.inputTokenPrice) : undefined,
outputPrice: model?.outputTokenPrice ? parseFloat(model.outputTokenPrice) : undefined,
cacheWritesPrice: model?.cacheWritePrice ? parseFloat(model.cacheWritePrice) : undefined,
49 changes: 48 additions & 1 deletion src/api/providers/unbound.ts
@@ -13,6 +13,8 @@ import { addCacheBreakpoints as addVertexCacheBreakpoints } from "../transform/c

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { RouterProvider } from "./router-provider"
import { getModelParams } from "../transform/model-params"
import { getModels } from "./fetchers/modelCache"

const ORIGIN_APP = "roo-code"

@@ -52,12 +54,35 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
})
}

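// Refresh the cached Unbound model list, then resolve the active model from it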
public override async fetchModel() {
this.models = await getModels({ provider: this.name, apiKey: this.client.apiKey, baseUrl: this.client.baseURL })
return this.getModel()
}

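// Resolve the requested model id, falling back to the Unbound default when it is not in the cache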
override getModel() {
const requestedId = this.options.unboundModelId ?? unboundDefaultModelId
const modelExists = this.models[requestedId]
const id = modelExists ? requestedId : unboundDefaultModelId
const info = modelExists ? this.models[requestedId] : unboundDefaultModelInfo

const params = getModelParams({
format: "openai",
modelId: id,
model: info,
settings: this.options,
})

return { id, info, ...params }
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const { id: modelId, info } = await this.fetchModel()
// Ensure we have up-to-date model metadata
await this.fetchModel()
const { id: modelId, info } = this.getModel()

const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
{ role: "system", content: systemPrompt },
Expand All @@ -83,16 +108,25 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
maxTokens = info.maxTokens ?? undefined
}

// Check if model supports native tools and tools are provided with native protocol
const supportsNativeTools = info.supportsNativeTools ?? false
const useNativeTools =
supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"

const requestOptions: UnboundChatCompletionCreateParamsStreaming = {
model: modelId.split("/")[1],
max_tokens: maxTokens,
messages: openAiMessages,
stream: true,
stream_options: { include_usage: true },
unbound_metadata: {
originApp: ORIGIN_APP,
taskId: metadata?.taskId,
mode: metadata?.mode,
},
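// Attach tool definitions, tool_choice, and parallel_tool_calls only when native tool calling is in effect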
...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
}

if (this.supportsTemperature(modelId)) {
Expand All @@ -111,6 +145,19 @@ export class UnboundHandler extends RouterProvider implements SingleCompletionHa
yield { type: "text", text: delta.content }
}

// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

if (usage) {
const usageData: ApiStreamUsageChunk = {
type: "usage",