Skip to content

Commit 5737dbf

Browse files
authored
Merge pull request #179 from rawagner/streaming_mcp
Pass mcp config and auth headers in streaming_query too
2 parents 9859c0f + 34a5224 commit 5737dbf

File tree

6 files changed

+504
-178
lines changed

6 files changed

+504
-178
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ dependencies = [
1111
"uvicorn>=0.34.3",
1212
"llama-stack>=0.2.13",
1313
"rich>=14.0.0",
14+
"expiringdict>=1.2.2",
15+
"cachetools>=6.1.0",
1416
]
1517

1618
[tool.pdm]

src/app/endpoints/query.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import os
77
from pathlib import Path
88
from typing import Any
9-
from llama_stack_client.lib.agents.agent import Agent
109

10+
from cachetools import TTLCache # type: ignore
11+
12+
from llama_stack_client.lib.agents.agent import Agent
1113
from llama_stack_client import APIConnectionError
1214
from llama_stack_client import LlamaStackClient # type: ignore
1315
from llama_stack_client.types import UserMessage # type: ignore
@@ -32,6 +34,8 @@
3234
logger = logging.getLogger("app.endpoints.handlers")
3335
router = APIRouter(tags=["query"])
3436

37+
# Global agent registry to persist agents across requests
38+
_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600)
3539

