diff --git a/packages/types/src/image-generation.ts b/packages/types/src/image-generation.ts index 2317a11b802..bc7fa32794b 100644 --- a/packages/types/src/image-generation.ts +++ b/packages/types/src/image-generation.ts @@ -2,10 +2,16 @@ * Image generation model constants */ +/** + * API method used for image generation + */ +export type ImageGenerationApiMethod = "chat_completions" | "images_api" + export interface ImageGenerationModel { value: string label: string provider: ImageGenerationProvider + apiMethod?: ImageGenerationApiMethod } export const IMAGE_GENERATION_MODELS: ImageGenerationModel[] = [ @@ -17,6 +23,7 @@ export const IMAGE_GENERATION_MODELS: ImageGenerationModel[] = [ // Roo Code Cloud models { value: "google/gemini-2.5-flash-image", label: "Gemini 2.5 Flash Image", provider: "roo" }, { value: "google/gemini-3-pro-image", label: "Gemini 3 Pro Image", provider: "roo" }, + { value: "bfl/flux-2-pro", label: "BFL Flux 2 Pro", provider: "roo", apiMethod: "images_api" }, ] /** diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 0641b5a45d2..665fa9f9e8a 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -435,7 +435,8 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } /** - * Generate an image using OpenRouter's image generation API + * Generate an image using OpenRouter's image generation API (chat completions with modalities) + * Note: OpenRouter only supports the chat completions approach, not the /images/generations endpoint * @param prompt The text prompt for image generation * @param model The model to use for generation * @param apiKey The OpenRouter API key (must be explicitly provided) @@ -456,6 +457,8 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH } const baseURL = this.options.openRouterBaseUrl || "https://openrouter.ai/api/v1" + + // OpenRouter only supports chat completions approach for image generation return generateImageWithProvider({ baseURL, authToken: apiKey, diff --git a/src/api/providers/roo.ts b/src/api/providers/roo.ts index 0d5590e74c7..3546d8cc760 100644 --- a/src/api/providers/roo.ts +++ b/src/api/providers/roo.ts @@ -1,7 +1,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import OpenAI from "openai" -import { rooDefaultModelId, getApiProtocol } from "@roo-code/types" +import { rooDefaultModelId, getApiProtocol, type ImageGenerationApiMethod } from "@roo-code/types" import { CloudService } from "@roo-code/cloud" import type { ApiHandlerOptions, ModelRecord } from "../../shared/api" @@ -15,7 +15,7 @@ import type { ApiHandlerCreateMessageMetadata } from "../index" import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider" import { getModels, getModelsFromCache } from "../providers/fetchers/modelCache" import { handleOpenAIError } from "./utils/openai-error-handler" -import { generateImageWithProvider, ImageGenerationResult } from "./utils/image-generation" +import { generateImageWithProvider, generateImageWithImagesApi, ImageGenerationResult } from "./utils/image-generation" import { t } from "../../i18n" // Extend OpenAI's CompletionUsage to include Roo specific fields @@ -273,9 +273,15 @@ export class RooHandler extends BaseOpenAiCompatibleProvider { * @param prompt The text prompt for image generation * @param model The model to use for generation * @param inputImage Optional base64 encoded input image data URL + * @param apiMethod The API method to use (chat_completions or images_api) * @returns The generated image data and format, or an error */ - async generateImage(prompt: string, model: string, inputImage?: string): Promise { + async generateImage( + prompt: string, + model: string, + inputImage?: string, + apiMethod?: ImageGenerationApiMethod, + ): Promise { const sessionToken = getSessionToken() if (!sessionToken || sessionToken === "unauthenticated") { @@ -285,8 +291,23 @@ export class RooHandler extends BaseOpenAiCompatibleProvider { } } + const baseURL = `${this.fetcherBaseURL}/v1` + + // Use the specified API method, defaulting to chat_completions for backward compatibility + if (apiMethod === "images_api") { + return generateImageWithImagesApi({ + baseURL, + authToken: sessionToken, + model, + prompt, + inputImage, + outputFormat: "png", + }) + } + + // Default to chat completions approach return generateImageWithProvider({ - baseURL: `${this.fetcherBaseURL}/v1`, + baseURL, authToken: sessionToken, model, prompt, diff --git a/src/api/providers/utils/__tests__/image-generation.spec.ts b/src/api/providers/utils/__tests__/image-generation.spec.ts new file mode 100644 index 00000000000..413402aed3d --- /dev/null +++ b/src/api/providers/utils/__tests__/image-generation.spec.ts @@ -0,0 +1,417 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest" +import { generateImageWithImagesApi, generateImageWithProvider } from "../image-generation" + +// Mock the i18n module +vi.mock("../../../i18n", () => ({ + t: (key: string, options?: any) => { + // Return a sensible mock for i18n + if (key === "tools:generateImage.failedWithMessage" && options?.message) { + return options.message + } + return key + }, +})) + +// Mock fetch globally +global.fetch = vi.fn() +global.FormData = vi.fn(() => ({ + append: vi.fn(), +})) as any +global.Blob = vi.fn() as any +global.atob = vi.fn((str: string) => { + return Buffer.from(str, "base64").toString("binary") +}) + +describe("generateImageWithImagesApi", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + describe("image generation (text-to-image)", () => { + it("should successfully generate an image", async () => { + const mockBase64 = Buffer.from("fake image data").toString("base64") + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ b64_json: mockBase64 }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + outputFormat: "png", + }) + + expect(result.success).toBe(true) + expect(result.imageData).toContain("data:image/png;base64,") + expect(result.imageFormat).toBe("png") + + // Verify fetch was called with correct parameters + expect(global.fetch).toHaveBeenCalledWith( + "https://api.example.com/v1/images/generations", + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + Authorization: "Bearer test-token", + "Content-Type": "application/json", + }), + }), + ) + }) + + it("should handle API errors gracefully", async () => { + const mockResponse = { + ok: false, + status: 400, + statusText: "Bad Request", + text: vi.fn().mockResolvedValue("{}"), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) + + it("should handle missing image data in response", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{}], // Missing b64_json and url + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) + + it("should handle URL response instead of b64_json", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ url: "" }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(true) + expect(result.imageData).toBe("") + expect(result.imageFormat).toBe("png") + }) + + it("should handle external URL response", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ url: "https://example.com/generated-image.png" }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + outputFormat: "png", + }) + + expect(result.success).toBe(true) + expect(result.imageData).toBe("https://example.com/generated-image.png") + expect(result.imageFormat).toBe("png") + }) + + it("should handle empty data array in response", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) + + it("should handle API error response", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + error: { + message: "Rate limit exceeded", + type: "rate_limit_error", + }, + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) + + it("should include optional parameters when provided", async () => { + const mockBase64 = Buffer.from("fake image data").toString("base64") + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ b64_json: mockBase64 }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + size: "1024x1024", + quality: "hd", + outputFormat: "png", + }) + + expect(result.success).toBe(true) + + // Verify fetch was called with optional parameters + const callArgs = vi.mocked(global.fetch).mock.calls[0] + const body = JSON.parse(callArgs[1]?.body as string) + expect(body.size).toBe("1024x1024") + expect(body.quality).toBe("hd") + }) + + it("should handle network errors", async () => { + vi.mocked(global.fetch).mockRejectedValue(new Error("Network error")) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toContain("Network error") + }) + }) + + describe("image editing", () => { + it("should use /images/generations endpoint with inputImage in request body", async () => { + const mockBase64 = Buffer.from("fake image data").toString("base64") + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ b64_json: mockBase64 }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const inputImageDataUrl = `data:image/png;base64,${mockBase64}` + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "Make it blue", + inputImage: inputImageDataUrl, + outputFormat: "png", + }) + + expect(result.success).toBe(true) + + // Verify /images/generations endpoint was used (not /images/edits) + const callUrl = vi.mocked(global.fetch).mock.calls[0][0] + expect(callUrl).toContain("/images/generations") + }) + + it("should handle edit operation errors", async () => { + const mockResponse = { + ok: false, + status: 400, + statusText: "Bad Request", + text: vi.fn().mockResolvedValue("{}"), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const inputImageDataUrl = + "" + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "Make it blue", + inputImage: inputImageDataUrl, + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) + }) + + describe("output format handling", () => { + it("should use png format by default", async () => { + const mockBase64 = Buffer.from("fake image data").toString("base64") + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ b64_json: mockBase64 }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + }) + + expect(result.imageFormat).toBe("png") + expect(result.imageData).toContain("data:image/png;base64,") + }) + + it("should use specified output format", async () => { + const mockBase64 = Buffer.from("fake image data").toString("base64") + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + data: [{ b64_json: mockBase64 }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithImagesApi({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-image-1", + prompt: "A cute cat", + outputFormat: "jpeg", + }) + + expect(result.imageFormat).toBe("jpeg") + expect(result.imageData).toContain("data:image/jpeg;base64,") + }) + }) +}) + +describe("generateImageWithProvider (chat completions)", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + vi.clearAllMocks() + }) + + it("should use /chat/completions endpoint", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + choices: [ + { + message: { + images: [ + { + image_url: { + url: "", + }, + }, + ], + }, + }, + ], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithProvider({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-4-vision", + prompt: "A cute cat", + }) + + expect(result.success).toBe(true) + + // Verify /chat/completions endpoint was used + const callUrl = vi.mocked(global.fetch).mock.calls[0][0] + expect(callUrl).toContain("/chat/completions") + }) + + it("should handle missing images in response", async () => { + const mockResponse = { + ok: true, + json: vi.fn().mockResolvedValue({ + choices: [{ message: { content: "No images" } }], + }), + } + + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const result = await generateImageWithProvider({ + baseURL: "https://api.example.com/v1", + authToken: "test-token", + model: "gpt-4-vision", + prompt: "A cute cat", + }) + + expect(result.success).toBe(false) + expect(result.error).toBeDefined() + }) +}) diff --git a/src/api/providers/utils/image-generation.ts b/src/api/providers/utils/image-generation.ts index 113dc3383e0..16ddb9c815b 100644 --- a/src/api/providers/utils/image-generation.ts +++ b/src/api/providers/utils/image-generation.ts @@ -20,6 +20,18 @@ interface ImageGenerationResponse { } } +interface ImagesApiResponse { + data?: Array<{ + b64_json?: string + url?: string + }> + error?: { + message?: string + type?: string + code?: string + } +} + export interface ImageGenerationResult { success: boolean imageData?: string @@ -35,6 +47,17 @@ interface ImageGenerationOptions { inputImage?: string } +interface ImagesApiOptions { + baseURL: string + authToken: string + model: string + prompt: string + inputImage?: string + size?: string + quality?: string + outputFormat?: string +} + /** * Shared image generation implementation for OpenRouter and Roo Code Cloud providers */ @@ -147,3 +170,144 @@ export async function generateImageWithProvider(options: ImageGenerationOptions) } } } + +/** + * Generate an image using OpenAI's Images API (/v1/images/generations) + * Supports BFL models (Flux) with provider-specific options for image editing + */ +export async function generateImageWithImagesApi(options: ImagesApiOptions): Promise { + const { baseURL, authToken, model, prompt, inputImage, outputFormat = "png" } = options + + try { + const url = `${baseURL}/images/generations` + + // Build the request body + // For BFL models, inputImage is passed via providerOptions.blackForestLabs.inputImage + const requestBody: Record = { + model, + prompt, + n: 1, + } + + // Add optional parameters + if (options.size) { + requestBody.size = options.size + } + if (options.quality) { + requestBody.quality = options.quality + } + + // For BFL (Black Forest Labs) models like flux-pro-1.1, use providerOptions + if (model.startsWith("bfl/")) { + requestBody.providerOptions = { + blackForestLabs: { + outputFormat: outputFormat, + // inputImage: Base64 encoded image or URL of image to use as reference + ...(inputImage && { inputImage }), + }, + } + } else { + // For other models, use standard output_format parameter + requestBody.output_format = outputFormat + } + + const fetchOptions: RequestInit = { + method: "POST", + headers: { + Authorization: `Bearer ${authToken}`, + "Content-Type": "application/json", + "HTTP-Referer": "https://github.com/RooVetGit/Roo-Code", + "X-Title": "Roo Code", + }, + body: JSON.stringify(requestBody), + } + + const response = await fetch(url, fetchOptions) + + if (!response.ok) { + const errorText = await response.text() + let errorMessage = t("tools:generateImage.failedWithStatus", { + status: response.status, + statusText: response.statusText, + }) + + try { + const errorJson = JSON.parse(errorText) + if (errorJson.error?.message) { + errorMessage = t("tools:generateImage.failedWithMessage", { + message: errorJson.error.message, + }) + } + } catch { + // Use default error message + } + return { + success: false, + error: errorMessage, + } + } + + const result: ImagesApiResponse = await response.json() + + if (result.error) { + return { + success: false, + error: t("tools:generateImage.failedWithMessage", { + message: result.error.message, + }), + } + } + + // Extract the generated image from the response + const images = result.data + if (!images || images.length === 0) { + return { + success: false, + error: t("tools:generateImage.noImageGenerated"), + } + } + + const imageItem = images[0] + + // Handle b64_json response (most common) + if (imageItem?.b64_json) { + // Convert base64 to data URL + const dataUrl = `data:image/${outputFormat};base64,${imageItem.b64_json}` + return { + success: true, + imageData: dataUrl, + imageFormat: outputFormat, + } + } + + // Handle URL response (fallback) + if (imageItem?.url) { + // If it's already a data URL, use it directly + if (imageItem.url.startsWith("data:image/")) { + const formatMatch = imageItem.url.match(/^data:image\/(\w+);/) + const format = formatMatch?.[1] || outputFormat + return { + success: true, + imageData: imageItem.url, + imageFormat: format, + } + } + // For external URLs, return as-is (the caller will need to handle fetching) + return { + success: true, + imageData: imageItem.url, + imageFormat: outputFormat, + } + } + + return { + success: false, + error: t("tools:generateImage.invalidImageData"), + } + } catch (error) { + return { + success: false, + error: error instanceof Error ? error.message : t("tools:generateImage.unknownError"), + } + } +} diff --git a/src/core/tools/GenerateImageTool.ts b/src/core/tools/GenerateImageTool.ts index 4c4f6819155..5f82ac44551 100644 --- a/src/core/tools/GenerateImageTool.ts +++ b/src/core/tools/GenerateImageTool.ts @@ -135,24 +135,28 @@ export class GenerateImageTool extends BaseTool<"generate_image"> { // Get the selected model let selectedModel = state?.openRouterImageGenerationSelectedModel + let modelInfo = undefined - // Verify the selected model matches the selected provider - // If not, default to first model of the selected provider + // Find the model info matching both value AND provider + // (since the same model value can exist for multiple providers) if (selectedModel) { - const modelInfo = IMAGE_GENERATION_MODELS.find((m) => m.value === selectedModel) - if (!modelInfo || modelInfo.provider !== imageProvider) { - // Model doesn't match provider, use first model for selected provider + modelInfo = IMAGE_GENERATION_MODELS.find((m) => m.value === selectedModel && m.provider === imageProvider) + if (!modelInfo) { + // Model doesn't exist for this provider, use first model for selected provider const providerModels = IMAGE_GENERATION_MODELS.filter((m) => m.provider === imageProvider) - selectedModel = providerModels[0]?.value || IMAGE_GENERATION_MODEL_IDS[0] + modelInfo = providerModels[0] + selectedModel = modelInfo?.value || IMAGE_GENERATION_MODEL_IDS[0] } } else { // No model selected, use first model for selected provider const providerModels = IMAGE_GENERATION_MODELS.filter((m) => m.provider === imageProvider) - selectedModel = providerModels[0]?.value || IMAGE_GENERATION_MODEL_IDS[0] + modelInfo = providerModels[0] + selectedModel = modelInfo?.value || IMAGE_GENERATION_MODEL_IDS[0] } // Use the provider selection const modelProvider = imageProvider + const apiMethod = modelInfo?.apiMethod // Validate API key for OpenRouter const openRouterApiKey = state?.openRouterImageApiKey @@ -192,11 +196,11 @@ export class GenerateImageTool extends BaseTool<"generate_image"> { let result if (modelProvider === "roo") { - // Use Roo Code Cloud provider + // Use Roo Code Cloud provider (supports both chat completions and images API) const rooHandler = new RooHandler({} as any) - result = await rooHandler.generateImage(prompt, selectedModel, inputImageData) + result = await rooHandler.generateImage(prompt, selectedModel, inputImageData, apiMethod) } else { - // Use OpenRouter provider + // Use OpenRouter provider (only supports chat completions API) const openRouterHandler = new OpenRouterHandler({} as any) result = await openRouterHandler.generateImage(prompt, selectedModel, openRouterApiKey!, inputImageData) }