66import os
77from pathlib import Path
88from typing import Any
9- from llama_stack_client .lib .agents .agent import Agent
109
10+ from cachetools import TTLCache # type: ignore
11+
12+ from llama_stack_client .lib .agents .agent import Agent
1113from llama_stack_client import APIConnectionError
1214from llama_stack_client import LlamaStackClient # type: ignore
1315from llama_stack_client .types import UserMessage # type: ignore
3234logger = logging .getLogger ("app.endpoints.handlers" )
3335router = APIRouter (tags = ["query" ])
3436
37+ # Global agent registry to persist agents across requests
38+ _agent_cache : TTLCache [str , Agent ] = TTLCache (maxsize = 1000 , ttl = 3600 )
3539
3640query_response : dict [int | str , dict [str , Any ]] = {
3741 200 : {
@@ -56,16 +60,33 @@ def is_transcripts_enabled() -> bool:
5660 return not configuration .user_data_collection_configuration .transcripts_disabled
5761
5862
59- def retrieve_conversation_id (query_request : QueryRequest ) -> str :
60- """Retrieve conversation ID based on existing ID or on newly generated one."""
61- conversation_id = query_request .conversation_id
62-
63- # Generate a new conversation ID if not provided
64- if not conversation_id :
65- conversation_id = get_suid ()
66- logger .info ("Generated new conversation ID: %s" , conversation_id )
67-
68- return conversation_id
def get_agent(
    client: LlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_shields: list[str],
    conversation_id: str | None,
) -> tuple[Agent, str]:
    """Return a cached agent for *conversation_id*, or create and cache a new one.

    Args:
        client: Connected Llama Stack client used to drive the agent.
        model_id: Identifier of the model the agent should use.
        system_prompt: System instructions given to a newly created agent.
        available_shields: Shield identifiers applied as input shields.
        conversation_id: Existing conversation/session ID, or ``None`` to
            start a fresh conversation.

    Returns:
        A ``(agent, conversation_id)`` tuple. The conversation ID is newly
        generated (it is the agent's session ID) whenever no cached agent
        was found for the supplied ID.
    """
    if conversation_id is not None:
        cached_agent = _agent_cache.get(conversation_id)
        if cached_agent:
            logger.debug("Reusing existing agent with key: %s", conversation_id)
            return cached_agent, conversation_id
        # NOTE(review): a supplied conversation_id that has expired out of the
        # TTL cache falls through to a brand-new agent/session, so the
        # caller's ID is silently replaced by a new one in the return value —
        # confirm this is the intended behavior. Logged here so the drop is
        # at least traceable.
        logger.debug(
            "No cached agent for conversation %s; creating a new one",
            conversation_id,
        )

    logger.debug("Creating new agent")
    # TODO(lucasagomes): move to ReActAgent
    agent = Agent(
        client,
        model=model_id,
        instructions=system_prompt,
        # `available_shields` is already a list; an empty list needs no
        # special-casing (the previous `x if x else []` was redundant).
        input_shields=available_shields,
        tools=[mcp.name for mcp in configuration.mcp_servers],
        enable_session_persistence=True,
    )
    # The new session ID doubles as the conversation ID and the cache key.
    conversation_id = agent.create_session(get_suid())
    _agent_cache[conversation_id] = agent
    return agent, conversation_id
6990
7091
7192@router .post ("/query" , responses = query_response )
@@ -83,8 +104,9 @@ def query_endpoint_handler(
83104 # try to get Llama Stack client
84105 client = get_llama_stack_client (llama_stack_config )
85106 model_id = select_model_id (client .models .list (), query_request )
86- conversation_id = retrieve_conversation_id (query_request )
87- response = retrieve_response (client , model_id , query_request , auth )
107+ response , conversation_id = retrieve_response (
108+ client , model_id , query_request , auth
109+ )
88110
89111 if not is_transcripts_enabled ():
90112 logger .debug ("Transcript collection is disabled in the configuration" )
@@ -163,7 +185,7 @@ def retrieve_response(
163185 model_id : str ,
164186 query_request : QueryRequest ,
165187 token : str ,
166- ) -> str :
188+ ) -> tuple [ str , str ] :
167189 """Retrieve response from LLMs and agents."""
168190 available_shields = [shield .identifier for shield in client .shields .list ()]
169191 if not available_shields :
@@ -184,40 +206,39 @@ def retrieve_response(
184206 if query_request .attachments :
185207 validate_attachments_metadata (query_request .attachments )
186208
187- # Build mcp_headers config dynamically for all MCP servers
188- # this will allow the agent to pass the user token to the MCP servers
209+ agent , conversation_id = get_agent (
210+ client ,
211+ model_id ,
212+ system_prompt ,
213+ available_shields ,
214+ query_request .conversation_id ,
215+ )
216+
189217 mcp_headers = {}
190218 if token :
191219 for mcp_server in configuration .mcp_servers :
192220 mcp_headers [mcp_server .url ] = {
193221 "Authorization" : f"Bearer { token } " ,
194222 }
195- # TODO(lucasagomes): move to ReActAgent
196- agent = Agent (
197- client ,
198- model = model_id ,
199- instructions = system_prompt ,
200- input_shields = available_shields if available_shields else [],
201- tools = [mcp .name for mcp in configuration .mcp_servers ],
202- extra_headers = {
203- "X-LlamaStack-Provider-Data" : json .dumps (
204- {
205- "mcp_headers" : mcp_headers ,
206- }
207- ),
208- },
209- )
210- session_id = agent .create_session ("chat_session" )
211- logger .debug ("Session ID: %s" , session_id )
223+
224+ agent .extra_headers = {
225+ "X-LlamaStack-Provider-Data" : json .dumps (
226+ {
227+ "mcp_headers" : mcp_headers ,
228+ }
229+ ),
230+ }
231+
212232 vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
213233 response = agent .create_turn (
214234 messages = [UserMessage (role = "user" , content = query_request .query )],
215- session_id = session_id ,
235+ session_id = conversation_id ,
216236 documents = query_request .get_documents (),
217237 stream = False ,
218238 toolgroups = get_rag_toolgroups (vector_db_ids ),
219239 )
220- return str (response .output_message .content ) # type: ignore[union-attr]
240+
241+ return str (response .output_message .content ), conversation_id # type: ignore[union-attr]
221242
222243
223244def validate_attachments_metadata (attachments : list [Attachment ]) -> None :
0 commit comments