Commit

use getter
leehuwuj committed Nov 29, 2024
1 parent fa17d39 commit 0ffb7ff
Showing 2 changed files with 14 additions and 43 deletions.
43 changes: 5 additions & 38 deletions templates/components/engines/python/agent/tools/query_engine.py
--- a/templates/components/engines/python/agent/tools/query_engine.py
+++ b/templates/components/engines/python/agent/tools/query_engine.py
@@ -20,7 +20,7 @@
 )
 from llama_index.core.tools.query_engine import QueryEngineTool
 
-from app.settings import multi_modal_llm
+from app.settings import get_multi_modal_llm
 
 
 class MultiModalSynthesizer(TreeSummarize):
@@ -39,38 +39,6 @@ def __init__(
         self._multi_modal_llm = multimodal_model
         self._text_qa_template = text_qa_template or DEFAULT_TREE_SUMMARIZE_PROMPT_SEL
 
-    def synthesize(
-        self,
-        query: QueryTextType,
-        nodes: List[NodeWithScore],
-        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
-        **response_kwargs: Any,
-    ) -> RESPONSE_TYPE:
-        image_nodes, text_nodes = _get_image_and_text_nodes(nodes)
-
-        # Summarize the text nodes to avoid exceeding the token limit
-        text_response = str(super().synthesize(query, nodes))
-
-        fmt_prompt = self._text_qa_template.format(
-            context_str=text_response,
-            query_str=query.query_str,  # type: ignore
-        )
-
-        llm_response = self._multi_modal_llm.complete(
-            prompt=fmt_prompt,
-            image_documents=[
-                image_node.node
-                for image_node in image_nodes
-                if isinstance(image_node.node, ImageNode)
-            ],
-        )
-
-        return Response(
-            response=str(llm_response),
-            source_nodes=nodes,
-            metadata={"text_nodes": text_nodes, "image_nodes": image_nodes},
-        )
-
     async def asynthesize(
         self,
         query: QueryTextType,
@@ -122,18 +90,17 @@ def create_query_engine(index, **kwargs) -> BaseQueryEngine:
         retrieval_mode = kwargs.get("retrieval_mode")
         if retrieval_mode is None:
             kwargs["retrieval_mode"] = "auto_routed"
-            mm_model = multi_modal_llm.get()
-            if mm_model:
+            multi_modal_llm = get_multi_modal_llm()
+            if multi_modal_llm:
                 kwargs["retrieve_image_nodes"] = True
-                print("Using multi-modal model")
                 return RetrieverQueryEngine(
                     retriever=index.as_retriever(**kwargs),
                     response_synthesizer=MultiModalSynthesizer(
-                        multimodal_model=mm_model
+                        multimodal_model=multi_modal_llm
                     ),
                 )
 
-    return index.as_query_engine(**kwargs)
+    raise ValueError("Multi-modal LLM is not set")
 
 
 def get_query_engine_tool(
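
For reference, a minimal caller-side sketch of the pattern this diff moves to: code that needs the multi-modal model imports the getter from app.settings and checks its return value at call time, instead of reading a module-level ContextVar. Illustrative only; in the template the check lives inside create_query_engine as shown above.

    from app.settings import get_multi_modal_llm

    multi_modal_llm = get_multi_modal_llm()
    if multi_modal_llm is None:
        # init_settings() never configured a vision-capable model
        raise ValueError("Multi-modal LLM is not set")
    # otherwise pass it on, e.g. MultiModalSynthesizer(multimodal_model=multi_modal_llm)
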
14 changes: 9 additions & 5 deletions templates/components/settings/python/settings.py
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -1,13 +1,16 @@
 import os
-from contextvars import ContextVar
 from typing import Dict, Optional
 
 from llama_index.core.multi_modal_llms import MultiModalLLM
 from llama_index.core.settings import Settings
 
-multi_modal_llm: ContextVar[Optional[MultiModalLLM]] = ContextVar(
-    "multi_modal_llm", default=None
-)
+# `Settings` does not support setting `MultiModalLLM`
+# so we use a global variable to store it
+_multi_modal_llm: Optional[MultiModalLLM] = None
+
+
+def get_multi_modal_llm():
+    return _multi_modal_llm
 
 
 def init_settings():
@@ -78,7 +81,8 @@ def init_openai():
         )
 
     if model_name in GPT4V_MODELS:
-        multi_modal_llm.set(OpenAIMultiModal(model=model_name))
+        global _multi_modal_llm
+        _multi_modal_llm = OpenAIMultiModal(model=model_name)
 
     dimensions = os.getenv("EMBEDDING_DIM")
     Settings.embed_model = OpenAIEmbedding(
