Skip to content

Commit 6dcbb5a

Browse files
wukathcopybara-github
authored andcommitted
feat: Support dynamic per-request headers in MCPToolset
Add a header_provider param which is a callable[ReadonlyContext, Dict[str, Any]] for users to build headers in MCPToolset fix: #3156 PiperOrigin-RevId: 820412372
1 parent 2a8fdd9 commit 6dcbb5a

File tree

8 files changed

+243
-5
lines changed

8 files changed

+243
-5
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
This agent connects to a local MCP server via Streamable HTTP and provides
2+
custom per-request headers to the MCP server.
3+
4+
To run this agent, start the local MCP server first by running:
5+
6+
```bash
7+
uv run header_server.py
8+
```
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.adk.agents.llm_agent import LlmAgent
17+
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
18+
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
19+
20+
root_agent = LlmAgent(
21+
model='gemini-2.0-flash',
22+
name='tenant_agent',
23+
instruction="""You are a helpful assistant that helps users get tenant
24+
information. Call the get_tenant_data tool when the user asks for tenant data.""",
25+
tools=[
26+
McpToolset(
27+
connection_params=StreamableHTTPConnectionParams(
28+
url='http://localhost:3000/mcp',
29+
),
30+
tool_filter=['get_tenant_data'],
31+
header_provider=lambda ctx: {'X-Tenant-ID': 'tenant1'},
32+
)
33+
],
34+
)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from fastapi import Request
18+
from mcp.server.fastmcp import Context
19+
from mcp.server.fastmcp import FastMCP
20+
21+
mcp = FastMCP('Header Check Server', host='localhost', port=3000)
22+
23+
TENANT_DATA = {
24+
'tenant1': {'name': 'Tenant 1', 'data': 'Data for tenant 1'},
25+
'tenant2': {'name': 'Tenant 2', 'data': 'Data for tenant 2'},
26+
}
27+
28+
29+
@mcp.tool(
30+
description='Returns tenant specific data based on X-Tenant-ID header.'
31+
)
32+
def get_tenant_data(context: Context) -> dict:
33+
"""Return tenant specific data."""
34+
if context.request_context and context.request_context.request:
35+
headers = context.request_context.request.headers
36+
tenant_id = headers.get('x-tenant-id')
37+
if tenant_id in TENANT_DATA:
38+
return TENANT_DATA[tenant_id]
39+
else:
40+
return {'error': f'Tenant {tenant_id} not found'}
41+
else:
42+
return {'error': 'Could not get request context'}
43+
44+
45+
if __name__ == '__main__':
46+
try:
47+
print('Starting Header Check MCP server on http://localhost:3000')
48+
mcp.run(transport='streamable-http')
49+
except KeyboardInterrupt:
50+
print('\nServer stopped.')

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import base64
1818
import inspect
1919
import logging
20+
import sys
2021
from typing import Any
2122
from typing import Callable
23+
from typing import Dict
2224
from typing import Optional
2325
from typing import Union
2426
import warnings
@@ -27,6 +29,7 @@
2729
from google.genai.types import FunctionDeclaration
2830
from typing_extensions import override
2931

