Skip to content

Commit b779979

Browse files
authored
fix: react agent (#200)
Summary: 1. update prompt to match the json response object 2. call tools with tool_args when present (e.g. vector_db_ids for rag tool) Test Plan: Run react agent with rag tool ``` react_agent_with_vector_db = ReActAgent( client=client, model=model_id, tools=[ { "name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}, } ], ) react_agent_with_vector_db_config = react_agent_with_vector_db.agent_config react_tool_parser = react_agent_with_vector_db.tool_parser react_agent_client_tools = react_agent_with_vector_db.client_tools results_vector_db_agent = evaluator_no_context.run( react_agent_with_vector_db_config, client, num_examples=100, client_tools=[], tool_parser=react_tool_parser, json_response_format=True, # answer_parsing_fn=lambda x: json.loads(x)["answer"], ) ```
1 parent 6fa8095 commit b779979

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/llama_stack_client/lib/agents/agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ def initialize(self) -> None:
188188
)
189189
self.agent_id = agentic_system_create_response.agent_id
190190
for tg in self.agent_config["toolgroups"]:
191-
for tool in self.client.tools.list(toolgroup_id=tg):
192-
self.builtin_tools[tool.identifier] = tool
191+
for tool in self.client.tools.list(toolgroup_id=tg if isinstance(tg, str) else tg.get("name")):
192+
self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {}
193193

194194
def create_session(self, session_name: str) -> str:
195195
agentic_system_create_session_response = self.client.agents.session.create(
@@ -225,7 +225,7 @@ def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
225225
if tool_call.tool_name in self.builtin_tools:
226226
tool_result = self.client.tool_runtime.invoke_tool(
227227
tool_name=tool_call.tool_name,
228-
kwargs=tool_call.arguments,
228+
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
229229
)
230230
tool_response = ToolResponseParam(
231231
call_id=tool_call.call_id,
@@ -411,7 +411,7 @@ async def initialize(self) -> None:
411411
self._agent_id = agentic_system_create_response.agent_id
412412
for tg in self.agent_config["toolgroups"]:
413413
for tool in await self.client.tools.list(toolgroup_id=tg):
414-
self.builtin_tools[tool.identifier] = tool
414+
self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {}
415415

416416
async def create_session(self, session_name: str) -> str:
417417
await self.initialize()
@@ -462,7 +462,7 @@ async def _run_tool(self, tool_calls: List[ToolCall]) -> ToolResponseParam:
462462
if tool_call.tool_name in self.builtin_tools:
463463
tool_result = await self.client.tool_runtime.invoke_tool(
464464
tool_name=tool_call.tool_name,
465-
kwargs=tool_call.arguments,
465+
kwargs={**tool_call.arguments, **self.builtin_tools[tool_call.tool_name]},
466466
)
467467
tool_response = ToolResponseParam(
468468
call_id=tool_call.call_id,

src/llama_stack_client/lib/agents/react/prompts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
2323
The `action` key should specify the $TOOL_NAME the name of the tool to use and the `tool_params` key should specify the parameters key as input to the tool.
2424
25-
Make sure to have the $TOOL_PARAMS as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
25+
Make sure to have the $TOOL_PARAMS as a list of dictionaries in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
2626
2727
You should always think about one action to take, and have the `thought` key contain your thought process about this action.
2828
If the tool responds, the tool will return an observation containing result of the action.
@@ -37,7 +37,7 @@
3737
"thought": "I need to transform the image that I received in the previous observation to make it green.",
3838
"action": {
3939
"tool_name": "image_transformer",
40-
"tool_params": {"image": "image_1.jpg"}
40+
"tool_params": [{"name": "image"}, {"value": "image_1.jpg"}]
4141
},
4242
"answer": null
4343
}
@@ -61,7 +61,7 @@
6161
"thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.",
6262
"action": {
6363
"tool_name": "document_qa",
64-
"tool_params": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
64+
"tool_params": [{"name": "document"}, {"value": "document.pdf"}, {"name": "question"}, {"value": "Who is the oldest person mentioned?"}]
6565
},
6666
"answer": null
6767
}
@@ -73,7 +73,7 @@
7373
"thought": "I will now generate an image showcasing the oldest person.",
7474
"action": {
7575
"tool_name": "image_generator",
76-
"tool_params": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}
76+
"tool_params": [{"name": "prompt"}, {"value": "A portrait of John Doe, a 55-year-old man living in Canada."}]
7777
},
7878
"answer": null
7979
}
@@ -93,7 +93,7 @@
9393
"thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool",
9494
"action": {
9595
"tool_name": "python_interpreter",
96-
"tool_params": {"code": "5 + 3 + 1294.678"}
96+
"tool_params": [{"name": "code"}, {"value": "5 + 3 + 1294.678"}]
9797
},
9898
"answer": null
9999
}
@@ -113,7 +113,7 @@
113113
"thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.",
114114
"action": {
115115
"tool_name": "search",
116-
"tool_params": {"query": "Population Guangzhou"}
116+
"tool_params": [{"name": "query"}, {"value": "Population Guangzhou"}]
117117
},
118118
"answer": null
119119
}
@@ -124,7 +124,7 @@
124124
"thought": "Now let's get the population of Shanghai using the tool 'search'.",
125125
"action": {
126126
"tool_name": "search",
127-
"tool_params": {"query": "Population Shanghai"}
127+
"tool_params": [{"name": "query"}, {"value": "Population Shanghai"}]
128128
},
129129
"answer": null
130130
}

0 commit comments

Comments
 (0)