Skip to content

Commit 125be8c

Browse files
authored
chore: add support for multi-tool within nested tags (#2501)
1 parent 56e9923 commit 125be8c

File tree

2 files changed

+89
-14
lines changed

2 files changed

+89
-14
lines changed

lib/llm/src/postprocessor/tool_calling/json_parser.rs

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ pub struct CalledFunctionArguments {
2424
pub arguments: HashMap<String, Value>,
2525
}
2626

27-
fn extract_tool_call_content<'a>(
28-
input: &'a str,
29-
start_token: &str,
30-
end_token: &str,
31-
) -> Option<&'a str> {
27+
// Extract the contents between start and end tokens using regex parsing.
28+
// Returns a JSON array string if there are multiple matches, otherwise returns the last match directly.
29+
fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> Option<String> {
3230
let escaped_start = regex::escape(start_token);
3331
let escaped_end = regex::escape(end_token);
3432
let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);
@@ -38,19 +36,62 @@ fn extract_tool_call_content<'a>(
3836
.build()
3937
{
4038
Ok(regex) => {
41-
// Get all matches and take the last one for now. TODO : Handle multiple tool calls
39+
// Get all matches and take the last one for now. TODO: Handle multiple tool calls
4240
let matches: Vec<_> = regex
4341
.captures_iter(input)
4442
.filter_map(|captures| captures.get(1))
45-
.map(|m| m.as_str().trim())
43+
.map(|m| m.as_str().trim().to_string())
4644
.collect();
47-
48-
matches.last().copied()
45+
if !matches.is_empty() {
46+
// If only one match, return it directly, otherwise return as a JSON array string
47+
if matches.len() == 1 {
48+
// Return the last match directly
49+
return Some(matches.last().unwrap().clone());
50+
} else {
51+
// Join the matches into a JSON array string
52+
return Some(format!("[{}]", matches.join(",")));
53+
}
54+
}
55+
None
4956
}
5057
Err(_) => None,
5158
}
5259
}
5360

61+
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
62+
// Handles single tool and multiple tool call cases for single start_token like <|python_tag|>
63+
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
64+
// Return the input if it doesn't contain the start token
65+
if !input.contains(start_token) {
66+
return input.to_string();
67+
}
68+
69+
// Split on the start token and keep only JSON-looking segments
70+
let mut items: Vec<String> = Vec::new();
71+
for seg in input.split(start_token) {
72+
let s = seg.trim();
73+
if s.is_empty() {
74+
continue;
75+
}
76+
// Only consider segments that start like JSON
77+
if s.starts_with('{') || s.starts_with('[') {
78+
// Trim trailing non-JSON by cutting at the last closing brace/bracket
79+
if let Some(pos) = s.rfind(['}', ']']) {
80+
let candidate = &s[..=pos];
81+
// Keep only valid JSON candidates
82+
if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
83+
items.push(candidate.to_string());
84+
}
85+
}
86+
}
87+
}
88+
89+
if items.is_empty() {
90+
return input.to_string();
91+
}
92+
format!("[{}]", items.join(","))
93+
}
94+
5495
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
5596
///
5697
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
@@ -110,21 +151,25 @@ pub fn try_tool_call_parse_json(
110151
);
111152

112153
// Iterate over all start and end tokens and try to extract the content between them
113-
let mut json = trimmed;
154+
// Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
155+
let mut json = trimmed.to_string();
114156
for (start_token, end_token) in tool_call_start_tokens
115157
.iter()
116158
.zip(tool_call_end_tokens.iter())
117159
{
118160
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
119161
json = if !start_token.is_empty() && end_token.is_empty() {
120-
json.strip_prefix(start_token).unwrap_or(json)
121-
} else if let Some(content) = extract_tool_call_content(json, start_token, end_token) {
162+
handle_single_token_tool_calls(&json, start_token)
163+
} else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
122164
content
123165
} else {
124166
json
125167
};
126168
}
127169

170+
// Convert json to &str if it's a String, otherwise keep as &str
171+
let json = json.as_str();
172+
128173
// Anonymous function to attempt deserialization into a known representation
129174
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
130175
Ok(ToolCallResponse {

lib/llm/src/postprocessor/tool_calling/parsers.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
334334
}
335335

336336
#[test]
337-
#[ignore]
338-
// TODO : Implement this
339337
fn test_qwen_qwq_32b_multiple_tool_calls() {
340338
let input = r#"<tool_call>
341339
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
@@ -426,6 +424,22 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
426424
assert_eq!(args["unit"], "fahrenheit");
427425
}
428426

427+
#[test]
428+
fn test_meta_llama_llama31_8b_instruct_with_python_tag_multiple() {
429+
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }<|python_tag|>{ "name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" } }"#;
430+
let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
431+
assert!(!result.is_empty());
432+
assert_eq!(result.len(), 2);
433+
let (name, args) = extract_name_and_args(result[0].clone());
434+
assert_eq!(name, "get_weather");
435+
assert_eq!(args["location"], "San Francisco, CA");
436+
assert_eq!(args["unit"], "fahrenheit");
437+
let (name, args) = extract_name_and_args(result[1].clone());
438+
assert_eq!(name, "get_weather");
439+
assert_eq!(args["location"], "New York, NY");
440+
assert_eq!(args["unit"], "fahrenheit");
441+
}
442+
429443
#[test]
430444
fn test_detect_and_parse_tool_call_error_handling() {
431445
// Unknown parser string should return an error
@@ -522,6 +536,22 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
522536
assert_eq!(args["unit"], "fahrenheit");
523537
}
524538

539+
#[test]
540+
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci_multiple() {
541+
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
542+
let result = detect_and_parse_tool_call(input, None).unwrap();
543+
assert!(!result.is_empty());
544+
assert_eq!(result.len(), 2);
545+
let (name, args) = extract_name_and_args(result[0].clone());
546+
assert_eq!(name, "get_weather");
547+
assert_eq!(args["location"], "San Francisco, CA");
548+
assert_eq!(args["unit"], "fahrenheit");
549+
let (name, args) = extract_name_and_args(result[1].clone());
550+
assert_eq!(name, "get_weather");
551+
assert_eq!(args["location"], "New York, NY");
552+
assert_eq!(args["unit"], "fahrenheit");
553+
}
554+
525555
#[test]
526556
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
527557
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;

0 commit comments

Comments
 (0)