Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor query engine tool code and use auto_routed mode for LlamaCloudIndex #450

Merged
merged 15 commits into from
Nov 27, 2024
33 changes: 11 additions & 22 deletions templates/components/agents/python/blog/app/agents/researcher.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,28 @@
import os
from textwrap import dedent
from typing import List
from typing import List, Optional, Dict, Any

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.tools import QueryEngineTool
from app.engine.tools.query_engine import get_query_engine_tool


def _create_query_engine_tool(
    params: Optional[Dict[str, Any]] = None, **kwargs
) -> Optional[QueryEngineTool]:
    """Build a query-engine tool over the configured index.

    Args:
        params: Optional keyword arguments forwarded to ``IndexConfig``.
        **kwargs: Extra options forwarded to ``get_query_engine_tool``.

    Returns:
        A ``QueryEngineTool`` wrapping the index, or ``None`` when no
        index is available (callers filter out ``None`` tools).
    """
    if params is None:
        params = {}
    # Only expose a query tool if an index actually exists.
    index_config = IndexConfig(**params)
    index = get_index(index_config)
    if index is None:
        return None
    return get_query_engine_tool(index=index, **kwargs)


def _get_research_tools(**kwargs) -> QueryEngineTool:
def _get_research_tools(**kwargs):
"""
Researcher take responsibility for retrieving information.
Try init wikipedia or duckduckgo tool if available.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
Expand All @@ -24,19 +23,23 @@
)


def _create_query_engine_tool(
    params: Optional[Dict[str, Any]] = None, **kwargs
) -> Optional[QueryEngineTool]:
    """Build a query-engine tool over the configured index.

    Args:
        params: Optional keyword arguments forwarded to ``IndexConfig``.
        **kwargs: Extra options forwarded to ``get_query_engine_tool``.

    Returns:
        A ``QueryEngineTool`` wrapping the index, or ``None`` when no
        index is available.
    """
    if params is None:
        params = {}
    # Only expose a query tool if an index actually exists.
    index_config = IndexConfig(**params)
    index = get_index(index_config)
    if index is None:
        return None
    return get_query_engine_tool(index=index, **kwargs)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
**kwargs,
) -> Workflow:
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
query_engine_tool = _create_query_engine_tool(params, **kwargs)

configured_tools: Dict[str, FunctionTool] = ToolFactory.from_env(map_result=True) # type: ignore
code_interpreter_tool = configured_tools.get("interpret")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
Expand All @@ -23,25 +14,35 @@
step,
)

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
) -> Workflow:
def _create_query_engine_tool(
    params: Optional[Dict[str, Any]] = None, **kwargs
) -> Optional[QueryEngineTool]:
    """Build a query-engine tool over the configured index.

    Args:
        params: Optional keyword arguments forwarded to ``IndexConfig``.
        **kwargs: Extra options forwarded to ``get_query_engine_tool``.

    Returns:
        A ``QueryEngineTool`` wrapping the index, or ``None`` when no
        index is available.
    """
    if params is None:
        params = {}
    # Only expose a query tool if an index actually exists.
    index_config = IndexConfig(**params)
    index = get_index(index_config)
    if index is None:
        return None
    return get_query_engine_tool(index=index, **kwargs)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Workflow:
query_engine_tool = _create_query_engine_tool(params, **kwargs)
configured_tools = ToolFactory.from_env(map_result=True)
extractor_tool = configured_tools.get("extract_questions") # type: ignore
filling_tool = configured_tools.get("fill_form") # type: ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import { ChatMessage } from "llamaindex";
import { getTool } from "../engine/tools";
import { FunctionCallingAgent } from "./single-agent";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

export const createResearcher = async (chatHistory: ChatMessage[]) => {
const queryEngineTools = await getQueryEngineTools();
const queryEngineTool = await getQueryEngineTool();
const tools = [
await getTool("wikipedia_tool"),
await getTool("duckduckgo_search"),
await getTool("image_generator"),
...(queryEngineTools ? queryEngineTools : []),
queryEngineTool,
].filter((tool) => tool !== undefined);

return new FunctionCallingAgent({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FinancialReportWorkflow } from "./fin-report";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

Expand All @@ -11,7 +11,7 @@ export async function createWorkflow(options: {
}) {
return new FinancialReportWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
queryEngineTool: (await getQueryEngineTool())!,
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
codeInterpreterTool: (await getTool("interpreter"))!,
documentGeneratorTool: (await getTool("document_generator"))!,
llm: options.llm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ export class FinancialReportWorkflow extends Workflow<
> {
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall[];
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;
Expand All @@ -70,7 +70,7 @@ export class FinancialReportWorkflow extends Workflow<
throw new Error("LLM is not a ToolCallLLM");
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.codeInterpreterTool = options.codeInterpreterTool;

this.documentGeneratorTool = options.documentGeneratorTool;
Expand Down Expand Up @@ -154,8 +154,8 @@ export class FinancialReportWorkflow extends Workflow<
const chatHistory = ev.data.input;

const tools = [this.codeInterpreterTool, this.documentGeneratorTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
if (this.queryEngineTool) {
tools.push(this.queryEngineTool);
}

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);
Expand Down Expand Up @@ -190,8 +190,8 @@ export class FinancialReportWorkflow extends Workflow<
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
this.queryEngineTool &&
this.queryEngineTool.metadata.name === toolName
) {
return new ResearchEvent({
toolCalls: toolCallResponse.toolCalls,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FormFillingWorkflow } from "./form-filling";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

Expand All @@ -11,7 +11,7 @@ export async function createWorkflow(options: {
}) {
return new FormFillingWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
queryEngineTool: (await getQueryEngineTool())!,
extractorTool: (await getTool("extract_missing_cells"))!,
fillMissingCellsTool: (await getTool("fill_missing_cells"))!,
llm: options.llm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ export class FormFillingWorkflow extends Workflow<
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool?: BaseToolWithCall;
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;
verbose?: boolean;
Expand All @@ -73,7 +73,7 @@ export class FormFillingWorkflow extends Workflow<
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.extractorTool = options.extractorTool;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.fillMissingCellsTool = options.fillMissingCellsTool;

this.memory = new ChatMemoryBuffer({
Expand Down Expand Up @@ -156,8 +156,8 @@ export class FormFillingWorkflow extends Workflow<
const chatHistory = ev.data.input;

const tools = [this.extractorTool, this.fillMissingCellsTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
if (this.queryEngineTool) {
tools.push(this.queryEngineTool);
}

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);
Expand Down Expand Up @@ -192,8 +192,8 @@ export class FormFillingWorkflow extends Workflow<
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
this.queryEngineTool &&
this.queryEngineTool.metadata.name === toolName
) {
return new FindAnswersEvent({
toolCalls: toolCallResponse.toolCalls,
Expand Down Expand Up @@ -232,7 +232,7 @@ export class FormFillingWorkflow extends Workflow<
ev: FindAnswersEvent,
): Promise<InputEvent> => {
const { toolCalls } = ev.data;
if (!this.queryEngineTools) {
if (!this.queryEngineTool) {
throw new Error("Query engine tool is not available");
}
ctx.sendEvent(
Expand All @@ -243,7 +243,7 @@ export class FormFillingWorkflow extends Workflow<
}),
);
const toolMsgs = await callTools({
tools: this.queryEngineTools,
tools: [this.queryEngineTool],
toolCalls,
ctx,
agentName: "Researcher",
Expand Down
15 changes: 6 additions & 9 deletions templates/components/engines/python/agent/engine.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
import os
from typing import List

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from llama_index.core.agent import AgentRunner
from llama_index.core.callbacks import CallbackManager
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.tools.query_engine import QueryEngineTool

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool


def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
def get_chat_engine(params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", 0))
tools: List[BaseTool] = []
callback_manager = CallbackManager(handlers=event_handlers or [])

# Add query tool if index exists
index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
index = get_index(index_config)
if index is not None:
query_engine = index.as_query_engine(
filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
query_engine_tool = get_query_engine_tool(index, **kwargs)
tools.append(query_engine_tool)

# Add additional tools
Expand Down
Loading
Loading