|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +use regex::Regex; |
| 5 | +use serde_json::Value; |
| 6 | +use std::sync::OnceLock; |
| 7 | + |
| 8 | +use super::config::JsonParserConfig; |
| 9 | +use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; |
| 10 | + |
| 11 | +static DEEPSEEK_V3_1_OUTER_REGEX: OnceLock<Regex> = OnceLock::new(); |
| 12 | +static DEEPSEEK_V3_1_INNER_REGEX: OnceLock<Regex> = OnceLock::new(); |
| 13 | + |
| 14 | +pub fn get_deepseek_v3_1_outer_regex() -> &'static Regex { |
| 15 | + DEEPSEEK_V3_1_OUTER_REGEX.get_or_init(|| { |
| 16 | + // Outer regex: matches the entire tool call block |
| 17 | + Regex::new(r"(?s)<|tool▁call▁begin|>.*?<|tool▁call▁end|>") |
| 18 | + .expect("Failed to compile deepseek v3.1 outer regex pattern") |
| 19 | + }) |
| 20 | +} |
| 21 | + |
| 22 | +pub fn get_deepseek_v3_1_inner_regex() -> &'static Regex { |
| 23 | + DEEPSEEK_V3_1_INNER_REGEX.get_or_init(|| { |
| 24 | + // Inner regex: captures function name and arguments between sep tokens |
| 25 | + Regex::new(r"(?s)<|tool▁call▁begin|>(.*?)<|tool▁sep|>(.*?)<|tool▁call▁end|>") |
| 26 | + .expect("Failed to compile deepseek v3.1 inner regex pattern") |
| 27 | + }) |
| 28 | +} |
| 29 | + |
| 30 | +pub fn parse_tool_calls_deepseek_v3_1( |
| 31 | + message: &str, |
| 32 | + config: &JsonParserConfig, |
| 33 | +) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { |
| 34 | + // Format Structure: |
| 35 | + // <|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁calls▁end|><|end▁of▁sentence|> |
| 36 | + let trimmed = message.trim(); |
| 37 | + |
| 38 | + let tool_call_start_tokens = &config.tool_call_start_tokens; |
| 39 | + |
| 40 | + // Early exit if no content or tool_call_start_tokens is empty |
| 41 | + if trimmed.is_empty() || tool_call_start_tokens.is_empty() { |
| 42 | + return Ok((vec![], Some(trimmed.to_string()))); |
| 43 | + } |
| 44 | + |
| 45 | + // If tool call start token is not present then, no tool calls are there, return empty tool calls and the original trimmed string |
| 46 | + if let Some(start_token) = tool_call_start_tokens.first() { |
| 47 | + if !trimmed.contains(start_token) { |
| 48 | + return Ok((vec![], Some(trimmed.to_string()))); |
| 49 | + } |
| 50 | + } else { |
| 51 | + // Invalid start token |
| 52 | + return Ok((vec![], Some(trimmed.to_string()))); |
| 53 | + } |
| 54 | + |
| 55 | + let outer_re = get_deepseek_v3_1_outer_regex(); |
| 56 | + let inner_re = get_deepseek_v3_1_inner_regex(); |
| 57 | + |
| 58 | + let outer_matches = outer_re.find_iter(trimmed); |
| 59 | + |
| 60 | + let mut tool_calls: Vec<ToolCallResponse> = Vec::new(); |
| 61 | + let mut call_idx = 0usize; |
| 62 | + // Two matches are there, first one using outer regex to extract multiple tool calls |
| 63 | + // Second one using inner regex to extract the structure of the tool call |
| 64 | + for outer_match in outer_matches { |
| 65 | + for grp in inner_re.captures_iter(outer_match.as_str()) { |
| 66 | + let Some(function_name) = grp.get(1).map(|x| x.as_str()) else { |
| 67 | + continue; // Skip if function name is not found |
| 68 | + }; |
| 69 | + |
| 70 | + let Some(arg_match) = grp.get(2) else { |
| 71 | + continue; // Skip if arguments Match is not found. |
| 72 | + }; |
| 73 | + |
| 74 | + let arguments = match serde_json::from_str::<Value>(arg_match.as_str()) { |
| 75 | + Ok(args) => args, |
| 76 | + Err(_) => { |
| 77 | + continue; // Skip if arguments are not valid JSON |
| 78 | + } |
| 79 | + }; |
| 80 | + |
| 81 | + call_idx += 1; |
| 82 | + tool_calls.push(ToolCallResponse { |
| 83 | + id: format!("call-{}", call_idx), |
| 84 | + tp: ToolCallType::Function, |
| 85 | + function: CalledFunction { |
| 86 | + name: function_name.to_string(), |
| 87 | + arguments: serde_json::to_string(&arguments)?, |
| 88 | + }, |
| 89 | + }); |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + // Fast path: if no tool calls, just return early |
| 94 | + // This may happen due to invalid json or any other parsing error reasons |
| 95 | + if tool_calls.is_empty() { |
| 96 | + return Ok((vec![], Some(trimmed.to_string()))); |
| 97 | + } |
| 98 | + |
| 99 | + // Safety: We already checked above that tool_call_start_tokens.first() is Some |
| 100 | + let start_token = tool_call_start_tokens.first().unwrap(); |
| 101 | + let normal_text = trimmed |
| 102 | + .split_once(start_token) |
| 103 | + .map(|(before, _)| before.to_string()) |
| 104 | + .unwrap_or_else(|| trimmed.to_string()); |
| 105 | + |
| 106 | + Ok((tool_calls, Some(normal_text))) |
| 107 | +} |
| 108 | + |
| 109 | +#[cfg(test)] |
| 110 | +mod tests { |
| 111 | + use super::*; |
| 112 | + |
| 113 | + fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) { |
| 114 | + let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap(); |
| 115 | + (call.function.name, args) |
| 116 | + } |
| 117 | + |
| 118 | + #[test] |
| 119 | + fn test_parse_tool_calls_deepseek_v3_1_basic() { |
| 120 | + let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; |
| 121 | + let config = JsonParserConfig { |
| 122 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 123 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 124 | + ..Default::default() |
| 125 | + }; |
| 126 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 127 | + assert_eq!(content, Some("".to_string())); |
| 128 | + assert_eq!(result.len(), 2); |
| 129 | + let (name, args) = extract_name_and_args(result[0].clone()); |
| 130 | + assert_eq!(name, "get_current_weather"); |
| 131 | + assert_eq!(args["location"], "Tokyo"); |
| 132 | + let (name, args) = extract_name_and_args(result[1].clone()); |
| 133 | + assert_eq!(name, "get_current_weather"); |
| 134 | + assert_eq!(args["location"], "Paris"); |
| 135 | + } |
| 136 | + |
| 137 | + #[test] |
| 138 | + fn test_parse_tool_calls_deepseek_v3_1_with_normal_text() { |
| 139 | + let text = r#"The following tool call retrieves weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "New York"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; |
| 140 | + let config = JsonParserConfig { |
| 141 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 142 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 143 | + ..Default::default() |
| 144 | + }; |
| 145 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 146 | + assert_eq!( |
| 147 | + content, |
| 148 | + Some("The following tool call retrieves weather information: ".to_string()) |
| 149 | + ); |
| 150 | + assert_eq!(result.len(), 1); |
| 151 | + let (name, args) = extract_name_and_args(result[0].clone()); |
| 152 | + assert_eq!(name, "get_current_weather"); |
| 153 | + assert_eq!(args["location"], "New York"); |
| 154 | + } |
| 155 | + |
| 156 | + #[test] |
| 157 | + fn test_parse_tool_calls_deepseek_v3_1_without_tool_call_start_token() { |
| 158 | + let text = r#"<|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#; |
| 159 | + let config = JsonParserConfig { |
| 160 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 161 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 162 | + ..Default::default() |
| 163 | + }; |
| 164 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 165 | + assert_eq!(content, Some(text.to_string())); |
| 166 | + assert_eq!(result.len(), 0); |
| 167 | + } |
| 168 | + |
| 169 | + #[test] |
| 170 | + fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_multiple_args() { |
| 171 | + let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Berlin", "units": "metric"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast<|tool▁sep|>{"location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality<|tool▁sep|>{"location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#; |
| 172 | + let config = JsonParserConfig { |
| 173 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 174 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 175 | + ..Default::default() |
| 176 | + }; |
| 177 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 178 | + assert_eq!(content, Some("".to_string())); |
| 179 | + assert_eq!(result.len(), 3); |
| 180 | + let (name, args) = extract_name_and_args(result[0].clone()); |
| 181 | + assert_eq!(name, "get_current_weather"); |
| 182 | + assert_eq!(args["location"], "Berlin"); |
| 183 | + assert_eq!(args["units"], "metric"); |
| 184 | + let (name, args) = extract_name_and_args(result[1].clone()); |
| 185 | + assert_eq!(name, "get_weather_forecast"); |
| 186 | + assert_eq!(args["location"], "Berlin"); |
| 187 | + assert_eq!(args["days"], 7); |
| 188 | + assert_eq!(args["units"], "imperial"); |
| 189 | + let (name, args) = extract_name_and_args(result[2].clone()); |
| 190 | + assert_eq!(name, "get_air_quality"); |
| 191 | + assert_eq!(args["location"], "Berlin"); |
| 192 | + assert_eq!(args["radius"], 50); |
| 193 | + } |
| 194 | + |
| 195 | + #[test] |
| 196 | + fn test_parse_tool_calls_deepseek_v3_1_with_invalid_json() { |
| 197 | + // Everything is normal text in case of invalid json |
| 198 | + let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather}{location": "Tokyo"}<|tool▁call▁end|><|tool▁calls▁end|>"#; |
| 199 | + let config = JsonParserConfig { |
| 200 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 201 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 202 | + ..Default::default() |
| 203 | + }; |
| 204 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 205 | + assert_eq!(content, Some(text.trim().to_string())); |
| 206 | + assert_eq!(result.len(), 0); |
| 207 | + } |
| 208 | + |
| 209 | + #[test] |
| 210 | + fn test_parse_tool_calls_deepseek_v3_1_with_multi_tool_calls_with_normal_text() { |
| 211 | + // Everything is normal text in case of invalid json |
| 212 | + let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}{location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather_forecast宽带}{location": "Berlin", "days": 7, "units": "imperial"}<|tool▁call▁end|><|tool▁call▁begin|>get_air_quality宽带}{location": "Berlin", "radius": 50}<|tool▁call▁end|><|tool▁calls▁end|>"#; |
| 213 | + let config = JsonParserConfig { |
| 214 | + tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()], |
| 215 | + tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()], |
| 216 | + ..Default::default() |
| 217 | + }; |
| 218 | + let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); |
| 219 | + assert_eq!(content, Some(text.trim().to_string())); |
| 220 | + assert_eq!(result.len(), 0); |
| 221 | + } |
| 222 | +} |
0 commit comments