Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions src/api/providers/__tests__/requesty.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { TOOL_PROTOCOL } from "@roo-code/types"

import { RequestyHandler } from "../requesty"
import { ApiHandlerOptions } from "../../../shared/api"
import { Package } from "../../../shared/package"
import { ApiHandlerCreateMessageMetadata } from "../../index"

const mockCreate = vitest.fn()
const mockResolveToolProtocol = vitest.fn()

vitest.mock("openai", () => {
return {
Expand All @@ -23,6 +27,10 @@ vitest.mock("openai", () => {

vitest.mock("delay", () => ({ default: vitest.fn(() => Promise.resolve()) }))

vitest.mock("../../../utils/resolveToolProtocol", () => ({
resolveToolProtocol: (...args: any[]) => mockResolveToolProtocol(...args),
}))

vitest.mock("../fetchers/modelCache", () => ({
getModels: vitest.fn().mockImplementation(() => {
return Promise.resolve({
Expand Down Expand Up @@ -200,6 +208,176 @@ describe("RequestyHandler", () => {
const generator = handler.createMessage("test", [])
await expect(generator.next()).rejects.toThrow("API Error")
})

describe("native tool support", () => {
const systemPrompt = "test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user" as const, content: "What's the weather?" },
]

const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
{
type: "function",
function: {
name: "get_weather",
description: "Get the current weather",
parameters: {
type: "object",
properties: {
location: { type: "string" },
},
required: ["location"],
},
},
},
]

beforeEach(() => {
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [{ delta: { content: "test response" } }],
}
},
}
mockCreate.mockResolvedValue(mockStream)
})

it("should include tools in request when toolProtocol is native", async () => {
mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.NATIVE)

const metadata: ApiHandlerCreateMessageMetadata = {
taskId: "test-task",
tools: mockTools,
tool_choice: "auto",
}

const handler = new RequestyHandler(mockOptions)
const iterator = handler.createMessage(systemPrompt, messages, metadata)
await iterator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
tools: expect.arrayContaining([
expect.objectContaining({
type: "function",
function: expect.objectContaining({
name: "get_weather",
description: "Get the current weather",
}),
}),
]),
tool_choice: "auto",
}),
)
})

it("should not include tools when toolProtocol is not native", async () => {
mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.XML)

const metadata: ApiHandlerCreateMessageMetadata = {
taskId: "test-task",
tools: mockTools,
tool_choice: "auto",
}

const handler = new RequestyHandler(mockOptions)
const iterator = handler.createMessage(systemPrompt, messages, metadata)
await iterator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.not.objectContaining({
tools: expect.anything(),
tool_choice: expect.anything(),
}),
)
})

it("should handle tool_call_partial chunks in streaming response", async () => {
mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.NATIVE)

const mockStreamWithToolCalls = {
async *[Symbol.asyncIterator]() {
yield {
id: "test-id",
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_123",
function: {
name: "get_weather",
arguments: '{"location":',
},
},
],
},
},
],
}
yield {
id: "test-id",
choices: [
{
delta: {
tool_calls: [
{
index: 0,
function: {
arguments: '"New York"}',
},
},
],
},
},
],
}
yield {
id: "test-id",
choices: [{ delta: {} }],
usage: { prompt_tokens: 10, completion_tokens: 20 },
}
},
}
mockCreate.mockResolvedValue(mockStreamWithToolCalls)

const metadata: ApiHandlerCreateMessageMetadata = {
taskId: "test-task",
tools: mockTools,
}

const handler = new RequestyHandler(mockOptions)
const chunks = []
for await (const chunk of handler.createMessage(systemPrompt, messages, metadata)) {
chunks.push(chunk)
}

// Expect two tool_call_partial chunks and one usage chunk
expect(chunks).toHaveLength(3)
expect(chunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_123",
name: "get_weather",
arguments: '{"location":',
})
expect(chunks[1]).toEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: '"New York"}',
})
expect(chunks[2]).toMatchObject({
type: "usage",
inputTokens: 10,
outputTokens: 20,
})
})
})
})

describe("completePrompt", () => {
Expand Down
22 changes: 21 additions & 1 deletion src/api/providers/requesty.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"

import { type ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo } from "@roo-code/types"
import { type ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo, TOOL_PROTOCOL } from "@roo-code/types"

import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
import { resolveToolProtocol } from "../../utils/resolveToolProtocol"
import { calculateApiCostOpenAI } from "../../shared/cost"

import { convertToOpenAiMessages } from "../transform/openai-format"
Expand Down Expand Up @@ -133,6 +134,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
? (reasoning_effort as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming["reasoning_effort"])
: undefined

// Check if native tool protocol is enabled
const toolProtocol = resolveToolProtocol(this.options, info)
const useNativeTools = toolProtocol === TOOL_PROTOCOL.NATIVE

const completionParams: RequestyChatCompletionParamsStreaming = {
messages: openAiMessages,
model,
Expand All @@ -143,6 +148,8 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
stream: true,
stream_options: { include_usage: true },
requesty: { trace_id: metadata?.taskId, extra: { mode: metadata?.mode } },
...(useNativeTools && metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
...(useNativeTools && metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
}

let stream
Expand All @@ -165,6 +172,19 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
}

// Handle native tool calls
if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

if (chunk.usage) {
lastUsage = chunk.usage
}
Expand Down
Loading