66import os
77from pathlib import Path
88from typing import Any
9+ from expiringdict import ExpiringDict
910from llama_stack_client .lib .agents .agent import Agent
10-
1111from llama_stack_client import LlamaStackClient # type: ignore
1212from llama_stack_client .types import UserMessage # type: ignore
1313from llama_stack_client .types .agents .turn_create_params import (
3030logger = logging .getLogger ("app.endpoints.handlers" )
3131router = APIRouter (tags = ["query" ])
3232
# Global agent registry to persist agents across requests.
# Keyed by conversation ID.  Entries are evicted once 1000 agents are cached
# or after 3600 seconds, so long-idle conversations silently lose their
# in-memory session state and will be recreated on next use.
# NOTE(review): annotated as a plain dict although the value is an
# ExpiringDict — ExpiringDict subclasses dict, so the annotation holds.
_agent_cache: dict[str, Agent] = ExpiringDict(max_len=1000, max_age_seconds=3600)
3335
3436query_response : dict [int | str , dict [str , Any ]] = {
3537 200 : {
@@ -48,16 +50,33 @@ def is_transcripts_enabled() -> bool:
4850 return not configuration .user_data_collection_configuration .transcripts_disabled
4951
5052
51- def retrieve_conversation_id (query_request : QueryRequest ) -> str :
52- """Retrieve conversation ID based on existing ID or on newly generated one."""
53- conversation_id = query_request .conversation_id
54-
55- # Generate a new conversation ID if not provided
56- if not conversation_id :
57- conversation_id = get_suid ()
58- logger .info ("Generated new conversation ID: %s" , conversation_id )
59-
60- return conversation_id
def get_agent(
    client: LlamaStackClient,
    model_id: str,
    system_prompt: str,
    available_shields: list[str],
    conversation_id: str | None,
) -> tuple[Agent, str]:
    """Return a cached agent for the conversation, or create a new one.

    Looks up the module-level ``_agent_cache`` (an expiring dict: entries are
    dropped after the configured TTL or when the cache is full) and reuses the
    cached agent so the llama-stack session — and with it the conversation
    history — persists across requests.  When no usable cached agent exists,
    a fresh agent with session persistence enabled is created and registered
    under a newly generated conversation ID.

    Args:
        client: Connected LlamaStackClient used to construct a new agent.
        model_id: Model identifier for a newly created agent.
        system_prompt: System instructions for a newly created agent.
        available_shields: Input shield identifiers attached to a new agent.
        conversation_id: Existing conversation ID to resume, or None to start
            a new conversation.

    Returns:
        Tuple of ``(agent, conversation_id)``.  If the supplied
        conversation ID has expired from the cache, a NEW conversation ID is
        returned — callers must use the returned ID, not the one passed in.
    """
    if conversation_id is not None:
        agent = _agent_cache.get(conversation_id)
        if agent:
            logger.debug("Reusing existing agent with key: %s", conversation_id)
            return agent, conversation_id
        # The caller supplied an ID but the agent is gone (TTL expiry or
        # max_len eviction).  Previously this fell through silently, which
        # made the loss of conversation history invisible; surface it.
        logger.warning(
            "Conversation ID %s not found in cache; starting a new conversation",
            conversation_id,
        )

    logger.debug("Creating new agent")
    # TODO(lucasagomes): move to ReActAgent
    agent = Agent(
        client,
        model=model_id,
        instructions=system_prompt,
        input_shields=available_shields if available_shields else [],
        tools=[mcp.name for mcp in configuration.mcp_servers],
        # Required so the llama-stack server keeps the session (and its
        # turn history) between create_turn calls.
        enable_session_persistence=True,
    )
    conversation_id = agent.create_session(get_suid())
    _agent_cache[conversation_id] = agent
    return agent, conversation_id
6180
6281
6382@router .post ("/query" , responses = query_response )
@@ -70,8 +89,7 @@ def query_endpoint_handler(
7089 logger .info ("LLama stack config: %s" , llama_stack_config )
7190 client = get_llama_stack_client (llama_stack_config )
7291 model_id = select_model_id (client .models .list (), query_request )
73- conversation_id = retrieve_conversation_id (query_request )
74- response = retrieve_response (client , model_id , query_request , auth )
92+ response , conversation_id = retrieve_response (client , model_id , query_request , auth )
7593
7694 if not is_transcripts_enabled ():
7795 logger .debug ("Transcript collection is disabled in the configuration" )
@@ -140,7 +158,7 @@ def retrieve_response(
140158 model_id : str ,
141159 query_request : QueryRequest ,
142160 token : str ,
143- ) -> str :
161+ ) -> tuple [ str , str ] :
144162 """Retrieve response from LLMs and agents."""
145163 available_shields = [shield .identifier for shield in client .shields .list ()]
146164 if not available_shields :
@@ -161,21 +179,28 @@ def retrieve_response(
161179 if query_request .attachments :
162180 validate_attachments_metadata (query_request .attachments )
163181
164- # Build mcp_headers config dynamically for all MCP servers
165- # this will allow the agent to pass the user token to the MCP servers
182+ agent , conversation_id = get_agent (
183+ client ,
184+ model_id ,
185+ system_prompt ,
186+ available_shields ,
187+ query_request .conversation_id ,
188+ )
189+
166190 mcp_headers = {}
167191 if token :
168192 for mcp_server in configuration .mcp_servers :
169193 mcp_headers [mcp_server .url ] = {
170194 "Authorization" : f"Bearer { token } " ,
171195 }
172- # TODO(lucasagomes): move to ReActAgent
173- agent = Agent (
174- client ,
175- model = model_id ,
176- instructions = system_prompt ,
177- input_shields = available_shields if available_shields else [],
178- tools = [mcp .name for mcp in configuration .mcp_servers ],
196+
197+ vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
198+ response = agent .create_turn (
199+ messages = [UserMessage (role = "user" , content = query_request .query )],
200+ session_id = conversation_id ,
201+ documents = query_request .get_documents (),
202+ stream = False ,
203+ toolgroups = get_rag_toolgroups (vector_db_ids ),
179204 extra_headers = {
180205 "X-LlamaStack-Provider-Data" : json .dumps (
181206 {
@@ -184,17 +209,8 @@ def retrieve_response(
184209 ),
185210 },
186211 )
187- session_id = agent .create_session ("chat_session" )
188- logger .debug ("Session ID: %s" , session_id )
189- vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
190- response = agent .create_turn (
191- messages = [UserMessage (role = "user" , content = query_request .query )],
192- session_id = session_id ,
193- documents = query_request .get_documents (),
194- stream = False ,
195- toolgroups = get_rag_toolgroups (vector_db_ids ),
196- )
197- return str (response .output_message .content ) # type: ignore[union-attr]
212+
213+ return str (response .output_message .content ), conversation_id # type: ignore[union-attr]
198214
199215
200216def validate_attachments_metadata (attachments : list [Attachment ]) -> None :
0 commit comments