Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor query engine tool code and use auto_routed mode for LlamaCloudIndex #450

Merged
merged 15 commits into from
Nov 27, 2024
5 changes: 5 additions & 0 deletions .changeset/twenty-snakes-play.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"create-llama": patch
---

Use auto_routed retriever mode for LlamaCloudIndex
2 changes: 1 addition & 1 deletion helpers/python.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ const getAdditionalDependencies = (
case "llamacloud":
dependencies.push({
name: "llama-index-indices-managed-llama-cloud",
version: "^0.3.1",
version: "^0.6.0",
});
break;
}
Expand Down
41 changes: 11 additions & 30 deletions templates/components/agents/python/blog/app/agents/researcher.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,28 @@
import os
from textwrap import dedent
from typing import List

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.single import FunctionCallingAgent
from llama_index.core.chat_engine.types import ChatMessage
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from app.engine.tools.query_engine import get_query_engine_tool


def _create_query_engine_tool(params=None) -> QueryEngineTool:
"""
Provide an agent worker that can be used to query the index.
"""
# Add query tool if index exists
index_config = IndexConfig(**(params or {}))
index = get_index(index_config)
if index is None:
return None
top_k = int(os.getenv("TOP_K", 0))
query_engine = index.as_query_engine(
**({"similarity_top_k": top_k} if top_k != 0 else {})
)
return QueryEngineTool(
query_engine=query_engine,
metadata=ToolMetadata(
name="query_index",
description="""
Use this tool to retrieve information about the text corpus from the index.
""",
),
)


def _get_research_tools(**kwargs) -> QueryEngineTool:
def _get_research_tools(**kwargs):
"""
The researcher is responsible for retrieving information.
Tries to initialize the wikipedia or duckduckgo tools if available.
"""
tools = []
query_engine_tool = _create_query_engine_tool(**kwargs)
if query_engine_tool is not None:
tools.append(query_engine_tool)
# Create query engine tool
index_config = IndexConfig(**kwargs)
index = get_index(index_config)
if index is not None:
query_engine_tool = get_query_engine_tool(index=index)
if query_engine_tool is not None:
tools.append(query_engine_tool)

# Create duckduckgo tool
researcher_tool_names = [
"duckduckgo_search",
"duckduckgo_image_search",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
Expand All @@ -27,16 +26,16 @@
def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
**kwargs,
) -> Workflow:
# Create query engine tool
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
index = get_index(index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
raise ValueError(
"Index is not found. Try run generation script to create the index first."
)
query_engine_tool = get_query_engine_tool(index=index)

configured_tools: Dict[str, FunctionTool] = ToolFactory.from_env(map_result=True) # type: ignore
code_interpreter_tool = configured_tools.get("interpret")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import os
from typing import Any, Dict, List, Optional

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.tools import FunctionTool, QueryEngineTool, ToolSelection
Expand All @@ -23,24 +14,28 @@
step,
)

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from app.engine.tools.query_engine import get_query_engine_tool
from app.workflows.events import AgentRunEvent
from app.workflows.tools import (
call_tools,
chat_with_tools,
)


def create_workflow(
chat_history: Optional[List[ChatMessage]] = None,
params: Optional[Dict[str, Any]] = None,
filters: Optional[List[Any]] = None,
**kwargs,
) -> Workflow:
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
if params is None:
params = {}
if filters is None:
filters = []
# Create query engine tool
index_config = IndexConfig(**params)
index: VectorStoreIndex = get_index(config=index_config)
index = get_index(index_config)
if index is None:
query_engine_tool = None
else:
top_k = int(os.getenv("TOP_K", 10))
query_engine = index.as_query_engine(similarity_top_k=top_k, filters=filters)
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
query_engine_tool = get_query_engine_tool(index=index)

configured_tools = ToolFactory.from_env(map_result=True)
extractor_tool = configured_tools.get("extract_questions") # type: ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import { ChatMessage } from "llamaindex";
import { getTool } from "../engine/tools";
import { FunctionCallingAgent } from "./single-agent";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

