diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts
index a8fabd40338..df799426a72 100644
--- a/src/api/providers/__tests__/requesty.spec.ts
+++ b/src/api/providers/__tests__/requesty.spec.ts
@@ -3,11 +3,15 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
+import { TOOL_PROTOCOL } from "@roo-code/types"
+
 import { RequestyHandler } from "../requesty"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Package } from "../../../shared/package"
+import { ApiHandlerCreateMessageMetadata } from "../../index"
 
 const mockCreate = vitest.fn()
+const mockResolveToolProtocol = vitest.fn()
 
 vitest.mock("openai", () => {
 	return {
@@ -23,6 +27,10 @@
 vitest.mock("delay", () => ({ default: vitest.fn(() => Promise.resolve()) }))
 
+vitest.mock("../../../utils/resolveToolProtocol", () => ({
+	resolveToolProtocol: (...args: any[]) => mockResolveToolProtocol(...args),
+}))
+
 vitest.mock("../fetchers/modelCache", () => ({
 	getModels: vitest.fn().mockImplementation(() => {
 		return Promise.resolve({
@@ -200,6 +208,176 @@ describe("RequestyHandler", () => {
 		const generator = handler.createMessage("test", [])
 		await expect(generator.next()).rejects.toThrow("API Error")
 	})
+
+	describe("native tool support", () => {
+		const systemPrompt = "test system prompt"
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user" as const, content: "What's the weather?" },
+		]
+
+		const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
+			{
+				type: "function",
+				function: {
+					name: "get_weather",
+					description: "Get the current weather",
+					parameters: {
+						type: "object",
+						properties: {
+							location: { type: "string" },
+						},
+						required: ["location"],
+					},
+				},
+			},
+		]
+
+		beforeEach(() => {
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						id: "test-id",
+						choices: [{ delta: { content: "test response" } }],
+					}
+				},
+			}
+			mockCreate.mockResolvedValue(mockStream)
+		})
+
+		it("should include tools in request when toolProtocol is native", async () => {
+			mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.NATIVE)
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				tool_choice: "auto",
+			}
+
+			const handler = new RequestyHandler(mockOptions)
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "get_weather",
+								description: "Get the current weather",
+							}),
+						}),
+					]),
+					tool_choice: "auto",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is not native", async () => {
+			mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.XML)
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				tool_choice: "auto",
+			}
+
+			const handler = new RequestyHandler(mockOptions)
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+					tool_choice: expect.anything(),
+				}),
+			)
+		})
+
+		it("should handle tool_call_partial chunks in streaming response", async () => {
+			mockResolveToolProtocol.mockReturnValue(TOOL_PROTOCOL.NATIVE)
+
+			const mockStreamWithToolCalls = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						id: "test-id",
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_123",
+											function: {
+												name: "get_weather",
+												arguments: '{"location":',
+											},
+										},
+									],
+								},
+							},
+						],
+					}
+					yield {
+						id: "test-id",
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											function: {
+												arguments: '"New York"}',
+											},
+										},
+									],
+								},
+							},
+						],
+					}
+					yield {
+						id: "test-id",
+						choices: [{ delta: {} }],
+						usage: { prompt_tokens: 10, completion_tokens: 20 },
+					}
+				},
+			}
+			mockCreate.mockResolvedValue(mockStreamWithToolCalls)
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+			}
+
+			const handler = new RequestyHandler(mockOptions)
+			const chunks = []
+			for await (const chunk of handler.createMessage(systemPrompt, messages, metadata)) {
+				chunks.push(chunk)
+			}
+
+			// Expect two tool_call_partial chunks and one usage chunk
+			expect(chunks).toHaveLength(3)
+			expect(chunks[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "get_weather",
+				arguments: '{"location":',
+			})
+			expect(chunks[1]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: undefined,
+				name: undefined,
+				arguments: '"New York"}',
+			})
+			expect(chunks[2]).toMatchObject({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 20,
+			})
+		})
+	})
 })
 
 describe("completePrompt", () => {
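The last test above encodes the contract for streamed tool calls: `id` and `name` appear only on the first fragment for a given `index`, while `arguments` arrives as JSON text fragments that concatenate into the complete argument object. A minimal consumer-side sketch of that reassembly, assuming only the chunk shape asserted by the test (this helper is illustrative and not part of the change):

interface ToolCallPartial {
	type: "tool_call_partial"
	index: number
	id?: string
	name?: string
	arguments?: string
}

// Merge fragments by `index`; fragments after the first usually carry only
// `arguments`, so keep `id`/`name` from whichever fragment introduced them.
function mergeToolCallPartials(partials: ToolCallPartial[]) {
	const byIndex = new Map<number, { id?: string; name?: string; args: string }>()
	for (const p of partials) {
		const entry = byIndex.get(p.index) ?? { args: "" }
		entry.id ??= p.id
		entry.name ??= p.name
		entry.args += p.arguments ?? ""
		byIndex.set(p.index, entry)
	}
	// e.g. '{"location":' + '"New York"}' parses to { location: "New York" }
	return [...byIndex.values()].map((e) => ({ id: e.id, name: e.name, input: JSON.parse(e.args) }))
}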
diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts
index 979146b378f..3668265669a 100644
--- a/src/api/providers/requesty.ts
+++ b/src/api/providers/requesty.ts
@@ -1,9 +1,10 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
-import { type ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo } from "@roo-code/types"
+import { type ModelInfo, requestyDefaultModelId, requestyDefaultModelInfo, TOOL_PROTOCOL } from "@roo-code/types"
 
 import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
+import { resolveToolProtocol } from "../../utils/resolveToolProtocol"
 import { calculateApiCostOpenAI } from "../../shared/cost"
 import { convertToOpenAiMessages } from "../transform/openai-format"
@@ -133,6 +134,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 				? (reasoning_effort as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming["reasoning_effort"])
 				: undefined
 
+		// Check if native tool protocol is enabled
+		const toolProtocol = resolveToolProtocol(this.options, info)
+		const useNativeTools = toolProtocol === TOOL_PROTOCOL.NATIVE
+
 		const completionParams: RequestyChatCompletionParamsStreaming = {
 			messages: openAiMessages,
 			model,
@@ -143,6 +148,8 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 			stream: true,
 			stream_options: { include_usage: true },
 			requesty: { trace_id: metadata?.taskId, extra: { mode: metadata?.mode } },
+			...(useNativeTools && metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(useNativeTools && metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		let stream
@@ -165,6 +172,19 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
 			}
 
+			// Handle native tool calls
+			if (delta && "tool_calls" in delta && Array.isArray(delta.tool_calls)) {
+				for (const toolCall of delta.tool_calls) {
+					yield {
+						type: "tool_call_partial",
+						index: toolCall.index,
+						id: toolCall.id,
+						name: toolCall.function?.name,
+						arguments: toolCall.function?.arguments,
+					}
+				}
+			}
+
 			if (chunk.usage) {
 				lastUsage = chunk.usage
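Two details of the provider change are worth spelling out. First, the conditional spreads `...(useNativeTools && metadata?.tools && { ... })` rely on the fact that spreading `false` or `undefined` into an object literal is a no-op, so non-native requests omit `tools` and `tool_choice` entirely, which is exactly what the second test asserts. Second, `this.convertToolsForOpenAI(metadata.tools)` is referenced but not shown in this diff; since the tests already pass `metadata.tools` typed as `OpenAI.Chat.ChatCompletionTool[]`, something close to a pass-through would satisfy the assertions shown. A sketch under that assumption (the repository's real helper may normalize schemas or strip unsupported fields):

import OpenAI from "openai"

// Illustrative only: forward well-formed function tools unchanged.
function convertToolsForOpenAI(tools: OpenAI.Chat.ChatCompletionTool[]): OpenAI.Chat.ChatCompletionTool[] {
	return tools.filter((tool) => tool.type === "function" && Boolean(tool.function?.name))
}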