
Commit

add MultiModalSynthesizer
leehuwuj committed Nov 29, 2024
1 parent bc8df66 commit 88d2918
Showing 2 changed files with 37 additions and 76 deletions.
104 changes: 32 additions & 72 deletions templates/components/engines/python/agent/tools/query_engine.py
@@ -1,102 +1,59 @@
 import os
-from typing import List, Optional, Sequence
+from typing import Any, List, Optional, Sequence
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
-from llama_index.core.query_engine import SimpleMultiModalQueryEngine
+from llama_index.core.multi_modal_llms import MultiModalLLM
+from llama_index.core.prompts.base import BasePromptTemplate
+from llama_index.core.prompts.default_prompt_selectors import (
+    DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
+)
+from llama_index.core.query_engine import (
+    RetrieverQueryEngine,
+)
 from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
-from llama_index.core.response_synthesizers import (
-    BaseSynthesizer,
-    get_response_synthesizer,
+from llama_index.core.response_synthesizers import TreeSummarize
+from llama_index.core.response_synthesizers.base import QueryTextType
+from llama_index.core.schema import (
+    ImageNode,
+    NodeWithScore,
 )
-from llama_index.core.response_synthesizers.type import ResponseMode
-from llama_index.core.schema import ImageNode, NodeWithScore, QueryBundle
 from llama_index.core.tools.query_engine import QueryEngineTool
 
 from app.settings import multi_modal_llm
 
 
-class MultiModalQueryEngine(SimpleMultiModalQueryEngine):
+class MultiModalSynthesizer(TreeSummarize):
     """
-    A multi-modal query engine that splits the retrieval results into chunks then summarizes each chunk to reduce the number of tokens in the response.
+    A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
     """
 
     def __init__(
         self,
-        text_synthesizer: Optional[BaseSynthesizer] = None,
+        multimodal_model: Optional[MultiModalLLM] = None,
+        text_qa_template: Optional[BasePromptTemplate] = None,
         *args,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        # Use a response synthesizer for text nodes summarization
-        self._text_synthesizer = text_synthesizer or get_response_synthesizer(
-            streaming=False,
-            response_mode=ResponseMode.TREE_SUMMARIZE,
-        )
-
-    def _summarize_text_nodes(
-        self, query_bundle: QueryBundle, nodes: List[NodeWithScore]
-    ) -> str:
-        """
-        Synthesize a response for the query using the retrieved nodes.
-        """
-        return str(
-            self._text_synthesizer.synthesize(
-                query=query_bundle,
-                nodes=nodes,
-                streaming=False,
-            )
-        )
-
-    def synthesize(
-        self,
-        query_bundle: QueryBundle,
-        nodes: List[NodeWithScore],
-        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
-    ) -> RESPONSE_TYPE:
-        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
-        # Summarize the text nodes
-        text_response = self._summarize_text_nodes(
-            query_bundle=query_bundle,
-            nodes=text_nodes,
-        )
-
-        fmt_prompt = self._text_qa_template.format(
-            context_str=text_response,
-            query_str=query_bundle.query_str,
-        )
-
-        llm_response = self._multi_modal_llm.complete(
-            prompt=fmt_prompt,
-            image_documents=[
-                image_node.node
-                for image_node in image_nodes
-                if isinstance(image_node.node, ImageNode)
-            ],
-        )
-
-        return Response(
-            response=str(llm_response),
-            source_nodes=nodes,
-            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
-        )
+        self._multi_modal_llm = multimodal_model
+        self._text_qa_template = text_qa_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
 
     async def asynthesize(
         self,
-        query_bundle: QueryBundle,
+        query: QueryTextType,
         nodes: List[NodeWithScore],
         additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
+        **response_kwargs: Any,
     ) -> RESPONSE_TYPE:
         image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
 
         # Summarize the text nodes to avoid exceeding the token limit
-        text_response = self._summarize_text_nodes(
-            query_bundle=query_bundle,
-            nodes=text_nodes,
-        )
+        text_response = str(await super().asynthesize(query, nodes))
 
         fmt_prompt = self._text_qa_template.format(
             context_str=text_response,
-            query_str=query_bundle.query_str,
+            query_str=query.query_str,  # type: ignore
         )
 
         llm_response = await self._multi_modal_llm.acomplete(
@@ -123,6 +80,7 @@ def create_query_engine(index, **kwargs) -> BaseQueryEngine:
         index: The index to create a query engine for.
         params (optional): Additional parameters for the query engine, e.g: similarity_top_k
     """
+
     top_k = int(os.getenv("TOP_K", 0))
     if top_k != 0 and kwargs.get("filters") is None:
         kwargs["similarity_top_k"] = top_k
@@ -132,12 +90,14 @@ def create_query_engine(index, **kwargs) -> BaseQueryEngine:
     retrieval_mode = kwargs.get("retrieval_mode")
     if retrieval_mode is None:
         kwargs["retrieval_mode"] = "auto_routed"
-        if multi_modal_llm:
+        # Note: image nodes are not supported for auto_routed or chunk retrieval mode
+        mm_model = multi_modal_llm.get()
+        if mm_model:
             kwargs["retrieve_image_nodes"] = True
-            return MultiModalQueryEngine(
+            return RetrieverQueryEngine(
                 retriever=index.as_retriever(**kwargs),
-                multi_modal_llm=multi_modal_llm,
+                response_synthesizer=MultiModalSynthesizer(
+                    multimodal_model=mm_model
+                ),
             )
 
     return index.as_query_engine(**kwargs)
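For context, a minimal usage sketch of the new code path (not part of the commit). The module path app.engine.tools.query_engine, the ./data folder, and the query string are hypothetical; it assumes init_settings() has registered a multi-modal LLM, and it uses the async query API because the diff overrides asynthesize.

import asyncio

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

from app.engine.tools.query_engine import create_query_engine  # hypothetical path in a generated app
from app.settings import init_settings


async def main():
    # Registers the multi-modal LLM ContextVar when a GPT-4 vision model is configured.
    init_settings()

    # Hypothetical data folder mixing text and image documents.
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    # With a multi-modal LLM registered, create_query_engine wires MultiModalSynthesizer
    # into a RetrieverQueryEngine; otherwise it falls back to index.as_query_engine.
    query_engine = create_query_engine(index)

    # The commit overrides asynthesize, so aquery exercises the new multi-modal path.
    response = await query_engine.aquery("What does the diagram describe?")
    print(response)


asyncio.run(main())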
9 changes: 5 additions & 4 deletions templates/components/settings/python/settings.py
@@ -1,11 +1,13 @@
 import os
+from contextvars import ContextVar
 from typing import Dict, Optional
 
 from llama_index.core.multi_modal_llms import MultiModalLLM
 from llama_index.core.settings import Settings
 
 # Singleton for multi-modal LLM
-multi_modal_llm: Optional[MultiModalLLM] = None
+multi_modal_llm: ContextVar[Optional[MultiModalLLM]] = ContextVar(
+    "multi_modal_llm", default=None
+)
 
 
 def init_settings():
@@ -76,8 +78,7 @@ def init_openai():
     )
 
     if model_name in GPT4V_MODELS:
-        global multi_modal_llm
-        multi_modal_llm = OpenAIMultiModal(model=model_name)
+        multi_modal_llm.set(OpenAIMultiModal(model=model_name))
 
     dimensions = os.getenv("EMBEDDING_DIM")
     Settings.embed_model = OpenAIEmbedding(
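For reference, a standalone sketch of the ContextVar pattern introduced above, with a plain string standing in for the OpenAIMultiModal instance:

from contextvars import ContextVar
from typing import Optional

# Same shape as settings.py: a context-local slot instead of a module-level global.
multi_modal_llm: ContextVar[Optional[str]] = ContextVar("multi_modal_llm", default=None)


def init_settings() -> None:
    # No `global` statement is needed; .set() records the value for the current context.
    multi_modal_llm.set("gpt-4o-placeholder")


def describe_engine() -> str:
    # Hypothetical helper mirroring create_query_engine's check: .get() returns the
    # default (None) until init_settings() has run in this context.
    mm_model = multi_modal_llm.get()
    return "multi-modal" if mm_model else "text-only"


print(describe_engine())  # text-only
init_settings()
print(describe_engine())  # multi-modal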
