Commit dc74af7

Fixes from PR review
1 parent 7e2c1b7 commit dc74af7

File tree

3 files changed (+138 additions, -109 deletions)

src/app/endpoints/streaming_query.py

Lines changed: 50 additions & 56 deletions
@@ -47,6 +47,7 @@
 from metrics.utils import update_llm_token_count_from_turn
 from models.cache_entry import CacheEntry
 from models.config import Action
+from models.context import ResponseGeneratorContext
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse
@@ -696,17 +697,8 @@ def _handle_heartbeat_event(
     )


-def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-many-locals
-    conversation_id: str,
-    user_id: str,
-    model_id: str,
-    provider_id: str,
-    query_request: QueryRequest,
-    metadata_map: dict[str, dict[str, Any]],
-    client: AsyncLlamaStackClient,
-    llama_stack_model_id: str,
-    started_at: str,
-    _skip_userid_check: bool,
+def create_agent_response_generator(  # pylint: disable=too-many-locals
+    context: ResponseGeneratorContext,
 ) -> Any:
     """
     Create a response generator function for Agent API streaming.
@@ -715,16 +707,7 @@ def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-m
     responses from the Agent API and yields Server-Sent Events (SSE).

     Args:
-        conversation_id: The conversation identifier
-        user_id: The user identifier
-        model_id: The model identifier
-        provider_id: The provider identifier
-        query_request: The query request object
-        metadata_map: Dictionary for storing metadata from tool responses
-        client: The Llama Stack client
-        llama_stack_model_id: The full llama stack model ID
-        started_at: Timestamp when the request started
-        _skip_userid_check: Whether to skip user ID validation
+        context: Context object containing all necessary parameters for response generation

     Returns:
         An async generator function that yields SSE-formatted strings
@@ -748,10 +731,10 @@ async def response_generator(
         summary = TurnSummary(llm_response="No response from the model", tool_calls=[])

         # Determine media type for response formatting
-        media_type = query_request.media_type or MEDIA_TYPE_JSON
+        media_type = context.query_request.media_type or MEDIA_TYPE_JSON

         # Send start event at the beginning of the stream
-        yield stream_start_event(conversation_id)
+        yield stream_start_event(context.conversation_id)

         latest_turn: Any | None = None

@@ -764,10 +747,10 @@ async def response_generator(
                     p.turn.output_message.content
                 )
                 latest_turn = p.turn
-                system_prompt = get_system_prompt(query_request, configuration)
+                system_prompt = get_system_prompt(context.query_request, configuration)
                 try:
                     update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
+                        p.turn, context.model_id, context.provider_id, system_prompt
                     )
                 except Exception:  # pylint: disable=broad-except
                     logger.exception("Failed to update token usage metrics")
@@ -776,7 +759,11 @@ async def response_generator(
                     summary.append_tool_calls_from_llama(p.step_details)

             for event in stream_build_event(
-                chunk, chunk_id, metadata_map, media_type, conversation_id
+                chunk,
+                chunk_id,
+                context.metadata_map,
+                media_type,
+                context.conversation_id,
             ):
                 chunk_id += 1
                 yield event
@@ -788,49 +775,53 @@ async def response_generator(
             else TokenCounter()
         )

-        yield stream_end_event(metadata_map, summary, token_usage, media_type)
+        yield stream_end_event(context.metadata_map, summary, token_usage, media_type)

         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
         else:
             store_transcript(
-                user_id=user_id,
-                conversation_id=conversation_id,
-                model_id=model_id,
-                provider_id=provider_id,
+                user_id=context.user_id,
+                conversation_id=context.conversation_id,
+                model_id=context.model_id,
+                provider_id=context.provider_id,
                 query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
+                query=context.query_request.query,
+                query_request=context.query_request,
                 summary=summary,
                 rag_chunks=create_rag_chunks_dict(summary),
                 truncated=False,  # TODO(lucasagomes): implement truncation as part
                 # of quota work
-                attachments=query_request.attachments or [],
+                attachments=context.query_request.attachments or [],
             )

         # Get the initial topic summary for the conversation
         topic_summary = None
         with get_session() as session:
             existing_conversation = (
-                session.query(UserConversation).filter_by(id=conversation_id).first()
+                session.query(UserConversation)
+                .filter_by(id=context.conversation_id)
+                .first()
             )
             if not existing_conversation:
                 topic_summary = await get_topic_summary(
-                    query_request.query, client, llama_stack_model_id
+                    context.query_request.query,
+                    context.client,
+                    context.llama_stack_model_id,
                 )

         completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

         referenced_documents = create_referenced_documents_with_metadata(
-            summary, metadata_map
+            summary, context.metadata_map
         )

         cache_entry = CacheEntry(
-            query=query_request.query,
+            query=context.query_request.query,
             response=summary.llm_response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
+            provider=context.provider_id,
+            model=context.model_id,
+            started_at=context.started_at,
             completed_at=completed_at,
             referenced_documents=(
                 referenced_documents if referenced_documents else None
@@ -839,25 +830,25 @@ async def response_generator(

         store_conversation_into_cache(
             configuration,
-            user_id,
-            conversation_id,
+            context.user_id,
+            context.conversation_id,
             cache_entry,
-            _skip_userid_check,
+            context.skip_userid_check,
             topic_summary,
         )

         persist_user_conversation_details(
-            user_id=user_id,
-            conversation_id=conversation_id,
-            model=model_id,
-            provider_id=provider_id,
+            user_id=context.user_id,
+            conversation_id=context.conversation_id,
+            model=context.model_id,
+            provider_id=context.provider_id,
             topic_summary=topic_summary,
         )

     return response_generator


-async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments
+async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments
     request: Request,
     query_request: QueryRequest,
     auth: AuthTuple,
@@ -866,7 +857,7 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
     create_response_generator_func: Callable[..., Any],
 ) -> StreamingResponse:
     """
-    Base handler for streaming query endpoints.
+    Handle streaming query endpoints with common logic.

     This base handler contains all the common logic for streaming query endpoints
     and accepts functions for API-specific behavior (Agent API vs Responses API).
@@ -937,20 +928,23 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
         )
         metadata_map: dict[str, dict[str, Any]] = {}

-        # Create the response generator using the provided factory function
-        response_generator = create_response_generator_func(
+        # Create context object for response generator
+        context = ResponseGeneratorContext(
             conversation_id=conversation_id,
             user_id=user_id,
+            skip_userid_check=_skip_userid_check,
             model_id=model_id,
             provider_id=provider_id,
-            query_request=query_request,
-            metadata_map=metadata_map,
-            client=client,
             llama_stack_model_id=llama_stack_model_id,
+            query_request=query_request,
             started_at=started_at,
-            _skip_userid_check=_skip_userid_check,
+            client=client,
+            metadata_map=metadata_map,
         )

+        # Create the response generator using the provided factory function
+        response_generator = create_response_generator_func(context)
+
         # Update metrics for the LLM call
         metrics.llm_calls_total.labels(provider_id, model_id).inc()

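Note: ResponseGeneratorContext is imported from models.context, one of the other files changed in this commit and not shown in this diff. Below is a hypothetical sketch of what that class plausibly looks like; every field name and type is inferred from how the context object is constructed and accessed in the diff above, and the real definition (including whether it is a dataclass or a Pydantic model) may differ.

from dataclasses import dataclass
from typing import Any

from llama_stack_client import AsyncLlamaStackClient

from models.requests import QueryRequest


@dataclass
class ResponseGeneratorContext:
    """Parameters for streaming response generators (hypothetical sketch).

    Fields mirror the keyword arguments passed to ResponseGeneratorContext(...)
    in streaming_query_endpoint_handler_base above.
    """

    conversation_id: str
    user_id: str
    skip_userid_check: bool
    model_id: str
    provider_id: str
    llama_stack_model_id: str
    query_request: QueryRequest
    started_at: str
    client: AsyncLlamaStackClient
    metadata_map: dict[str, dict[str, Any]]

Bundling the ten former positional parameters into a single object is what lets the commit drop the too-many-arguments pylint suppression on create_agent_response_generator, and it keeps both generator factories (Agent API and Responses API) behind the single-argument Callable[..., Any] interface that streaming_query_endpoint_handler_base accepts.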