33
44use async_stream:: stream;
55use dynamo_async_openai:: types:: {
6- ChatChoiceStream , ChatCompletionMessageToolCallChunk , ChatCompletionStreamResponseDelta ,
7- FinishReason , FunctionCallStream , Role ,
6+ ChatChoiceLogprobs , ChatChoiceStream , ChatCompletionMessageToolCallChunk ,
7+ ChatCompletionStreamResponseDelta , FinishReason , FunctionCallStream , Role ,
88} ;
99
1010use dynamo_parsers:: tool_calling:: parsers:: get_tool_parser_map;
11- use dynamo_parsers:: tool_calling:: { detect_tool_call_start, try_tool_call_parse_aggregate} ;
11+ use dynamo_parsers:: tool_calling:: {
12+ detect_tool_call_start, find_tool_call_end_position, try_tool_call_parse_aggregate,
13+ } ;
1214use dynamo_runtime:: protocols:: annotated:: Annotated ;
1315use futures:: { Stream , StreamExt } ;
1416
@@ -74,7 +76,7 @@ struct ChoiceJailState {
7476
7577fn create_choice_stream (
7678 content : & str ,
77- role : Role ,
79+ role : Option < Role > ,
7880 index : u32 ,
7981 finish_reason : Option < FinishReason > ,
8082 logprobs : Option < ChatChoiceLogprobs > ,
@@ -83,7 +85,7 @@ fn create_choice_stream(
8385 ChatChoiceStream {
8486 index,
8587 delta : ChatCompletionStreamResponseDelta {
86- role : Some ( role ) ,
88+ role,
8789 content : Some ( content. to_string ( ) ) ,
8890 tool_calls : None ,
8991 function_call : None ,
@@ -327,7 +329,7 @@ impl ChoiceJailState {
327329 #[ allow( deprecated) ]
328330 let dummy_choice = create_choice_stream (
329331 & self . accumulated_content ,
330- Role :: Assistant ,
332+ Some ( Role :: Assistant ) ,
331333 self . index ,
332334 None ,
333335 None ,
@@ -648,7 +650,7 @@ impl JailedStream {
648650 if let Some ( parser) = & self . tool_call_parser {
649651 if let Ok ( ( _, _) ) = try_tool_call_parse_aggregate ( accumulated_content, Some ( parser) )
650652 {
651- let split_pos = self . find_tool_call_end_position ( accumulated_content, parser) ;
653+ let split_pos = find_tool_call_end_position ( accumulated_content, Some ( parser) ) ;
652654 ( true , split_pos)
653655 } else {
654656 ( false , accumulated_content. len ( ) )
@@ -732,63 +734,6 @@ impl JailedStream {
732734 }
733735 false
734736 }
735-
736- /// Find the exact position where the tool call ends for splitting content
737- /// This handles the early exit case where we have trailing content after the tool call
738- fn find_tool_call_end_position ( & self , content : & str , parser : & str ) -> usize {
739- match parser {
740- "hermes" => {
741- // For Hermes, look for </tool_call> marker
742- if let Some ( pos) = content. find ( "</tool_call>" ) {
743- pos + "</tool_call>" . len ( )
744- } else {
745- content. len ( )
746- }
747- }
748- "nemotron_deci" => {
749- // For Nemotron, look for </TOOLCALL> marker
750- if let Some ( pos) = content. find ( "</TOOLCALL>" ) {
751- pos + "</TOOLCALL>" . len ( )
752- } else {
753- content. len ( )
754- }
755- }
756- "mistral" => {
757- // For Mistral, look for [/TOOL_CALLS] marker or end of JSON array
758- if let Some ( pos) = content. find ( "[/TOOL_CALLS]" ) {
759- pos + "[/TOOL_CALLS]" . len ( )
760- } else if let Some ( pos) = content. rfind ( ']' ) {
761- // Find the last ] which should be the end of the tool calls array
762- pos + 1
763- } else {
764- content. len ( )
765- }
766- }
767- "phi4" => {
768- // For Phi4, look for <|tool_call|> end marker
769- if let Some ( pos) = content. rfind ( "<|tool_call|>" ) {
770- // Look for the next occurrence after this position
771- if let Some ( end_pos) = content[ pos..] . find ( ">" ) {
772- pos + end_pos + 1
773- } else {
774- content. len ( )
775- }
776- } else {
777- content. len ( )
778- }
779- }
780- "llama3_json" => {
781- // For Llama3 JSON, there's no explicit end marker
782- // The end is determined by complete JSON parsing
783- // Return full content length to avoid early splitting
784- content. len ( )
785- }
786- _ => {
787- // Unknown parser, default to full content
788- content. len ( )
789- }
790- }
791- }
792737}
793738
794739/// Builder for configuring a JailedStream
0 commit comments