forked from run-llama/LlamaIndexTS
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: use images in context chat engine (run-llama#886)
- Loading branch information
1 parent: 0b51995 · commit: 6e156ed
Showing 7 changed files with 171 additions and 47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
--- | ||
"llamaindex": patch | ||
--- | ||
|
||
Use images in context chat engine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// call pnpm tsx multimodal/load.ts first to init the storage | ||
import { | ||
ContextChatEngine, | ||
NodeWithScore, | ||
ObjectType, | ||
OpenAI, | ||
RetrievalEndEvent, | ||
Settings, | ||
VectorStoreIndex, | ||
} from "llamaindex"; | ||
import { getStorageContext } from "./storage"; | ||
|
||
// Update chunk size and overlap | ||
Settings.chunkSize = 512; | ||
Settings.chunkOverlap = 20; | ||
|
||
// Update llm | ||
Settings.llm = new OpenAI({ model: "gpt-4-turbo", maxTokens: 512 }); | ||
|
||
// Update callbackManager | ||
Settings.callbackManager.on("retrieve-end", (event: RetrievalEndEvent) => { | ||
const { nodes, query } = event.detail.payload; | ||
const imageNodes = nodes.filter( | ||
(node: NodeWithScore) => node.node.type === ObjectType.IMAGE_DOCUMENT, | ||
); | ||
const textNodes = nodes.filter( | ||
(node: NodeWithScore) => node.node.type === ObjectType.TEXT, | ||
); | ||
console.log( | ||
`Retrieved ${textNodes.length} text nodes and ${imageNodes.length} image nodes for query: ${query}`, | ||
); | ||
}); | ||
|
||
async function main() { | ||
const storageContext = await getStorageContext(); | ||
const index = await VectorStoreIndex.init({ | ||
storageContext, | ||
}); | ||
// topK for text is 0 and for image 1 => we only retrieve one image and no text based on the query | ||
const retriever = index.asRetriever({ topK: { TEXT: 0, IMAGE: 1 } }); | ||
// NOTE: we set the contextRole to "user" (default is "system"). The reason is that GPT-4 does not support | ||
// images in a system message | ||
const chatEngine = new ContextChatEngine({ retriever, contextRole: "user" }); | ||
|
||
// the ContextChatEngine will use the Clip embedding to retrieve the closest image | ||
// (the lady in the chair) and use it in the context for the query | ||
const response = await chatEngine.chat({ | ||
message: "What is the name of the painting with the lady in the chair?", | ||
}); | ||
|
||
console.log(response.response, "\n"); | ||
} | ||
|
||
main().catch(console.error); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import { | ||
ImageNode, | ||
MetadataMode, | ||
ModalityType, | ||
splitNodesByType, | ||
type BaseNode, | ||
} from "../Node.js"; | ||
import type { SimplePrompt } from "../Prompt.js"; | ||
import { imageToDataUrl } from "../embeddings/utils.js"; | ||
import type { MessageContentDetail } from "../llm/types.js"; | ||
|
||
export async function createMessageContent( | ||
prompt: SimplePrompt, | ||
nodes: BaseNode[], | ||
extraParams: Record<string, string | undefined> = {}, | ||
metadataMode: MetadataMode = MetadataMode.NONE, | ||
): Promise<MessageContentDetail[]> { | ||
const content: MessageContentDetail[] = []; | ||
const nodeMap = splitNodesByType(nodes); | ||
for (const type in nodeMap) { | ||
// for each retrieved modality type, create message content | ||
const nodes = nodeMap[type as ModalityType]; | ||
if (nodes) { | ||
content.push( | ||
...(await createContentPerModality( | ||
prompt, | ||
type as ModalityType, | ||
nodes, | ||
extraParams, | ||
metadataMode, | ||
)), | ||
); | ||
} | ||
} | ||
return content; | ||
} | ||
|
||
// eslint-disable-next-line max-params | ||
async function createContentPerModality( | ||
prompt: SimplePrompt, | ||
type: ModalityType, | ||
nodes: BaseNode[], | ||
extraParams: Record<string, string | undefined>, | ||
metadataMode: MetadataMode, | ||
): Promise<MessageContentDetail[]> { | ||
switch (type) { | ||
case ModalityType.TEXT: | ||
return [ | ||
{ | ||
type: "text", | ||
text: prompt({ | ||
...extraParams, | ||
context: nodes.map((r) => r.getContent(metadataMode)).join("\n\n"), | ||
}), | ||
}, | ||
]; | ||
case ModalityType.IMAGE: | ||
const images: MessageContentDetail[] = await Promise.all( | ||
(nodes as ImageNode[]).map(async (node) => { | ||
return { | ||
type: "image_url", | ||
image_url: { | ||
url: await imageToDataUrl(node.image), | ||
}, | ||
} satisfies MessageContentDetail; | ||
}), | ||
); | ||
return images; | ||
default: | ||
return []; | ||
} | ||
} |