Skip to content

Commit 2ae2010

Browse files
authored
chore: jail stream optimizations (v1) (#3195)
Signed-off-by: ayushag <ayushag@nvidia.com>
1 parent 6ba64c3 commit 2ae2010

File tree

6 files changed

+189
-171
lines changed

6 files changed

+189
-171
lines changed

lib/llm/src/protocols/openai/chat_completions/jail.rs

Lines changed: 94 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
use async_stream::stream;
55
use dynamo_async_openai::types::{
6-
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
7-
FinishReason, FunctionCallStream, Role,
6+
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk,
7+
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, Role,
88
};
99

1010
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
11-
use dynamo_parsers::tool_calling::{detect_tool_call_start, try_tool_call_parse_aggregate};
11+
use dynamo_parsers::tool_calling::{
12+
detect_tool_call_start, find_tool_call_end_position, try_tool_call_parse_aggregate,
13+
};
1214
use dynamo_runtime::protocols::annotated::Annotated;
1315
use futures::{Stream, StreamExt};
1416

@@ -72,6 +74,30 @@ struct ChoiceJailState {
7274
partial_match_buffer: String,
7375
}
7476

77+
fn create_choice_stream(
78+
index: u32,
79+
role: Option<Role>,
80+
content: &str,
81+
tool_calls: Option<Vec<ChatCompletionMessageToolCallChunk>>,
82+
finish_reason: Option<FinishReason>,
83+
logprobs: Option<ChatChoiceLogprobs>,
84+
) -> ChatChoiceStream {
85+
#[allow(deprecated)]
86+
ChatChoiceStream {
87+
index,
88+
delta: ChatCompletionStreamResponseDelta {
89+
role,
90+
content: Some(content.to_string()),
91+
tool_calls,
92+
function_call: None,
93+
refusal: None,
94+
reasoning_content: None,
95+
},
96+
finish_reason,
97+
logprobs,
98+
}
99+
}
100+
75101
impl ChoiceJailState {
76102
/// Create a new jail state for a choice
77103
fn new(index: u32) -> Self {
@@ -120,19 +146,14 @@ impl ChoiceJailState {
120146
// Emit prefix if any
121147
if !prefix.is_empty() {
122148
#[allow(deprecated)]
123-
let prefix_choice = ChatChoiceStream {
124-
index: choice.index,
125-
delta: ChatCompletionStreamResponseDelta {
126-
role: choice.delta.role,
127-
content: Some(prefix),
128-
tool_calls: None,
129-
function_call: None,
130-
refusal: None,
131-
reasoning_content: None,
132-
},
133-
finish_reason: None,
134-
logprobs: choice.logprobs.clone(),
135-
};
149+
let prefix_choice = create_choice_stream(
150+
choice.index,
151+
choice.delta.role,
152+
&prefix,
153+
None,
154+
None,
155+
choice.logprobs.clone(),
156+
);
136157
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
137158
}
138159

@@ -165,19 +186,14 @@ impl ChoiceJailState {
165186
// Handle trailing content if any
166187
if !trailing_part.is_empty() {
167188
#[allow(deprecated)]
168-
let trailing_choice = ChatChoiceStream {
169-
index: choice.index,
170-
delta: ChatCompletionStreamResponseDelta {
171-
role: choice.delta.role,
172-
content: Some(trailing_part.to_string()),
173-
tool_calls: None,
174-
function_call: None,
175-
refusal: None,
176-
reasoning_content: None,
177-
},
178-
finish_reason: None,
179-
logprobs: choice.logprobs.clone(),
180-
};
189+
let trailing_choice = create_choice_stream(
190+
choice.index,
191+
choice.delta.role,
192+
trailing_part,
193+
None,
194+
None,
195+
choice.logprobs.clone(),
196+
);
181197
emissions.push(ChoiceEmission::Trailing(trailing_choice));
182198
}
183199
} else {
@@ -202,19 +218,14 @@ impl ChoiceJailState {
202218
// Emit the safe prefix
203219
if !prefix.is_empty() {
204220
#[allow(deprecated)]
205-
let prefix_choice = ChatChoiceStream {
206-
index: choice.index,
207-
delta: ChatCompletionStreamResponseDelta {
208-
role: choice.delta.role,
209-
content: Some(prefix),
210-
tool_calls: None,
211-
function_call: None,
212-
refusal: None,
213-
reasoning_content: None,
214-
},
215-
finish_reason: None,
216-
logprobs: choice.logprobs.clone(),
217-
};
221+
let prefix_choice = create_choice_stream(
222+
choice.index,
223+
choice.delta.role,
224+
&prefix,
225+
None,
226+
None,
227+
choice.logprobs.clone(),
228+
);
218229
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
219230
}
220231

@@ -250,19 +261,14 @@ impl ChoiceJailState {
250261
// No markers - emit everything
251262
if !content.is_empty() {
252263
#[allow(deprecated)]
253-
let pass_through_choice = ChatChoiceStream {
254-
index: choice.index,
255-
delta: ChatCompletionStreamResponseDelta {
256-
role: choice.delta.role,
257-
content: Some(content),
258-
tool_calls: None,
259-
function_call: None,
260-
refusal: None,
261-
reasoning_content: None,
262-
},
263-
finish_reason: None,
264-
logprobs: choice.logprobs.clone(),
265-
};
264+
let pass_through_choice = create_choice_stream(
265+
choice.index,
266+
choice.delta.role,
267+
&content,
268+
None,
269+
None,
270+
choice.logprobs.clone(),
271+
);
266272
emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
267273
}
268274
self.partial_match_buffer.clear();
@@ -300,19 +306,14 @@ impl ChoiceJailState {
300306
// Handle trailing content if any
301307
if !trailing_part.is_empty() {
302308
#[allow(deprecated)]
303-
let trailing_choice = ChatChoiceStream {
304-
index: choice.index,
305-
delta: ChatCompletionStreamResponseDelta {
306-
role: choice.delta.role,
307-
content: Some(trailing_part.to_string()),
308-
tool_calls: None,
309-
function_call: None,
310-
refusal: None,
311-
reasoning_content: None,
312-
},
313-
finish_reason: None,
314-
logprobs: choice.logprobs.clone(),
315-
};
309+
let trailing_choice = create_choice_stream(
310+
choice.index,
311+
choice.delta.role,
312+
trailing_part,
313+
None,
314+
None,
315+
choice.logprobs.clone(),
316+
);
316317
emissions.push(ChoiceEmission::Trailing(trailing_choice));
317318
}
318319

@@ -335,19 +336,14 @@ impl ChoiceJailState {
335336

336337
// Create a dummy choice for the method call
337338
#[allow(deprecated)]
338-
let dummy_choice = ChatChoiceStream {
339-
index: self.index,
340-
delta: ChatCompletionStreamResponseDelta {
341-
role: Some(Role::Assistant),
342-
content: None,
343-
tool_calls: None,
344-
function_call: None,
345-
refusal: None,
346-
reasoning_content: None,
347-
},
348-
finish_reason: None,
349-
logprobs: None,
350-
};
339+
let dummy_choice = create_choice_stream(
340+
self.index,
341+
Some(Role::Assistant),
342+
&self.accumulated_content,
343+
None,
344+
None,
345+
None,
346+
);
351347

352348
let final_choice = jail_stream
353349
.create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice)
@@ -663,7 +659,7 @@ impl JailedStream {
663659
if let Ok((_, _)) =
664660
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
665661
{
666-
let split_pos = self.find_tool_call_end_position(accumulated_content, parser);
662+
let split_pos = find_tool_call_end_position(accumulated_content, Some(parser));
667663
(true, split_pos)
668664
} else {
669665
(false, accumulated_content.len())
@@ -704,37 +700,25 @@ impl JailedStream {
704700
.collect();
705701

706702
// Create choice with tool calls
707-
#[allow(deprecated)]
708-
return ChatChoiceStream {
709-
index: choice_index,
710-
delta: ChatCompletionStreamResponseDelta {
711-
role: Some(Role::Assistant),
712-
content: normal_text.filter(|t| !t.is_empty()),
713-
tool_calls: Some(tool_call_chunks),
714-
function_call: None,
715-
refusal: None,
716-
reasoning_content: None,
717-
},
718-
finish_reason: Some(FinishReason::ToolCalls),
719-
logprobs: None,
720-
};
703+
return create_choice_stream(
704+
choice_index,
705+
Some(Role::Assistant),
706+
normal_text.as_deref().unwrap_or(""),
707+
Some(tool_call_chunks),
708+
Some(FinishReason::ToolCalls),
709+
None,
710+
);
721711
}
722712

723713
// No tool calls found or parsing failed, return content choice
724-
#[allow(deprecated)]
725-
ChatChoiceStream {
726-
index: choice_index,
727-
delta: ChatCompletionStreamResponseDelta {
728-
role: Some(Role::Assistant),
729-
content: Some(accumulated_content.to_string()),
730-
tool_calls: None,
731-
function_call: None,
732-
refusal: None,
733-
reasoning_content: None,
734-
},
735-
finish_reason: None,
736-
logprobs: base_choice.logprobs.clone(),
737-
}
714+
create_choice_stream(
715+
choice_index,
716+
Some(Role::Assistant),
717+
accumulated_content,
718+
None,
719+
None,
720+
base_choice.logprobs.clone(),
721+
)
738722
}
739723

740724
/// Check if accumulated content contains complete tool calls that can be parsed
@@ -750,63 +734,6 @@ impl JailedStream {
750734
}
751735
false
752736
}
753-
754-
/// Find the exact position where the tool call ends for splitting content
755-
/// This handles the early exit case where we have trailing content after the tool call
756-
fn find_tool_call_end_position(&self, content: &str, parser: &str) -> usize {
757-
match parser {
758-
"hermes" => {
759-
// For Hermes, look for </tool_call> marker
760-
if let Some(pos) = content.find("</tool_call>") {
761-
pos + "</tool_call>".len()
762-
} else {
763-
content.len()
764-
}
765-
}
766-
"nemotron_deci" => {
767-
// For Nemotron, look for </TOOLCALL> marker
768-
if let Some(pos) = content.find("</TOOLCALL>") {
769-
pos + "</TOOLCALL>".len()
770-
} else {
771-
content.len()
772-
}
773-
}
774-
"mistral" => {
775-
// For Mistral, look for [/TOOL_CALLS] marker or end of JSON array
776-
if let Some(pos) = content.find("[/TOOL_CALLS]") {
777-
pos + "[/TOOL_CALLS]".len()
778-
} else if let Some(pos) = content.rfind(']') {
779-
// Find the last ] which should be the end of the tool calls array
780-
pos + 1
781-
} else {
782-
content.len()
783-
}
784-
}
785-
"phi4" => {
786-
// For Phi4, look for <|tool_call|> end marker
787-
if let Some(pos) = content.rfind("<|tool_call|>") {
788-
// Look for the next occurrence after this position
789-
if let Some(end_pos) = content[pos..].find(">") {
790-
pos + end_pos + 1
791-
} else {
792-
content.len()
793-
}
794-
} else {
795-
content.len()
796-
}
797-
}
798-
"llama3_json" => {
799-
// For Llama3 JSON, there's no explicit end marker
800-
// The end is determined by complete JSON parsing
801-
// Return full content length to avoid early splitting
802-
content.len()
803-
}
804-
_ => {
805-
// Unknown parser, default to full content
806-
content.len()
807-
}
808-
}
809-
}
810737
}
811738

812739
/// Builder for configuring a JailedStream

lib/parsers/src/tool_calling/harmony/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,20 @@
33

44
pub mod harmony_parser;
55

6+
pub use super::config::JsonParserConfig;
67
pub use super::{config, response};
78
pub use harmony_parser::{
89
detect_tool_call_start_harmony, parse_tool_calls_harmony, parse_tool_calls_harmony_complete,
910
};
11+
12+
pub fn find_tool_call_end_position_harmony(chunk: &str, config: &JsonParserConfig) -> usize {
13+
let end_token = config
14+
.tool_call_end_tokens
15+
.first()
16+
.map_or("<|call|>", |v| v);
17+
if let Some(pos) = chunk.rfind(end_token) {
18+
pos + end_token.len()
19+
} else {
20+
chunk.len()
21+
}
22+
}

lib/parsers/src/tool_calling/json/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,31 @@ pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> bo
4141
JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
4242
}
4343
}
44+
45+
pub fn find_tool_call_end_position_json(
46+
chunk: &str,
47+
parser: &str,
48+
config: &JsonParserConfig,
49+
) -> usize {
50+
match parser {
51+
"hermes" | "nemotron_deci" => {
52+
if let Some(end_token) = config.tool_call_end_tokens.first() {
53+
if let Some(pos) = chunk.find(end_token) {
54+
pos + end_token.len()
55+
} else {
56+
chunk.len()
57+
}
58+
} else {
59+
chunk.len()
60+
}
61+
}
62+
"mistral" | "phi4" => {
63+
if let Some(pos) = chunk.rfind(']') {
64+
pos + 1
65+
} else {
66+
chunk.len()
67+
}
68+
}
69+
_ => chunk.len(),
70+
}
71+
}

0 commit comments

Comments
 (0)