feat: add support for multimodal #453

Merged (19 commits), Nov 29, 2024
Changes from 3 commits

132 changes: 125 additions & 7 deletions templates/components/engines/python/agent/tools/query_engine.py
@@ -1,10 +1,121 @@
import os
from typing import Optional
from typing import List, Optional, Sequence

from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.query_engine import SimpleMultiModalQueryEngine
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers import (
    BaseSynthesizer,
    get_response_synthesizer,
)
from llama_index.core.response_synthesizers.type import ResponseMode
from llama_index.core.schema import ImageNode, NodeWithScore, QueryBundle
from llama_index.core.tools.query_engine import QueryEngineTool

from app.settings import multi_modal_llm

def create_query_engine(index, **kwargs):

class MultiModalQueryEngine(SimpleMultiModalQueryEngine):
    """
    A multi-modal query engine that splits the retrieval results into chunks,
    then summarizes each chunk to reduce the number of tokens in the response.
    """

    def __init__(
        self,
        text_synthesizer: Optional[BaseSynthesizer] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Use a response synthesizer to summarize the text nodes
        self._text_synthesizer = text_synthesizer or get_response_synthesizer(
            streaming=False,
            response_mode=ResponseMode.TREE_SUMMARIZE,
        )

    def _summarize_text_nodes(
        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
    ) -> str:
        """
        Synthesize a response for the query using the retrieved nodes.
        """
        return str(
            self._text_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                streaming=False,
            )
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        # Summarize the text nodes
        text_response = self._summarize_text_nodes(
            query_bundle=query_bundle,
            nodes=text_nodes,
        )

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query_bundle.query_str,
        )

        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
        # Summarize the text nodes to avoid exceeding the token limit
        text_response = self._summarize_text_nodes(
            query_bundle=query_bundle,
            nodes=text_nodes,
        )

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query_bundle.query_str,
        )

        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )


def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

@@ -17,11 +128,18 @@ def create_query_engine(index, **kwargs):
kwargs["similarity_top_k"] = top_k
# If index is index is LlamaCloudIndex
# use auto_routed mode for better query results
if (
index.__class__.__name__ == "LlamaCloudIndex"
and kwargs.get("auto_routed") is None
):
kwargs["auto_routed"] = True
if index.__class__.__name__ == "LlamaCloudIndex":
retrieval_mode = kwargs.get("retrieval_mode")
if retrieval_mode is None:
kwargs["retrieval_mode"] = "auto_routed"
if multi_modal_llm:
# Note: image nodes are not supported for auto_routed or chunk retrieval mode
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
kwargs["retrieve_image_nodes"] = True
return MultiModalQueryEngine(
retriever=index.as_retriever(**kwargs),
multi_modal_llm=multi_modal_llm,
)

return index.as_query_engine(**kwargs)
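
For reviewers who want to try the new engine outside the template, here is a minimal usage sketch (not part of this diff). It wires MultiModalQueryEngine up by hand against a local VectorStoreIndex, whereas the factory above only takes this path for a LlamaCloudIndex; it also assumes the module is importable from a generated app (the app.engine.tools.query_engine path is illustrative), that a ./data folder with documents exists, and that OPENAI_API_KEY is set.

# Illustrative sketch only: exercises the MultiModalQueryEngine added above.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

from app.engine.tools.query_engine import MultiModalQueryEngine  # import path assumed

# Build a small local index; any retriever works, image nodes are simply
# absent when the index only contains text.
documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)

engine = MultiModalQueryEngine(
    retriever=index.as_retriever(similarity_top_k=5),
    multi_modal_llm=OpenAIMultiModal(model="gpt-4o"),
)

response = engine.query("Summarize the key findings, including any figures.")
print(response)                  # answer produced by the multi-modal LLM
print(response.metadata.keys())  # dict_keys(['text_nodes', 'image_nodes'])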


15 changes: 13 additions & 2 deletions templates/components/settings/python/settings.py
@@ -1,8 +1,12 @@
import os
from typing import Dict
from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

# Singleton for multi-modal LLM
multi_modal_llm: Optional[MultiModalLLM] = None


def init_settings():
    model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +64,21 @@ def init_openai():
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

    max_tokens = os.getenv("LLM_MAX_TOKENS")
    model_name = os.getenv("MODEL", "gpt-4o-mini")
    Settings.llm = OpenAI(
        model=os.getenv("MODEL", "gpt-4o-mini"),
        model=model_name,
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
    )

    if model_name in GPT4V_MODELS:
        global multi_modal_llm
        multi_modal_llm = OpenAIMultiModal(model=model_name)

    dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),