 from metrics.utils import update_llm_token_count_from_turn
 from models.cache_entry import CacheEntry
 from models.config import Action
+from models.context import ResponseGeneratorContext
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse
@@ -696,17 +697,8 @@ def _handle_heartbeat_event(
     )
 
 
-def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-many-locals
-    conversation_id: str,
-    user_id: str,
-    model_id: str,
-    provider_id: str,
-    query_request: QueryRequest,
-    metadata_map: dict[str, dict[str, Any]],
-    client: AsyncLlamaStackClient,
-    llama_stack_model_id: str,
-    started_at: str,
-    _skip_userid_check: bool,
+def create_agent_response_generator(  # pylint: disable=too-many-locals
+    context: ResponseGeneratorContext,
 ) -> Any:
     """
     Create a response generator function for Agent API streaming.
@@ -715,16 +707,7 @@ def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-m
     responses from the Agent API and yields Server-Sent Events (SSE).
 
     Args:
-        conversation_id: The conversation identifier
-        user_id: The user identifier
-        model_id: The model identifier
-        provider_id: The provider identifier
-        query_request: The query request object
-        metadata_map: Dictionary for storing metadata from tool responses
-        client: The Llama Stack client
-        llama_stack_model_id: The full llama stack model ID
-        started_at: Timestamp when the request started
-        _skip_userid_check: Whether to skip user ID validation
+        context: Context object containing all necessary parameters for response generation
 
     Returns:
         An async generator function that yields SSE-formatted strings
@@ -748,10 +731,10 @@ async def response_generator(
         summary = TurnSummary(llm_response="No response from the model", tool_calls=[])
 
         # Determine media type for response formatting
-        media_type = query_request.media_type or MEDIA_TYPE_JSON
+        media_type = context.query_request.media_type or MEDIA_TYPE_JSON
 
         # Send start event at the beginning of the stream
-        yield stream_start_event(conversation_id)
+        yield stream_start_event(context.conversation_id)
 
         latest_turn: Any | None = None
 
@@ -764,10 +747,10 @@ async def response_generator(
                         p.turn.output_message.content
                     )
                     latest_turn = p.turn
-                    system_prompt = get_system_prompt(query_request, configuration)
+                    system_prompt = get_system_prompt(context.query_request, configuration)
                     try:
                         update_llm_token_count_from_turn(
-                            p.turn, model_id, provider_id, system_prompt
+                            p.turn, context.model_id, context.provider_id, system_prompt
                         )
                     except Exception:  # pylint: disable=broad-except
                         logger.exception("Failed to update token usage metrics")
@@ -776,7 +759,11 @@ async def response_generator(
                     summary.append_tool_calls_from_llama(p.step_details)
 
             for event in stream_build_event(
-                chunk, chunk_id, metadata_map, media_type, conversation_id
+                chunk,
+                chunk_id,
+                context.metadata_map,
+                media_type,
+                context.conversation_id,
             ):
                 chunk_id += 1
                 yield event
@@ -788,49 +775,53 @@ async def response_generator(
             else TokenCounter()
         )
 
-        yield stream_end_event(metadata_map, summary, token_usage, media_type)
+        yield stream_end_event(context.metadata_map, summary, token_usage, media_type)
 
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
         else:
             store_transcript(
-                user_id=user_id,
-                conversation_id=conversation_id,
-                model_id=model_id,
-                provider_id=provider_id,
+                user_id=context.user_id,
+                conversation_id=context.conversation_id,
+                model_id=context.model_id,
+                provider_id=context.provider_id,
                 query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
+                query=context.query_request.query,
+                query_request=context.query_request,
                 summary=summary,
                 rag_chunks=create_rag_chunks_dict(summary),
                 truncated=False,  # TODO(lucasagomes): implement truncation as part
                 # of quota work
-                attachments=query_request.attachments or [],
+                attachments=context.query_request.attachments or [],
             )
 
         # Get the initial topic summary for the conversation
        topic_summary = None
         with get_session() as session:
             existing_conversation = (
-                session.query(UserConversation).filter_by(id=conversation_id).first()
+                session.query(UserConversation)
+                .filter_by(id=context.conversation_id)
+                .first()
             )
             if not existing_conversation:
                 topic_summary = await get_topic_summary(
-                    query_request.query, client, llama_stack_model_id
+                    context.query_request.query,
+                    context.client,
+                    context.llama_stack_model_id,
                 )
 
         completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
 
         referenced_documents = create_referenced_documents_with_metadata(
-            summary, metadata_map
+            summary, context.metadata_map
         )
 
         cache_entry = CacheEntry(
-            query=query_request.query,
+            query=context.query_request.query,
             response=summary.llm_response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
+            provider=context.provider_id,
+            model=context.model_id,
+            started_at=context.started_at,
             completed_at=completed_at,
             referenced_documents=(
                 referenced_documents if referenced_documents else None
@@ -839,25 +830,25 @@ async def response_generator(
 
         store_conversation_into_cache(
             configuration,
-            user_id,
-            conversation_id,
+            context.user_id,
+            context.conversation_id,
             cache_entry,
-            _skip_userid_check,
+            context.skip_userid_check,
             topic_summary,
         )
 
         persist_user_conversation_details(
-            user_id=user_id,
-            conversation_id=conversation_id,
-            model=model_id,
-            provider_id=provider_id,
+            user_id=context.user_id,
+            conversation_id=context.conversation_id,
+            model=context.model_id,
+            provider_id=context.provider_id,
             topic_summary=topic_summary,
         )
 
     return response_generator
 
 
-async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments
+async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments
     request: Request,
     query_request: QueryRequest,
     auth: AuthTuple,
@@ -866,7 +857,7 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
     create_response_generator_func: Callable[..., Any],
 ) -> StreamingResponse:
     """
-    Base handler for streaming query endpoints.
+    Handle streaming query endpoints with common logic.
 
     This base handler contains all the common logic for streaming query endpoints
     and accepts functions for API-specific behavior (Agent API vs Responses API).
@@ -937,20 +928,23 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
     )
     metadata_map: dict[str, dict[str, Any]] = {}
 
-    # Create the response generator using the provided factory function
-    response_generator = create_response_generator_func(
+    # Create context object for response generator
+    context = ResponseGeneratorContext(
         conversation_id=conversation_id,
         user_id=user_id,
+        skip_userid_check=_skip_userid_check,
         model_id=model_id,
         provider_id=provider_id,
-        query_request=query_request,
-        metadata_map=metadata_map,
-        client=client,
         llama_stack_model_id=llama_stack_model_id,
+        query_request=query_request,
         started_at=started_at,
-        _skip_userid_check=_skip_userid_check,
+        client=client,
+        metadata_map=metadata_map,
     )
 
+    # Create the response generator using the provided factory function
+    response_generator = create_response_generator_func(context)
+
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
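The heart of the commit is the new ResponseGeneratorContext object: rather than threading ten separate parameters through create_agent_response_generator, the base handler now builds a single context and hands it to whichever factory function (Agent API or Responses API) the endpoint supplied. The diff does not show models/context.py itself, so the following is only a sketch of what the class plausibly looks like; the field names and types are taken from the removed parameter list and the constructor call above, while the dataclass form, the import paths, and the default for metadata_map are assumptions (the real definition could just as well be a Pydantic model).

# Hypothetical sketch of models/context.py, inferred from this commit.
# Field names and types come from the removed parameter list and the
# ResponseGeneratorContext(...) call in the handler; the @dataclass form,
# import paths, and the metadata_map default are assumptions.
from dataclasses import dataclass, field
from typing import Any

from llama_stack_client import AsyncLlamaStackClient

from models.requests import QueryRequest


@dataclass
class ResponseGeneratorContext:
    """Per-request state for streaming response generation."""

    conversation_id: str
    user_id: str
    skip_userid_check: bool
    model_id: str
    provider_id: str
    llama_stack_model_id: str
    query_request: QueryRequest
    started_at: str
    client: AsyncLlamaStackClient
    metadata_map: dict[str, dict[str, Any]] = field(default_factory=dict)

Bundling the arguments this way is what lets the commit drop the too-many-arguments pylint disable from create_agent_response_generator, leaving only too-many-locals. Note also that the private-looking _skip_userid_check parameter becomes an ordinary skip_userid_check field on the context.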