Skip to content

Commit 9a6b850

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add bypass_multi_tools_limit option to GoogleSearchTool and VertexAiSearchTool
PiperOrigin-RevId: 817493869
1 parent 64646e0 commit 9a6b850

File tree

4 files changed

+74
-15
lines changed

4 files changed

+74
-15
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,19 @@ async def _convert_tool_union_to_tools(
118118
model: Union[str, BaseLlm],
119119
multiple_tools: bool = False,
120120
) -> list[BaseTool]:
121-
from ..tools.google_search_tool import google_search
121+
from ..tools.google_search_tool import GoogleSearchTool
122122
from ..tools.vertex_ai_search_tool import VertexAiSearchTool
123123

124124
# Wrap google_search tool with AgentTool if there are multiple tools because
125125
# the built-in tools cannot be used together with other tools.
126126
# TODO(b/448114567): Remove once the workaround is no longer needed.
127-
if multiple_tools and tool_union is google_search:
127+
if multiple_tools and isinstance(tool_union, GoogleSearchTool):
128128
from ..tools.google_search_agent_tool import create_google_search_agent
129129
from ..tools.google_search_agent_tool import GoogleSearchAgentTool
130130

131-
return [GoogleSearchAgentTool(create_google_search_agent(model))]
131+
search_tool = cast(GoogleSearchTool, tool_union)
132+
if search_tool.bypass_multi_tools_limit:
133+
return [GoogleSearchAgentTool(create_google_search_agent(model))]
132134

133135
# Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are
134136
# multiple tools because the built-in tools cannot be used together with
@@ -138,15 +140,16 @@ async def _convert_tool_union_to_tools(
138140
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
139141

140142
vais_tool = cast(VertexAiSearchTool, tool_union)
141-
return [
142-
DiscoveryEngineSearchTool(
143-
data_store_id=vais_tool.data_store_id,
144-
data_store_specs=vais_tool.data_store_specs,
145-
search_engine_id=vais_tool.search_engine_id,
146-
filter=vais_tool.filter,
147-
max_results=vais_tool.max_results,
148-
)
149-
]
143+
if vais_tool.bypass_multi_tools_limit:
144+
return [
145+
DiscoveryEngineSearchTool(
146+
data_store_id=vais_tool.data_store_id,
147+
data_store_specs=vais_tool.data_store_specs,
148+
search_engine_id=vais_tool.search_engine_id,
149+
filter=vais_tool.filter,
150+
max_results=vais_tool.max_results,
151+
)
152+
]
150153

151154
if isinstance(tool_union, BaseTool):
152155
return [tool_union]

src/google/adk/tools/google_search_tool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,17 @@ class GoogleSearchTool(BaseTool):
3535
local code execution.
3636
"""
3737

38-
def __init__(self):
38+
def __init__(self, *, bypass_multi_tools_limit: bool = True):
39+
"""Initializes the Google search tool.
40+
41+
Args:
42+
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
43+
so that the tool can be used with other tools in the same agent.
44+
"""
45+
3946
# Name and description are not used because this is a model built-in tool.
4047
super().__init__(name='google_search', description='google_search')
48+
self.bypass_multi_tools_limit = bypass_multi_tools_limit
4149

4250
@override
4351
async def process_llm_request(

src/google/adk/tools/vertex_ai_search_tool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
search_engine_id: Optional[str] = None,
4848
filter: Optional[str] = None,
4949
max_results: Optional[int] = None,
50+
bypass_multi_tools_limit: bool = True,
5051
):
5152
"""Initializes the Vertex AI Search tool.
5253
@@ -58,6 +59,10 @@ def __init__(
5859
searched. It should only be set if engine is used.
5960
search_engine_id: The Vertex AI search engine resource ID in the format of
6061
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
62+
filter: The filter to apply to the search results.
63+
max_results: The maximum number of results to return.
64+
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
65+
so that the tool can be used with other tools in the same agent.
6166
6267
Raises:
6368
ValueError: If both data_store_id and search_engine_id are not specified
@@ -80,6 +85,7 @@ def __init__(
8085
self.search_engine_id = search_engine_id
8186
self.filter = filter
8287
self.max_results = max_results
88+
self.bypass_multi_tools_limit = bypass_multi_tools_limit
8389

8490
@override
8591
async def process_llm_request(

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.adk.models.registry import LLMRegistry
2727
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2828
from google.adk.tools.google_search_tool import google_search
29+
from google.adk.tools.google_search_tool import GoogleSearchTool
2930
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
3031
from google.genai import types
3132
from pydantic import BaseModel
@@ -310,6 +311,25 @@ async def test_handle_google_search_with_other_tools(self):
310311
assert tools[1].name == 'google_search_agent'
311312
assert tools[1].__class__.__name__ == 'GoogleSearchAgentTool'
312313

314+
async def test_handle_google_search_with_other_tools_no_bypass(self):
315+
"""Test that google_search is not wrapped into an agent."""
316+
agent = LlmAgent(
317+
name='test_agent',
318+
model='gemini-pro',
319+
tools=[
320+
self._my_tool,
321+
GoogleSearchTool(bypass_multi_tools_limit=False),
322+
],
323+
)
324+
ctx = await _create_readonly_context(agent)
325+
tools = await agent.canonical_tools(ctx)
326+
327+
assert len(tools) == 2
328+
assert tools[0].name == '_my_tool'
329+
assert tools[0].__class__.__name__ == 'FunctionTool'
330+
assert tools[1].name == 'google_search'
331+
assert tools[1].__class__.__name__ == 'GoogleSearchTool'
332+
313333
async def test_handle_google_search_only(self):
314334
"""Test that google_search is not wrapped into an agent."""
315335
agent = LlmAgent(
@@ -346,8 +366,8 @@ async def test_function_tool_only(self):
346366
'google.auth.default',
347367
mock.MagicMock(return_value=('credentials', 'project')),
348368
)
349-
async def test_handle_google_vais_with_other_tools(self):
350-
"""Test that VertexAiSearchTool is wrapped into an agent."""
369+
async def test_handle_vais_with_other_tools(self):
370+
"""Test that VertexAiSearchTool is replaced with Discovery Engine Search."""
351371
agent = LlmAgent(
352372
name='test_agent',
353373
model='gemini-pro',
@@ -365,6 +385,28 @@ async def test_handle_google_vais_with_other_tools(self):
365385
assert tools[1].name == 'discovery_engine_search'
366386
assert tools[1].__class__.__name__ == 'DiscoveryEngineSearchTool'
367387

388+
async def test_handle_vais_with_other_tools_no_bypass(self):
389+
"""Test that VertexAiSearchTool is not replaced."""
390+
agent = LlmAgent(
391+
name='test_agent',
392+
model='gemini-pro',
393+
tools=[
394+
self._my_tool,
395+
VertexAiSearchTool(
396+
data_store_id='test_data_store_id',
397+
bypass_multi_tools_limit=False,
398+
),
399+
],
400+
)
401+
ctx = await _create_readonly_context(agent)
402+
tools = await agent.canonical_tools(ctx)
403+
404+
assert len(tools) == 2
405+
assert tools[0].name == '_my_tool'
406+
assert tools[0].__class__.__name__ == 'FunctionTool'
407+
assert tools[1].name == 'vertex_ai_search'
408+
assert tools[1].__class__.__name__ == 'VertexAiSearchTool'
409+
368410
async def test_handle_vais_only(self):
369411
"""Test that VertexAiSearchTool is not wrapped into an agent."""
370412
agent = LlmAgent(

0 commit comments

Comments
 (0)