32+
from ...agents.readonly_context import ReadonlyContext
3033
from .._gemini_schema_util import _to_gemini_schema
3134
from .mcp_session_manager import MCPSessionManager
3235
from .mcp_session_manager import retry_on_closed_resource
@@ -36,8 +39,6 @@
3639
try:
3740
from mcp.types import Tool as McpBaseTool
3841
except ImportError as e:
39-
import sys
40-
4142
if sys.version_info < (3, 10):
4243
raise ImportError(
4344
"MCP Tool requires Python 3.10 or above. Please upgrade your Python"
@@ -75,6 +76,9 @@ def __init__(
7576
auth_scheme: Optional[AuthScheme] = None,
7677
auth_credential: Optional[AuthCredential] = None,
7778
require_confirmation: Union[bool, Callable[..., bool]] = False,
79+
header_provider: Optional[
80+
Callable[[ReadonlyContext], Dict[str, str]]
81+
] = None,
7882
):
7983
"""Initializes an MCPTool.
8084
@@ -106,6 +110,7 @@ def __init__(
106110
self._mcp_tool = mcp_tool
107111
self._mcp_session_manager = mcp_session_manager
108112
self._require_confirmation = require_confirmation
113+
self._header_provider = header_provider
109114

110115
@override
111116
def _get_declaration(self) -> FunctionDeclaration:
@@ -192,10 +197,24 @@ async def _run_async_impl(
192197
Any: The response from the tool.
193198
"""
194199
# Extract headers from credential for session pooling
195-
headers = await self._get_headers(tool_context, credential)
200+
auth_headers = await self._get_headers(tool_context, credential)
201+
dynamic_headers = None
202+
if self._header_provider:
203+
dynamic_headers = self._header_provider(
204+
ReadonlyContext(tool_context._invocation_context)
205+
)
206+
207+
headers: Dict[str, str] = {}
208+
if auth_headers:
209+
headers.update(auth_headers)
210+
if dynamic_headers:
211+
headers.update(dynamic_headers)
212+
final_headers = headers if headers else None
196213

197214
# Get the session from the session manager
198-
session = await self._mcp_session_manager.create_session(headers=headers)
215+
session = await self._mcp_session_manager.create_session(
216+
headers=final_headers
217+
)
199218

200219
response = await session.call_tool(self._mcp_tool.name, arguments=args)
201220
return response

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import logging
1818
import sys
19+
from typing import Any
20+
from typing import AsyncIterator
1921
from typing import Callable
2022
from typing import Dict
2123
from typing import List
@@ -107,6 +109,9 @@ def __init__(
107109
auth_scheme: Optional[AuthScheme] = None,
108110
auth_credential: Optional[AuthCredential] = None,
109111
require_confirmation: Union[bool, Callable[..., bool]] = False,
112+
header_provider: Optional[
113+
Callable[[ReadonlyContext], Dict[str, str]]
114+
] = None,
110115
):
111116
"""Initializes the MCPToolset.
112117
@@ -130,6 +135,8 @@ def __init__(
130135
require_confirmation: Whether tools in this toolset require
131136
confirmation. Can be a single boolean or a callable to apply to all
132137
tools.
138+
header_provider: A callable that takes a ReadonlyContext and returns a
139+
dictionary of headers to be used for the MCP session.
133140
"""
134141
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
135142

@@ -138,6 +145,7 @@ def __init__(
138145

139146
self._connection_params = connection_params
140147
self._errlog = errlog
148+
self._header_provider = header_provider
141149

142150
# Create the session manager that will handle the MCP connection
143151
self._mcp_session_manager = MCPSessionManager(
@@ -162,8 +170,13 @@ async def get_tools(
162170
Returns:
163171
List[BaseTool]: A list of tools available under the specified context.
164172
"""
173+
headers = (
174+
self._header_provider(readonly_context)
175+
if self._header_provider and readonly_context
176+
else None
177+
)
165178
# Get session from session manager
166-
session = await self._mcp_session_manager.create_session()
179+
session = await self._mcp_session_manager.create_session(headers=headers)
167180

168181
# Fetch available tools from the MCP server
169182
tools_response: ListToolsResult = await session.list_tools()
@@ -177,6 +190,7 @@ async def get_tools(
177190
auth_scheme=self._auth_scheme,
178191
auth_credential=self._auth_credential,
179192
require_confirmation=self._require_confirmation,
193+
header_provider=self._header_provider,
180194
)
181195

182196
if self._is_tool_selected(mcp_tool, readonly_context):

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,74 @@ def test_init_validation(self):
640640

641641
with pytest.raises(TypeError):
642642
MCPTool(mcp_tool=self.mock_mcp_tool) # Missing session manager
643+
644+
@pytest.mark.asyncio
645+
async def test_run_async_impl_with_header_provider_no_auth(self):
646+
"""Test running tool with header_provider but no auth."""
647+
expected_headers = {"X-Tenant-ID": "test-tenant"}
648+
header_provider = Mock(return_value=expected_headers)
649+
tool = MCPTool(
650+
mcp_tool=self.mock_mcp_tool,
651+
mcp_session_manager=self.mock_session_manager,
652+
header_provider=header_provider,
653+
)
654+
655+
expected_response = {"result": "success"}
656+
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
657+
658+
tool_context = Mock(spec=ToolContext)
659+
tool_context._invocation_context = Mock()
660+
args = {"param1": "test_value"}
661+
662+
result = await tool._run_async_impl(
663+
args=args, tool_context=tool_context, credential=None
664+
)
665+
666+
assert result == expected_response
667+
header_provider.assert_called_once()
668+
self.mock_session_manager.create_session.assert_called_once_with(
669+
headers=expected_headers
670+
)
671+
self.mock_session.call_tool.assert_called_once_with(
672+
"test_tool", arguments=args
673+
)
674+
675+
@pytest.mark.asyncio
676+
async def test_run_async_impl_with_header_provider_and_oauth2(self):
677+
"""Test running tool with header_provider and OAuth2 auth."""
678+
dynamic_headers = {"X-Tenant-ID": "test-tenant"}
679+
header_provider = Mock(return_value=dynamic_headers)
680+
tool = MCPTool(
681+
mcp_tool=self.mock_mcp_tool,
682+
mcp_session_manager=self.mock_session_manager,
683+
header_provider=header_provider,
684+
)
685+
686+
oauth2_auth = OAuth2Auth(access_token="test_access_token")
687+
credential = AuthCredential(
688+
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
689+
)
690+
691+
expected_response = {"result": "success"}
692+
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
693+
694+
tool_context = Mock(spec=ToolContext)
695+
tool_context._invocation_context = Mock()
696+
args = {"param1": "test_value"}
697+
698+
result = await tool._run_async_impl(
699+
args=args, tool_context=tool_context, credential=credential
700+
)
701+
702+
assert result == expected_response
703+
header_provider.assert_called_once()
704+
self.mock_session_manager.create_session.assert_called_once()
705+
call_args = self.mock_session_manager.create_session.call_args
706+
headers = call_args[1]["headers"]
707+
assert headers == {
708+
"Authorization": "Bearer test_access_token",
709+
"X-Tenant-ID": "test-tenant",
710+
}
711+
self.mock_session.call_tool.assert_called_once_with(
712+
"test_tool", arguments=args
713+
)

tests/unittests/tools/mcp_tool/test_mcp_toolset.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
# Import dependencies with version checking
3131
try:
32+
from google.adk.agents.readonly_context import ReadonlyContext
3233
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
3334
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
3435
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
@@ -55,6 +56,7 @@ def __init__(self, command="test_command", args=None):
5556
StreamableHTTPConnectionParams = DummyClass
5657
MCPTool = DummyClass
5758
MCPToolset = DummyClass
59+
ReadonlyContext = DummyClass
5860
else:
5961
raise e
6062

@@ -245,6 +247,31 @@ def file_tools_filter(tool, context):
245247
assert tools[0].name == "read_file"
246248
assert tools[1].name == "write_file"
247249

250+
@pytest.mark.asyncio
251+
async def test_get_tools_with_header_provider(self):
252+
"""Test get_tools with a header_provider."""
253+
mock_tools = [MockMCPTool("tool1"), MockMCPTool("tool2")]
254+
self.mock_session.list_tools = AsyncMock(
255+
return_value=MockListToolsResult(mock_tools)
256+
)
257+
mock_readonly_context = Mock(spec=ReadonlyContext)
258+
expected_headers = {"X-Tenant-ID": "test-tenant"}
259+
header_provider = Mock(return_value=expected_headers)
260+
261+
toolset = MCPToolset(
262+
connection_params=self.mock_stdio_params,
263+
header_provider=header_provider,
264+
)
265+
toolset._mcp_session_manager = self.mock_session_manager
266+
267+
tools = await toolset.get_tools(readonly_context=mock_readonly_context)
268+
269+
assert len(tools) == 2
270+
header_provider.assert_called_once_with(mock_readonly_context)
271+
self.mock_session_manager.create_session.assert_called_once_with(
272+
headers=expected_headers
273+
)
274+
248275
@pytest.mark.asyncio
249276
async def test_close_success(self):
250277
"""Test successful cleanup."""

0 commit comments

Comments
 (0)