3640
query_response: dict[int | str, dict[str, Any]] = {
3741
200: {
@@ -56,16 +60,33 @@ def is_transcripts_enabled() -> bool:
5660
return not configuration.user_data_collection_configuration.transcripts_disabled
5761

5862

59-
def retrieve_conversation_id(query_request: QueryRequest) -> str:
60-
"""Retrieve conversation ID based on existing ID or on newly generated one."""
61-
conversation_id = query_request.conversation_id
62-
63-
# Generate a new conversation ID if not provided
64-
if not conversation_id:
65-
conversation_id = get_suid()
66-
logger.info("Generated new conversation ID: %s", conversation_id)
67-
68-
return conversation_id
63+
def get_agent(
64+
client: LlamaStackClient,
65+
model_id: str,
66+
system_prompt: str,
67+
available_shields: list[str],
68+
conversation_id: str | None,
69+
) -> tuple[Agent, str]:
70+
"""Get existing agent or create a new one with session persistence."""
71+
if conversation_id is not None:
72+
agent = _agent_cache.get(conversation_id)
73+
if agent:
74+
logger.debug("Reusing existing agent with key: %s", conversation_id)
75+
return agent, conversation_id
76+
77+
logger.debug("Creating new agent")
78+
# TODO(lucasagomes): move to ReActAgent
79+
agent = Agent(
80+
client,
81+
model=model_id,
82+
instructions=system_prompt,
83+
input_shields=available_shields if available_shields else [],
84+
tools=[mcp.name for mcp in configuration.mcp_servers],
85+
enable_session_persistence=True,
86+
)
87+
conversation_id = agent.create_session(get_suid())
88+
_agent_cache[conversation_id] = agent
89+
return agent, conversation_id
6990

7091

7192
@router.post("/query", responses=query_response)
@@ -83,8 +104,9 @@ def query_endpoint_handler(
83104
# try to get Llama Stack client
84105
client = get_llama_stack_client(llama_stack_config)
85106
model_id = select_model_id(client.models.list(), query_request)
86-
conversation_id = retrieve_conversation_id(query_request)
87-
response = retrieve_response(client, model_id, query_request, auth)
107+
response, conversation_id = retrieve_response(
108+
client, model_id, query_request, auth
109+
)
88110

89111
if not is_transcripts_enabled():
90112
logger.debug("Transcript collection is disabled in the configuration")
@@ -163,7 +185,7 @@ def retrieve_response(
163185
model_id: str,
164186
query_request: QueryRequest,
165187
token: str,
166-
) -> str:
188+
) -> tuple[str, str]:
167189
"""Retrieve response from LLMs and agents."""
168190
available_shields = [shield.identifier for shield in client.shields.list()]
169191
if not available_shields:
@@ -184,40 +206,39 @@ def retrieve_response(
184206
if query_request.attachments:
185207
validate_attachments_metadata(query_request.attachments)
186208

187-
# Build mcp_headers config dynamically for all MCP servers
188-
# this will allow the agent to pass the user token to the MCP servers
209+
agent, conversation_id = get_agent(
210+
client,
211+
model_id,
212+
system_prompt,
213+
available_shields,
214+
query_request.conversation_id,
215+
)
216+
189217
mcp_headers = {}
190218
if token:
191219
for mcp_server in configuration.mcp_servers:
192220
mcp_headers[mcp_server.url] = {
193221
"Authorization": f"Bearer {token}",
194222
}
195-
# TODO(lucasagomes): move to ReActAgent
196-
agent = Agent(
197-
client,
198-
model=model_id,
199-
instructions=system_prompt,
200-
input_shields=available_shields if available_shields else [],
201-
tools=[mcp.name for mcp in configuration.mcp_servers],
202-
extra_headers={
203-
"X-LlamaStack-Provider-Data": json.dumps(
204-
{
205-
"mcp_headers": mcp_headers,
206-
}
207-
),
208-
},
209-
)
210-
session_id = agent.create_session("chat_session")
211-
logger.debug("Session ID: %s", session_id)
223+
224+
agent.extra_headers = {
225+
"X-LlamaStack-Provider-Data": json.dumps(
226+
{
227+
"mcp_headers": mcp_headers,
228+
}
229+
),
230+
}
231+
212232
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
213233
response = agent.create_turn(
214234
messages=[UserMessage(role="user", content=query_request.query)],
215-
session_id=session_id,
235+
session_id=conversation_id,
216236
documents=query_request.get_documents(),
217237
stream=False,
218238
toolgroups=get_rag_toolgroups(vector_db_ids),
219239
)
220-
return str(response.output_message.content) # type: ignore[union-attr]
240+
241+
return str(response.output_message.content), conversation_id # type: ignore[union-attr]
221242

222243

223244
def validate_attachments_metadata(attachments: list[Attachment]) -> None:

src/app/endpoints/streaming_query.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from typing import Any, AsyncIterator
66

7+
from cachetools import TTLCache # type: ignore
8+
79
from llama_stack_client import APIConnectionError
810
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
911
from llama_stack_client import AsyncLlamaStackClient # type: ignore
@@ -19,12 +21,12 @@
1921
from utils.auth import auth_dependency
2022
from utils.endpoints import check_configuration_loaded
2123
from utils.common import retrieve_user_id
24+
from utils.suid import get_suid
2225

2326

2427
from app.endpoints.query import (
2528
get_rag_toolgroups,
2629
is_transcripts_enabled,
27-
retrieve_conversation_id,
2830
store_transcript,
2931
select_model_id,
3032
validate_attachments_metadata,
@@ -33,6 +35,37 @@
3335
logger = logging.getLogger("app.endpoints.handlers")
3436
router = APIRouter(tags=["streaming_query"])
3537

38+
# Global agent registry to persist agents across requests
39+
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
40+
41+
42+
async def get_agent(
43+
client: AsyncLlamaStackClient,
44+
model_id: str,
45+
system_prompt: str,
46+
available_shields: list[str],
47+
conversation_id: str | None,
48+
) -> tuple[AsyncAgent, str]:
49+
"""Get existing agent or create a new one with session persistence."""
50+
if conversation_id is not None:
51+
agent = _agent_cache.get(conversation_id)
52+
if agent:
53+
logger.debug("Reusing existing agent with key: %s", conversation_id)
54+
return agent, conversation_id
55+
56+
logger.debug("Creating new agent")
57+
agent = AsyncAgent(
58+
client, # type: ignore[arg-type]
59+
model=model_id,
60+
instructions=system_prompt,
61+
input_shields=available_shields if available_shields else [],
62+
tools=[mcp.name for mcp in configuration.mcp_servers],
63+
enable_session_persistence=True,
64+
)
65+
conversation_id = await agent.create_session(get_suid())
66+
_agent_cache[conversation_id] = agent
67+
return agent, conversation_id
68+
3669

3770
def format_stream_data(d: dict) -> str:
3871
"""Format outbound data in the Event Stream Format."""
@@ -139,8 +172,9 @@ async def streaming_query_endpoint_handler(
139172
# try to get Llama Stack client
140173
client = await get_async_llama_stack_client(llama_stack_config)
141174
model_id = select_model_id(await client.models.list(), query_request)
142-
conversation_id = retrieve_conversation_id(query_request)
143-
response = await retrieve_response(client, model_id, query_request)
175+
response, conversation_id = await retrieve_response(
176+
client, model_id, query_request, auth
177+
)
144178

145179
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
146180
"""Generate SSE formatted streaming response."""
@@ -190,8 +224,11 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
190224

191225

192226
async def retrieve_response(
193-
client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
194-
) -> Any:
227+
client: AsyncLlamaStackClient,
228+
model_id: str,
229+
query_request: QueryRequest,
230+
token: str,
231+
) -> tuple[Any, str]:
195232
"""Retrieve response from LLMs and agents."""
196233
available_shields = [shield.identifier for shield in await client.shields.list()]
197234
if not available_shields:
@@ -212,24 +249,39 @@ async def retrieve_response(
212249
if query_request.attachments:
213250
validate_attachments_metadata(query_request.attachments)
214251

215-
agent = AsyncAgent(
216-
client, # type: ignore[arg-type]
217-
model=model_id,
218-
instructions=system_prompt,
219-
input_shields=available_shields if available_shields else [],
220-
tools=[],
252+
agent, conversation_id = await get_agent(
253+
client,
254+
model_id,
255+
system_prompt,
256+
available_shields,
257+
query_request.conversation_id,
221258
)
222-
session_id = await agent.create_session("chat_session")
223-
logger.debug("Session ID: %s", session_id)
259+
260+
mcp_headers = {}
261+
if token:
262+
for mcp_server in configuration.mcp_servers:
263+
mcp_headers[mcp_server.url] = {
264+
"Authorization": f"Bearer {token}",
265+
}
266+
267+
agent.extra_headers = {
268+
"X-LlamaStack-Provider-Data": json.dumps(
269+
{
270+
"mcp_headers": mcp_headers,
271+
}
272+
),
273+
}
274+
275+
logger.debug("Session ID: %s", conversation_id)
224276
vector_db_ids = [
225277
vector_db.identifier for vector_db in await client.vector_dbs.list()
226278
]
227279
response = await agent.create_turn(
228280
messages=[UserMessage(role="user", content=query_request.query)],
229-
session_id=session_id,
281+
session_id=conversation_id,
230282
documents=query_request.get_documents(),
231283
stream=True,
232284
toolgroups=get_rag_toolgroups(vector_db_ids),
233285
)
234286

235-
return response
287+
return response, conversation_id

0 commit comments

Comments
 (0)