packages/types/src/providers/mistral.ts (27 changes: 18 additions & 9 deletions)

@@ -11,73 +11,82 @@ export const mistralModels = {
 		contextWindow: 128_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 5.0,
 	},
 	"devstral-medium-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.4,
 		outputPrice: 2.0,
 	},
 	"mistral-medium-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.4,
 		outputPrice: 2.0,
 	},
 	"codestral-latest": {
-		maxTokens: 256_000,
+		maxTokens: 8192,
 		contextWindow: 256_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.3,
 		outputPrice: 0.9,
 	},
 	"mistral-large-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 6.0,
 	},
 	"ministral-8b-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.1,
 		outputPrice: 0.1,
 	},
 	"ministral-3b-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.04,
 		outputPrice: 0.04,
 	},
 	"mistral-small-latest": {
-		maxTokens: 32_000,
+		maxTokens: 8192,
 		contextWindow: 32_000,
 		supportsImages: false,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 0.2,
 		outputPrice: 0.6,
 	},
 	"pixtral-large-latest": {
-		maxTokens: 131_000,
+		maxTokens: 8192,
 		contextWindow: 131_000,
 		supportsImages: true,
 		supportsPromptCache: false,
+		supportsNativeTools: true,
 		inputPrice: 2.0,
 		outputPrice: 6.0,
 	},
 } as const satisfies Record<string, ModelInfo>
 
-export const MISTRAL_DEFAULT_TEMPERATURE = 0
+export const MISTRAL_DEFAULT_TEMPERATURE = 1
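The map is type-checked by the `as const satisfies Record<string, ModelInfo>` line at its close. A minimal sketch of that pattern, using a trimmed-down stand-in for `ModelInfo` (an assumed shape, reconstructed from the fields visible in this diff; the real type lives elsewhere in `packages/types`):

```ts
// Assumed, simplified stand-in for the package's real ModelInfo type.
interface ModelInfo {
	maxTokens: number
	contextWindow: number
	supportsImages: boolean
	supportsPromptCache: boolean
	supportsNativeTools?: boolean
	inputPrice?: number
	outputPrice?: number
}

// `as const` keeps the literal keys ("codestral-latest") so a model-id union
// can be derived, while `satisfies` still checks each entry against ModelInfo
// without widening the values.
const models = {
	"codestral-latest": {
		maxTokens: 8192,
		contextWindow: 256_000,
		supportsImages: false,
		supportsPromptCache: false,
		supportsNativeTools: true,
		inputPrice: 0.3,
		outputPrice: 0.9,
	},
} as const satisfies Record<string, ModelInfo>

type ModelId = keyof typeof models // "codestral-latest"
```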
src/api/providers/__tests__/mistral.spec.ts (221 changes: 220 additions & 1 deletion)

@@ -39,9 +39,11 @@ vi.mock("@mistralai/mistralai", () => {
 })
 
 import type { Anthropic } from "@anthropic-ai/sdk"
+import type OpenAI from "openai"
 import { MistralHandler } from "../mistral"
 import type { ApiHandlerOptions } from "../../../shared/api"
-import type { ApiStreamTextChunk, ApiStreamReasoningChunk } from "../../transform/stream"
+import type { ApiHandlerCreateMessageMetadata } from "../../index"
+import type { ApiStreamTextChunk, ApiStreamReasoningChunk, ApiStreamToolCallPartialChunk } from "../../transform/stream"
 
 describe("MistralHandler", () => {
 	let handler: MistralHandler
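The tests added below drive `createMessage` through its metadata argument. A rough sketch of the fields they rely on, reconstructed from the tests themselves (the real `ApiHandlerCreateMessageMetadata` is imported from `../../index` and may carry more than this):

```ts
import type OpenAI from "openai"

// Assumed shape, inferred from the fields the tests below exercise.
interface MetadataSketch {
	taskId: string
	tools?: OpenAI.Chat.ChatCompletionTool[]
	toolProtocol?: "native" | "xml"
	tool_choice?: string // per the tests, the handler ignores this and sends toolChoice: "any"
}
```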
@@ -223,6 +225,223 @@ describe("MistralHandler", () => {
 		})
 	})

+	describe("native tool calling", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [{ type: "text", text: "What's the weather?" }],
+			},
+		]
+
+		const mockTools: OpenAI.Chat.ChatCompletionTool[] = [
+			{
+				type: "function",
+				function: {
+					name: "get_weather",
+					description: "Get the current weather",
+					parameters: {
+						type: "object",
+						properties: {
+							location: { type: "string" },
+						},
+						required: ["location"],
+					},
+				},
+			},
+		]
+
+		it("should include tools in request when toolProtocol is native", async () => {
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					tools: expect.arrayContaining([
+						expect.objectContaining({
+							type: "function",
+							function: expect.objectContaining({
+								name: "get_weather",
+								description: "Get the current weather",
+								parameters: expect.any(Object),
+							}),
+						}),
+					]),
+					toolChoice: "any",
+				}),
+			)
+		})
+
+		it("should not include tools when toolProtocol is xml", async () => {
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "xml",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.not.objectContaining({
+					tools: expect.anything(),
+				}),
+			)
+		})
+
+		it("should handle tool calls in streaming response", async () => {
+			// Mock stream with tool calls
+			mockCreate.mockImplementationOnce(async (_options) => {
+				const stream = {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							data: {
+								choices: [
+									{
+										delta: {
+											toolCalls: [
+												{
+													id: "call_123",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"New York"}',
+													},
+												},
+											],
+										},
+										index: 0,
+									},
+								],
+							},
+						}
+					},
+				}
+				return stream
+			})
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			const results: ApiStreamToolCallPartialChunk[] = []
+
+			for await (const chunk of iterator) {
+				if (chunk.type === "tool_call_partial") {
+					results.push(chunk)
+				}
+			}
+
+			expect(results).toHaveLength(1)
+			expect(results[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_123",
+				name: "get_weather",
+				arguments: '{"location":"New York"}',
+			})
+		})
+
+		it("should handle multiple tool calls in a single response", async () => {
+			// Mock stream with multiple tool calls
+			mockCreate.mockImplementationOnce(async (_options) => {
+				const stream = {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							data: {
+								choices: [
+									{
+										delta: {
+											toolCalls: [
+												{
+													id: "call_1",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"NYC"}',
+													},
+												},
+												{
+													id: "call_2",
+													type: "function",
+													function: {
+														name: "get_weather",
+														arguments: '{"location":"LA"}',
+													},
+												},
+											],
+										},
+										index: 0,
+									},
+								],
+							},
+						}
+					},
+				}
+				return stream
+			})
+
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			const results: ApiStreamToolCallPartialChunk[] = []
+
+			for await (const chunk of iterator) {
+				if (chunk.type === "tool_call_partial") {
+					results.push(chunk)
+				}
+			}
+
+			expect(results).toHaveLength(2)
+			expect(results[0]).toEqual({
+				type: "tool_call_partial",
+				index: 0,
+				id: "call_1",
+				name: "get_weather",
+				arguments: '{"location":"NYC"}',
+			})
+			expect(results[1]).toEqual({
+				type: "tool_call_partial",
+				index: 1,
+				id: "call_2",
+				name: "get_weather",
+				arguments: '{"location":"LA"}',
+			})
+		})
+
+		it("should always set toolChoice to 'any' when tools are provided", async () => {
+			// Even if tool_choice is provided in metadata, we override it to "any"
+			const metadata: ApiHandlerCreateMessageMetadata = {
+				taskId: "test-task",
+				tools: mockTools,
+				toolProtocol: "native",
+				tool_choice: "auto", // This should be ignored
+			}
+
+			const iterator = handler.createMessage(systemPrompt, messages, metadata)
+			await iterator.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					toolChoice: "any",
+				}),
+			)
+		})
+	})
+
 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
 			const prompt = "Test prompt"
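Taken together, the streaming tests pin down the expected mapping: every entry in a `toolCalls` delta becomes one `tool_call_partial` chunk, and two calls arriving in a single delta come out with indices 0 and 1, so the emitted index must be a running counter over tool calls rather than the choice's `index`. A sketch of a consuming loop with that behavior (an illustration of what the tests assert, not the handler's actual implementation):

```ts
// Event/field names (data.choices[].delta.toolCalls) mirror the mocked
// Mistral stream shape used in the tests above.
async function* emitToolCallPartials(stream: AsyncIterable<any>) {
	let toolCallIndex = 0
	for await (const event of stream) {
		const delta = event.data?.choices?.[0]?.delta
		for (const toolCall of delta?.toolCalls ?? []) {
			yield {
				type: "tool_call_partial" as const,
				index: toolCallIndex++, // running counter: two calls in one delta yield 0, then 1
				id: toolCall.id,
				name: toolCall.function?.name,
				arguments: toolCall.function?.arguments,
			}
		}
	}
}
```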