Skip to content

Commit

Permalink
openai[minor]: Allow for multiple images to be returned with dalle (#…
Browse files Browse the repository at this point in the history
…6154)

* openai[minor]: Allow for multiple images to be returned with dalle

* skip dalle tests
  • Loading branch information
bracesproul authored Jul 19, 2024
1 parent 9ca0a3a commit 0ff34f9
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 5 deletions.
83 changes: 78 additions & 5 deletions libs/langchain-openai/src/tools/dalle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { OpenAI as OpenAIClient } from "openai";
import { Tool, ToolParams } from "@langchain/core/tools";
import {
MessageContentComplex,
MessageContentImageUrl,
} from "@langchain/core/messages";

/**
* An interface for the Dall-E API Wrapper.
Expand Down Expand Up @@ -149,18 +153,88 @@ export class DallEAPIWrapper extends Tool {
this.user = fields?.user;
}

/**
* Processes the API response if multiple images are generated.
* Returns a list of MessageContentImageUrl objects. If the response
* format is `url`, then the `image_url` field will contain the URL.
* If it is `b64_json`, then the `image_url` field will contain an object
* with a `url` field with the base64 encoded image.
*
* @param {OpenAIClient.Images.ImagesResponse[]} response The API response
* @returns {MessageContentImageUrl[]}
*/
private processMultipleGeneratedUrls(
response: OpenAIClient.Images.ImagesResponse[]
): MessageContentImageUrl[] {
if (this.dallEResponseFormat === "url") {
return response.flatMap((res) => {
const imageUrlContent = res.data
.flatMap((item) => {
if (!item.url) return [];
return {
type: "image_url" as const,
image_url: item.url,
};
})
.filter(
(item) =>
item !== undefined &&
item.type === "image_url" &&
typeof item.image_url === "string" &&
item.image_url !== undefined
);
return imageUrlContent;
});
} else {
return response.flatMap((res) => {
const b64Content = res.data
.flatMap((item) => {
if (!item.b64_json) return [];
return {
type: "image_url" as const,
image_url: {
url: item.b64_json,
},
};
})
.filter(
(item) =>
item !== undefined &&
item.type === "image_url" &&
typeof item.image_url === "object" &&
"url" in item.image_url &&
typeof item.image_url.url === "string" &&
item.image_url.url !== undefined
);
return b64Content;
});
}
}

/** @ignore */
async _call(input: string): Promise<string> {
const response = await this.client.images.generate({
async _call(input: string): Promise<string | MessageContentComplex[]> {
const generateImageFields = {
model: this.model,
prompt: input,
n: this.n,
n: 1,
size: this.size,
response_format: this.dallEResponseFormat,
style: this.style,
quality: this.quality,
user: this.user,
});
};

if (this.n > 1) {
const results = await Promise.all(
Array.from({ length: this.n }).map(() =>
this.client.images.generate(generateImageFields)
)
);

return this.processMultipleGeneratedUrls(results);
}

const response = await this.client.images.generate(generateImageFields);

let data = "";
if (this.dallEResponseFormat === "url") {
Expand All @@ -172,7 +246,6 @@ export class DallEAPIWrapper extends Tool {
.map((item) => item.b64_json)
.filter((b64_json): b64_json is string => b64_json !== "undefined");
}

return data;
}
}
49 changes: 49 additions & 0 deletions libs/langchain-openai/src/tools/tests/dalle.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,52 @@ test.skip("Dalle can generate images with base 64 response format", async () =>
expect(res).toBeDefined();
expect(res).not.toContain("https://");
});

test.skip("Dalle returns multiple image URLs if n > 1", async () => {
const dalle = new DallEAPIWrapper({
n: 2,
});
const res = await dalle.invoke("A painting of a cat");
expect(res).toBeDefined();
expect(res).toBeInstanceOf(Array);
if (!Array.isArray(res)) return;
expect(res).toHaveLength(2);

// The types for each should be `image_url` with an `image_url` field containing the URL
expect(res[0].type).toBe("image_url");
expect(res[1].type).toBe("image_url");

expect(res[0]).toHaveProperty("image_url");
expect(res[1]).toHaveProperty("image_url");

expect(res[0].image_url.startsWith("https://")).toBe(true);
expect(res[1].image_url.startsWith("https://")).toBe(true);
});

test.skip("Dalle returns multiple base64 image strings if n > 1", async () => {
const dalle = new DallEAPIWrapper({
n: 2,
dallEResponseFormat: "b64_json",
});
const res = await dalle.invoke("A painting of a cat");
expect(res).toBeDefined();
expect(res).toBeInstanceOf(Array);
if (!Array.isArray(res)) return;
expect(res).toHaveLength(2);

// The types for each should be `b64_json` with an `b64_json` field containing the URL
expect(res[0].type).toBe("image_url");
expect(res[1].type).toBe("image_url");

expect(res[0]).toHaveProperty("image_url");
expect(res[1]).toHaveProperty("image_url");

expect(res[0].image_url).toHaveProperty("url");
expect(res[1].image_url).toHaveProperty("url");

expect(res[0].image_url.url).toBeDefined();
expect(res[1].image_url.url).toBeDefined();

expect(res[0].image_url.url).not.toBe("");
expect(res[1].image_url.url).not.toBe("");
});

0 comments on commit 0ff34f9

Please sign in to comment.