feat: use images in context chat engine (run-llama#886)
marcusschiesser authored Jun 3, 2024
1 parent 0b51995 commit 6e156ed
Showing 7 changed files with 171 additions and 47 deletions.
5 changes: 5 additions & 0 deletions .changeset/afraid-seahorses-learn.md
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

Use images in context chat engine
54 changes: 54 additions & 0 deletions examples/multimodal/context.ts
@@ -0,0 +1,54 @@
// call pnpm tsx multimodal/load.ts first to init the storage
import {
  ContextChatEngine,
  NodeWithScore,
  ObjectType,
  OpenAI,
  RetrievalEndEvent,
  Settings,
  VectorStoreIndex,
} from "llamaindex";
import { getStorageContext } from "./storage";

// Update chunk size and overlap
Settings.chunkSize = 512;
Settings.chunkOverlap = 20;

// Update llm
Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 });

// Update callbackManager
Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
  const { nodes, query } = event.detail.payload;
  const imageNodes = nodes.filter(
    (node: NodeWithScore) => node.node.type === ObjectType.IMAGE_DOCUMENT,
  );
  const textNodes = nodes.filter(
    (node: NodeWithScore) => node.node.type === ObjectType.TEXT,
  );
  console.log(
    `Retrieved ${textNodes.length} text nodes and ${imageNodes.length} image nodes for query: ${query}`,
  );
});

async function main() {
  const storageContext = await getStorageContext();
  const index = await VectorStoreIndex.init({
    storageContext,
  });
  // topK for text is 0 and for image 1 => we only retrieve one image and no text based on the query
  const retriever = index.asRetriever({ topK: { TEXT: 0, IMAGE: 1 } });
  // NOTE: we set the contextRole to "user" (default is "system"). The reason is that GPT-4 does not support
  // images in a system message
  const chatEngine = new ContextChatEngine({ retriever, contextRole: "user" });

  // the ContextChatEngine will use the Clip embedding to retrieve the closest image
  // (the lady in the chair) and use it in the context for the query
  const response = await chatEngine.chat({
    message: "What is the name of the painting with the lady in the chair?",
  });

  console.log(response.response, "\n");
}

main().catch(console.error);
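Note: the example above imports getStorageContext from ./storage, which is not part of this commit. For reference, a minimal sketch of what such a helper might look like; the persistDir value and the storeImages flag are assumptions, so the actual examples/multimodal/storage.ts may differ:

// storage.ts (hypothetical sketch, not part of this diff)
import { storageContextFromDefaults } from "llamaindex";

export async function getStorageContext() {
  // Assumes multimodal/load.ts persisted text and image embeddings to ./storage;
  // storeImages is assumed to enable the separate image vector store.
  return await storageContextFromDefaults({
    persistDir: "storage",
    storeImages: true,
  });
}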
14 changes: 6 additions & 8 deletions examples/multimodal/rag.ts
@@ -1,5 +1,4 @@
import {
  ImageType,
  MultiModalResponseSynthesizer,
  OpenAI,
  RetrievalEndEvent,
@@ -22,8 +21,6 @@ Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => {
});

async function main() {
  const images: ImageType[] = [];

  const storageContext = await getStorageContext();
  const index = await VectorStoreIndex.init({
    nodes: [],
@@ -34,13 +31,14 @@ async function main() {
    responseSynthesizer: new MultiModalResponseSynthesizer(),
    retriever: index.asRetriever({ topK: { TEXT: 3, IMAGE: 1 } }),
  });
  const result = await queryEngine.query({
  const stream = await queryEngine.query({
    query: "Tell me more about Vincent van Gogh's famous paintings",
    stream: true,
  });
  console.log(result.response, "\n");
  images.forEach((image) =>
    console.log(`Image retrieved and used in inference: ${image.toString()}`),
  );
  for await (const chunk of stream) {
    process.stdout.write(chunk.response);
  }
  process.stdout.write("\n");
}

main().catch(console.error);
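With the updated MultiModalResponseSynthesizer, the non-streaming call this example used before keeps working; a minimal sketch, assuming the same queryEngine built inside main() in rag.ts above:

// Non-streaming variant (continuation of main() above): omitting `stream: true`
// returns a single Response whose .response field holds the full text.
const result = await queryEngine.query({
  query: "Tell me more about Vincent van Gogh's famous paintings",
});
console.log(result.response, "\n");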
4 changes: 3 additions & 1 deletion packages/core/src/engines/chat/ContextChatEngine.ts
@@ -6,7 +6,7 @@ import type { BaseRetriever } from "../../Retriever.js";
import { Settings } from "../../Settings.js";
import { wrapEventCaller } from "../../internal/context/EventCaller.js";
import type { ChatMessage, ChatResponseChunk, LLM } from "../../llm/index.js";
import type { MessageContent } from "../../llm/types.js";
import type { MessageContent, MessageType } from "../../llm/types.js";
import {
  extractText,
  streamConverter,
@@ -40,6 +40,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
    contextSystemPrompt?: ContextSystemPrompt;
    nodePostprocessors?: BaseNodePostprocessor[];
    systemPrompt?: string;
    contextRole?: MessageType;
  }) {
    super();
    this.chatModel = init.chatModel ?? Settings.llm;
@@ -48,6 +49,7 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
      retriever: init.retriever,
      contextSystemPrompt: init?.contextSystemPrompt,
      nodePostprocessors: init?.nodePostprocessors,
      contextRole: init?.contextRole,
    });
    this.systemPrompt = init.systemPrompt;
  }
20 changes: 13 additions & 7 deletions packages/core/src/engines/chat/DefaultContextGenerator.ts
@@ -1,10 +1,11 @@
import type { NodeWithScore, TextNode } from "../../Node.js";
import { type NodeWithScore } from "../../Node.js";
import type { ContextSystemPrompt } from "../../Prompt.js";
import { defaultContextSystemPrompt } from "../../Prompt.js";
import type { BaseRetriever } from "../../Retriever.js";
import type { MessageContent } from "../../llm/types.js";
import type { MessageContent, MessageType } from "../../llm/types.js";
import type { BaseNodePostprocessor } from "../../postprocessors/index.js";
import { PromptMixin } from "../../prompts/index.js";
import { createMessageContent } from "../../synthesizers/utils.js";
import type { Context, ContextGenerator } from "./types.js";

export class DefaultContextGenerator
@@ -14,18 +15,21 @@ export class DefaultContextGenerator
  retriever: BaseRetriever;
  contextSystemPrompt: ContextSystemPrompt;
  nodePostprocessors: BaseNodePostprocessor[];
  contextRole: MessageType;

  constructor(init: {
    retriever: BaseRetriever;
    contextSystemPrompt?: ContextSystemPrompt;
    nodePostprocessors?: BaseNodePostprocessor[];
    contextRole?: MessageType;
  }) {
    super();

    this.retriever = init.retriever;
    this.contextSystemPrompt =
      init?.contextSystemPrompt ?? defaultContextSystemPrompt;
    this.nodePostprocessors = init.nodePostprocessors || [];
    this.contextRole = init.contextRole ?? "system";
  }

  protected _getPrompts(): { contextSystemPrompt: ContextSystemPrompt } {
@@ -68,13 +72,15 @@ export class DefaultContextGenerator
      message,
    );

    // TODO: also use retrieved image nodes in context
    const content = await createMessageContent(
      this.contextSystemPrompt,
      nodes.map((r) => r.node),
    );

    return {
      message: {
        content: this.contextSystemPrompt({
          context: nodes.map((r) => (r.node as TextNode).text).join("\n\n"),
        }),
        role: "system",
        content,
        role: this.contextRole,
      },
      nodes,
    };
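In effect, DefaultContextGenerator.generate() now returns a context message whose content is a list of parts rather than a single prompt string, with the role taken from contextRole. A rough illustration of the resulting shape when one text node and one image node are retrieved (all values below are made up):

// Illustrative only — approximate shape of the generated context message.
const exampleContextMessage = {
  role: "user", // from contextRole; defaults to "system"
  content: [
    {
      type: "text",
      text: "Context information is below.\n---------------------\n<text node content>\n---------------------",
    },
    {
      type: "image_url",
      image_url: { url: "data:image/jpeg;base64,<encoded image>" },
    },
  ],
};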
49 changes: 18 additions & 31 deletions packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
@@ -1,10 +1,8 @@
import type { ImageNode } from "../Node.js";
import { MetadataMode, ModalityType, splitNodesByType } from "../Node.js";
import { MetadataMode } from "../Node.js";
import { Response } from "../Response.js";
import type { ServiceContext } from "../ServiceContext.js";
import { llmFromSettingsOrContext } from "../Settings.js";
import { imageToDataUrl } from "../embeddings/index.js";
import type { MessageContentDetail } from "../llm/types.js";
import { streamConverter } from "../llm/utils.js";
import { PromptMixin } from "../prompts/Mixin.js";
import type { TextQaPrompt } from "./../Prompt.js";
import { defaultTextQaPrompt } from "./../Prompt.js";
@@ -13,6 +11,7 @@ import type {
  SynthesizeParamsNonStreaming,
  SynthesizeParamsStreaming,
} from "./types.js";
import { createMessageContent } from "./utils.js";

export class MultiModalResponseSynthesizer
  extends PromptMixin
@@ -59,41 +58,29 @@ export class MultiModalResponseSynthesizer
  }: SynthesizeParamsStreaming | SynthesizeParamsNonStreaming): Promise<
    AsyncIterable<Response> | Response
  > {
    if (stream) {
      throw new Error("streaming not implemented");
    }
    const nodes = nodesWithScore.map(({ node }) => node);
    const nodeMap = splitNodesByType(nodes);
    const imageNodes: ImageNode[] =
      (nodeMap[ModalityType.IMAGE] as ImageNode[]) ?? [];
    const textNodes = nodeMap[ModalityType.TEXT] ?? [];
    const textChunks = textNodes.map((node) =>
      node.getContent(this.metadataMode),
    );
    // TODO: use builders to generate context
    const context = textChunks.join("\n\n");
    const textPrompt = this.textQATemplate({ context, query });
    const images = await Promise.all(
      imageNodes.map(async (node: ImageNode) => {
        return {
          type: "image_url",
          image_url: {
            url: await imageToDataUrl(node.image),
          },
        } as MessageContentDetail;
      }),
    const prompt = await createMessageContent(
      this.textQATemplate,
      nodes,
      { query },
      this.metadataMode,
    );
    const prompt: MessageContentDetail[] = [
      { type: "text", text: textPrompt },
      ...images,
    ];

    const llm = llmFromSettingsOrContext(this.serviceContext);

    if (stream) {
      const response = await llm.complete({
        prompt,
        stream,
      });
      return streamConverter(
        response,
        ({ text }) => new Response(text, nodesWithScore),
      );
    }
    const response = await llm.complete({
      prompt,
    });

    return new Response(response.text, nodesWithScore);
  }
}
72 changes: 72 additions & 0 deletions packages/core/src/synthesizers/utils.ts
@@ -0,0 +1,72 @@
import {
  ImageNode,
  MetadataMode,
  ModalityType,
  splitNodesByType,
  type BaseNode,
} from "../Node.js";
import type { SimplePrompt } from "../Prompt.js";
import { imageToDataUrl } from "../embeddings/utils.js";
import type { MessageContentDetail } from "../llm/types.js";

export async function createMessageContent(
  prompt: SimplePrompt,
  nodes: BaseNode[],
  extraParams: Record<string, string | undefined> = {},
  metadataMode: MetadataMode = MetadataMode.NONE,
): Promise<MessageContentDetail[]> {
  const content: MessageContentDetail[] = [];
  const nodeMap = splitNodesByType(nodes);
  for (const type in nodeMap) {
    // for each retrieved modality type, create message content
    const nodes = nodeMap[type as ModalityType];
    if (nodes) {
      content.push(
        ...(await createContentPerModality(
          prompt,
          type as ModalityType,
          nodes,
          extraParams,
          metadataMode,
        )),
      );
    }
  }
  return content;
}

// eslint-disable-next-line max-params
async function createContentPerModality(
  prompt: SimplePrompt,
  type: ModalityType,
  nodes: BaseNode[],
  extraParams: Record<string, string | undefined>,
  metadataMode: MetadataMode,
): Promise<MessageContentDetail[]> {
  switch (type) {
    case ModalityType.TEXT:
      return [
        {
          type: "text",
          text: prompt({
            ...extraParams,
            context: nodes.map((r) => r.getContent(metadataMode)).join("\n\n"),
          }),
        },
      ];
    case ModalityType.IMAGE:
      const images: MessageContentDetail[] = await Promise.all(
        (nodes as ImageNode[]).map(async (node) => {
          return {
            type: "image_url",
            image_url: {
              url: await imageToDataUrl(node.image),
            },
          } satisfies MessageContentDetail;
        }),
      );
      return images;
    default:
      return [];
  }
}
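To illustrate how the new helper is used, here is a hedged sketch; the node contents and image path are made up, and the relative import paths assume the code lives next to utils.ts inside the package (createMessageContent is not necessarily exported from the package root):

// Hypothetical usage sketch (not part of this diff).
import { ImageNode, TextNode } from "../Node.js";
import { defaultTextQaPrompt } from "../Prompt.js";
import { createMessageContent } from "./utils.js";

async function demo() {
  const nodes = [
    new TextNode({ text: "Van Gogh painted The Starry Night in 1889." }),
    new ImageNode({ image: "./data/starry-night.jpg" }), // assumed local image path
  ];
  // Produces one "text" part (the prompt filled with the text-node context and the query)
  // plus one "image_url" part per image node, encoded as a data URL.
  const content = await createMessageContent(defaultTextQaPrompt, nodes, {
    query: "When was The Starry Night painted?",
  });
  console.log(content);
}

demo().catch(console.error);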
