diff --git a/packages/types/src/providers/mistral.ts b/packages/types/src/providers/mistral.ts index fa70f2606a4..25546e5a42d 100644 --- a/packages/types/src/providers/mistral.ts +++ b/packages/types/src/providers/mistral.ts @@ -11,73 +11,82 @@ export const mistralModels = { contextWindow: 128_000, supportsImages: true, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 2.0, outputPrice: 5.0, }, "devstral-medium-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: true, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.4, outputPrice: 2.0, }, "mistral-medium-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: true, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.4, outputPrice: 2.0, }, "codestral-latest": { - maxTokens: 256_000, + maxTokens: 8192, contextWindow: 256_000, supportsImages: false, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.3, outputPrice: 0.9, }, "mistral-large-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: false, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 2.0, outputPrice: 6.0, }, "ministral-8b-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: false, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.1, outputPrice: 0.1, }, "ministral-3b-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: false, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.04, outputPrice: 0.04, }, "mistral-small-latest": { - maxTokens: 32_000, + maxTokens: 8192, contextWindow: 32_000, supportsImages: false, supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 0.2, outputPrice: 0.6, }, "pixtral-large-latest": { - maxTokens: 131_000, + maxTokens: 8192, contextWindow: 131_000, supportsImages: true, 
supportsPromptCache: false, + supportsNativeTools: true, inputPrice: 2.0, outputPrice: 6.0, }, } as const satisfies Record<string, ModelInfo> -export const MISTRAL_DEFAULT_TEMPERATURE = 0 +export const MISTRAL_DEFAULT_TEMPERATURE = 1 diff --git a/src/api/providers/__tests__/mistral.spec.ts b/src/api/providers/__tests__/mistral.spec.ts index ff3c5d3d8ba..9c1a26763c1 100644 --- a/src/api/providers/__tests__/mistral.spec.ts +++ b/src/api/providers/__tests__/mistral.spec.ts @@ -39,9 +39,11 @@ vi.mock("@mistralai/mistralai", () => { }) import type { Anthropic } from "@anthropic-ai/sdk" +import type OpenAI from "openai" import { MistralHandler } from "../mistral" import type { ApiHandlerOptions } from "../../../shared/api" -import type { ApiStreamTextChunk, ApiStreamReasoningChunk } from "../../transform/stream" +import type { ApiHandlerCreateMessageMetadata } from "../../index" +import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream" describe("MistralHandler", () => { let handler: MistralHandler @@ -223,6 +225,223 @@ describe("MistralHandler", () => { }) }) + describe("native tool calling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text", text: "What's the weather?" 
}], + }, + ] + + const mockTools: OpenAI.Chat.ChatCompletionTool[] = [ + { + type: "function", + function: { + name: "get_weather", + description: "Get the current weather", + parameters: { + type: "object", + properties: { + location: { type: "string" }, + }, + required: ["location"], + }, + }, + }, + ] + + it("should include tools in request when toolProtocol is native", async () => { + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: mockTools, + toolProtocol: "native", + } + + const iterator = handler.createMessage(systemPrompt, messages, metadata) + await iterator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.arrayContaining([ + expect.objectContaining({ + type: "function", + function: expect.objectContaining({ + name: "get_weather", + description: "Get the current weather", + parameters: expect.any(Object), + }), + }), + ]), + toolChoice: "any", + }), + ) + }) + + it("should not include tools when toolProtocol is xml", async () => { + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: mockTools, + toolProtocol: "xml", + } + + const iterator = handler.createMessage(systemPrompt, messages, metadata) + await iterator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.not.objectContaining({ + tools: expect.anything(), + }), + ) + }) + + it("should handle tool calls in streaming response", async () => { + // Mock stream with tool calls + mockCreate.mockImplementationOnce(async (_options) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + data: { + choices: [ + { + delta: { + toolCalls: [ + { + id: "call_123", + type: "function", + function: { + name: "get_weather", + arguments: '{"location":"New York"}', + }, + }, + ], + }, + index: 0, + }, + ], + }, + } + }, + } + return stream + }) + + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: mockTools, + toolProtocol: "native", 
+ } + + const iterator = handler.createMessage(systemPrompt, messages, metadata) + const results: ApiStreamToolCallPartialChunk[] = [] + + for await (const chunk of iterator) { + if (chunk.type === "tool_call_partial") { + results.push(chunk) + } + } + + expect(results).toHaveLength(1) + expect(results[0]).toEqual({ + type: "tool_call_partial", + index: 0, + id: "call_123", + name: "get_weather", + arguments: '{"location":"New York"}', + }) + }) + + it("should handle multiple tool calls in a single response", async () => { + // Mock stream with multiple tool calls + mockCreate.mockImplementationOnce(async (_options) => { + const stream = { + [Symbol.asyncIterator]: async function* () { + yield { + data: { + choices: [ + { + delta: { + toolCalls: [ + { + id: "call_1", + type: "function", + function: { + name: "get_weather", + arguments: '{"location":"NYC"}', + }, + }, + { + id: "call_2", + type: "function", + function: { + name: "get_weather", + arguments: '{"location":"LA"}', + }, + }, + ], + }, + index: 0, + }, + ], + }, + } + }, + } + return stream + }) + + const metadata: ApiHandlerCreateMessageMetadata = { + taskId: "test-task", + tools: mockTools, + toolProtocol: "native", + } + + const iterator = handler.createMessage(systemPrompt, messages, metadata) + const results: ApiStreamToolCallPartialChunk[] = [] + + for await (const chunk of iterator) { + if (chunk.type === "tool_call_partial") { + results.push(chunk) + } + } + + expect(results).toHaveLength(2) + expect(results[0]).toEqual({ + type: "tool_call_partial", + index: 0, + id: "call_1", + name: "get_weather", + arguments: '{"location":"NYC"}', + }) + expect(results[1]).toEqual({ + type: "tool_call_partial", + index: 1, + id: "call_2", + name: "get_weather", + arguments: '{"location":"LA"}', + }) + }) + + it("should always set toolChoice to 'any' when tools are provided", async () => { + // Even if tool_choice is provided in metadata, we override it to "any" + const metadata: ApiHandlerCreateMessageMetadata 
= { + taskId: "test-task", + tools: mockTools, + toolProtocol: "native", + tool_choice: "auto", // This should be ignored + } + + const iterator = handler.createMessage(systemPrompt, messages, metadata) + await iterator.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.objectContaining({ + toolChoice: "any", + }), + ) + }) + }) + describe("completePrompt", () => { it("should complete prompt successfully", async () => { const prompt = "Test prompt" diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index fef215d43f9..96d2c332552 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -1,5 +1,6 @@ import { Anthropic } from "@anthropic-ai/sdk" import { Mistral } from "@mistralai/mistralai" +import OpenAI from "openai" import { type MistralModelId, mistralDefaultModelId, mistralModels, MISTRAL_DEFAULT_TEMPERATURE } from "@roo-code/types" @@ -19,6 +20,26 @@ type ContentChunkWithThinking = { thinking?: Array<{ type: string; text?: string }> } +// Type for Mistral tool calls in stream delta +type MistralToolCall = { + id?: string + type?: string + function?: { + name?: string + arguments?: string + } +} + +// Type for Mistral tool definition - matches Mistral SDK Tool type +type MistralTool = { + type: "function" + function: { + name: string + description?: string + parameters: Record + } +} + export class MistralHandler extends BaseProvider implements SingleCompletionHandler { protected options: ApiHandlerOptions private client: Mistral @@ -47,14 +68,35 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: model, maxTokens, temperature } = this.getModel() - - const response = await this.client.chat.stream({ + const { id: model, info, maxTokens, temperature } = this.getModel() + + // Build request options + const requestOptions: { + model: string + messages: ReturnType + 
maxTokens: number + temperature: number + tools?: MistralTool[] + toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } } + } = { model, messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)], - maxTokens, + maxTokens: maxTokens ?? info.maxTokens, temperature, - }) + } + + // Add tools if provided and toolProtocol is not 'xml' and model supports native tools + const supportsNativeTools = info.supportsNativeTools ?? false + if (metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml" && supportsNativeTools) { + requestOptions.tools = this.convertToolsForMistral(metadata.tools) + // Always use "any" to require tool use + requestOptions.toolChoice = "any" + } + + // Temporary debug log for QA + // console.log("[MISTRAL DEBUG] Raw API request body:", requestOptions) + + const response = await this.client.chat.stream(requestOptions) for await (const event of response) { const delta = event.data.choices[0]?.delta @@ -83,6 +125,22 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand } } + // Handle tool calls in stream + // Mistral SDK provides tool_calls in delta similar to OpenAI format + const toolCalls = (delta as { toolCalls?: MistralToolCall[] })?.toolCalls + if (toolCalls) { + for (let i = 0; i < toolCalls.length; i++) { + const toolCall = toolCalls[i] + yield { + type: "tool_call_partial", + index: i, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + if (event.data.usage) { yield { type: "usage", @@ -93,6 +151,24 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand } } + /** + * Convert OpenAI tool definitions to Mistral format. + * Mistral uses the same format as OpenAI for function tools. 
+ */ + private convertToolsForMistral(tools: OpenAI.Chat.ChatCompletionTool[]): MistralTool[] { + return tools + .filter((tool) => tool.type === "function") + .map((tool) => ({ + type: "function" as const, + function: { + name: tool.function.name, + description: tool.function.description, + // Mistral SDK requires parameters to be defined, use empty object as fallback + parameters: (tool.function.parameters as Record<string, unknown>) || {}, + }, + })) + } + override getModel() { const id = this.options.apiModelId ?? mistralDefaultModelId const info = mistralModels[id as MistralModelId] ?? mistralModels[mistralDefaultModelId] diff --git a/src/api/transform/__tests__/mistral-format.spec.ts b/src/api/transform/__tests__/mistral-format.spec.ts index dce99406c72..51d70bb3114 100644 --- a/src/api/transform/__tests__/mistral-format.spec.ts +++ b/src/api/transform/__tests__/mistral-format.spec.ts @@ -83,10 +83,12 @@ describe("convertToMistralMessages", () => { }, ] - // Based on the implementation, tool results without accompanying text/image - // don't generate any messages + // Tool results are converted to Mistral "tool" role messages const mistralMessages = convertToMistralMessages(anthropicMessages) - expect(mistralMessages).toHaveLength(0) + expect(mistralMessages).toHaveLength(1) + expect(mistralMessages[0].role).toBe("tool") + expect((mistralMessages[0] as { toolCallId?: string }).toolCallId).toBe("weather-123") + expect(mistralMessages[0].content).toBe("Current temperature in London: 20°C") }) it("should handle user messages with mixed content (text, image, and tool results)", () => { @@ -116,24 +118,14 @@ describe("convertToMistralMessages", () => { ] const mistralMessages = convertToMistralMessages(anthropicMessages) - // Based on the implementation, only the text and image content is included - // Tool results are not converted to separate messages + // Mistral doesn't allow user messages after tool messages, so only tool results are converted + // User content (text/images) 
is intentionally skipped when there are tool results expect(mistralMessages).toHaveLength(1) - // Message should be the user message with text and image - expect(mistralMessages[0].role).toBe("user") - const userContent = mistralMessages[0].content as Array<{ - type: string - text?: string - imageUrl?: { url: string } - }> - expect(Array.isArray(userContent)).toBe(true) - expect(userContent).toHaveLength(2) - expect(userContent[0]).toEqual({ type: "text", text: "Here's the weather data and an image:" }) - expect(userContent[1]).toEqual({ - type: "image_url", - imageUrl: { url: "" }, - }) + // Only the tool result should be present + expect(mistralMessages[0].role).toBe("tool") + expect((mistralMessages[0] as { toolCallId?: string }).toolCallId).toBe("weather-123") + expect(mistralMessages[0].content).toBe("Current temperature in London: 20°C") }) it("should handle assistant messages with text content", () => { @@ -254,8 +246,8 @@ describe("convertToMistralMessages", () => { ] const mistralMessages = convertToMistralMessages(anthropicMessages) - // Based on the implementation, user messages with only tool results don't generate messages - expect(mistralMessages).toHaveLength(3) + // Tool results are now converted to tool messages + expect(mistralMessages).toHaveLength(4) // User message with image expect(mistralMessages[0].role).toBe("user") @@ -267,12 +259,17 @@ describe("convertToMistralMessages", () => { expect(Array.isArray(userContent)).toBe(true) expect(userContent).toHaveLength(2) - // Assistant message with text (tool_use is not included in Mistral format) + // Assistant message with text and toolCalls expect(mistralMessages[1].role).toBe("assistant") expect(mistralMessages[1].content).toBe("This image shows a landscape with mountains.") + // Tool result message + expect(mistralMessages[2].role).toBe("tool") + expect((mistralMessages[2] as { toolCallId?: string }).toolCallId).toBe("search-123") + expect(mistralMessages[2].content).toBe("Found information 
about different mountain types.") + // Final assistant message - expect(mistralMessages[2]).toEqual({ + expect(mistralMessages[3]).toEqual({ role: "assistant", content: "Based on the search results, I can tell you more about the mountains in the image.", }) diff --git a/src/api/transform/mistral-format.ts b/src/api/transform/mistral-format.ts index 3f9487a9980..c184e2d731c 100644 --- a/src/api/transform/mistral-format.ts +++ b/src/api/transform/mistral-format.ts @@ -10,6 +10,16 @@ export type MistralMessage = | (AssistantMessage & { role: "assistant" }) | (ToolMessage & { role: "tool" }) +// Type for Mistral tool calls in assistant messages +type MistralToolCallMessage = { + id: string + type: "function" + function: { + name: string + arguments: string + } +} + export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.MessageParam[]): MistralMessage[] { const mistralMessages: MistralMessage[] = [] @@ -21,7 +31,7 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M }) } else { if (anthropicMessage.role === "user") { - const { nonToolMessages } = anthropicMessage.content.reduce<{ + const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] toolMessages: Anthropic.ToolResultBlockParam[] }>( @@ -36,7 +46,35 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M { nonToolMessages: [], toolMessages: [] }, ) - if (nonToolMessages.length > 0) { + // If there are tool results, handle them + // Mistral's message order is strict: user → assistant → tool → assistant + // We CANNOT put user messages after tool messages + if (toolMessages.length > 0) { + // Convert tool_result blocks to Mistral tool messages + for (const toolResult of toolMessages) { + let resultContent: string + if (typeof toolResult.content === "string") { + resultContent = toolResult.content + } else if (Array.isArray(toolResult.content)) { 
+ // Extract text from content blocks + resultContent = toolResult.content + .filter((block): block is Anthropic.TextBlockParam => block.type === "text") + .map((block) => block.text) + .join("\n") + } else { + resultContent = "" + } + + mistralMessages.push({ + role: "tool", + toolCallId: toolResult.tool_use_id, + content: resultContent, + } as ToolMessage & { role: "tool" }) + } + // Note: We intentionally skip any non-tool user content when there are tool results + // because Mistral doesn't allow user messages after tool messages + } else if (nonToolMessages.length > 0) { + // Only add user content if there are NO tool results mistralMessages.push({ role: "user", content: nonToolMessages.map((part) => { @@ -53,7 +91,7 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M }) } } else if (anthropicMessage.role === "assistant") { - const { nonToolMessages } = anthropicMessage.content.reduce<{ + const { nonToolMessages, toolMessages } = anthropicMessage.content.reduce<{ nonToolMessages: (Anthropic.TextBlockParam | Anthropic.ImageBlockParam)[] toolMessages: Anthropic.ToolUseBlockParam[] }>( @@ -80,10 +118,37 @@ export function convertToMistralMessages(anthropicMessages: Anthropic.Messages.M .join("\n") } - mistralMessages.push({ + // Convert tool_use blocks to Mistral toolCalls format + let toolCalls: MistralToolCallMessage[] | undefined + if (toolMessages.length > 0) { + toolCalls = toolMessages.map((toolUse) => ({ + id: toolUse.id, + type: "function" as const, + function: { + name: toolUse.name, + arguments: + typeof toolUse.input === "string" ? 
toolUse.input : JSON.stringify(toolUse.input), + }, + })) + } + + // Mistral requires either content or toolCalls to be non-empty. + // Start from the content-only message; toolCalls is attached below only when tool_use blocks were present. + const assistantMessage: AssistantMessage & { role: "assistant" } = { role: "assistant", content, - }) + } + + if (toolCalls && toolCalls.length > 0) { + ;( + assistantMessage as AssistantMessage & { + role: "assistant" + toolCalls?: MistralToolCallMessage[] + } + ).toolCalls = toolCalls + } + + mistralMessages.push(assistantMessage) } } }