Skip to content

Commit 71d3667

Browse files
authored
Merge pull request #221 from ldjebran/allow-mcp-headers-with-mcp-servers-name
[LCORE-353] Allow mcp headers to contain mcp servers names
2 parents 2f71c65 + 8dd0435 commit 71d3667

File tree

5 files changed

+83
-8
lines changed

5 files changed

+83
-8
lines changed

src/app/endpoints/query.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from auth import get_auth_dependency
3030
from utils.common import retrieve_user_id
3131
from utils.endpoints import check_configuration_loaded, get_system_prompt
32-
from utils.mcp_headers import mcp_headers_dependency
32+
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3333
from utils.suid import get_suid
3434
from utils.types import GraniteToolParser
3535

@@ -231,6 +231,7 @@ def retrieve_response(
231231
# preserve compatibility when mcp_headers is not provided
232232
if mcp_headers is None:
233233
mcp_headers = {}
234+
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
234235
if not mcp_headers and token:
235236
for mcp_server in configuration.mcp_servers:
236237
mcp_headers[mcp_server.url] = {

src/app/endpoints/streaming_query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from models.requests import QueryRequest
2323
from utils.endpoints import check_configuration_loaded, get_system_prompt
2424
from utils.common import retrieve_user_id
25-
from utils.mcp_headers import mcp_headers_dependency
25+
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
2626
from utils.suid import get_suid
2727
from utils.types import GraniteToolParser
2828

@@ -290,6 +290,9 @@ async def retrieve_response(
290290
# preserve compatibility when mcp_headers is not provided
291291
if mcp_headers is None:
292292
mcp_headers = {}
293+
294+
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
295+
293296
if not mcp_headers and token:
294297
for mcp_server in configuration.mcp_servers:
295298
mcp_headers[mcp_server.url] = {

src/utils/mcp_headers.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22

33
import json
44
import logging
5+
from urllib.parse import urlparse
6+
57
from fastapi import Request
68

9+
from configuration import AppConfig
10+
11+
712
logger = logging.getLogger("app.endpoints.dependencies")
813

914

@@ -46,3 +51,40 @@ def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
4651
)
4752
mcp_headers = {}
4853
return mcp_headers
54+
55+
56+
def handle_mcp_headers_with_toolgroups(
57+
mcp_headers: dict[str, dict[str, str]], config: AppConfig
58+
) -> dict[str, dict[str, str]]:
59+
"""Process MCP headers by converting toolgroup names to URLs.
60+
61+
This function takes MCP headers where keys can be either valid URLs or
62+
toolgroup names. For valid URLs (HTTP/HTTPS), it keeps them as-is. For
63+
toolgroup names, it looks up the corresponding MCP server URL in the
64+
configuration and replaces the key with the URL. Unknown toolgroup names
65+
are filtered out.
66+
67+
Args:
68+
mcp_headers: Dictionary with keys as URLs or toolgroup names
69+
config: Application configuration containing MCP server definitions
70+
71+
Returns:
72+
Dictionary with URLs as keys and their corresponding headers as values
73+
"""
74+
converted_mcp_headers = {}
75+
76+
for key, item in mcp_headers.items():
77+
key_url_parsed = urlparse(key)
78+
if key_url_parsed.scheme in ("http", "https") and key_url_parsed.netloc:
79+
# a valid url is supplied, deliver it as is
80+
converted_mcp_headers[key] = item
81+
else:
82+
# assume the key is a toolgroup name
83+
# look for toolgroups name in mcp_servers configuration
84+
# if the mcp server is not found, the mcp header gets ignored
85+
for mcp_server in config.mcp_servers:
86+
if mcp_server.name == key and mcp_server.url:
87+
converted_mcp_headers[mcp_server.url] = item
88+
break
89+
90+
return converted_mcp_headers

tests/unit/app/endpoints/test_query.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,14 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
694694
model_id = "fake_model_id"
695695
access_token = ""
696696
mcp_headers = {
697-
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
698-
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_123"},
697+
"filesystem-server": {"Authorization": "Bearer test_token_123"},
698+
"git-server": {"Authorization": "Bearer test_token_456"},
699+
"http://another-server-mcp-server:3000": {
700+
"Authorization": "Bearer test_token_789"
701+
},
702+
"unknown-mcp-server": {
703+
"Authorization": "Bearer test_token_for_unknown-mcp-server"
704+
},
699705
}
700706

701707
response, conversation_id = retrieve_response(
@@ -718,11 +724,20 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
718724
None, # conversation_id
719725
)
720726

727+
expected_mcp_headers = {
728+
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
729+
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
730+
"http://another-server-mcp-server:3000": {
731+
"Authorization": "Bearer test_token_789"
732+
},
733+
# we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack
734+
}
735+
721736
# Check that the agent's extra_headers property was set correctly
722737
expected_extra_headers = {
723738
"X-LlamaStack-Provider-Data": json.dumps(
724739
{
725-
"mcp_headers": mcp_headers,
740+
"mcp_headers": expected_mcp_headers,
726741
}
727742
)
728743
}

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,14 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
762762
model_id = "fake_model_id"
763763
access_token = ""
764764
mcp_headers = {
765-
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
766-
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
765+
"filesystem-server": {"Authorization": "Bearer test_token_123"},
766+
"git-server": {"Authorization": "Bearer test_token_456"},
767+
"http://another-server-mcp-server:3000": {
768+
"Authorization": "Bearer test_token_789"
769+
},
770+
"unknown-mcp-server": {
771+
"Authorization": "Bearer test_token_for_unknown-mcp-server"
772+
},
767773
}
768774

769775
response, conversation_id = await retrieve_response(
@@ -786,9 +792,17 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
786792
None, # conversation_id
787793
)
788794

795+
expected_mcp_headers = {
796+
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
797+
"https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"},
798+
"http://another-server-mcp-server:3000": {
799+
"Authorization": "Bearer test_token_789"
800+
},
801+
# we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack
802+
}
789803
# Check that the agent's extra_headers property was set correctly
790804
expected_extra_headers = {
791-
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": mcp_headers})
805+
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": expected_mcp_headers})
792806
}
793807
assert mock_agent.extra_headers == expected_extra_headers
794808

0 commit comments

Comments
 (0)