feat: add support for multimodal indexes (#453)
---------
Co-authored-by: thucpn <thucsh2@gmail.com>
Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
leehuwuj authored Nov 29, 2024
1 parent aedd73d commit f9a057d
Showing 3 changed files with 164 additions and 9 deletions.
5 changes: 5 additions & 0 deletions .changeset/blue-hornets-boil.md
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Add support for multimodal indexes (e.g. from LlamaCloud)
148 changes: 141 additions & 7 deletions templates/components/engines/python/agent/tools/query_engine.py
@@ -1,27 +1,51 @@
import os
-from typing import Optional
+from typing import Any, Dict, List, Optional, Sequence

from llama_index.core import get_response_synthesizer
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
    DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
from llama_index.core.schema import (
    ImageNode,
    NodeWithScore,
)
from llama_index.core.tools.query_engine import QueryEngineTool
from llama_index.core.types import RESPONSE_TEXT_TYPE

from app.settings import get_multi_modal_llm

-def create_query_engine(index, **kwargs):
+def create_query_engine(index, **kwargs) -> BaseQueryEngine:
    """
    Create a query engine for the given index.

    Args:
        index: The index to create a query engine for.
        params (optional): Additional parameters for the query engine, e.g.: similarity_top_k
    """

    top_k = int(os.getenv("TOP_K", 0))
    if top_k != 0 and kwargs.get("filters") is None:
        kwargs["similarity_top_k"] = top_k
    multimodal_llm = get_multi_modal_llm()
    if multimodal_llm:
        kwargs["response_synthesizer"] = MultiModalSynthesizer(
            multimodal_model=multimodal_llm,
        )

    # If the index is a LlamaCloudIndex,
    # use auto_routed mode for better query results
-    if (
-        index.__class__.__name__ == "LlamaCloudIndex"
-        and kwargs.get("auto_routed") is None
-    ):
-        kwargs["auto_routed"] = True
+    if index.__class__.__name__ == "LlamaCloudIndex":
+        if kwargs.get("retrieval_mode") is None:
+            kwargs["retrieval_mode"] = "auto_routed"
+            if multimodal_llm:
+                kwargs["retrieve_image_nodes"] = True
    return index.as_query_engine(**kwargs)


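A minimal usage sketch of the factory above (the toy document, index, and TOP_K value are illustrative assumptions, not part of this commit):

import os

from llama_index.core import Document, VectorStoreIndex

os.environ["TOP_K"] = "3"  # picked up by create_query_engine as similarity_top_k

# Hypothetical: any LlamaIndex index works; LlamaCloud indexes additionally
# get auto_routed retrieval and image-node retrieval when a multimodal LLM is set.
index = VectorStoreIndex.from_documents([Document(text="The sky is blue.")])
engine = create_query_engine(index)
print(engine.query("What color is the sky?"))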
@@ -51,3 +75,113 @@ def get_query_engine_tool(
        name=name,
        description=description,
    )


class MultiModalSynthesizer(BaseSynthesizer):
    """
    A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
    """

    def __init__(
        self,
        multimodal_model: MultiModalLLM,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        text_qa_template: Optional[BasePromptTemplate] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._multi_modal_llm = multimodal_model
        self._response_synthesizer = response_synthesizer or get_response_synthesizer()
        self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL

    def _get_prompts(self, **kwargs) -> Dict[str, Any]:
        return {
            "text_qa_template": self._text_qa_template,
        }

    def _update_prompts(self, prompts: Dict[str, Any]) -> None:
        if "text_qa_template" in prompts:
            self._text_qa_template = prompts["text_qa_template"]

    async def aget_response(
        self,
        *args,
        **response_kwargs: Any,
    ) -> RESPONSE_TEXT_TYPE:
        return await self._response_synthesizer.aget_response(*args, **response_kwargs)

    def get_response(self, *args, **kwargs) -> RESPONSE_TEXT_TYPE:
        return self._response_synthesizer.get_response(*args, **kwargs)

    async def asynthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

        if len(image_nodes) == 0:
            return await self._response_synthesizer.asynthesize(query, text_nodes)

        # Summarize the text nodes to avoid exceeding the token limit
        text_response = str(
            await self._response_synthesizer.asynthesize(query, text_nodes)
        )

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )

        llm_response = await self._multi_modal_llm.acomplete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )

    def synthesize(
        self,
        query: QueryTextType,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
        **response_kwargs: Any,
    ) -> RESPONSE_TYPE:
        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

        if len(image_nodes) == 0:
            return self._response_synthesizer.synthesize(query, text_nodes)

        # Summarize the text nodes to avoid exceeding the token limit
        text_response = str(self._response_synthesizer.synthesize(query, text_nodes))

        fmt_prompt = self._text_qa_template.format(
            context_str=text_response,
            query_str=query.query_str,  # type: ignore
        )

        llm_response = self._multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[
                image_node.node
                for image_node in image_nodes
                if isinstance(image_node.node, ImageNode)
            ],
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
        )
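To make the synthesizer's role concrete, a hedged wiring sketch (the toy index and the gpt-4o model choice are assumptions; with a plain vector index no image nodes are retrieved, so the text-only fallback path runs):

from llama_index.core import Document, VectorStoreIndex
from llama_index.multi_modal_llms.openai import OpenAIMultiModal

# Hypothetical: pass the synthesizer to any query engine. Text nodes are
# summarized first; if image nodes are present, the multimodal LLM answers
# over the summary plus images, otherwise the inner synthesizer's answer stands.
index = VectorStoreIndex.from_documents([Document(text="hello world")])
engine = index.as_query_engine(
    response_synthesizer=MultiModalSynthesizer(
        multimodal_model=OpenAIMultiModal(model="gpt-4o"),
    ),
)
print(engine.query("What does the document say?"))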
20 changes: 18 additions & 2 deletions templates/components/settings/python/settings.py
@@ -1,8 +1,17 @@
import os
-from typing import Dict
+from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

# `Settings` does not support setting `MultiModalLLM`
# so we use a global variable to store it
_multi_modal_llm: Optional[MultiModalLLM] = None


def get_multi_modal_llm():
    return _multi_modal_llm


def init_settings():
    model_provider = os.getenv("MODEL_PROVIDER")
@@ -60,14 +69,21 @@ def init_openai():
    from llama_index.core.constants import DEFAULT_TEMPERATURE
    from llama_index.embeddings.openai import OpenAIEmbedding
    from llama_index.llms.openai import OpenAI
    from llama_index.multi_modal_llms.openai import OpenAIMultiModal
    from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

    max_tokens = os.getenv("LLM_MAX_TOKENS")
    model_name = os.getenv("MODEL", "gpt-4o-mini")
    Settings.llm = OpenAI(
-        model=os.getenv("MODEL", "gpt-4o-mini"),
+        model=model_name,
        temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        max_tokens=int(max_tokens) if max_tokens is not None else None,
    )

    if model_name in GPT4V_MODELS:
        global _multi_modal_llm
        _multi_modal_llm = OpenAIMultiModal(model=model_name)

    dimensions = os.getenv("EMBEDDING_DIM")
    Settings.embed_model = OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
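A quick smoke-test sketch of the new getter (the env values and printout are assumptions, not part of the commit; an OPENAI_API_KEY must be configured):

import os

os.environ["MODEL_PROVIDER"] = "openai"
os.environ["MODEL"] = "gpt-4o"  # listed in GPT4V_MODELS, so the multimodal path activates

init_settings()
mm_llm = get_multi_modal_llm()
print(type(mm_llm).__name__ if mm_llm else "no multimodal LLM configured")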
