Skip to content

Commit 34a5224

Browse files
committed
Pass mcp config and auth headers in streaming_query too
1 parent 48248a0 commit 34a5224

File tree

4 files changed

+230
-45
lines changed

4 files changed

+230
-45
lines changed

src/app/endpoints/query.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,21 @@ def retrieve_response(
221221
"Authorization": f"Bearer {token}",
222222
}
223223

224+
agent.extra_headers = {
225+
"X-LlamaStack-Provider-Data": json.dumps(
226+
{
227+
"mcp_headers": mcp_headers,
228+
}
229+
),
230+
}
231+
224232
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
225233
response = agent.create_turn(
226234
messages=[UserMessage(role="user", content=query_request.query)],
227235
session_id=conversation_id,
228236
documents=query_request.get_documents(),
229237
stream=False,
230238
toolgroups=get_rag_toolgroups(vector_db_ids),
231-
extra_headers={
232-
"X-LlamaStack-Provider-Data": json.dumps(
233-
{
234-
"mcp_headers": mcp_headers,
235-
}
236-
),
237-
},
238239
)
239240

240241
return str(response.output_message.content), conversation_id # type: ignore[union-attr]

src/app/endpoints/streaming_query.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def get_agent(
5959
model=model_id,
6060
instructions=system_prompt,
6161
input_shields=available_shields if available_shields else [],
62-
tools=[], # mcp config ?
62+
tools=[mcp.name for mcp in configuration.mcp_servers],
6363
enable_session_persistence=True,
6464
)
6565
conversation_id = await agent.create_session(get_suid())
@@ -173,7 +173,7 @@ async def streaming_query_endpoint_handler(
173173
client = await get_async_llama_stack_client(llama_stack_config)
174174
model_id = select_model_id(await client.models.list(), query_request)
175175
response, conversation_id = await retrieve_response(
176-
client, model_id, query_request
176+
client, model_id, query_request, auth
177177
)
178178

179179
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
@@ -224,7 +224,10 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
224224

225225

226226
async def retrieve_response(
227-
client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
227+
client: AsyncLlamaStackClient,
228+
model_id: str,
229+
query_request: QueryRequest,
230+
token: str,
228231
) -> tuple[Any, str]:
229232
"""Retrieve response from LLMs and agents."""
230233
available_shields = [shield.identifier for shield in await client.shields.list()]
@@ -254,6 +257,21 @@ async def retrieve_response(
254257
query_request.conversation_id,
255258
)
256259

260+
mcp_headers = {}
261+
if token:
262+
for mcp_server in configuration.mcp_servers:
263+
mcp_headers[mcp_server.url] = {
264+
"Authorization": f"Bearer {token}",
265+
}
266+
267+
agent.extra_headers = {
268+
"X-LlamaStack-Provider-Data": json.dumps(
269+
{
270+
"mcp_headers": mcp_headers,
271+
}
272+
),
273+
}
274+
257275
logger.debug("Session ID: %s", conversation_id)
258276
vector_db_ids = [
259277
vector_db.identifier for vector_db in await client.vector_dbs.list()

tests/unit/app/endpoints/test_query.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ def test_retrieve_response_vector_db_available(mocker):
312312
documents=[],
313313
stream=False,
314314
toolgroups=get_rag_toolgroups(["VectorDB-1"]),
315-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
316315
)
317316

318317

@@ -348,7 +347,6 @@ def test_retrieve_response_no_available_shields(mocker):
348347
documents=[],
349348
stream=False,
350349
toolgroups=None,
351-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
352350
)
353351

354352

@@ -389,7 +387,6 @@ def __init__(self, identifier):
389387
documents=[],
390388
stream=False,
391389
toolgroups=None,
392-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
393390
)
394391

395392

@@ -433,7 +430,6 @@ def __init__(self, identifier):
433430
documents=[],
434431
stream=False,
435432
toolgroups=None,
436-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
437433
)
438434

439435

@@ -482,7 +478,6 @@ def test_retrieve_response_with_one_attachment(mocker):
482478
},
483479
],
484480
toolgroups=None,
485-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
486481
)
487482

488483

@@ -540,7 +535,6 @@ def test_retrieve_response_with_two_attachments(mocker):
540535
},
541536
],
542537
toolgroups=None,
543-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
544538
)
545539

546540

@@ -590,23 +584,28 @@ def test_retrieve_response_with_mcp_servers(mocker):
590584
None, # conversation_id
591585
)
592586

593-
# Check that the agent's create_turn was called with MCP headers
594-
mock_agent.create_turn.assert_called_once()
595-
call_args = mock_agent.create_turn.call_args
596-
597-
extra_headers_data = json.loads(
598-
call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"]
599-
)
600-
mcp_headers = extra_headers_data["mcp_headers"]
587+
# Check that the agent's extra_headers property was set correctly
588+
expected_extra_headers = {
589+
"X-LlamaStack-Provider-Data": json.dumps(
590+
{
591+
"mcp_headers": {
592+
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
593+
"https://git.example.com/mcp": {
594+
"Authorization": "Bearer test_token_123"
595+
},
596+
}
597+
}
598+
)
599+
}
600+
assert mock_agent.extra_headers == expected_extra_headers
601601

602-
assert "http://localhost:3000" in mcp_headers
603-
assert (
604-
mcp_headers["http://localhost:3000"]["Authorization"] == "Bearer test_token_123"
605-
)
606-
assert "https://git.example.com/mcp" in mcp_headers
607-
assert (
608-
mcp_headers["https://git.example.com/mcp"]["Authorization"]
609-
== "Bearer test_token_123"
602+
# Check that create_turn was called with the correct parameters
603+
mock_agent.create_turn.assert_called_once_with(
604+
messages=[UserMessage(role="user", content="What is OpenStack?")],
605+
session_id="fake_session_id",
606+
documents=[],
607+
stream=False,
608+
toolgroups=None,
610609
)
611610

612611

@@ -649,15 +648,20 @@ def test_retrieve_response_with_mcp_servers_empty_token(mocker):
649648
None, # conversation_id
650649
)
651650

652-
# Check that the agent's create_turn was called with empty MCP headers
653-
mock_agent.create_turn.assert_called_once()
654-
call_args = mock_agent.create_turn.call_args
651+
# Check that the agent's extra_headers property was set correctly (empty mcp_headers)
652+
expected_extra_headers = {
653+
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": {}})
654+
}
655+
assert mock_agent.extra_headers == expected_extra_headers
655656

656-
extra_headers_data = json.loads(
657-
call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"]
657+
# Check that create_turn was called with the correct parameters
658+
mock_agent.create_turn.assert_called_once_with(
659+
messages=[UserMessage(role="user", content="What is OpenStack?")],
660+
session_id="fake_session_id",
661+
documents=[],
662+
stream=False,
663+
toolgroups=None,
658664
)
659-
mcp_headers = extra_headers_data["mcp_headers"]
660-
assert len(mcp_headers) == 0
661665

662666

663667
def test_construct_transcripts_path(setup_configuration, mocker):

0 commit comments

Comments
 (0)