diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 417e69a07ad..0a47ae416ae 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -755,6 +755,12 @@ importers: '@ai-sdk/fireworks': specifier: ^2.0.26 version: 2.0.26(zod@3.25.76) + '@ai-sdk/google': + specifier: ^3.0.20 + version: 3.0.20(zod@3.25.76) + '@ai-sdk/google-vertex': + specifier: ^3.0.20 + version: 3.0.98(zod@3.25.76) '@ai-sdk/groq': specifier: ^3.0.19 version: 3.0.19(zod@3.25.76) @@ -1411,6 +1417,12 @@ packages: '@adobe/css-tools@4.4.2': resolution: {integrity: sha512-baYZExFpsdkBNuvGKTKWCwKH57HRZLVtycZS05WTQNVOiXVSeAki3nU35zlRbToeMW8aHlJfyS+1C4BOv27q0A==} + '@ai-sdk/anthropic@2.0.58': + resolution: {integrity: sha512-CkNW5L1Arv8gPtPlEmKd+yf/SG9ucJf0XQdpMG8OiYEtEMc2smuCA+tyCp8zI7IBVg/FE7nUfFHntQFaOjRwJQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35': resolution: {integrity: sha512-JrNdMYptrOUjNthibgBeAcBjZ/H+fXb49sSrWhOx5Aq8eUcrYvwQ2DtSAi8VraHssZu78NAnBMrgFWSUOTXFxw==} engines: {node: '>=18'} @@ -1435,6 +1447,24 @@ packages: peerDependencies: zod: 3.25.76 + '@ai-sdk/google-vertex@3.0.98': + resolution: {integrity: sha512-uuv0RHkdJ5vTzeH1+iuBlv7GAjRcOPd2jiqtGLz6IKOUDH+PRQoE3ExrvOysVnKuhhTBMqvawkktDhMDQE6sVQ==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + + '@ai-sdk/google@2.0.52': + resolution: {integrity: sha512-2XUnGi3f7TV4ujoAhA+Fg3idUoG/+Y2xjCRg70a1/m0DH1KSQqYaCboJ1C19y6ZHGdf5KNT20eJdswP6TvrY2g==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + + '@ai-sdk/google@3.0.20': + resolution: {integrity: sha512-bVGsulEr6JiipAFlclo9bjL5WaUV0iCSiiekLt+PY6pwmtJeuU2GaD9DoE3OqR8LN2W779mU13IhVEzlTupf8g==} + engines: {node: '>=18'} + peerDependencies: + zod: 3.25.76 + '@ai-sdk/groq@3.0.19': resolution: {integrity: sha512-WAeGVnp9rvU3RUvu6S1HiD8hAjKgNlhq+z3m4j5Z1fIKRXqcKjOscVZGwL36If8qxsqXNVCtG3ltXawM5UAa8w==} engines: {node: '>=18'} @@ -11100,6 +11130,12 @@ snapshots: '@adobe/css-tools@4.4.2': {} + '@ai-sdk/anthropic@2.0.58(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 2.0.1 + '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/cerebras@1.0.35(zod@3.25.76)': dependencies: '@ai-sdk/openai-compatible': 1.0.31(zod@3.25.76) @@ -11127,6 +11163,29 @@ snapshots: '@vercel/oidc': 3.1.0 zod: 3.25.76 + '@ai-sdk/google-vertex@3.0.98(zod@3.25.76)': + dependencies: + '@ai-sdk/anthropic': 2.0.58(zod@3.25.76) + '@ai-sdk/google': 2.0.52(zod@3.25.76) + '@ai-sdk/provider': 2.0.1 + '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) + google-auth-library: 10.5.0 + zod: 3.25.76 + transitivePeerDependencies: + - supports-color + + '@ai-sdk/google@2.0.52(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 2.0.1 + '@ai-sdk/provider-utils': 3.0.20(zod@3.25.76) + zod: 3.25.76 + + '@ai-sdk/google@3.0.20(zod@3.25.76)': + dependencies: + '@ai-sdk/provider': 3.0.7 + '@ai-sdk/provider-utils': 4.0.13(zod@3.25.76) + zod: 3.25.76 + '@ai-sdk/groq@3.0.19(zod@3.25.76)': dependencies: '@ai-sdk/provider': 3.0.6 diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts index a9544a0b97f..2753c1ad516 100644 --- a/src/api/providers/__tests__/gemini-handler.spec.ts +++ b/src/api/providers/__tests__/gemini-handler.spec.ts @@ -1,28 +1,78 @@ +// npx vitest run src/api/providers/__tests__/gemini-handler.spec.ts + +// Mock the AI SDK functions +const mockStreamText = vi.fn() +const mockGenerateText = vi.fn() + +vi.mock("ai", async (importOriginal) => { + const original = await importOriginal() + return { + ...original, + streamText: 
(...args: unknown[]) => mockStreamText(...args), + generateText: (...args: unknown[]) => mockGenerateText(...args), + } +}) + import { t } from "i18next" -import { FunctionCallingConfigMode } from "@google/genai" import { GeminiHandler } from "../gemini" import type { ApiHandlerOptions } from "../../../shared/api" describe("GeminiHandler backend support", () => { - it("createMessage uses function declarations (URL context and grounding are only for completePrompt)", async () => { - // URL context and grounding are mutually exclusive with function declarations - // in Gemini API, so createMessage only uses function declarations. - // URL context/grounding are only added in completePrompt. + beforeEach(() => { + mockStreamText.mockClear() + mockGenerateText.mockClear() + }) + + it("createMessage uses AI SDK tools format", async () => { const options = { apiProvider: "gemini", enableUrlContext: true, enableGrounding: true, } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + + const mockFullStream = (async function* () {})() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + await handler.createMessage("instr", [] as any).next() - const config = stub.mock.calls[0][0].config - // createMessage always uses function declarations only - // (tools are always present from ALWAYS_AVAILABLE_TOOLS) - expect(config.tools).toEqual([{ functionDeclarations: expect.any(Array) }]) + + // Verify streamText was called + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + system: "instr", + }), + ) + }) + + it("completePrompt passes tools when URL context and grounding enabled", async () => { + const options = { + apiProvider: "gemini", + enableUrlContext: true, + enableGrounding: true, + } as ApiHandlerOptions + const handler = new GeminiHandler(options) + + mockGenerateText.mockResolvedValue({ + text: "ok", + providerMetadata: {}, + }) + + const res = await handler.completePrompt("hi") + expect(res).toBe("ok") + + // Verify generateText was called with tools + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "hi", + tools: expect.any(Object), + }), + ) }) it("completePrompt passes config overrides without tools when URL context and grounding disabled", async () => { @@ -32,13 +82,18 @@ describe("GeminiHandler backend support", () => { enableGrounding: false, } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockResolvedValue({ text: "ok" }) - // @ts-ignore access private client - handler["client"].models.generateContent = stub + + mockGenerateText.mockResolvedValue({ + text: "ok", + providerMetadata: {}, + }) + const res = await handler.completePrompt("hi") expect(res).toBe("ok") - const promptConfig = stub.mock.calls[0][0].config - expect(promptConfig.tools).toBeUndefined() + + // Verify generateText was called without tools + const callArgs = mockGenerateText.mock.calls[0][0] + expect(callArgs.tools).toBeUndefined() }) describe("error scenarios", () => { @@ -49,23 +104,22 @@ describe("GeminiHandler backend support", () => { } as ApiHandlerOptions const handler = new GeminiHandler(options) - const mockStream = async function* () { - yield { - candidates: [ - { - groundingMetadata: { - // Invalid structure - 
missing groundingChunks - }, - content: { parts: [{ text: "test response" }] }, + // AI SDK text-delta events have a 'text' property (processAiSdkStreamPart casts to this) + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "test response" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({ + google: { + groundingMetadata: { + // Invalid structure - missing groundingChunks }, - ], - usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 }, - } - } - - const stub = vi.fn().mockReturnValue(mockStream()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + }, + }), + }) const messages = [] for await (const chunk of handler.createMessage("test", [] as any)) { @@ -74,7 +128,7 @@ describe("GeminiHandler backend support", () => { // Should still return the main content without sources expect(messages.some((m) => m.type === "text" && m.text === "test response")).toBe(true) - expect(messages.some((m) => m.type === "text" && m.text?.includes("Sources:"))).toBe(false) + expect(messages.some((m) => m.type === "grounding")).toBe(false) }) it("should handle malformed grounding metadata", async () => { @@ -84,27 +138,26 @@ describe("GeminiHandler backend support", () => { } as ApiHandlerOptions const handler = new GeminiHandler(options) - const mockStream = async function* () { - yield { - candidates: [ - { - groundingMetadata: { - groundingChunks: [ - { web: null }, // Missing URI - { web: { uri: "https://example.com", title: "Example Site" } }, // Valid - {}, // Missing web property entirely - ], - }, - content: { parts: [{ text: "test response" }] }, + // AI SDK text-delta events have a 'text' property (processAiSdkStreamPart casts to this) + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "test response" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({ + google: { + groundingMetadata: { + groundingChunks: [ + { web: null }, // Missing URI + { web: { uri: "https://example.com", title: "Example Site" } }, // Valid + {}, // Missing web property entirely + ], }, - ], - usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 }, - } - } - - const stub = vi.fn().mockReturnValue(mockStream()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + }, + }), + }) const messages = [] for await (const chunk of handler.createMessage("test", [] as any)) { @@ -137,9 +190,16 @@ describe("GeminiHandler backend support", () => { const handler = new GeminiHandler(options) const mockError = new Error("API rate limit exceeded") - const stub = vi.fn().mockRejectedValue(mockError) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + // eslint-disable-next-line require-yield + const mockFullStream = (async function* () { + throw mockError + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) await expect(async () => { const generator = handler.createMessage("test", [] as any) @@ -148,7 +208,7 @@ describe("GeminiHandler backend support", () => { }) }) - describe("allowedFunctionNames support", () => { + describe("toolChoice support", () => { const testTools = [ { type: "function" 
as const, @@ -176,123 +236,120 @@ describe("GeminiHandler backend support", () => { }, ] - it("should pass allowedFunctionNames to toolConfig when provided", async () => { + it("should pass tools to streamText", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub - await handler - .createMessage("test", [] as any, { - taskId: "test-task", - tools: testTools, - allowedFunctionNames: ["read_file", "write_to_file"], - }) - .next() + const mockFullStream = (async function* () {})() - const config = stub.mock.calls[0][0].config - expect(config.toolConfig).toEqual({ - functionCallingConfig: { - mode: FunctionCallingConfigMode.ANY, - allowedFunctionNames: ["read_file", "write_to_file"], - }, + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), }) - }) - - it("should include all tools but restrict callable functions via allowedFunctionNames", async () => { - const options = { - apiProvider: "gemini", - } as ApiHandlerOptions - const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub await handler .createMessage("test", [] as any, { taskId: "test-task", tools: testTools, - allowedFunctionNames: ["read_file"], }) .next() - const config = stub.mock.calls[0][0].config - // All tools should be passed to the model - expect(config.tools[0].functionDeclarations).toHaveLength(3) - // But only read_file should be allowed to be called - expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"]) + // Verify streamText was called with tools + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.any(Object), + }), + ) }) - it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => { + it("should pass toolChoice when allowedFunctionNames is provided", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + + const mockFullStream = (async function* () {})() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) await handler .createMessage("test", [] as any, { taskId: "test-task", tools: testTools, - tool_choice: "auto", - allowedFunctionNames: ["read_file"], + allowedFunctionNames: ["read_file", "write_to_file"], }) .next() - const config = stub.mock.calls[0][0].config - // allowedFunctionNames should take precedence - mode should be ANY, not AUTO - expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY) - expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"]) + // Verify toolChoice is 'required' when allowedFunctionNames is provided + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + toolChoice: "required", + }), + ) }) - it("should fall back to tool_choice when allowedFunctionNames is empty", async () => { + it("should use tool_choice when 
allowedFunctionNames is not provided", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + + const mockFullStream = (async function* () {})() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) await handler .createMessage("test", [] as any, { taskId: "test-task", tools: testTools, tool_choice: "auto", - allowedFunctionNames: [], }) .next() - const config = stub.mock.calls[0][0].config - // Empty allowedFunctionNames should fall back to tool_choice behavior - expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO) - expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined() + // Verify toolChoice follows tool_choice + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + toolChoice: "auto", + }), + ) }) - it("should not set toolConfig when allowedFunctionNames is undefined and no tool_choice", async () => { + it("should not set toolChoice when allowedFunctionNames is empty and no tool_choice", async () => { const options = { apiProvider: "gemini", } as ApiHandlerOptions const handler = new GeminiHandler(options) - const stub = vi.fn().mockReturnValue((async function* () {})()) - // @ts-ignore access private client - handler["client"].models.generateContentStream = stub + + const mockFullStream = (async function* () {})() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) await handler .createMessage("test", [] as any, { taskId: "test-task", tools: testTools, + allowedFunctionNames: [], }) .next() - const config = stub.mock.calls[0][0].config - // No toolConfig should be set when neither allowedFunctionNames nor tool_choice is provided - expect(config.toolConfig).toBeUndefined() + // With empty allowedFunctionNames, toolChoice should be undefined + const callArgs = mockStreamText.mock.calls[0][0] + expect(callArgs.toolChoice).toBeUndefined() }) }) }) diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 8c2ee87a787..8019e3e4360 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -10,6 +10,19 @@ vitest.mock("@roo-code/telemetry", () => ({ }, })) +// Mock the AI SDK functions +const mockStreamText = vitest.fn() +const mockGenerateText = vitest.fn() + +vitest.mock("ai", async (importOriginal) => { + const original = await importOriginal() + return { + ...original, + streamText: (...args: unknown[]) => mockStreamText(...args), + generateText: (...args: unknown[]) => mockGenerateText(...args), + } +}) + import { Anthropic } from "@anthropic-ai/sdk" import { type ModelInfo, geminiDefaultModelId, ApiProviderError } from "@roo-code/types" @@ -25,26 +38,14 @@ describe("GeminiHandler", () => { beforeEach(() => { // Reset mocks mockCaptureException.mockClear() - - // Create mock functions - const mockGenerateContentStream = vitest.fn() - const mockGenerateContent = vitest.fn() - const mockGetGenerativeModel = vitest.fn() + mockStreamText.mockClear() + mockGenerateText.mockClear() handler = new GeminiHandler({ apiKey: "test-key", apiModelId: GEMINI_MODEL_NAME, geminiApiKey: "test-key", }) - - // Replace the 
client with our mock - handler["client"] = { - models: { - generateContentStream: mockGenerateContentStream, - generateContent: mockGenerateContent, - getGenerativeModel: mockGetGenerativeModel, - }, - } as any }) describe("constructor", () => { @@ -69,13 +70,17 @@ describe("GeminiHandler", () => { const systemPrompt = "You are a helpful assistant" it("should handle text messages correctly", async () => { - // Setup the mock implementation to return an async generator - ;(handler["client"].models.generateContentStream as any).mockResolvedValue({ - [Symbol.asyncIterator]: async function* () { - yield { text: "Hello" } - yield { text: " world!" } - yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } - }, + // Setup the mock implementation to return an async generator for fullStream + // AI SDK text-delta events have a 'text' property (processAiSdkStreamPart casts to this) + const mockFullStream = (async function* () { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world!" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), }) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -91,21 +96,27 @@ describe("GeminiHandler", () => { expect(chunks[1]).toEqual({ type: "text", text: " world!" }) expect(chunks[2]).toMatchObject({ type: "usage", inputTokens: 10, outputTokens: 5 }) - // Verify the call to generateContentStream - expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith( + // Verify the call to streamText + expect(mockStreamText).toHaveBeenCalledWith( expect.objectContaining({ - model: GEMINI_MODEL_NAME, - config: expect.objectContaining({ - temperature: 1, - systemInstruction: systemPrompt, - }), + system: systemPrompt, + temperature: 1, }), ) }) it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - ;(handler["client"].models.generateContentStream as any).mockRejectedValue(mockError) + // eslint-disable-next-line require-yield + const mockFullStream = (async function* () { + throw mockError + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -119,28 +130,26 @@ describe("GeminiHandler", () => { describe("completePrompt", () => { it("should complete prompt successfully", async () => { - // Mock the response with text property - ;(handler["client"].models.generateContent as any).mockResolvedValue({ + mockGenerateText.mockResolvedValue({ text: "Test response", + providerMetadata: {}, }) const result = await handler.completePrompt("Test prompt") expect(result).toBe("Test response") - // Verify the call to generateContent - expect(handler["client"].models.generateContent).toHaveBeenCalledWith({ - model: GEMINI_MODEL_NAME, - contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], - config: { - httpOptions: undefined, + // Verify the call to generateText + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", temperature: 1, - }, - }) + }), + ) }) it("should handle API errors", async () => { const mockError = new Error("Gemini API error") - ;(handler["client"].models.generateContent as any).mockRejectedValue(mockError) + mockGenerateText.mockRejectedValue(mockError) await expect(handler.completePrompt("Test 
prompt")).rejects.toThrow( t("common:errors.gemini.generate_complete_prompt", { error: "Gemini API error" }), @@ -148,9 +157,9 @@ describe("GeminiHandler", () => { }) it("should handle empty response", async () => { - // Mock the response with empty text - ;(handler["client"].models.generateContent as any).mockResolvedValue({ + mockGenerateText.mockResolvedValue({ text: "", + providerMetadata: {}, }) const result = await handler.completePrompt("Test prompt") @@ -255,7 +264,16 @@ describe("GeminiHandler", () => { it("should capture telemetry on createMessage error", async () => { const mockError = new Error("Gemini API error") - ;(handler["client"].models.generateContentStream as any).mockRejectedValue(mockError) + // eslint-disable-next-line require-yield + const mockFullStream = (async function* () { + throw mockError + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) const stream = handler.createMessage(systemPrompt, mockMessages) @@ -283,7 +301,7 @@ describe("GeminiHandler", () => { it("should capture telemetry on completePrompt error", async () => { const mockError = new Error("Gemini completion error") - ;(handler["client"].models.generateContent as any).mockRejectedValue(mockError) + mockGenerateText.mockRejectedValue(mockError) await expect(handler.completePrompt("Test prompt")).rejects.toThrow() @@ -305,7 +323,16 @@ describe("GeminiHandler", () => { it("should still throw the error after capturing telemetry", async () => { const mockError = new Error("Gemini API error") - ;(handler["client"].models.generateContentStream as any).mockRejectedValue(mockError) + // eslint-disable-next-line require-yield + const mockFullStream = (async function* () { + throw mockError + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({}), + providerMetadata: Promise.resolve({}), + }) const stream = handler.createMessage(systemPrompt, mockMessages) diff --git a/src/api/providers/__tests__/vertex.spec.ts b/src/api/providers/__tests__/vertex.spec.ts index 1420b05c7a0..1a23ba16c3a 100644 --- a/src/api/providers/__tests__/vertex.spec.ts +++ b/src/api/providers/__tests__/vertex.spec.ts @@ -3,6 +3,37 @@ // Mock vscode first to avoid import errors vitest.mock("vscode", () => ({})) +// Mock the createVertex function from @ai-sdk/google-vertex +const mockCreateVertex = vitest.fn() +const mockGoogleSearchTool = vitest.fn() +const mockUrlContextTool = vitest.fn() + +vitest.mock("@ai-sdk/google-vertex", () => ({ + createVertex: (...args: unknown[]) => { + mockCreateVertex(...args) + const provider = Object.assign((modelId: string) => ({ modelId }), { + tools: { + googleSearch: mockGoogleSearchTool, + urlContext: mockUrlContextTool, + }, + }) + return provider + }, +})) + +// Mock the AI SDK functions +const mockStreamText = vitest.fn() +const mockGenerateText = vitest.fn() + +vitest.mock("ai", async (importOriginal) => { + const original = await importOriginal() + return { + ...original, + streamText: (...args: unknown[]) => mockStreamText(...args), + generateText: (...args: unknown[]) => mockGenerateText(...args), + } +}) + import { Anthropic } from "@anthropic-ai/sdk" import { ApiStreamChunk } from "../../transform/stream" @@ -14,25 +45,105 @@ describe("VertexHandler", () => { let handler: VertexHandler beforeEach(() => { - // Create mock functions - const mockGenerateContentStream = vitest.fn() - const mockGenerateContent = vitest.fn() - const mockGetGenerativeModel = 
vitest.fn() + mockStreamText.mockClear() + mockGenerateText.mockClear() + mockCreateVertex.mockClear() + mockGoogleSearchTool.mockClear() + mockUrlContextTool.mockClear() handler = new VertexHandler({ apiModelId: "gemini-1.5-pro-001", vertexProjectId: "test-project", vertexRegion: "us-central1", }) + }) + + describe("constructor", () => { + it("should create provider with project and location", () => { + new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "my-project", + vertexRegion: "europe-west1", + }) + + expect(mockCreateVertex).toHaveBeenCalledWith( + expect.objectContaining({ + project: "my-project", + location: "europe-west1", + }), + ) + }) - // Replace the client with our mock - handler["client"] = { - models: { - generateContentStream: mockGenerateContentStream, - generateContent: mockGenerateContent, - getGenerativeModel: mockGetGenerativeModel, - }, - } as any + it("should create provider with JSON credentials", () => { + const credentials = { type: "service_account", project_id: "test" } + + new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "my-project", + vertexRegion: "us-central1", + vertexJsonCredentials: JSON.stringify(credentials), + }) + + expect(mockCreateVertex).toHaveBeenCalledWith( + expect.objectContaining({ + project: "my-project", + location: "us-central1", + googleAuthOptions: { credentials }, + }), + ) + }) + + it("should create provider with key file", () => { + new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "my-project", + vertexRegion: "us-central1", + vertexKeyFile: "/path/to/keyfile.json", + }) + + expect(mockCreateVertex).toHaveBeenCalledWith( + expect.objectContaining({ + project: "my-project", + location: "us-central1", + googleAuthOptions: { keyFile: "/path/to/keyfile.json" }, + }), + ) + }) + + it("should prefer JSON credentials over key file", () => { + const credentials = { type: "service_account", project_id: "test" } + + new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "my-project", + vertexRegion: "us-central1", + vertexJsonCredentials: JSON.stringify(credentials), + vertexKeyFile: "/path/to/keyfile.json", + }) + + expect(mockCreateVertex).toHaveBeenCalledWith( + expect.objectContaining({ + googleAuthOptions: { credentials }, + }), + ) + }) + + it("should handle invalid JSON credentials gracefully", () => { + new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "my-project", + vertexRegion: "us-central1", + vertexJsonCredentials: "invalid-json", + }) + + // Should not throw and should create provider without credentials + expect(mockCreateVertex).toHaveBeenCalledWith( + expect.objectContaining({ + project: "my-project", + googleAuthOptions: undefined, + }), + ) + }) }) describe("createMessage", () => { @@ -43,19 +154,11 @@ describe("VertexHandler", () => { const systemPrompt = "You are a helpful assistant" - it("should handle streaming responses correctly for Gemini", async () => { - // Let's examine the test expectations and adjust our mock accordingly - // The test expects 4 chunks: - // 1. Usage chunk with input tokens - // 2. Text chunk with "Gemini response part 1" - // 3. Text chunk with " part 2" - // 4. 
Usage chunk with output tokens - - // Let's modify our approach and directly mock the createMessage method - // instead of mocking the client + it("should handle streaming responses correctly", async () => { + // Mock the createMessage method to test the streaming behavior vitest.spyOn(handler, "createMessage").mockImplementation(async function* () { yield { type: "usage", inputTokens: 10, outputTokens: 0 } - yield { type: "text", text: "Gemini response part 1" } + yield { type: "text", text: "Vertex response part 1" } yield { type: "text", text: " part 2" } yield { type: "usage", inputTokens: 0, outputTokens: 5 } }) @@ -70,59 +173,130 @@ describe("VertexHandler", () => { expect(chunks.length).toBe(4) expect(chunks[0]).toEqual({ type: "usage", inputTokens: 10, outputTokens: 0 }) - expect(chunks[1]).toEqual({ type: "text", text: "Gemini response part 1" }) + expect(chunks[1]).toEqual({ type: "text", text: "Vertex response part 1" }) expect(chunks[2]).toEqual({ type: "text", text: " part 2" }) expect(chunks[3]).toEqual({ type: "usage", inputTokens: 0, outputTokens: 5 }) + }) + + it("should call streamText with correct options", async () => { + const mockFullStream = (async function* () { + yield { type: "text-delta", textDelta: "Hello" } + })() + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream, + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, mockMessages) + const chunks: ApiStreamChunk[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } - // Since we're directly mocking createMessage, we don't need to verify - // that generateContentStream was called + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + system: systemPrompt, + temperature: 1, + }), + ) }) }) describe("completePrompt", () => { - it("should complete prompt successfully for Gemini", async () => { - // Mock the response with text property - ;(handler["client"].models.generateContent as any).mockResolvedValue({ - text: "Test Gemini response", + it("should complete prompt successfully", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test Vertex response", + providerMetadata: {}, }) const result = await handler.completePrompt("Test prompt") - expect(result).toBe("Test Gemini response") + expect(result).toBe("Test Vertex response") - // Verify the call to generateContent - expect(handler["client"].models.generateContent).toHaveBeenCalledWith( + // Verify generateText was called with the prompt + expect(mockGenerateText).toHaveBeenCalledWith( expect.objectContaining({ - model: expect.any(String), - contents: [{ role: "user", parts: [{ text: "Test prompt" }] }], - config: expect.objectContaining({ - temperature: 1, - }), + prompt: "Test prompt", + temperature: 1, }), ) }) - it("should handle API errors for Gemini", async () => { + it("should handle API errors", async () => { const mockError = new Error("Vertex API error") - ;(handler["client"].models.generateContent as any).mockRejectedValue(mockError) + mockGenerateText.mockRejectedValue(mockError) await expect(handler.completePrompt("Test prompt")).rejects.toThrow( t("common:errors.gemini.generate_complete_prompt", { error: "Vertex API error" }), ) }) - it("should handle empty response for Gemini", async () => { - // Mock the response with empty text - ;(handler["client"].models.generateContent as any).mockResolvedValue({ + it("should handle empty response", async () => { + 
mockGenerateText.mockResolvedValue({ text: "", + providerMetadata: {}, }) const result = await handler.completePrompt("Test prompt") expect(result).toBe("") }) + + it("should add Google Search tool when grounding is enabled", async () => { + const handlerWithGrounding = new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + enableGrounding: true, + }) + + mockGenerateText.mockResolvedValue({ + text: "Search result", + providerMetadata: {}, + }) + mockGoogleSearchTool.mockReturnValue({ type: "googleSearch" }) + + await handlerWithGrounding.completePrompt("Search query") + + expect(mockGoogleSearchTool).toHaveBeenCalledWith({}) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.objectContaining({ + google_search: { type: "googleSearch" }, + }), + }), + ) + }) + + it("should add URL Context tool when enabled", async () => { + const handlerWithUrlContext = new VertexHandler({ + apiModelId: "gemini-1.5-pro-001", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + enableUrlContext: true, + }) + + mockGenerateText.mockResolvedValue({ + text: "URL context result", + providerMetadata: {}, + }) + mockUrlContextTool.mockReturnValue({ type: "urlContext" }) + + await handlerWithUrlContext.completePrompt("Fetch URL") + + expect(mockUrlContextTool).toHaveBeenCalledWith({}) + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.objectContaining({ + url_context: { type: "urlContext" }, + }), + }), + ) + }) }) describe("getModel", () => { - it("should return correct model info for Gemini", () => { + it("should return correct model info", () => { // Create a new instance with specific model ID const testHandler = new VertexHandler({ apiModelId: "gemini-2.0-flash-001", @@ -130,12 +304,135 @@ describe("VertexHandler", () => { vertexRegion: "us-central1", }) - // Don't mock getModel here as we want to test the actual implementation const modelInfo = testHandler.getModel() expect(modelInfo.id).toBe("gemini-2.0-flash-001") expect(modelInfo.info).toBeDefined() expect(modelInfo.info.maxTokens).toBe(8192) expect(modelInfo.info.contextWindow).toBe(1048576) }) + + it("should return default model when invalid ID provided", () => { + const testHandler = new VertexHandler({ + apiModelId: "invalid-model-id", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const modelInfo = testHandler.getModel() + // Should fall back to default model + expect(modelInfo.info).toBeDefined() + }) + + it("should strip :thinking suffix from model ID", () => { + const testHandler = new VertexHandler({ + apiModelId: "gemini-2.5-flash-preview-05-20:thinking", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const modelInfo = testHandler.getModel() + expect(modelInfo.id).toBe("gemini-2.5-flash-preview-05-20") + }) + }) + + describe("calculateCost", () => { + it("should calculate cost correctly", () => { + const result = handler.calculateCost({ + info: { + maxTokens: 8192, + contextWindow: 1048576, + supportsPromptCache: false, + inputPrice: 1.25, + outputPrice: 5.0, + }, + inputTokens: 1000, + outputTokens: 500, + }) + + // Input: 1.25 * (1000 / 1_000_000) = 0.00125 + // Output: 5.0 * (500 / 1_000_000) = 0.0025 + // Total: 0.00375 + expect(result).toBeCloseTo(0.00375, 5) + }) + + it("should handle cache read tokens", () => { + const result = handler.calculateCost({ + info: { + maxTokens: 8192, + contextWindow: 1048576, + 
supportsPromptCache: true, + inputPrice: 1.25, + outputPrice: 5.0, + cacheReadsPrice: 0.3125, + }, + inputTokens: 1000, + outputTokens: 500, + cacheReadTokens: 400, + }) + + // Uncached input: 600 tokens at 1.25/M = 0.00075 + // Cache read: 400 tokens at 0.3125/M = 0.000125 + // Output: 500 tokens at 5.0/M = 0.0025 + // Total: 0.003375 + expect(result).toBeCloseTo(0.003375, 5) + }) + + it("should handle reasoning tokens", () => { + const result = handler.calculateCost({ + info: { + maxTokens: 8192, + contextWindow: 1048576, + supportsPromptCache: false, + inputPrice: 1.25, + outputPrice: 5.0, + }, + inputTokens: 1000, + outputTokens: 500, + reasoningTokens: 200, + }) + + // Input: 1.25 * (1000 / 1_000_000) = 0.00125 + // Output + Reasoning: 5.0 * (700 / 1_000_000) = 0.0035 + // Total: 0.00475 + expect(result).toBeCloseTo(0.00475, 5) + }) + + it("should return undefined when prices are missing", () => { + const result = handler.calculateCost({ + info: { + maxTokens: 8192, + contextWindow: 1048576, + supportsPromptCache: false, + }, + inputTokens: 1000, + outputTokens: 500, + }) + + expect(result).toBeUndefined() + }) + + it("should use tiered pricing when available", () => { + const result = handler.calculateCost({ + info: { + maxTokens: 8192, + contextWindow: 1048576, + supportsPromptCache: false, + inputPrice: 1.25, + outputPrice: 5.0, + tiers: [ + { contextWindow: 128000, inputPrice: 0.5, outputPrice: 2.0 }, + { contextWindow: 1048576, inputPrice: 1.0, outputPrice: 4.0 }, + ], + }, + inputTokens: 50000, // Falls into first tier + outputTokens: 500, + }) + + // Uses tier 1 pricing: inputPrice=0.5, outputPrice=2.0 + // Input: 0.5 * (50000 / 1_000_000) = 0.025 + // Output: 2.0 * (500 / 1_000_000) = 0.001 + // Total: 0.026 + expect(result).toBeCloseTo(0.026, 5) + }) }) }) diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 823ed0ac8b0..ed291db74e7 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -1,13 +1,6 @@ import type { Anthropic } from "@anthropic-ai/sdk" -import { - GoogleGenAI, - type GenerateContentResponseUsageMetadata, - type GenerateContentParameters, - type GenerateContentConfig, - type GroundingMetadata, - FunctionCallingConfigMode, -} from "@google/genai" -import type { JWTInput } from "google-auth-library" +import { createGoogleGenerativeAI, type GoogleGenerativeAIProvider } from "@ai-sdk/google" +import { streamText, generateText, ToolSet } from "ai" import { type ModelInfo, @@ -16,59 +9,42 @@ import { geminiModels, ApiProviderError, } from "@roo-code/types" -import { safeJsonParse } from "@roo-code/core" import { TelemetryService } from "@roo-code/telemetry" import type { ApiHandlerOptions } from "../../shared/api" -import { convertAnthropicMessageToGemini } from "../transform/gemini-format" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" import { t } from "i18next" -import type { ApiStream, GroundingSource } from "../transform/stream" +import type { ApiStream, ApiStreamUsageChunk, GroundingSource } from "../transform/stream" import { getModelParams } from "../transform/model-params" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" import { BaseProvider } from "./base-provider" - -type GeminiHandlerOptions = ApiHandlerOptions & { - isVertex?: boolean -} +import { DEFAULT_HEADERS } from "./constants" export class GeminiHandler extends BaseProvider implements SingleCompletionHandler { protected 
options: ApiHandlerOptions - - private client: GoogleGenAI - private lastThoughtSignature?: string - private lastResponseId?: string + protected provider: GoogleGenerativeAIProvider private readonly providerName = "Gemini" - constructor({ isVertex, ...options }: GeminiHandlerOptions) { + constructor(options: ApiHandlerOptions) { super() this.options = options - const project = this.options.vertexProjectId ?? "not-provided" - const location = this.options.vertexRegion ?? "not-provided" - const apiKey = this.options.geminiApiKey ?? "not-provided" - - this.client = this.options.vertexJsonCredentials - ? new GoogleGenAI({ - vertexai: true, - project, - location, - googleAuthOptions: { - credentials: safeJsonParse(this.options.vertexJsonCredentials, undefined), - }, - }) - : this.options.vertexKeyFile - ? new GoogleGenAI({ - vertexai: true, - project, - location, - googleAuthOptions: { keyFile: this.options.vertexKeyFile }, - }) - : isVertex - ? new GoogleGenAI({ vertexai: true, project, location }) - : new GoogleGenAI({ apiKey }) + // Create the Google Generative AI provider using AI SDK + // For Vertex AI, we still use this provider but with different authentication + // (Vertex authentication happens separately) + this.provider = createGoogleGenerativeAI({ + apiKey: this.options.geminiApiKey ?? "not-provided", + baseURL: this.options.googleGeminiBaseUrl, + headers: DEFAULT_HEADERS, + }) } async *createMessage( @@ -76,10 +52,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const { id: model, info, reasoning: thinkingConfig, maxTokens } = this.getModel() - // Reset per-request metadata that we persist into apiConversationHistory. - this.lastThoughtSignature = undefined - this.lastResponseId = undefined + const { id: modelId, info, reasoning: thinkingConfig, maxTokens } = this.getModel() // For hybrid/budget reasoning models (e.g. Gemini 2.5 Pro), respect user-configured // modelMaxTokens so the ThinkingBudget slider can control the cap. For effort-only or @@ -90,19 +63,23 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl ? (this.options.modelMaxTokens ?? maxTokens ?? undefined) : (maxTokens ?? undefined) - // Gemini 3 validates thought signatures for tool/function calling steps. - // We must round-trip the signature when tools are in use, even if the user chose - // a minimal thinking level (or thinkingConfig is otherwise absent). - const includeThoughtSignatures = Boolean(thinkingConfig) || Boolean(metadata?.tools?.length) + // Determine temperature respecting model capabilities and defaults: + // - If supportsTemperature is explicitly false, ignore user overrides + // and pin to the model's defaultTemperature (or omit if undefined). + // - Otherwise, allow the user setting to override, falling back to model default, + // then to 1 for Gemini provider default. + const supportsTemperature = info.supportsTemperature !== false + const temperatureConfig: number | undefined = supportsTemperature + ? (this.options.modelTemperature ?? info.defaultTemperature ?? 1) + : info.defaultTemperature // The message list can include provider-specific meta entries such as // `{ type: "reasoning", ... }` that are intended only for providers like // openai-native. Gemini should never see those; they are not valid - // Anthropic.MessageParam values and will cause failures (e.g. missing - // `content` for the converter). 
Filter them out here. + // Anthropic.MessageParam values and will cause failures. type ReasoningMetaLike = { type?: string } - const geminiMessages = messages.filter((message): message is Anthropic.Messages.MessageParam => { + const filteredMessages = messages.filter((message): message is Anthropic.Messages.MessageParam => { const meta = message as ReasoningMetaLike if (meta.type === "reasoning") { return false @@ -110,232 +87,80 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return true }) - // Build a map of tool IDs to names from previous messages - // This is needed because Anthropic's tool_result blocks only contain the ID, - // but Gemini requires the name in functionResponse - const toolIdToName = new Map() - for (const message of messages) { - if (Array.isArray(message.content)) { - for (const block of message.content) { - if (block.type === "tool_use") { - toolIdToName.set(block.id, block.name) - } - } - } - } - - const contents = geminiMessages - .map((message) => convertAnthropicMessageToGemini(message, { includeThoughtSignatures, toolIdToName })) - .flat() - - // Tools are always present (minimum ALWAYS_AVAILABLE_TOOLS). - // Google built-in tools (Grounding, URL Context) are mutually exclusive - // with function declarations in the Gemini API, so we always use - // function declarations when tools are provided. - const tools: GenerateContentConfig["tools"] = [ - { - functionDeclarations: (metadata?.tools ?? []).map((tool) => ({ - name: (tool as any).function.name, - description: (tool as any).function.description, - parametersJsonSchema: (tool as any).function.parameters, - })), - }, - ] + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(filteredMessages) - // Determine temperature respecting model capabilities and defaults: - // - If supportsTemperature is explicitly false, ignore user overrides - // and pin to the model's defaultTemperature (or omit if undefined). - // - Otherwise, allow the user setting to override, falling back to model default, - // then to 1 for Gemini provider default. - const supportsTemperature = info.supportsTemperature !== false - const temperatureConfig: number | undefined = supportsTemperature - ? (this.options.modelTemperature ?? info.defaultTemperature ?? 1) - : info.defaultTemperature + // Convert tools to OpenAI format first, then to AI SDK format + let openAiTools = this.convertToolsForOpenAI(metadata?.tools) - const config: GenerateContentConfig = { - systemInstruction, - httpOptions: this.options.googleGeminiBaseUrl ? { baseUrl: this.options.googleGeminiBaseUrl } : undefined, - thinkingConfig, - maxOutputTokens, - temperature: temperatureConfig, - ...(tools.length > 0 ? { tools } : {}), + // Filter tools based on allowedFunctionNames for mode-restricted tool access + if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0 && openAiTools) { + const allowedSet = new Set(metadata.allowedFunctionNames) + openAiTools = openAiTools.filter((tool) => tool.type === "function" && allowedSet.has(tool.function.name)) } - // Handle allowedFunctionNames for mode-restricted tool access. - // When provided, all tool definitions are passed to the model (so it can reference - // historical tool calls in conversation), but only the specified tools can be invoked. - // This takes precedence over tool_choice to ensure mode restrictions are honored. 

-		if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) {
-			config.toolConfig = {
-				functionCallingConfig: {
-					// Use ANY mode to allow calling any of the allowed functions
-					mode: FunctionCallingConfigMode.ANY,
-					allowedFunctionNames: metadata.allowedFunctionNames,
-				},
-			}
-		} else if (metadata?.tool_choice) {
-			const choice = metadata.tool_choice
-			let mode: FunctionCallingConfigMode
-			let allowedFunctionNames: string[] | undefined
-
-			if (choice === "auto") {
-				mode = FunctionCallingConfigMode.AUTO
-			} else if (choice === "none") {
-				mode = FunctionCallingConfigMode.NONE
-			} else if (choice === "required") {
-				// "required" means the model must call at least one tool; Gemini uses ANY for this.
-				mode = FunctionCallingConfigMode.ANY
-			} else if (typeof choice === "object" && "function" in choice && choice.type === "function") {
-				mode = FunctionCallingConfigMode.ANY
-				allowedFunctionNames = [choice.function.name]
-			} else {
-				// Fall back to AUTO for unknown values to avoid unintentionally broadening tool access.
-				mode = FunctionCallingConfigMode.AUTO
-			}
+		const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined

-			config.toolConfig = {
-				functionCallingConfig: {
-					mode,
-					...(allowedFunctionNames ? { allowedFunctionNames } : {}),
-				},
-			}
-		}
+		// Build tool choice - use 'required' when allowedFunctionNames restricts available tools
+		const toolChoice =
+			metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0
+				? "required"
+				: mapToolChoice(metadata?.tool_choice)

-		const params: GenerateContentParameters = { model, contents, config }
+		// Build the request options
+		const requestOptions: Parameters<typeof streamText>[0] = {
+			model: this.provider(modelId),
+			system: systemInstruction,
+			messages: aiSdkMessages,
+			temperature: temperatureConfig,
+			maxOutputTokens,
+			tools: aiSdkTools,
+			toolChoice,
+			// Add thinking/reasoning configuration if present
+			// Cast to any to bypass strict JSONObject typing - the AI SDK accepts the correct runtime values
+			...(thinkingConfig && {
+				providerOptions: { google: { thinkingConfig } } as any,
+			}),
+		}

 		try {
-			const result = await this.client.models.generateContentStream(params)
-
-			let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
-			let pendingGroundingMetadata: GroundingMetadata | undefined
-			let finalResponse: { responseId?: string } | undefined
-			let finishReason: string | undefined
-
-			let toolCallCounter = 0
-			let hasContent = false
-			let hasReasoning = false
-
-			for await (const chunk of result) {
-				// Track the final structured response (per SDK pattern: candidate.finishReason)
-				if (chunk.candidates && chunk.candidates[0]?.finishReason) {
-					finalResponse = chunk as { responseId?: string }
-					finishReason = chunk.candidates[0].finishReason
-				}
-				// Process candidates and their parts to separate thoughts from content
-				if (chunk.candidates && chunk.candidates.length > 0) {
-					const candidate = chunk.candidates[0]
+			// Use streamText for streaming responses
+			const result = streamText(requestOptions)

-					if (candidate.groundingMetadata) {
-						pendingGroundingMetadata = candidate.groundingMetadata
-					}
-
-					if (candidate.content && candidate.content.parts) {
-						for (const part of candidate.content.parts as Array<{
-							thought?: boolean
-							text?: string
-							thoughtSignature?: string
-							functionCall?: { name: string; args: Record<string, unknown> }
-						}>) {
-							// Capture thought signatures so they can be persisted into API history. 
- const thoughtSignature = part.thoughtSignature - // Persist thought signatures so they can be round-tripped in the next step. - // Gemini 3 requires this during tool calling; other Gemini thinking models - // benefit from it for continuity. - if (includeThoughtSignatures && thoughtSignature) { - this.lastThoughtSignature = thoughtSignature - } - - if (part.thought) { - // This is a thinking/reasoning part - if (part.text) { - hasReasoning = true - yield { type: "reasoning", text: part.text } - } - } else if (part.functionCall) { - hasContent = true - // Gemini sends complete function calls in a single chunk - // Emit as partial chunks for consistent handling with NativeToolCallParser - const callId = `${part.functionCall.name}-${toolCallCounter}` - const args = JSON.stringify(part.functionCall.args) - - // Emit name first - yield { - type: "tool_call_partial", - index: toolCallCounter, - id: callId, - name: part.functionCall.name, - arguments: undefined, - } - - // Then emit arguments - yield { - type: "tool_call_partial", - index: toolCallCounter, - id: callId, - name: undefined, - arguments: args, - } - - toolCallCounter++ - } else { - // This is regular content - if (part.text) { - hasContent = true - yield { type: "text", text: part.text } - } - } - } - } - } - - // Fallback to the original text property if no candidates structure - else if (chunk.text) { - hasContent = true - yield { type: "text", text: chunk.text } - } - - if (chunk.usageMetadata) { - lastUsageMetadata = chunk.usageMetadata + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk } } - if (finalResponse?.responseId) { - // Capture responseId so Task.addToApiConversationHistory can store it - // alongside the assistant message in api_history.json. - this.lastResponseId = finalResponse.responseId - } + // Extract grounding sources from providerMetadata if available + const providerMetadata = await result.providerMetadata + const groundingMetadata = providerMetadata?.google as + | { + groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + } + } + | undefined - if (pendingGroundingMetadata) { - const sources = this.extractGroundingSources(pendingGroundingMetadata) + if (groundingMetadata?.groundingMetadata) { + const sources = this.extractGroundingSources(groundingMetadata.groundingMetadata) if (sources.length > 0) { yield { type: "grounding", sources } } } - if (lastUsageMetadata) { - const inputTokens = lastUsageMetadata.promptTokenCount ?? 0 - const outputTokens = lastUsageMetadata.candidatesTokenCount ?? 0 - const cacheReadTokens = lastUsageMetadata.cachedContentTokenCount - const reasoningTokens = lastUsageMetadata.thoughtsTokenCount - - yield { - type: "usage", - inputTokens, - outputTokens, - cacheReadTokens, - reasoningTokens, - totalCost: this.calculateCost({ - info, - inputTokens, - outputTokens, - cacheReadTokens, - reasoningTokens, - }), - } + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage, info, providerMetadata) } } catch (error) { const errorMessage = error instanceof Error ? 
error.message : String(error)
-			const apiError = new ApiProviderError(errorMessage, this.providerName, model, "createMessage")
+			const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage")
 			TelemetryService.instance.captureException(apiError)

 			if (error instanceof Error) {
@@ -366,7 +191,47 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
 	}

-	private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] {
+	/**
+	 * Process usage metrics from the AI SDK response.
+	 */
+	protected processUsageMetrics(
+		usage: {
+			inputTokens?: number
+			outputTokens?: number
+			details?: {
+				cachedInputTokens?: number
+				reasoningTokens?: number
+			}
+		},
+		info: ModelInfo,
+		providerMetadata?: Record<string, unknown>,
+	): ApiStreamUsageChunk {
+		const inputTokens = usage.inputTokens || 0
+		const outputTokens = usage.outputTokens || 0
+		const cacheReadTokens = usage.details?.cachedInputTokens
+		const reasoningTokens = usage.details?.reasoningTokens
+
+		return {
+			type: "usage",
+			inputTokens,
+			outputTokens,
+			cacheReadTokens,
+			reasoningTokens,
+			totalCost: this.calculateCost({
+				info,
+				inputTokens,
+				outputTokens,
+				cacheReadTokens,
+				reasoningTokens,
+			}),
+		}
+	}
+
+	private extractGroundingSources(groundingMetadata?: {
+		groundingChunks?: Array<{
+			web?: { uri?: string; title?: string }
+		}>
+	}): GroundingSource[] {
 		const chunks = groundingMetadata?.groundingChunks

 		if (!chunks) {
@@ -389,7 +254,11 @@
 			.filter((source): source is GroundingSource => source !== null)
 	}

-	private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
+	private extractCitationsOnly(groundingMetadata?: {
+		groundingChunks?: Array<{
+			web?: { uri?: string; title?: string }
+		}>
+	}): string | null {
 		const sources = this.extractGroundingSources(groundingMetadata)

 		if (sources.length === 0) {
@@ -401,15 +270,21 @@
 	}

 	async completePrompt(prompt: string): Promise<string> {
-		const { id: model, info } = this.getModel()
+		const { id: modelId, info } = this.getModel()

 		try {
-			const tools: GenerateContentConfig["tools"] = []
+			// Build tools for grounding - cast to any to bypass strict typing
+			// Google provider tools have a different shape than standard ToolSet
+			const tools: Record<string, unknown> = {}
+
+			// Add URL context tool if enabled
 			if (this.options.enableUrlContext) {
-				tools.push({ urlContext: {} })
+				tools.url_context = this.provider.tools.urlContext({})
 			}
+
+			// Add Google Search grounding tool if enabled
 			if (this.options.enableGrounding) {
-				tools.push({ googleSearch: {} })
+				tools.google_search = this.provider.tools.googleSearch({})
 			}

 			const supportsTemperature = info.supportsTemperature !== false
 			const temperatureConfig: number | undefined = supportsTemperature
 				? (this.options.modelTemperature ?? info.defaultTemperature ?? 1)
 				: info.defaultTemperature

-			const promptConfig: GenerateContentConfig = {
-				httpOptions: this.options.googleGeminiBaseUrl
-					? { baseUrl: this.options.googleGeminiBaseUrl }
-					: undefined,
+			const result = await generateText({
+				model: this.provider(modelId),
+				prompt,
 				temperature: temperatureConfig,
-				...(tools.length > 0 ? 
{ tools } : {}), - } - - const request = { - model, - contents: [{ role: "user", parts: [{ text: prompt }] }], - config: promptConfig, - } - - const result = await this.client.models.generateContent(request) + ...(Object.keys(tools).length > 0 && { tools: tools as ToolSet }), + }) let text = result.text ?? "" - const candidate = result.candidates?.[0] - if (candidate?.groundingMetadata) { - const citations = this.extractCitationsOnly(candidate.groundingMetadata) + // Extract grounding citations from providerMetadata if available + const providerMetadata = result.providerMetadata + const groundingMetadata = providerMetadata?.google as + | { + groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + } + } + | undefined + + if (groundingMetadata?.groundingMetadata) { + const citations = this.extractCitationsOnly(groundingMetadata.groundingMetadata) if (citations) { text += `\n\n${t("common:errors.gemini.sources")} ${citations}` } @@ -446,7 +323,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl return text } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) - const apiError = new ApiProviderError(errorMessage, this.providerName, model, "completePrompt") + const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt") TelemetryService.instance.captureException(apiError) if (error instanceof Error) { @@ -457,14 +334,6 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl } } - public getThoughtSignature(): string | undefined { - return this.lastThoughtSignature - } - - public getResponseId(): string | undefined { - return this.lastResponseId - } - public calculateCost({ info, inputTokens, diff --git a/src/api/providers/vertex.ts b/src/api/providers/vertex.ts index 2c077d97b7e..62aec505c92 100644 --- a/src/api/providers/vertex.ts +++ b/src/api/providers/vertex.ts @@ -1,22 +1,202 @@ -import { type ModelInfo, type VertexModelId, vertexDefaultModelId, vertexModels } from "@roo-code/types" +import type { Anthropic } from "@anthropic-ai/sdk" +import { createVertex, type GoogleVertexProvider } from "@ai-sdk/google-vertex" +import { streamText, generateText, ToolSet } from "ai" + +import { + type ModelInfo, + type VertexModelId, + vertexDefaultModelId, + vertexModels, + ApiProviderError, +} from "@roo-code/types" +import { TelemetryService } from "@roo-code/telemetry" import type { ApiHandlerOptions } from "../../shared/api" +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, +} from "../transform/ai-sdk" +import { t } from "i18next" +import type { ApiStream, ApiStreamUsageChunk, GroundingSource } from "../transform/stream" import { getModelParams } from "../transform/model-params" -import { GeminiHandler } from "./gemini" -import { SingleCompletionHandler } from "../index" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { BaseProvider } from "./base-provider" +import { DEFAULT_HEADERS } from "./constants" + +/** + * Vertex AI provider using the dedicated @ai-sdk/google-vertex package. + * Provides native support for Google's Vertex AI with proper authentication. 
+ */ +export class VertexHandler extends BaseProvider implements SingleCompletionHandler { + protected options: ApiHandlerOptions + protected provider: GoogleVertexProvider + private readonly providerName = "Vertex" -export class VertexHandler extends GeminiHandler implements SingleCompletionHandler { constructor(options: ApiHandlerOptions) { - super({ ...options, isVertex: true }) + super() + this.options = options + + // Build googleAuthOptions based on provided credentials + let googleAuthOptions: { credentials?: object; keyFile?: string } | undefined + if (options.vertexJsonCredentials) { + try { + googleAuthOptions = { credentials: JSON.parse(options.vertexJsonCredentials) } + } catch { + // If JSON parsing fails, ignore and try other auth methods + } + } else if (options.vertexKeyFile) { + googleAuthOptions = { keyFile: options.vertexKeyFile } + } + + // Create the Vertex AI provider using AI SDK + this.provider = createVertex({ + project: options.vertexProjectId, + location: options.vertexRegion, + googleAuthOptions, + headers: DEFAULT_HEADERS, + }) + } + + async *createMessage( + systemInstruction: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { id: modelId, info, reasoning: thinkingConfig, maxTokens } = this.getModel() + + // For hybrid/budget reasoning models (e.g. Gemini 2.5 Pro), respect user-configured + // modelMaxTokens so the ThinkingBudget slider can control the cap. For effort-only or + // standard models (like gemini-3-pro-preview), ignore any stale modelMaxTokens and + // default to the model's computed maxTokens from getModelMaxOutputTokens. + const isHybridReasoningModel = info.supportsReasoningBudget || info.requiredReasoningBudget + const maxOutputTokens = isHybridReasoningModel + ? (this.options.modelMaxTokens ?? maxTokens ?? undefined) + : (maxTokens ?? undefined) + + // Determine temperature respecting model capabilities and defaults: + // - If supportsTemperature is explicitly false, ignore user overrides + // and pin to the model's defaultTemperature (or omit if undefined). + // - Otherwise, allow the user setting to override, falling back to model default, + // then to 1 for Gemini provider default. + const supportsTemperature = info.supportsTemperature !== false + const temperatureConfig: number | undefined = supportsTemperature + ? (this.options.modelTemperature ?? info.defaultTemperature ?? 1) + : info.defaultTemperature + + // The message list can include provider-specific meta entries such as + // `{ type: "reasoning", ... }` that are intended only for providers like + // openai-native. Vertex should never see those; they are not valid + // Anthropic.MessageParam values and will cause failures. 
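+ // For example, a history entry shaped like { type: "reasoning", text: "..." } (an
+ // illustrative shape; emitted by providers that persist reasoning separately) is
+ // dropped by the filter below, while ordinary user/assistant MessageParam entries
+ // pass through unchanged.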
+ type ReasoningMetaLike = { type?: string } + + const filteredMessages = messages.filter((message): message is Anthropic.Messages.MessageParam => { + const meta = message as ReasoningMetaLike + if (meta.type === "reasoning") { + return false + } + return true + }) + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(filteredMessages) + + // Convert tools to OpenAI format first, then to AI SDK format + let openAiTools = this.convertToolsForOpenAI(metadata?.tools) + + // Filter tools based on allowedFunctionNames for mode-restricted tool access + if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0 && openAiTools) { + const allowedSet = new Set(metadata.allowedFunctionNames) + openAiTools = openAiTools.filter((tool) => tool.type === "function" && allowedSet.has(tool.function.name)) + } + + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build tool choice - use 'required' when allowedFunctionNames restricts available tools + const toolChoice = + metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0 + ? "required" + : mapToolChoice(metadata?.tool_choice) + + // Build the request options + const requestOptions: Parameters[0] = { + model: this.provider(modelId), + system: systemInstruction, + messages: aiSdkMessages, + temperature: temperatureConfig, + maxOutputTokens, + tools: aiSdkTools, + toolChoice, + // Add thinking/reasoning configuration if present + // Cast to any to bypass strict JSONObject typing - the AI SDK accepts the correct runtime values + ...(thinkingConfig && { + providerOptions: { google: { thinkingConfig } } as any, + }), + } + + try { + // Use streamText for streaming responses + const result = streamText(requestOptions) + + // Process the full stream to get all events including reasoning + for await (const part of result.fullStream) { + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } + } + + // Extract grounding sources from providerMetadata if available + const providerMetadata = await result.providerMetadata + const groundingMetadata = providerMetadata?.google as + | { + groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + } + } + | undefined + + if (groundingMetadata?.groundingMetadata) { + const sources = this.extractGroundingSources(groundingMetadata.groundingMetadata) + if (sources.length > 0) { + yield { type: "grounding", sources } + } + } + + // Yield usage metrics at the end + const usage = await result.usage + if (usage) { + yield this.processUsageMetrics(usage, info, providerMetadata) + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "createMessage") + TelemetryService.instance.captureException(apiError) + + if (error instanceof Error) { + throw new Error(t("common:errors.gemini.generate_stream", { error: error.message })) + } + + throw error + } } override getModel() { const modelId = this.options.apiModelId let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId - const info: ModelInfo = vertexModels[id] - const params = getModelParams({ format: "gemini", modelId: id, model: info, settings: this.options }) + let info: ModelInfo = vertexModels[id] + + const params = getModelParams({ + format: "gemini", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: info.defaultTemperature ?? 
1, + }) // The `:thinking` suffix indicates that the model is a "Hybrid" // reasoning model and that reasoning is required to be enabled. @@ -24,4 +204,202 @@ export class VertexHandler extends GeminiHandler implements SingleCompletionHand // suffix. return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params } } + + /** + * Process usage metrics from the AI SDK response. + */ + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + info: ModelInfo, + providerMetadata?: Record, + ): ApiStreamUsageChunk { + const inputTokens = usage.inputTokens || 0 + const outputTokens = usage.outputTokens || 0 + const cacheReadTokens = usage.details?.cachedInputTokens + const reasoningTokens = usage.details?.reasoningTokens + + return { + type: "usage", + inputTokens, + outputTokens, + cacheReadTokens, + reasoningTokens, + totalCost: this.calculateCost({ + info, + inputTokens, + outputTokens, + cacheReadTokens, + reasoningTokens, + }), + } + } + + private extractGroundingSources(groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + }): GroundingSource[] { + const chunks = groundingMetadata?.groundingChunks + + if (!chunks) { + return [] + } + + return chunks + .map((chunk): GroundingSource | null => { + const uri = chunk.web?.uri + const title = chunk.web?.title || uri || "Unknown Source" + + if (uri) { + return { + title, + url: uri, + } + } + return null + }) + .filter((source): source is GroundingSource => source !== null) + } + + private extractCitationsOnly(groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + }): string | null { + const sources = this.extractGroundingSources(groundingMetadata) + + if (sources.length === 0) { + return null + } + + const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`) + return citationLinks.join(", ") + } + + async completePrompt(prompt: string): Promise { + const { id: modelId, info } = this.getModel() + + try { + // Build tools for grounding - cast to any to bypass strict typing + // Google provider tools have a different shape than standard ToolSet + const tools: Record = {} + + // Add URL context tool if enabled + if (this.options.enableUrlContext) { + tools.url_context = this.provider.tools.urlContext({}) + } + + // Add Google Search grounding tool if enabled + if (this.options.enableGrounding) { + tools.google_search = this.provider.tools.googleSearch({}) + } + + const supportsTemperature = info.supportsTemperature !== false + const temperatureConfig: number | undefined = supportsTemperature + ? (this.options.modelTemperature ?? info.defaultTemperature ?? 1) + : info.defaultTemperature + + const result = await generateText({ + model: this.provider(modelId), + prompt, + temperature: temperatureConfig, + ...(Object.keys(tools).length > 0 && { tools: tools as ToolSet }), + }) + + let text = result.text ?? 
"" + + // Extract grounding citations from providerMetadata if available + const providerMetadata = result.providerMetadata + const groundingMetadata = providerMetadata?.google as + | { + groundingMetadata?: { + groundingChunks?: Array<{ + web?: { uri?: string; title?: string } + }> + } + } + | undefined + + if (groundingMetadata?.groundingMetadata) { + const citations = this.extractCitationsOnly(groundingMetadata.groundingMetadata) + if (citations) { + text += `\n\n${t("common:errors.gemini.sources")} ${citations}` + } + } + + return text + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + const apiError = new ApiProviderError(errorMessage, this.providerName, modelId, "completePrompt") + TelemetryService.instance.captureException(apiError) + + if (error instanceof Error) { + throw new Error(t("common:errors.gemini.generate_complete_prompt", { error: error.message })) + } + + throw error + } + } + + public calculateCost({ + info, + inputTokens, + outputTokens, + cacheReadTokens = 0, + reasoningTokens = 0, + }: { + info: ModelInfo + inputTokens: number + outputTokens: number + cacheReadTokens?: number + reasoningTokens?: number + }) { + // For models with tiered pricing, prices might only be defined in tiers + let inputPrice = info.inputPrice + let outputPrice = info.outputPrice + let cacheReadsPrice = info.cacheReadsPrice + + // If there's tiered pricing then adjust the input and output token prices + // based on the input tokens used. + if (info.tiers) { + const tier = info.tiers.find((tier) => inputTokens <= tier.contextWindow) + + if (tier) { + inputPrice = tier.inputPrice ?? inputPrice + outputPrice = tier.outputPrice ?? outputPrice + cacheReadsPrice = tier.cacheReadsPrice ?? cacheReadsPrice + } + } + + // Check if we have the required prices after considering tiers + if (!inputPrice || !outputPrice) { + return undefined + } + + // cacheReadsPrice is optional - if not defined, treat as 0 + if (!cacheReadsPrice) { + cacheReadsPrice = 0 + } + + // Subtract the cached input tokens from the total input tokens. + const uncachedInputTokens = inputTokens - cacheReadTokens + + // Bill both completion and reasoning ("thoughts") tokens as output. + const billedOutputTokens = outputTokens + reasoningTokens + + let cacheReadCost = cacheReadTokens > 0 ? cacheReadsPrice * (cacheReadTokens / 1_000_000) : 0 + + const inputTokensCost = inputPrice * (uncachedInputTokens / 1_000_000) + const outputTokensCost = outputPrice * (billedOutputTokens / 1_000_000) + const totalCost = inputTokensCost + outputTokensCost + cacheReadCost + + return totalCost + } } diff --git a/src/api/transform/__tests__/gemini-format.spec.ts b/src/api/transform/__tests__/gemini-format.spec.ts deleted file mode 100644 index 23f752e207f..00000000000 --- a/src/api/transform/__tests__/gemini-format.spec.ts +++ /dev/null @@ -1,487 +0,0 @@ -// npx vitest run src/api/transform/__tests__/gemini-format.spec.ts - -import { Anthropic } from "@anthropic-ai/sdk" - -import { convertAnthropicMessageToGemini } from "../gemini-format" - -describe("convertAnthropicMessageToGemini", () => { - it("should convert a simple text message", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: "Hello, world!", - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "user", - parts: [{ text: "Hello, world!" 
}], - }, - ]) - }) - - it("should convert assistant role to model role", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "assistant", - content: "I'm an assistant", - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "model", - parts: [{ text: "I'm an assistant" }], - }, - ]) - }) - - it("should convert a message with text blocks", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { type: "text", text: "First paragraph" }, - { type: "text", text: "Second paragraph" }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "user", - parts: [{ text: "First paragraph" }, { text: "Second paragraph" }], - }, - ]) - }) - - it("should convert a message with an image", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { type: "text", text: "Check out this image:" }, - { - type: "image", - source: { - type: "base64", - media_type: "image/jpeg", - data: "base64encodeddata", - }, - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "user", - parts: [ - { text: "Check out this image:" }, - { - inlineData: { - data: "base64encodeddata", - mimeType: "image/jpeg", - }, - }, - ], - }, - ]) - }) - - it("should throw an error for unsupported image source type", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "image", - source: { - type: "url", // Not supported - url: "https://example.com/image.jpg", - } as any, - }, - ], - } - - expect(() => convertAnthropicMessageToGemini(anthropicMessage)).toThrow("Unsupported image source type") - }) - - it("should convert a message with tool use", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "assistant", - content: [ - { type: "text", text: "Let me calculate that for you." }, - { - type: "tool_use", - id: "calc-123", - name: "calculator", - input: { operation: "add", numbers: [2, 3] }, - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "model", - parts: [ - { text: "Let me calculate that for you." 
}, - { - functionCall: { - name: "calculator", - args: { operation: "add", numbers: [2, 3] }, - }, - thoughtSignature: "skip_thought_signature_validator", - }, - ], - }, - ]) - }) - - it("should only attach thoughtSignature to the first functionCall in the message", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "assistant", - content: [ - { type: "thoughtSignature", thoughtSignature: "sig-123" } as any, - { type: "tool_use", id: "call-1", name: "toolA", input: { a: 1 } }, - { type: "tool_use", id: "call-2", name: "toolB", input: { b: 2 } }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - expect(result).toHaveLength(1) - - const parts = result[0]!.parts as any[] - const functionCallParts = parts.filter((p) => p.functionCall) - expect(functionCallParts).toHaveLength(2) - - expect(functionCallParts[0].thoughtSignature).toBe("sig-123") - expect(functionCallParts[1].thoughtSignature).toBeUndefined() - }) - - it("should convert a message with tool result as string", () => { - const toolIdToName = new Map() - toolIdToName.set("calculator-123", "calculator") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { type: "text", text: "Here's the result:" }, - { - type: "tool_result", - tool_use_id: "calculator-123", - content: "The result is 5", - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName }) - - expect(result).toEqual([ - { - role: "user", - parts: [ - { text: "Here's the result:" }, - { - functionResponse: { - name: "calculator", - response: { - name: "calculator", - content: "The result is 5", - }, - }, - }, - ], - }, - ]) - }) - - it("should handle empty tool result content", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "calculator-123", - content: null as any, // Empty content - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - // Should skip the empty tool result - expect(result).toEqual([]) - }) - - it("should convert a message with tool result as array with text only", () => { - const toolIdToName = new Map() - toolIdToName.set("search-123", "search") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "search-123", - content: [ - { type: "text", text: "First result" }, - { type: "text", text: "Second result" }, - ], - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName }) - - expect(result).toEqual([ - { - role: "user", - parts: [ - { - functionResponse: { - name: "search", - response: { - name: "search", - content: "First result\n\nSecond result", - }, - }, - }, - ], - }, - ]) - }) - - it("should convert a message with tool result as array with text and images", () => { - const toolIdToName = new Map() - toolIdToName.set("search-123", "search") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "search-123", - content: [ - { type: "text", text: "Search results:" }, - { - type: "image", - source: { - type: "base64", - media_type: "image/png", - data: "image1data", - }, - }, - { - type: "image", - source: { - type: "base64", - media_type: "image/jpeg", - data: "image2data", - }, - }, - ], - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName }) - - 
expect(result).toEqual([ - { - role: "user", - parts: [ - { - functionResponse: { - name: "search", - response: { - name: "search", - content: "Search results:\n\n(See next part for image)", - }, - }, - }, - { - inlineData: { - data: "image1data", - mimeType: "image/png", - }, - }, - { - inlineData: { - data: "image2data", - mimeType: "image/jpeg", - }, - }, - ], - }, - ]) - }) - - it("should convert a message with tool result containing only images", () => { - const toolIdToName = new Map() - toolIdToName.set("imagesearch-123", "imagesearch") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "imagesearch-123", - content: [ - { - type: "image", - source: { - type: "base64", - media_type: "image/png", - data: "onlyimagedata", - }, - }, - ], - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName }) - - expect(result).toEqual([ - { - role: "user", - parts: [ - { - functionResponse: { - name: "imagesearch", - response: { - name: "imagesearch", - content: "\n\n(See next part for image)", - }, - }, - }, - { - inlineData: { - data: "onlyimagedata", - mimeType: "image/png", - }, - }, - ], - }, - ]) - }) - - it("should handle tool names with hyphens using toolIdToName map", () => { - const toolIdToName = new Map() - toolIdToName.set("search-files-123", "search-files") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "search-files-123", - content: "found files", - }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName }) - - expect(result).toEqual([ - { - role: "user", - parts: [ - { - functionResponse: { - name: "search-files", - response: { - name: "search-files", - content: "found files", - }, - }, - }, - ], - }, - ]) - }) - - it("should throw error when toolIdToName map is not provided", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "calculator-123", - content: "result is 5", - }, - ], - } - - expect(() => convertAnthropicMessageToGemini(anthropicMessage)).toThrow( - 'Unable to find tool name for tool_use_id "calculator-123"', - ) - }) - - it("should throw error when tool_use_id is not in the map", () => { - const toolIdToName = new Map() - toolIdToName.set("other-tool-456", "other-tool") - - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: "calculator-123", - content: "result is 5", - }, - ], - } - - expect(() => convertAnthropicMessageToGemini(anthropicMessage, { toolIdToName })).toThrow( - 'Unable to find tool name for tool_use_id "calculator-123"', - ) - }) - - it("should skip unsupported content block types", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "user", - content: [ - { - type: "unknown_type", // Unsupported type - data: "some data", - } as any, - { type: "text", text: "Valid content" }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "user", - parts: [{ text: "Valid content" }], - }, - ]) - }) - - it("should skip reasoning content blocks", () => { - const anthropicMessage: Anthropic.Messages.MessageParam = { - role: "assistant", - content: [ - { - type: "reasoning" as any, - text: "Let me think about this...", - }, - { type: "text", text: "Here's my 
answer" }, - ], - } - - const result = convertAnthropicMessageToGemini(anthropicMessage) - - expect(result).toEqual([ - { - role: "model", - parts: [{ text: "Here's my answer" }], - }, - ]) - }) -}) diff --git a/src/api/transform/gemini-format.ts b/src/api/transform/gemini-format.ts deleted file mode 100644 index 6f240362960..00000000000 --- a/src/api/transform/gemini-format.ts +++ /dev/null @@ -1,199 +0,0 @@ -import { Anthropic } from "@anthropic-ai/sdk" -import { Content, Part } from "@google/genai" - -type ThoughtSignatureContentBlock = { - type: "thoughtSignature" - thoughtSignature?: string -} - -type ReasoningContentBlock = { - type: "reasoning" - text: string -} - -type ExtendedContentBlockParam = Anthropic.ContentBlockParam | ThoughtSignatureContentBlock | ReasoningContentBlock -type ExtendedAnthropicContent = string | ExtendedContentBlockParam[] - -// Extension type to safely add thoughtSignature to Part -type PartWithThoughtSignature = Part & { - thoughtSignature?: string -} - -function isThoughtSignatureContentBlock(block: ExtendedContentBlockParam): block is ThoughtSignatureContentBlock { - return block.type === "thoughtSignature" -} - -export function convertAnthropicContentToGemini( - content: ExtendedAnthropicContent, - options?: { includeThoughtSignatures?: boolean; toolIdToName?: Map }, -): Part[] { - const includeThoughtSignatures = options?.includeThoughtSignatures ?? true - const toolIdToName = options?.toolIdToName - - // First pass: find thoughtSignature if it exists in the content blocks - let activeThoughtSignature: string | undefined - if (Array.isArray(content)) { - const sigBlock = content.find((block) => isThoughtSignatureContentBlock(block)) as ThoughtSignatureContentBlock - if (sigBlock?.thoughtSignature) { - activeThoughtSignature = sigBlock.thoughtSignature - } - } - - // Determine the signature to attach to function calls. - // If we're in a mode that expects signatures (includeThoughtSignatures is true): - // 1. Use the actual signature if we found one in the history/content. - // 2. Fallback to "skip_thought_signature_validator" if missing (e.g. cross-model history). - let functionCallSignature: string | undefined - if (includeThoughtSignatures) { - functionCallSignature = activeThoughtSignature || "skip_thought_signature_validator" - } - - if (typeof content === "string") { - return [{ text: content }] - } - - const parts = content.flatMap((block): Part | Part[] => { - // Handle thoughtSignature blocks first - if (isThoughtSignatureContentBlock(block)) { - // We process thought signatures globally and attach them to the relevant parts - // or create a placeholder part if no other content exists. - return [] - } - - switch (block.type) { - case "text": - return { text: block.text } - case "image": - if (block.source.type !== "base64") { - throw new Error("Unsupported image source type") - } - - return { inlineData: { data: block.source.data, mimeType: block.source.media_type } } - case "tool_use": - // Gemini 3 validation rules: - // - In a parallel function calling response, only the FIRST functionCall part has a signature. - // - In sequential steps, each step's first functionCall must include its signature. - // When converting from our history, we don't always have enough information to perfectly - // recreate the original per-part distribution, but we can and should avoid attaching the - // signature to every parallel call in a single assistant message. 
- return { - functionCall: { - name: block.name, - args: block.input as Record, - }, - // Inject the thoughtSignature into the functionCall part if required. - // This is necessary for Gemini 3+ thinking models to validate the tool call. - ...(functionCallSignature ? { thoughtSignature: functionCallSignature } : {}), - } as Part - case "tool_result": { - if (!block.content) { - return [] - } - - // Get tool name from the map (built from tool_use blocks in message history). - // The map must contain the tool name - if it doesn't, this indicates a bug - // where the conversation history is incomplete or tool_use blocks are missing. - const toolName = toolIdToName?.get(block.tool_use_id) - if (!toolName) { - throw new Error( - `Unable to find tool name for tool_use_id "${block.tool_use_id}". ` + - `This indicates the conversation history is missing the corresponding tool_use block. ` + - `Available tool IDs: ${Array.from(toolIdToName?.keys() ?? []).join(", ") || "none"}`, - ) - } - - if (typeof block.content === "string") { - return { - functionResponse: { name: toolName, response: { name: toolName, content: block.content } }, - } - } - - if (!Array.isArray(block.content)) { - return [] - } - - const textParts: string[] = [] - const imageParts: Part[] = [] - - for (const item of block.content) { - if (item.type === "text") { - textParts.push(item.text) - } else if (item.type === "image" && item.source.type === "base64") { - const { data, media_type } = item.source - imageParts.push({ inlineData: { data, mimeType: media_type } }) - } - } - - // Create content text with a note about images if present - const contentText = - textParts.join("\n\n") + (imageParts.length > 0 ? "\n\n(See next part for image)" : "") - - // Return function response followed by any images - return [ - { functionResponse: { name: toolName, response: { name: toolName, content: contentText } } }, - ...imageParts, - ] - } - default: - // Skip unsupported content block types (e.g., "reasoning", "thinking", "redacted_thinking", "document") - // These are typically metadata from other providers that don't need to be sent to Gemini - console.warn(`Skipping unsupported content block type: ${block.type}`) - return [] - } - }) - - // Post-processing: - // 1) Ensure thought signature is attached if required - // 2) For multiple function calls in a single message, keep the signature only on the first - // functionCall part to match Gemini 3 parallel-calling behavior. - if (includeThoughtSignatures && activeThoughtSignature) { - const hasSignature = parts.some((p) => "thoughtSignature" in p) - - if (!hasSignature) { - if (parts.length > 0) { - // Attach to the first part (usually text) - // We use the intersection type to allow adding the property safely - ;(parts[0] as PartWithThoughtSignature).thoughtSignature = activeThoughtSignature - } else { - // Create a placeholder part if no other content exists - const placeholder: PartWithThoughtSignature = { text: "", thoughtSignature: activeThoughtSignature } - parts.push(placeholder) - } - } - } - - if (includeThoughtSignatures) { - let seenFirstFunctionCall = false - for (const part of parts) { - if (part && typeof part === "object" && "functionCall" in part && (part as any).functionCall) { - const partWithSig = part as PartWithThoughtSignature - if (!seenFirstFunctionCall) { - seenFirstFunctionCall = true - } else { - // Remove signature from subsequent function calls in this message. 
- delete partWithSig.thoughtSignature - } - } - } - } - - return parts -} - -export function convertAnthropicMessageToGemini( - message: Anthropic.Messages.MessageParam, - options?: { includeThoughtSignatures?: boolean; toolIdToName?: Map }, -): Content[] { - const parts = convertAnthropicContentToGemini(message.content, options) - - if (parts.length === 0) { - return [] - } - - return [ - { - role: message.role === "assistant" ? "model" : "user", - parts, - }, - ] -} diff --git a/src/package.json b/src/package.json index 423463bf0ba..042119134ba 100644 --- a/src/package.json +++ b/src/package.json @@ -453,6 +453,8 @@ "@ai-sdk/cerebras": "^1.0.0", "@ai-sdk/deepseek": "^2.0.14", "@ai-sdk/fireworks": "^2.0.26", + "@ai-sdk/google": "^3.0.20", + "@ai-sdk/google-vertex": "^3.0.20", "@ai-sdk/groq": "^3.0.19", "@ai-sdk/mistral": "^3.0.0", "@ai-sdk/xai": "^3.0.46",