export const createResearcher = async (chatHistory: ChatMessage[]) => {
const queryEngineTools = await getQueryEngineTools();
const queryEngineTool = await getQueryEngineTool();
const tools = [
await getTool("wikipedia_tool"),
await getTool("duckduckgo_search"),
await getTool("image_generator"),
...(queryEngineTools ? queryEngineTools : []),
queryEngineTool,
].filter((tool) => tool !== undefined);

return new FunctionCallingAgent({
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FinancialReportWorkflow } from "./fin-report";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

export async function createWorkflow(options: {
chatHistory: ChatMessage[];
llm?: ToolCallLLM;
}) {
const queryEngineTool = await getQueryEngineTool();
const codeInterpreterTool = await getTool("interpreter");
const documentGeneratorTool = await getTool("document_generator");

if (!queryEngineTool || !codeInterpreterTool || !documentGeneratorTool) {
throw new Error("One or more required tools are not defined");
}

return new FinancialReportWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
codeInterpreterTool: (await getTool("interpreter"))!,
documentGeneratorTool: (await getTool("document_generator"))!,
queryEngineTool,
codeInterpreterTool,
documentGeneratorTool,
llm: options.llm,
timeout: TIMEOUT,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ export class FinancialReportWorkflow extends Workflow<
> {
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
queryEngineTools: BaseToolWithCall[];
queryEngineTool: BaseToolWithCall;
codeInterpreterTool: BaseToolWithCall;
documentGeneratorTool: BaseToolWithCall;
systemPrompt?: string;
Expand All @@ -70,7 +70,7 @@ export class FinancialReportWorkflow extends Workflow<
throw new Error("LLM is not a ToolCallLLM");
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.codeInterpreterTool = options.codeInterpreterTool;

this.documentGeneratorTool = options.documentGeneratorTool;
Expand Down Expand Up @@ -153,10 +153,11 @@ export class FinancialReportWorkflow extends Workflow<
> => {
const chatHistory = ev.data.input;

const tools = [this.codeInterpreterTool, this.documentGeneratorTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
}
const tools = [
this.codeInterpreterTool,
this.documentGeneratorTool,
this.queryEngineTool,
];

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);

Expand Down Expand Up @@ -189,10 +190,7 @@ export class FinancialReportWorkflow extends Workflow<
toolCalls: toolCallResponse.toolCalls,
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
) {
if (this.queryEngineTool.metadata.name === toolName) {
return new ResearchEvent({
toolCalls: toolCallResponse.toolCalls,
});
Expand All @@ -216,7 +214,7 @@ export class FinancialReportWorkflow extends Workflow<
const { toolCalls } = ev.data;

const toolMsgs = await callTools({
tools: this.queryEngineTools,
tools: [this.queryEngineTool],
toolCalls,
ctx,
agentName: "Researcher",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import { ChatMessage, ToolCallLLM } from "llamaindex";
import { getTool } from "../engine/tools";
import { FormFillingWorkflow } from "./form-filling";
import { getQueryEngineTools } from "./tools";
import { getQueryEngineTool } from "./tools";

const TIMEOUT = 360 * 1000;

export async function createWorkflow(options: {
chatHistory: ChatMessage[];
llm?: ToolCallLLM;
}) {
const extractorTool = await getTool("extract_missing_cells");
const fillMissingCellsTool = await getTool("fill_missing_cells");

if (!extractorTool || !fillMissingCellsTool) {
throw new Error("One or more required tools are not defined");
}

return new FormFillingWorkflow({
chatHistory: options.chatHistory,
queryEngineTools: (await getQueryEngineTools()) || [],
extractorTool: (await getTool("extract_missing_cells"))!,
fillMissingCellsTool: (await getTool("fill_missing_cells"))!,
queryEngineTool: (await getQueryEngineTool()) || undefined,
extractorTool,
fillMissingCellsTool,
llm: options.llm,
timeout: TIMEOUT,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ export class FormFillingWorkflow extends Workflow<
llm: ToolCallLLM;
memory: ChatMemoryBuffer;
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool?: BaseToolWithCall;
leehuwuj marked this conversation as resolved.
Show resolved Hide resolved
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;

constructor(options: {
llm?: ToolCallLLM;
chatHistory: ChatMessage[];
extractorTool: BaseToolWithCall;
queryEngineTools?: BaseToolWithCall[];
queryEngineTool?: BaseToolWithCall;
fillMissingCellsTool: BaseToolWithCall;
systemPrompt?: string;
verbose?: boolean;
Expand All @@ -73,7 +73,7 @@ export class FormFillingWorkflow extends Workflow<
}
this.systemPrompt = options.systemPrompt ?? DEFAULT_SYSTEM_PROMPT;
this.extractorTool = options.extractorTool;
this.queryEngineTools = options.queryEngineTools;
this.queryEngineTool = options.queryEngineTool;
this.fillMissingCellsTool = options.fillMissingCellsTool;

this.memory = new ChatMemoryBuffer({
Expand Down Expand Up @@ -156,8 +156,8 @@ export class FormFillingWorkflow extends Workflow<
const chatHistory = ev.data.input;

const tools = [this.extractorTool, this.fillMissingCellsTool];
if (this.queryEngineTools) {
tools.push(...this.queryEngineTools);
if (this.queryEngineTool) {
tools.push(this.queryEngineTool);
}

const toolCallResponse = await chatWithTools(this.llm, tools, chatHistory);
Expand Down Expand Up @@ -192,8 +192,8 @@ export class FormFillingWorkflow extends Workflow<
});
default:
if (
this.queryEngineTools &&
this.queryEngineTools.some((tool) => tool.metadata.name === toolName)
this.queryEngineTool &&
this.queryEngineTool.metadata.name === toolName
) {
return new FindAnswersEvent({
toolCalls: toolCallResponse.toolCalls,
Expand Down Expand Up @@ -232,7 +232,7 @@ export class FormFillingWorkflow extends Workflow<
ev: FindAnswersEvent,
): Promise<InputEvent> => {
const { toolCalls } = ev.data;
if (!this.queryEngineTools) {
if (!this.queryEngineTool) {
throw new Error("Query engine tool is not available");
}
ctx.sendEvent(
Expand All @@ -243,7 +243,7 @@ export class FormFillingWorkflow extends Workflow<
}),
);
const toolMsgs = await callTools({
tools: this.queryEngineTools,
tools: [this.queryEngineTool],
toolCalls,
ctx,
agentName: "Researcher",
Expand Down
Loading