diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 58f6354f62d..c231b43e53a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -749,6 +749,9 @@ importers: '@ai-sdk/amazon-bedrock': specifier: ^4.0.50 version: 4.0.50(zod@3.25.76) + '@ai-sdk/baseten': + specifier: ^1.0.31 + version: 1.0.31(zod@3.25.76) '@ai-sdk/cerebras': specifier: ^1.0.0 version: 1.0.35(zod@3.25.76) @@ -1435,6 +1438,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/baseten@1.0.31': + resolution: {integrity: sha512-tGbV96WBb5nnfyUYFrPyBxrhw53YlKSJbMC+rH3HhQlUaIs8+m/Bm4M0isrek9owIIf4MmmSDZ5VZL08zz7eFQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35': resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} engines: {node: '>=18'} @@ -1513,6 +1522,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.28': + resolution: {integrity: sha512-WzDnU0B13FMSSupDtm2lksFZvWGXnOfhG5S0HoPI0pkX5uVkr6N1UTATMyVaxLCG0MRkMhXCjkg4NXgEbb330Q==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20': resolution: {integrity: sha512-iXHVe0apM2zUEzauqJwqmpC37A5rihrStAih5Ks+JE32iTe4LZ58y17UGBjpQQTCRw9YxMeo2UFLxLpBluyvLQ==} engines: {node: '>=18'} @@ -1543,6 +1558,12 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/provider-utils@4.0.14': + resolution: {integrity: sha512-7bzKd9lgiDeXM7O4U4nQ8iTxguAOkg8LZGD9AfDVZYjO5cKYRwBPwVjboFcVrxncRHu0tYxZtXZtiLKpG4pEng==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/provider@2.0.0': resolution: {integrity: sha512-6o7Y2SeO9vFKB8lArHXehNuusnpddKPk7xqL7T2/b+OvXMRIXUO1rR4wcv1hAFUAT9avGZshty3Wlua/XA7TvA==} engines: {node: '>=18'} @@ -1563,6 +1584,10 @@ packages: resolution: {integrity: sha512-VkPLrutM6VdA924/mG8OS+5frbVTcu6e046D2bgDo00tehBANR1QBJ/mPcZ9tXMFOsVcm6SQArOregxePzTFPw==} engines: {node: '>=18'} + '@ai-sdk/provider@3.0.8': + resolution: {integrity: sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==} + engines: {node: '>=18'} + '@ai-sdk/xai@3.0.46': resolution: {integrity: sha512-26qM/jYcFhF5krTM7bQT1CiZcdz22EQmA+r5me1hKYFM/yM20sSUMHnAcUzvzuuG9oQVKF0tziU2IcC0HX5huQ==} engines: {node: '>=18'} @@ -1880,6 +1905,93 @@ packages: resolution: {integrity: sha512-+EzkxvLNfiUeKMgy/3luqfsCWFRXLb7U6wNQTk60tovuckwB15B191tJWvpp4HjiQWdJkCxO3Wbvc6jlk3Xb2Q==} engines: {node: '>=6.9.0'} + '@basetenlabs/performance-client-android-arm-eabi@0.0.10': + resolution: {integrity: sha512-gwDZ6GDJA0AAmQAHxt2vaCz0tYTaLjxJKZnoYt+0Eji4gy231JZZFAwvbAqNdQCrGEQ9lXnk7SNM1Apet4NlYg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [android] + + '@basetenlabs/performance-client-android-arm64@0.0.10': + resolution: {integrity: sha512-oGRB/6hH89majhsmoVmj1IAZv4C7F2aLeTSebevBelmdYO4CFkn5qewxLzU1pDkkmxVVk2k+TRpYa1Dt4B96qQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@basetenlabs/performance-client-darwin-arm64@0.0.10': + resolution: {integrity: sha512-QpBOUjeO05tWgFWkDw2RUQZa3BMplX5jNiBBTi5mH1lIL/m1sm2vkxoc0iorEESp1mMPstYFS/fr4ssBuO7wyA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@basetenlabs/performance-client-darwin-universal@0.0.10': + resolution: {integrity: sha512-CBM38GAhekjylrlf7jW/0WNyFAGnAMBCNHZxaPnAjjhDNzJh1tcrwhvtOs66XbAqCOjO/tkt5Pdu6mg2Ui2Pjw==} + engines: {node: '>= 10'} + os: [darwin] + + '@basetenlabs/performance-client-darwin-x64@0.0.10': + resolution: {integrity: 
sha512-R+NsA72Axclh1CUpmaWOCLTWCqXn5/tFMj2z9BnHVSRTelx/pYFlx6ZngVTB1HYp1n21m3upPXGo8CHF8R7Itw==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10': + resolution: {integrity: sha512-96kEo0Eas4GVQdFkxIB1aAv6dy5Ga57j+RIg5l0Yiawv+AYIEmgk9BsGkqcwayp8Iiu6LN22Z+AUsGY2gstNrg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10': + resolution: {integrity: sha512-lzEHeu+/BWDl2q+QZcqCkg1rDGF4MeyM3HgYwX+07t+vGZoqtM2we9vEV68wXMpl6ToEHQr7ML2KHA1Gb6ogxg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10': + resolution: {integrity: sha512-MnY2cIRY/cQOYERWIHhh5CoaS2wgmmXtGDVGSLYyZvjwizrXZvjkEz7Whv2jaQ21T5S56VER67RABjz2TItrHQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10': + resolution: {integrity: sha512-2KUvdK4wuoZdIqNnJhx7cu6ybXCwtiwGAtlrEvhai3FOkUQ3wE2Xa+TQ33mNGSyFbw6wAvLawYtKVFmmw27gJw==} + engines: {node: '>= 10'} + cpu: [riscv64] + os: [linux] + + '@basetenlabs/performance-client-linux-x64-gnu@0.0.10': + resolution: {integrity: sha512-9jjQPjHLiVOGwUPlmhnBl7OmmO7hQ8WMt+v3mJuxkS5JTNDmVOngfmgGlbN9NjBhQMENjdcMUVOquVo7HeybGQ==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@basetenlabs/performance-client-linux-x64-musl@0.0.10': + resolution: {integrity: sha512-bjYB8FKcPvEa251Ep2Gm3tvywADL9eavVjZsikdf0AvJ1K5pT+vLLvJBU9ihBsTPWnbF4pJgxVjwS6UjVObsQA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10': + resolution: {integrity: sha512-Vxq5UXEmfh3C3hpwXdp3Daaf0dnLR9zFH2x8MJ1Hf/TcilmOP1clneewNpIv0e7MrnT56Z4pM6P3d8VFMZqBKg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@basetenlabs/performance-client-win32-ia32-msvc@0.0.10': + resolution: {integrity: sha512-KJrm7CgZdP/UDC5+tHtqE6w9XMfY5YUfMOxJfBZGSsLMqS2OGsakQsaF0a55k+58l29X5w/nAkjHrI1BcQO03w==} + engines: {node: '>= 10'} + cpu: [ia32] + os: [win32] + + '@basetenlabs/performance-client-win32-x64-msvc@0.0.10': + resolution: {integrity: sha512-M/mhvfTItUcUX+aeXRb5g5MbRlndfg6yelV7tSYfLU4YixMIe5yoGaAP3iDilpFJjcC99f+EU4l4+yLbPtpXig==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@basetenlabs/performance-client@0.0.10': + resolution: {integrity: sha512-H6bpd1JcDbuJsOS2dNft+CCGLzBqHJO/ST/4mMKhLAW641J6PpVJUw1szYsk/dTetdedbWxHpMkvFObOKeP8nw==} + engines: {node: '>= 10'} + '@bcoe/v8-coverage@0.2.3': resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==} @@ -11016,6 +11128,14 @@ snapshots: '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/baseten@1.0.31(zod@3.25.76)': + dependencies: + '@ai-sdk/openai-compatible': 2.0.28(zod@3.25.76) + '@ai-sdk/provider': 3.0.8 + '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76) + '@basetenlabs/performance-client': 0.0.10 + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) @@ -11102,6 +11222,12 @@ snapshots: '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) zod: 3.25.76 + '@ai-sdk/openai-compatible@2.0.28(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.8 + '@ai-sdk/provider-utils': 4.0.14(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/provider-utils@3.0.20(zod@3.25.76)': dependencies: '@ai-sdk/provider': 2.0.1 @@ -11138,6 +11264,13 @@ snapshots: eventsource-parser: 3.0.6 zod: 3.25.76 + 
'@ai-sdk/provider-utils@4.0.14(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.8 + '@standard-schema/spec': 1.1.0 + eventsource-parser: 3.0.6 + zod: 3.25.76 + '@ai-sdk/provider@2.0.0': dependencies: json-schema: 0.4.0 @@ -11158,6 +11291,10 @@ snapshots: dependencies: json-schema: 0.4.0 + '@ai-sdk/provider@3.0.8': + dependencies: + json-schema: 0.4.0 + '@ai-sdk/xai@3.0.46(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 2.0.26(zod@3.25.76) @@ -11893,6 +12030,65 @@ snapshots: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.27.1 + '@basetenlabs/performance-client-android-arm-eabi@0.0.10': + optional: true + + '@basetenlabs/performance-client-android-arm64@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-arm64@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-universal@0.0.10': + optional: true + + '@basetenlabs/performance-client-darwin-x64@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm-gnueabihf@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm-musleabihf@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-arm64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-riscv64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-x64-gnu@0.0.10': + optional: true + + '@basetenlabs/performance-client-linux-x64-musl@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-arm64-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-ia32-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client-win32-x64-msvc@0.0.10': + optional: true + + '@basetenlabs/performance-client@0.0.10': + optionalDependencies: + '@basetenlabs/performance-client-android-arm-eabi': 0.0.10 + '@basetenlabs/performance-client-android-arm64': 0.0.10 + '@basetenlabs/performance-client-darwin-arm64': 0.0.10 + '@basetenlabs/performance-client-darwin-universal': 0.0.10 + '@basetenlabs/performance-client-darwin-x64': 0.0.10 + '@basetenlabs/performance-client-linux-arm-gnueabihf': 0.0.10 + '@basetenlabs/performance-client-linux-arm-musleabihf': 0.0.10 + '@basetenlabs/performance-client-linux-arm64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-riscv64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-x64-gnu': 0.0.10 + '@basetenlabs/performance-client-linux-x64-musl': 0.0.10 + '@basetenlabs/performance-client-win32-arm64-msvc': 0.0.10 + '@basetenlabs/performance-client-win32-ia32-msvc': 0.0.10 + '@basetenlabs/performance-client-win32-x64-msvc': 0.0.10 + '@bcoe/v8-coverage@0.2.3': {} '@braintree/sanitize-url@7.1.1': {} diff --git a/src/api/providers/__tests__/baseten.spec.ts b/src/api/providers/__tests__/baseten.spec.ts new file mode 100644 index 00000000000..e44b201f291 --- /dev/null +++ b/src/api/providers/__tests__/baseten.spec.ts @@ -0,0 +1,446 @@ +// npx vitest run src/api/providers/__tests__/baseten.spec.ts + +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) + +vi.mock("@ai-sdk/baseten", () => ({ + createBaseten: vi.fn(() => { + return vi.fn(() => ({ + modelId: "zai-org/GLM-4.6", + provider: "baseten", + })) + }), +})) + +import type { 
Anthropic } from "@anthropic-ai/sdk" + +import { basetenDefaultModelId, basetenModels, type BasetenModelId } from "@roo-code/types" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { BasetenHandler } from "../baseten" + +describe("BasetenHandler", () => { + let handler: BasetenHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + basetenApiKey: "test-baseten-api-key", + apiModelId: "zai-org/GLM-4.6", + } + handler = new BasetenHandler(mockOptions) + vi.clearAllMocks() + }) + + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(BasetenHandler) + expect(handler.getModel().id).toBe(mockOptions.apiModelId) + }) + + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new BasetenHandler({ + ...mockOptions, + apiModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe(basetenDefaultModelId) + }) + }) + + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new BasetenHandler({ + basetenApiKey: "test-baseten-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe(basetenDefaultModelId) + expect(model.info).toEqual(basetenModels[basetenDefaultModelId]) + }) + + it("should return specified model when valid model is provided", () => { + const testModelId: BasetenModelId = "deepseek-ai/DeepSeek-R1" + const handlerWithModel = new BasetenHandler({ + apiModelId: testModelId, + basetenApiKey: "test-baseten-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + expect(model.info).toEqual(basetenModels[testModelId]) + }) + + it("should return provided model ID with default model info if model does not exist", () => { + const handlerWithInvalidModel = new BasetenHandler({ + ...mockOptions, + apiModelId: "invalid-model", + }) + const model = handlerWithInvalidModel.getModel() + expect(model.id).toBe("invalid-model") + expect(model.info).toBeDefined() + expect(model.info).toBe(basetenModels[basetenDefaultModelId]) + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." 
+ const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from Baseten" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from Baseten") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) + + it("should pass correct temperature (0.5 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + }) + + const handlerWithDefaultTemp = new BasetenHandler({ + basetenApiKey: "test-key", + apiModelId: "zai-org/GLM-4.6", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) + + it("should use user-specified temperature over default", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + }) + + const handlerWithCustomTemp = new BasetenHandler({ + basetenApiKey: "test-key", + apiModelId: "zai-org/GLM-4.6", + modelTemperature: 0.9, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.9, + }), + ) + }) + + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", 
text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from Baseten", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from Baseten") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) + }) + + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + await handler.completePrompt("Test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.5, + }), + ) + }) + }) + + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, + }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + + it("should ignore tool-call events to prevent duplicate tools in UI", async () => { + async function* mockFullStream() { + yield { + type: "tool-call", + toolCallId: "tool-call-1", + toolName: "read_file", + input: { path: "test.ts" }, + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallChunks = chunks.filter( + (c) => c.type === "tool_call_start" || c.type === "tool_call_delta" || c.type === "tool_call_end", + ) + expect(toolCallChunks.length).toBe(0) + }) + }) + + 
describe("error handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle AI SDK errors with handleAiSdkError", async () => { + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw new Error("API Error") + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("Baseten: API Error") + }) + + it("should preserve status codes in error handling", async () => { + const apiError = new Error("Rate limit exceeded") + ;(apiError as any).status = 429 + + // eslint-disable-next-line require-yield + async function* mockFullStream(): AsyncGenerator { + throw apiError + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + try { + for await (const _ of stream) { + // consume stream + } + expect.fail("Should have thrown an error") + } catch (error: any) { + expect(error.message).toContain("Baseten") + expect(error.status).toBe(429) + } + }) + }) +}) diff --git a/src/api/providers/baseten.ts b/src/api/providers/baseten.ts index ca0c2867756..2e63f3d52c1 100644 --- a/src/api/providers/baseten.ts +++ b/src/api/providers/baseten.ts @@ -1,18 +1,156 @@ -import { type BasetenModelId, basetenDefaultModelId, basetenModels } from "@roo-code/types" +import { Anthropic } from "@anthropic-ai/sdk" +import { createBaseten } from "@ai-sdk/baseten" +import { streamText, generateText, ToolSet } from "ai" + +import { basetenModels, basetenDefaultModelId, type ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" -export class BasetenHandler extends BaseOpenAiCompatibleProvider { +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + +import { DEFAULT_HEADERS } from "./constants" +import { BaseProvider } from "./base-provider" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +const BASETEN_DEFAULT_TEMPERATURE = 0.5 + +/** + * Baseten provider using the dedicated @ai-sdk/baseten package. + * Provides native support for Baseten's inference API. + */ +export class BasetenHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + protected provider: ReturnType + constructor(options: ApiHandlerOptions) { - super({ - ...options, - providerName: "Baseten", + super() + this.options = options + + this.provider = createBaseten({ baseURL: "https://inference.baseten.co/v1", - apiKey: options.basetenApiKey, - defaultProviderModelId: basetenDefaultModelId, - providerModels: basetenModels, - defaultTemperature: 0.5, + apiKey: options.basetenApiKey ?? 
"not-provided", + headers: DEFAULT_HEADERS, + }) + } + + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = this.options.apiModelId ?? basetenDefaultModelId + const info = basetenModels[id as keyof typeof basetenModels] || basetenModels[basetenDefaultModelId] + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: BASETEN_DEFAULT_TEMPERATURE, + }) + return { id, info, ...params } + } + + /** + * Get the language model for the configured model ID. + */ + protected getLanguageModel() { + const { id } = this.getModel() + return this.provider(id) + } + + /** + * Process usage metrics from the AI SDK response. + */ + protected processUsageMetrics(usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }): ApiStreamUsageChunk { + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + reasoningTokens: usage.details?.reasoningTokens, + } + } + + /** + * Get the max tokens parameter to include in the request. + */ + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined + } + + /** + * Create a message stream using the AI SDK. + */ + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const aiSdkMessages = convertToAiSdkMessages(messages) + + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + const requestOptions: Parameters[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), + } + + const result = streamText(requestOptions) + + try { + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } + } + + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage) + } + } catch (error) { + throw handleAiSdkError(error, "Baseten") + } + } + + /** + * Complete a prompt using the AI SDK generateText. + */ + async completePrompt(prompt: string): Promise { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? BASETEN_DEFAULT_TEMPERATURE, }) + + return text + } + + override isAiSdkProvider(): boolean { + return true } } diff --git a/src/esbuild.mjs b/src/esbuild.mjs index aabacfcee99..fb7b1866797 100644 --- a/src/esbuild.mjs +++ b/src/esbuild.mjs @@ -43,6 +43,22 @@ async function main() { * @type {import('esbuild').Plugin[]} */ const plugins = [ + { + // Stub out @basetenlabs/performance-client which contains native .node + // binaries that esbuild cannot bundle. This module is only used by + // @ai-sdk/baseten for embedding models, not for chat completions. 
+ name: "stub-baseten-native", + setup(build) { + build.onResolve({ filter: /^@basetenlabs\/performance-client/ }, (args) => ({ + path: args.path, + namespace: "stub-baseten-native", + })) + build.onLoad({ filter: /.*/, namespace: "stub-baseten-native" }, () => ({ + contents: "module.exports = { PerformanceClient: class PerformanceClient {} };", + loader: "js", + })) + }, + }, { name: "copyFiles", setup(build) { diff --git a/src/package.json b/src/package.json index f2b574c2626..eb67c0d1d7d 100644 --- a/src/package.json +++ b/src/package.json @@ -451,6 +451,7 @@ }, "dependencies": { "@ai-sdk/amazon-bedrock": "^4.0.50", + "@ai-sdk/baseten": "^1.0.31", "@ai-sdk/cerebras": "^1.0.0", "@ai-sdk/deepseek": "^2.0.14", "@ai-sdk/fireworks": "^2.0.26",