Skip to content

Commit 41c7ef8

Browse files
committed
chore: move tool end detection to parser lib
Signed-off-by: ayushag <ayushag@nvidia.com>
1 parent 583ff5b commit 41c7ef8

File tree

6 files changed

+89
-68
lines changed

6 files changed

+89
-68
lines changed

lib/llm/src/protocols/openai/chat_completions/jail.rs

Lines changed: 9 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
use async_stream::stream;
55
use dynamo_async_openai::types::{
6-
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
7-
FinishReason, FunctionCallStream, Role,
6+
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionMessageToolCallChunk,
7+
ChatCompletionStreamResponseDelta, FinishReason, FunctionCallStream, Role,
88
};
99

1010
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
11-
use dynamo_parsers::tool_calling::{detect_tool_call_start, try_tool_call_parse_aggregate};
11+
use dynamo_parsers::tool_calling::{
12+
detect_tool_call_start, find_tool_call_end_position, try_tool_call_parse_aggregate,
13+
};
1214
use dynamo_runtime::protocols::annotated::Annotated;
1315
use futures::{Stream, StreamExt};
1416

@@ -74,7 +76,7 @@ struct ChoiceJailState {
7476

7577
fn create_choice_stream(
7678
content: &str,
77-
role: Role,
79+
role: Option<Role>,
7880
index: u32,
7981
finish_reason: Option<FinishReason>,
8082
logprobs: Option<ChatChoiceLogprobs>,
@@ -83,7 +85,7 @@ fn create_choice_stream(
8385
ChatChoiceStream {
8486
index,
8587
delta: ChatCompletionStreamResponseDelta {
86-
role: Some(role),
88+
role,
8789
content: Some(content.to_string()),
8890
tool_calls: None,
8991
function_call: None,
@@ -327,7 +329,7 @@ impl ChoiceJailState {
327329
#[allow(deprecated)]
328330
let dummy_choice = create_choice_stream(
329331
&self.accumulated_content,
330-
Role::Assistant,
332+
Some(Role::Assistant),
331333
self.index,
332334
None,
333335
None,
@@ -648,7 +650,7 @@ impl JailedStream {
648650
if let Some(parser) = &self.tool_call_parser {
649651
if let Ok((_, _)) = try_tool_call_parse_aggregate(accumulated_content, Some(parser))
650652
{
651-
let split_pos = self.find_tool_call_end_position(accumulated_content, parser);
653+
let split_pos = find_tool_call_end_position(accumulated_content, Some(parser));
652654
(true, split_pos)
653655
} else {
654656
(false, accumulated_content.len())
@@ -732,63 +734,6 @@ impl JailedStream {
732734
}
733735
false
734736
}
735-
736-
/// Find the exact position where the tool call ends for splitting content
737-
/// This handles the early exit case where we have trailing content after the tool call
738-
fn find_tool_call_end_position(&self, content: &str, parser: &str) -> usize {
739-
match parser {
740-
"hermes" => {
741-
// For Hermes, look for </tool_call> marker
742-
if let Some(pos) = content.find("</tool_call>") {
743-
pos + "</tool_call>".len()
744-
} else {
745-
content.len()
746-
}
747-
}
748-
"nemotron_deci" => {
749-
// For Nemotron, look for </TOOLCALL> marker
750-
if let Some(pos) = content.find("</TOOLCALL>") {
751-
pos + "</TOOLCALL>".len()
752-
} else {
753-
content.len()
754-
}
755-
}
756-
"mistral" => {
757-
// For Mistral, look for [/TOOL_CALLS] marker or end of JSON array
758-
if let Some(pos) = content.find("[/TOOL_CALLS]") {
759-
pos + "[/TOOL_CALLS]".len()
760-
} else if let Some(pos) = content.rfind(']') {
761-
// Find the last ] which should be the end of the tool calls array
762-
pos + 1
763-
} else {
764-
content.len()
765-
}
766-
}
767-
"phi4" => {
768-
// For Phi4, look for <|tool_call|> end marker
769-
if let Some(pos) = content.rfind("<|tool_call|>") {
770-
// Look for the next occurrence after this position
771-
if let Some(end_pos) = content[pos..].find(">") {
772-
pos + end_pos + 1
773-
} else {
774-
content.len()
775-
}
776-
} else {
777-
content.len()
778-
}
779-
}
780-
"llama3_json" => {
781-
// For Llama3 JSON, there's no explicit end marker
782-
// The end is determined by complete JSON parsing
783-
// Return full content length to avoid early splitting
784-
content.len()
785-
}
786-
_ => {
787-
// Unknown parser, default to full content
788-
content.len()
789-
}
790-
}
791-
}
792737
}
793738

794739
/// Builder for configuring a JailedStream

lib/parsers/src/tool_calling/harmony/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@ pub use super::{config, response};
77
pub use harmony_parser::{
88
detect_tool_call_start_harmony, parse_tool_calls_harmony, parse_tool_calls_harmony_complete,
99
};
10+
11+
pub fn find_tool_call_end_position_harmony(chunk: &str) -> usize {
12+
chunk.len()
13+
}

lib/parsers/src/tool_calling/json/mod.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,31 @@ pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> bo
4141
JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
4242
}
4343
}
44+
45+
pub fn find_tool_call_end_position_json(
46+
chunk: &str,
47+
parser: &str,
48+
config: &JsonParserConfig,
49+
) -> usize {
50+
match parser {
51+
"hermes" | "nemotron_deci" => {
52+
if let Some(end_token) = config.tool_call_end_tokens.first() {
53+
if let Some(pos) = chunk.find(end_token) {
54+
pos + end_token.len()
55+
} else {
56+
chunk.len()
57+
}
58+
} else {
59+
chunk.len()
60+
}
61+
}
62+
"mistral" | "phi4" => {
63+
if let Some(pos) = chunk.rfind(']') {
64+
pos + 1
65+
} else {
66+
chunk.len()
67+
}
68+
}
69+
_ => chunk.len(),
70+
}
71+
}

lib/parsers/src/tool_calling/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ pub mod tools;
1313
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
1414
pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete};
1515
pub use json::try_tool_call_parse_json;
16-
pub use parsers::{detect_and_parse_tool_call, detect_tool_call_start, try_tool_call_parse};
16+
pub use parsers::{
17+
detect_and_parse_tool_call, detect_tool_call_start, find_tool_call_end_position,
18+
try_tool_call_parse,
19+
};
1720
pub use pythonic::try_tool_call_parse_pythonic;
1821
pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
1922
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};

lib/parsers/src/tool_calling/parsers.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,17 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
use super::config::{ToolCallConfig, ToolCallParserType};
5-
use super::harmony::{detect_tool_call_start_harmony, parse_tool_calls_harmony_complete};
6-
use super::json::{detect_tool_call_start_json, try_tool_call_parse_json};
7-
use super::pythonic::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
5+
use super::harmony::{
6+
detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
7+
parse_tool_calls_harmony_complete,
8+
};
9+
use super::json::{
10+
detect_tool_call_start_json, find_tool_call_end_position_json, try_tool_call_parse_json,
11+
};
12+
use super::pythonic::{
13+
detect_tool_call_start_pythonic, find_tool_call_end_position_pythonic,
14+
try_tool_call_parse_pythonic,
15+
};
816
use super::response::ToolCallResponse;
917
use std::collections::HashMap;
1018
use std::sync::OnceLock;
@@ -116,6 +124,35 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::
116124
}
117125
}
118126

127+
pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usize {
128+
let parser_map = get_tool_parser_map();
129+
let parser_key = match parser_str {
130+
Some(s) if !s.is_empty() => s,
131+
_ => "default",
132+
};
133+
134+
match parser_map.get(parser_key) {
135+
Some(config) => match config.format {
136+
ToolCallParserType::Json => {
137+
find_tool_call_end_position_json(chunk, parser_key, &config.json)
138+
}
139+
ToolCallParserType::Harmony => find_tool_call_end_position_harmony(chunk),
140+
ToolCallParserType::Pythonic => find_tool_call_end_position_pythonic(chunk),
141+
ToolCallParserType::Typescript => {
142+
// Typescript parser not implemented
143+
chunk.len()
144+
}
145+
ToolCallParserType::Xml => {
146+
// Xml parser not implemented
147+
chunk.len()
148+
}
149+
},
150+
None => {
151+
// Unknown parser, return full content length
152+
chunk.len()
153+
}
154+
}
155+
}
119156
// Tests
120157
// cargo test postprocessor::tool_calling::parsers
121158
#[cfg(test)]

lib/parsers/src/tool_calling/pythonic/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@ pub mod pythonic_parser;
55

66
pub use super::{config, response};
77
pub use pythonic_parser::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
8+
9+
pub fn find_tool_call_end_position_pythonic(chunk: &str) -> usize {
10+
chunk.len()
11+
}

0 commit comments

Comments
 (0)