Skip to content

Commit 54b4149

Browse files
committed
Do not create new session if conversation_id is provided
1 parent 9c2950f commit 54b4149

File tree

7 files changed

+298
-162
lines changed

7 files changed

+298
-162
lines changed

pdm.lock

Lines changed: 14 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies = [
1111
"uvicorn>=0.34.3",
1212
"llama-stack>=0.2.13",
1313
"rich>=14.0.0",
14+
"expiringdict>=1.2.2",
1415
]
1516

1617
[tool.pdm]

src/app/endpoints/query.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import os
77
from pathlib import Path
88
from typing import Any
9+
from expiringdict import ExpiringDict
910
from llama_stack_client.lib.agents.agent import Agent
10-
1111
from llama_stack_client import LlamaStackClient # type: ignore
1212
from llama_stack_client.types import UserMessage # type: ignore
1313
from llama_stack_client.types.agents.turn_create_params import (
@@ -30,6 +30,8 @@
3030
logger = logging.getLogger("app.endpoints.handlers")
3131
router = APIRouter(tags=["query"])
3232

33+
# Global agent registry to persist agents across requests
34+
_agent_cache: dict[str, Agent] = ExpiringDict(max_len=1000, max_age_seconds=3600)
3335

3436
query_response: dict[int | str, dict[str, Any]] = {
3537
200: {
@@ -48,16 +50,33 @@ def is_transcripts_enabled() -> bool:
4850
return not configuration.user_data_collection_configuration.transcripts_disabled
4951

5052

51-
def retrieve_conversation_id(query_request: QueryRequest) -> str:
52-
"""Retrieve conversation ID based on existing ID or on newly generated one."""
53-
conversation_id = query_request.conversation_id
54-
55-
# Generate a new conversation ID if not provided
56-
if not conversation_id:
57-
conversation_id = get_suid()
58-
logger.info("Generated new conversation ID: %s", conversation_id)
59-
60-
return conversation_id
53+
def get_agent(
54+
client: LlamaStackClient,
55+
model_id: str,
56+
system_prompt: str,
57+
available_shields: list[str],
58+
conversation_id: str | None,
59+
) -> tuple[Agent, str]:
60+
"""Get existing agent or create a new one with session persistence."""
61+
if conversation_id is not None:
62+
agent = _agent_cache.get(conversation_id)
63+
if agent:
64+
logger.debug("Reusing existing agent with key: %s", conversation_id)
65+
return agent, conversation_id
66+
67+
logger.debug("Creating new agent")
68+
# TODO(lucasagomes): move to ReActAgent
69+
agent = Agent(
70+
client,
71+
model=model_id,
72+
instructions=system_prompt,
73+
input_shields=available_shields if available_shields else [],
74+
tools=[mcp.name for mcp in configuration.mcp_servers],
75+
enable_session_persistence=True,
76+
)
77+
conversation_id = agent.create_session(get_suid())
78+
_agent_cache[conversation_id] = agent
79+
return agent, conversation_id
6180

6281

6382
@router.post("/query", responses=query_response)
@@ -70,8 +89,7 @@ def query_endpoint_handler(
7089
logger.info("LLama stack config: %s", llama_stack_config)
7190
client = get_llama_stack_client(llama_stack_config)
7291
model_id = select_model_id(client.models.list(), query_request)
73-
conversation_id = retrieve_conversation_id(query_request)
74-
response = retrieve_response(client, model_id, query_request, auth)
92+
response, conversation_id = retrieve_response(client, model_id, query_request, auth)
7593

7694
if not is_transcripts_enabled():
7795
logger.debug("Transcript collection is disabled in the configuration")
@@ -140,7 +158,7 @@ def retrieve_response(
140158
model_id: str,
141159
query_request: QueryRequest,
142160
token: str,
143-
) -> str:
161+
) -> tuple[str, str]:
144162
"""Retrieve response from LLMs and agents."""
145163
available_shields = [shield.identifier for shield in client.shields.list()]
146164
if not available_shields:
@@ -161,21 +179,28 @@ def retrieve_response(
161179
if query_request.attachments:
162180
validate_attachments_metadata(query_request.attachments)
163181

164-
# Build mcp_headers config dynamically for all MCP servers
165-
# this will allow the agent to pass the user token to the MCP servers
182+
agent, conversation_id = get_agent(
183+
client,
184+
model_id,
185+
system_prompt,
186+
available_shields,
187+
query_request.conversation_id,
188+
)
189+
166190
mcp_headers = {}
167191
if token:
168192
for mcp_server in configuration.mcp_servers:
169193
mcp_headers[mcp_server.url] = {
170194
"Authorization": f"Bearer {token}",
171195
}
172-
# TODO(lucasagomes): move to ReActAgent
173-
agent = Agent(
174-
client,
175-
model=model_id,
176-
instructions=system_prompt,
177-
input_shields=available_shields if available_shields else [],
178-
tools=[mcp.name for mcp in configuration.mcp_servers],
196+
197+
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
198+
response = agent.create_turn(
199+
messages=[UserMessage(role="user", content=query_request.query)],
200+
session_id=conversation_id,
201+
documents=query_request.get_documents(),
202+
stream=False,
203+
toolgroups=get_rag_toolgroups(vector_db_ids),
179204
extra_headers={
180205
"X-LlamaStack-Provider-Data": json.dumps(
181206
{
@@ -184,17 +209,8 @@ def retrieve_response(
184209
),
185210
},
186211
)
187-
session_id = agent.create_session("chat_session")
188-
logger.debug("Session ID: %s", session_id)
189-
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
190-
response = agent.create_turn(
191-
messages=[UserMessage(role="user", content=query_request.query)],
192-
session_id=session_id,
193-
documents=query_request.get_documents(),
194-
stream=False,
195-
toolgroups=get_rag_toolgroups(vector_db_ids),
196-
)
197-
return str(response.output_message.content) # type: ignore[union-attr]
212+
213+
return str(response.output_message.content), conversation_id # type: ignore[union-attr]
198214

199215

200216
def validate_attachments_metadata(attachments: list[Attachment]) -> None:

src/app/endpoints/streaming_query.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from typing import Any, AsyncIterator
66

7+
from expiringdict import ExpiringDict
8+
79
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
810
from llama_stack_client import AsyncLlamaStackClient # type: ignore
911
from llama_stack_client.types import UserMessage # type: ignore
@@ -17,12 +19,12 @@
1719
import constants
1820
from utils.auth import auth_dependency
1921
from utils.common import retrieve_user_id
22+
from utils.suid import get_suid
2023

2124

2225
from app.endpoints.query import (
2326
get_rag_toolgroups,
2427
is_transcripts_enabled,
25-
retrieve_conversation_id,
2628
store_transcript,
2729
select_model_id,
2830
validate_attachments_metadata,
@@ -31,6 +33,37 @@
3133
logger = logging.getLogger("app.endpoints.handlers")
3234
router = APIRouter(tags=["streaming_query"])
3335

36+
# Global agent registry to persist agents across requests
37+
_agent_cache: dict[str, AsyncAgent] = ExpiringDict(max_len=1000, max_age_seconds=3600)
38+
39+
40+
async def get_agent(
41+
client: AsyncLlamaStackClient,
42+
model_id: str,
43+
system_prompt: str,
44+
available_shields: list[str],
45+
conversation_id: str | None,
46+
) -> tuple[AsyncAgent, str]:
47+
"""Get existing agent or create a new one with session persistence."""
48+
if conversation_id is not None:
49+
agent = _agent_cache.get(conversation_id)
50+
if agent:
51+
logger.debug("Reusing existing agent with key: %s", conversation_id)
52+
return agent, conversation_id
53+
54+
logger.debug("Creating new agent")
55+
agent = AsyncAgent(
56+
client, # type: ignore[arg-type]
57+
model=model_id,
58+
instructions=system_prompt,
59+
input_shields=available_shields if available_shields else [],
60+
tools=[], # mcp config ?
61+
enable_session_persistence=True,
62+
)
63+
conversation_id = await agent.create_session(get_suid())
64+
_agent_cache[conversation_id] = agent
65+
return agent, conversation_id
66+
3467

3568
def format_stream_data(d: dict) -> str:
3669
"""Format outbound data in the Event Stream Format."""
@@ -132,8 +165,7 @@ async def streaming_query_endpoint_handler(
132165
logger.info("LLama stack config: %s", llama_stack_config)
133166
client = await get_async_llama_stack_client(llama_stack_config)
134167
model_id = select_model_id(await client.models.list(), query_request)
135-
conversation_id = retrieve_conversation_id(query_request)
136-
response = await retrieve_response(client, model_id, query_request)
168+
response, conversation_id = await retrieve_response(client, model_id, query_request)
137169

138170
async def response_generator(turn_response: Any) -> AsyncIterator[str]:
139171
"""Generate SSE formatted streaming response."""
@@ -173,7 +205,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
173205

174206
async def retrieve_response(
175207
client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest
176-
) -> Any:
208+
) -> tuple[Any, str]:
177209
"""Retrieve response from LLMs and agents."""
178210
available_shields = [shield.identifier for shield in await client.shields.list()]
179211
if not available_shields:
@@ -194,24 +226,24 @@ async def retrieve_response(
194226
if query_request.attachments:
195227
validate_attachments_metadata(query_request.attachments)
196228

197-
agent = AsyncAgent(
198-
client, # type: ignore[arg-type]
199-
model=model_id,
200-
instructions=system_prompt,
201-
input_shields=available_shields if available_shields else [],
202-
tools=[],
229+
agent, conversation_id = await get_agent(
230+
client,
231+
model_id,
232+
system_prompt,
233+
available_shields,
234+
query_request.conversation_id,
203235
)
204-
session_id = await agent.create_session("chat_session")
205-
logger.debug("Session ID: %s", session_id)
236+
237+
logger.debug("Session ID: %s", conversation_id)
206238
vector_db_ids = [
207239
vector_db.identifier for vector_db in await client.vector_dbs.list()
208240
]
209241
response = await agent.create_turn(
210242
messages=[UserMessage(role="user", content=query_request.query)],
211-
session_id=session_id,
243+
session_id=conversation_id,
212244
documents=query_request.get_documents(),
213245
stream=True,
214246
toolgroups=get_rag_toolgroups(vector_db_ids),
215247
)
216248

217-
return response
249+
return response, conversation_id

0 commit comments

Comments (0)