Skip to content

Commit 41d6f95

Browse files
committed
Pass MCP config and auth headers in streaming_query too
1 parent 54b4149 commit 41d6f95

File tree

4 files changed

+235
-45
lines changed

4 files changed

+235
-45
lines changed

src/app/endpoints/query.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,23 @@ def retrieve_response(
194194
"Authorization": f"Bearer {token}",
195195
}
196196

197+
extra_headers = {
198+
"X-LlamaStack-Provider-Data": json.dumps(
199+
{
200+
"mcp_headers": mcp_headers,
201+
}
202+
),
203+
}
204+
205+
agent.extra_headers = extra_headers
206+
197207
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
198208
response = agent.create_turn(
199209
messages=[UserMessage(role="user", content=query_request.query)],
200210
session_id=conversation_id,
201211
documents=query_request.get_documents(),
202212
stream=False,
203213
toolgroups=get_rag_toolgroups(vector_db_ids),
204-
extra_headers={
205-
"X-LlamaStack-Provider-Data": json.dumps(
206-
{
207-
"mcp_headers": mcp_headers,
208-
}
209-
),
210-
},
211214
)
212215

213216
return str(response.output_message.content), conversation_id # type: ignore[union-attr]

src/app/endpoints/streaming_query.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
validate_attachments_metadata,
3131
)
3232

33+
3334
logger = logging.getLogger("app.endpoints.handlers")
3435
router = APIRouter(tags=["streaming_query"])
3536

@@ -57,7 +58,7 @@ async def get_agent(
5758
model=model_id,
5859
instructions=system_prompt,
5960
input_shields=available_shields if available_shields else [],
60-
tools=[], # mcp config ?
61+
tools=[mcp.name for mcp in configuration.mcp_servers],
6162
enable_session_persistence=True,
6263
)
6364
conversation_id = await agent.create_session(get_suid())
@@ -165,7 +166,9 @@ async def streaming_query_endpoint_handler(
165166
logger.info("LLama stack config: %s", llama_stack_config)
166167
client = await get_async_llama_stack_client(llama_stack_config)
167168
model_id = select_model_id(await client.models.list(), query_request)
168-
response, conversation_id = await retrieve_response(client, model_id, query_request)
169+
response, conversation_id = await retrieve_response(
170+
client, model_id, query_request, auth
171+
)
169172

170173
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
171174
"""Generate SSE formatted streaming response."""
@@ -204,7 +207,10 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
204207

205208

206209
async def retrieve_response(
207-
client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
210+
client: AsyncLlamaStackClient,
211+
model_id: str,
212+
query_request: QueryRequest,
213+
token: str,
208214
) -> tuple[Any, str]:
209215
"""Retrieve response from LLMs and agents."""
210216
available_shields = [shield.identifier for shield in await client.shields.list()]
@@ -234,6 +240,21 @@ async def retrieve_response(
234240
query_request.conversation_id,
235241
)
236242

243+
mcp_headers = {}
244+
if token:
245+
for mcp_server in configuration.mcp_servers:
246+
mcp_headers[mcp_server.url] = {
247+
"Authorization": f"Bearer {token}",
248+
}
249+
250+
agent.extra_headers = {
251+
"X-LlamaStack-Provider-Data": json.dumps(
252+
{
253+
"mcp_headers": mcp_headers,
254+
}
255+
),
256+
}
257+
237258
logger.debug("Session ID: %s", conversation_id)
238259
vector_db_ids = [
239260
vector_db.identifier for vector_db in await client.vector_dbs.list()

tests/unit/app/endpoints/test_query.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ def test_retrieve_response_vector_db_available(mocker):
295295
documents=[],
296296
stream=False,
297297
toolgroups=get_rag_toolgroups(["VectorDB-1"]),
298-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
299298
)
300299

301300

@@ -331,7 +330,6 @@ def test_retrieve_response_no_available_shields(mocker):
331330
documents=[],
332331
stream=False,
333332
toolgroups=None,
334-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
335333
)
336334

337335

@@ -372,7 +370,6 @@ def __init__(self, identifier):
372370
documents=[],
373371
stream=False,
374372
toolgroups=None,
375-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
376373
)
377374

378375

@@ -416,7 +413,6 @@ def __init__(self, identifier):
416413
documents=[],
417414
stream=False,
418415
toolgroups=None,
419-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
420416
)
421417

422418

@@ -465,7 +461,6 @@ def test_retrieve_response_with_one_attachment(mocker):
465461
},
466462
],
467463
toolgroups=None,
468-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
469464
)
470465

471466

@@ -523,7 +518,6 @@ def test_retrieve_response_with_two_attachments(mocker):
523518
},
524519
],
525520
toolgroups=None,
526-
extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'},
527521
)
528522

529523

@@ -573,23 +567,28 @@ def test_retrieve_response_with_mcp_servers(mocker):
573567
None, # conversation_id
574568
)
575569

576-
# Check that the agent's create_turn was called with MCP headers
577-
mock_agent.create_turn.assert_called_once()
578-
call_args = mock_agent.create_turn.call_args
579-
580-
extra_headers_data = json.loads(
581-
call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"]
582-
)
583-
mcp_headers = extra_headers_data["mcp_headers"]
570+
# Check that the agent's extra_headers property was set correctly
571+
expected_extra_headers = {
572+
"X-LlamaStack-Provider-Data": json.dumps(
573+
{
574+
"mcp_headers": {
575+
"http://localhost:3000": {"Authorization": "Bearer test_token_123"},
576+
"https://git.example.com/mcp": {
577+
"Authorization": "Bearer test_token_123"
578+
},
579+
}
580+
}
581+
)
582+
}
583+
assert mock_agent.extra_headers == expected_extra_headers
584584

585-
assert "http://localhost:3000" in mcp_headers
586-
assert (
587-
mcp_headers["http://localhost:3000"]["Authorization"] == "Bearer test_token_123"
588-
)
589-
assert "https://git.example.com/mcp" in mcp_headers
590-
assert (
591-
mcp_headers["https://git.example.com/mcp"]["Authorization"]
592-
== "Bearer test_token_123"
585+
# Check that create_turn was called with the correct parameters
586+
mock_agent.create_turn.assert_called_once_with(
587+
messages=[UserMessage(role="user", content="What is OpenStack?")],
588+
session_id="fake_session_id",
589+
documents=[],
590+
stream=False,
591+
toolgroups=None,
593592
)
594593

595594

@@ -632,15 +631,20 @@ def test_retrieve_response_with_mcp_servers_empty_token(mocker):
632631
None, # conversation_id
633632
)
634633

635-
# Check that the agent's create_turn was called with empty MCP headers
636-
mock_agent.create_turn.assert_called_once()
637-
call_args = mock_agent.create_turn.call_args
634+
# Check that the agent's extra_headers property was set correctly (empty mcp_headers)
635+
expected_extra_headers = {
636+
"X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": {}})
637+
}
638+
assert mock_agent.extra_headers == expected_extra_headers
638639

639-
extra_headers_data = json.loads(
640-
call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"]
640+
# Check that create_turn was called with the correct parameters
641+
mock_agent.create_turn.assert_called_once_with(
642+
messages=[UserMessage(role="user", content="What is OpenStack?")],
643+
session_id="fake_session_id",
644+
documents=[],
645+
stream=False,
646+
toolgroups=None,
641647
)
642-
mcp_headers = extra_headers_data["mcp_headers"]
643-
assert len(mcp_headers) == 0
644648

645649

646650
def test_construct_transcripts_path(setup_configuration, mocker):

0 commit comments

Comments
 (0)