Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for multimodal #453

Merged
merged 19 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/blue-hornets-boil.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Add support multimodal indexes (e.g. from LlamaCloud)
148 changes: 141 additions & 7 deletions templates/components/engines/python/agent/tools/query_engine.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,51 @@
import os
from typing import Optional
from typing import Any, Dict, List, Optional, Sequence

from llama_index.core import get_response_synthesizer
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.prompts.base import BasePromptTemplate
from llama_index.core.prompts.default_prompt_selectors import (
DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.query_engine.multi_modal import _get_image_and_text_nodes
from llama_index.core.response_synthesizers.base import BaseSynthesizer, QueryTextType
from llama_index.core.schema import (
ImageNode,
NodeWithScore,
)
from llama_index.core.tools.query_engine import QueryEngineTool
from llama_index.core.types import RESPONSE_TEXT_TYPE

from app.settings import get_multi_modal_llm

def create_query_engine(index, **kwargs):

def create_query_engine(index, **kwargs) -> BaseQueryEngine:
"""
Create a query engine for the given index.

Args:
index: The index to create a query engine for.
params (optional): Additional parameters for the query engine, e.g: similarity_top_k
"""

top_k = int(os.getenv("TOP_K", 0))
if top_k != 0 and kwargs.get("filters") is None:
kwargs["similarity_top_k"] = top_k
multimodal_llm = get_multi_modal_llm()
if multimodal_llm:
kwargs["response_synthesizer"] = MultiModalSynthesizer(
multimodal_model=multimodal_llm,
)

# If index is index is LlamaCloudIndex
# use auto_routed mode for better query results
if (
index.__class__.__name__ == "LlamaCloudIndex"
and kwargs.get("auto_routed") is None
):
kwargs["auto_routed"] = True
if index.__class__.__name__ == "LlamaCloudIndex":
if kwargs.get("retrieval_mode") is None:
kwargs["retrieval_mode"] = "auto_routed"
if multimodal_llm:
kwargs["retrieve_image_nodes"] = True
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
return index.as_query_engine(**kwargs)


Expand Down Expand Up @@ -51,3 +75,113 @@ def get_query_engine_tool(
name=name,
description=description,
)


class MultiModalSynthesizer(BaseSynthesizer):
"""
A synthesizer that summarizes text nodes and uses a multi-modal LLM to generate a response.
"""

def __init__(
self,
multimodal_model: MultiModalLLM,
response_synthesizer: Optional[BaseSynthesizer] = None,
text_qa_template: Optional[BasePromptTemplate] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self._multi_modal_llm = multimodal_model
self._response_synthesizer = response_synthesizer or get_response_synthesizer()
self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT_SEL

leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
def _get_prompts(self, **kwargs) -> Dict[str, Any]:
return {
"text_qa_template": self._text_qa_template,
}

def _update_prompts(self, prompts: Dict[str, Any]) -> None:
if "text_qa_template" in prompts:
self._text_qa_template = prompts["text_qa_template"]

async def aget_response(
self,
*args,
**response_kwargs: Any,
) -> RESPONSE_TEXT_TYPE:
return await self._response_synthesizer.aget_response(*args, **response_kwargs)

def get_response(self, *args, **kwargs) -> RESPONSE_TEXT_TYPE:
return self._response_synthesizer.get_response(*args, **kwargs)

async def asynthesize(
self,
query: QueryTextType,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
**response_kwargs: Any,
) -> RESPONSE_TYPE:
image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

if len(image_nodes) == 0:
return await self._response_synthesizer.asynthesize(query, text_nodes)

# Summarize the text nodes to avoid exceeding the token limit
text_response = str(
await self._response_synthesizer.asynthesize(query, text_nodes)
)

fmt_prompt = self._text_qa_template.format(
context_str=text_response,
query_str=query.query_str, # type: ignore
)
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved

llm_response = await self._multi_modal_llm.acomplete(
prompt=fmt_prompt,
image_documents=[
image_node.node
for image_node in image_nodes
if isinstance(image_node.node, ImageNode)
],
)
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved

return Response(
response=str(llm_response),
source_nodes=nodes,
metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
)
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved

def synthesize(
self,
query: QueryTextType,
nodes: List[NodeWithScore],
additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
**response_kwargs: Any,
) -> RESPONSE_TYPE:
image_nodes, text_nodes = _get_image_and_text_nodes(nodes)

if len(image_nodes) == 0:
return self._response_synthesizer.synthesize(query, text_nodes)

# Summarize the text nodes to avoid exceeding the token limit
text_response = str(self._response_synthesizer.synthesize(query, text_nodes))

fmt_prompt = self._text_qa_template.format(
context_str=text_response,
query_str=query.query_str, # type: ignore
)

llm_response = self._multi_modal_llm.complete(
prompt=fmt_prompt,
image_documents=[
image_node.node
for image_node in image_nodes
if isinstance(image_node.node, ImageNode)
],
)

return Response(
response=str(llm_response),
source_nodes=nodes,
metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
)
20 changes: 18 additions & 2 deletions templates/components/settings/python/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import os
from typing import Dict
from typing import Dict, Optional

from llama_index.core.multi_modal_llms import MultiModalLLM
from llama_index.core.settings import Settings

# `Settings` does not support setting `MultiModalLLM`
# so we use a global variable to store it
_multi_modal_llm: Optional[MultiModalLLM] = None


def get_multi_modal_llm():
return _multi_modal_llm


def init_settings():
model_provider = os.getenv("MODEL_PROVIDER")
Expand Down Expand Up @@ -60,14 +69,21 @@ def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.multi_modal_llms.openai.utils import GPT4V_MODELS

max_tokens = os.getenv("LLM_MAX_TOKENS")
model_name = os.getenv("MODEL", "gpt-4o-mini")
marcusschiesser marked this conversation as resolved.
Show resolved Hide resolved
Settings.llm = OpenAI(
model=os.getenv("MODEL", "gpt-4o-mini"),
model=model_name,
temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
max_tokens=int(max_tokens) if max_tokens is not None else None,
)

if model_name in GPT4V_MODELS:
global _multi_modal_llm
_multi_modal_llm = OpenAIMultiModal(model=model_name)
marcusschiesser marked this conversation as resolved.
Show resolved Hide resolved

dimensions = os.getenv("EMBEDDING_DIM")
Settings.embed_model = OpenAIEmbedding(
model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
Expand Down