Skip to content

Commit 72bd750

Browse files
committed
preprocessor: warn and proceed when no parser configured for tool_choice; jail: gate common markers when parser is set and clarify local naming; tests: tighten dual-entry path assertions
1 parent 2b4c77e commit 72bd750

File tree

3 files changed

+57
-210
lines changed

3 files changed

+57
-210
lines changed

lib/llm/src/preprocessor.rs

Lines changed: 19 additions & 188 deletions
Original file line numberDiff line numberDiff line change
@@ -624,15 +624,24 @@ impl OpenAIPreprocessor {
624624
) -> std::result::Result<bool, Error> {
625625
match (tool_call_parser, tool_choice, has_tools) {
626626
// No parser but tools requested - error cases
627-
(None, Some(ChatCompletionToolChoiceOption::Required), true) => Err(anyhow::anyhow!(
628-
"Tool choice 'required' specified but no tool parser configured"
629-
)),
630-
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => Err(anyhow::anyhow!(
631-
"Tool choice 'auto' specified but no tool parser configured"
632-
)),
633-
(None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => Err(anyhow::anyhow!(
634-
"Named tool choice specified but no tool parser configured"
635-
)),
627+
(None, Some(ChatCompletionToolChoiceOption::Required), true) => {
628+
tracing::warn!(
629+
"Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
630+
);
631+
Ok(false)
632+
}
633+
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
634+
tracing::warn!(
635+
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
636+
);
637+
Ok(false)
638+
}
639+
(None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
640+
tracing::warn!(
641+
"Named tool choice specified but no tool parser configured; proceeding without jailing"
642+
);
643+
Ok(false)
644+
}
636645

637646
// Parser exists and tools might be called
638647
(Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
@@ -864,182 +873,4 @@ impl
864873
}
865874
}
866875

867-
#[allow(deprecated, dead_code)]
868-
#[cfg(test)]
869-
mod tests {
870-
use super::*;
871-
use dynamo_async_openai::types::{
872-
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role,
873-
};
874-
875-
use dynamo_runtime::protocols::annotated::Annotated;
876-
877-
use std::sync::Arc;
878-
879-
// Helper function to create a mock chat response chunk
880-
fn create_mock_response_chunk(
881-
content: String,
882-
index: u32,
883-
) -> Annotated<NvCreateChatCompletionStreamResponse> {
884-
let choice = ChatChoiceStream {
885-
index,
886-
delta: ChatCompletionStreamResponseDelta {
887-
role: Some(Role::Assistant),
888-
content: Some(content),
889-
tool_calls: None,
890-
function_call: None,
891-
refusal: None,
892-
reasoning_content: None,
893-
},
894-
finish_reason: None,
895-
logprobs: None,
896-
};
897-
898-
let response = NvCreateChatCompletionStreamResponse {
899-
id: "test-id".to_string(),
900-
choices: vec![choice],
901-
created: 1234567890,
902-
model: "test-model".to_string(),
903-
system_fingerprint: Some("test-fingerprint".to_string()),
904-
object: "chat.completion.chunk".to_string(),
905-
usage: None,
906-
service_tier: None,
907-
};
908-
909-
Annotated {
910-
data: Some(response),
911-
id: None,
912-
event: None,
913-
comment: None,
914-
}
915-
}
916-
917-
// Helper function to create a final response chunk with finish reason
918-
fn create_final_response_chunk(index: u32) -> Annotated<NvCreateChatCompletionStreamResponse> {
919-
let choice = ChatChoiceStream {
920-
index,
921-
delta: ChatCompletionStreamResponseDelta {
922-
role: None,
923-
content: None,
924-
tool_calls: None,
925-
function_call: None,
926-
refusal: None,
927-
reasoning_content: None,
928-
},
929-
finish_reason: Some(OAIFinishReason::Stop),
930-
logprobs: None,
931-
};
932-
933-
let response = NvCreateChatCompletionStreamResponse {
934-
id: "test-id".to_string(),
935-
choices: vec![choice],
936-
created: 1234567890,
937-
model: "test-model".to_string(),
938-
system_fingerprint: Some("test-fingerprint".to_string()),
939-
object: "chat.completion.chunk".to_string(),
940-
usage: None,
941-
service_tier: None,
942-
};
943-
944-
Annotated {
945-
data: Some(response),
946-
id: None,
947-
event: None,
948-
comment: None,
949-
}
950-
}
951-
952-
// Mock async engine context for testing
953-
#[derive(Debug)]
954-
struct MockAsyncEngineContext {
955-
id: String,
956-
stopped: std::sync::atomic::AtomicBool,
957-
}
958-
959-
impl MockAsyncEngineContext {
960-
fn new(id: String) -> Self {
961-
Self {
962-
id,
963-
stopped: std::sync::atomic::AtomicBool::new(false),
964-
}
965-
}
966-
}
967-
968-
#[async_trait]
969-
impl dynamo_runtime::pipeline::AsyncEngineContext for MockAsyncEngineContext {
970-
fn id(&self) -> &str {
971-
&self.id
972-
}
973-
974-
fn stop(&self) {
975-
self.stopped
976-
.store(true, std::sync::atomic::Ordering::Relaxed);
977-
}
978-
979-
fn stop_generating(&self) {
980-
self.stopped
981-
.store(true, std::sync::atomic::Ordering::Relaxed);
982-
}
983-
984-
fn kill(&self) {
985-
self.stopped
986-
.store(true, std::sync::atomic::Ordering::Relaxed);
987-
}
988-
989-
fn is_stopped(&self) -> bool {
990-
self.stopped.load(std::sync::atomic::Ordering::Relaxed)
991-
}
992-
993-
fn is_killed(&self) -> bool {
994-
self.stopped.load(std::sync::atomic::Ordering::Relaxed)
995-
}
996-
997-
async fn stopped(&self) {
998-
// No-op for testing
999-
}
1000-
1001-
async fn killed(&self) {
1002-
// No-op for testing
1003-
}
1004-
1005-
fn link_child(&self, _: Arc<dyn dynamo_runtime::pipeline::AsyncEngineContext>) {
1006-
// No-op for testing
1007-
}
1008-
}
1009-
1010-
// Test for tool call detection with different parsers - still valuable to keep
1011-
#[tokio::test]
1012-
async fn test_detect_tool_call_start_different_parsers() {
1013-
use dynamo_parsers::tool_calling::detect_tool_call_start;
1014-
1015-
// Test nemotron_deci parser
1016-
assert!(detect_tool_call_start("<TOOLCALL>", Some("nemotron_deci")).unwrap());
1017-
assert!(!detect_tool_call_start("Hello world", Some("nemotron_deci")).unwrap());
1018-
assert!(!detect_tool_call_start("<tool_call>", Some("nemotron_deci")).unwrap()); // Wrong format
1019-
1020-
// Test hermes parser - now also detects JSON patterns
1021-
assert!(detect_tool_call_start("<tool_call>", Some("hermes")).unwrap());
1022-
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("hermes")).unwrap()); // JSON detection
1023-
assert!(!detect_tool_call_start("Hello world", Some("hermes")).unwrap());
1024-
assert!(!detect_tool_call_start("<TOOLCALL>", Some("hermes")).unwrap()); // Wrong format
1025-
1026-
// Test phi4 parser
1027-
assert!(detect_tool_call_start("functools[", Some("phi4")).unwrap());
1028-
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("phi4")).unwrap()); // JSON detection
1029-
assert!(!detect_tool_call_start("Hello world", Some("phi4")).unwrap());
1030-
1031-
// Test mistral parser
1032-
assert!(detect_tool_call_start("[{", Some("mistral")).unwrap());
1033-
assert!(detect_tool_call_start("[TOOL_CALLS]", Some("mistral")).unwrap());
1034-
assert!(!detect_tool_call_start("Hello world", Some("mistral")).unwrap());
1035-
1036-
// Test llama3_json parser
1037-
assert!(detect_tool_call_start("<|python_tag|>", Some("llama3_json")).unwrap());
1038-
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("llama3_json")).unwrap()); // JSON detection
1039-
1040-
// Test default parser (should behave like nemotron_deci)
1041-
assert!(detect_tool_call_start("<TOOLCALL>", None).unwrap());
1042-
assert!(detect_tool_call_start("{\"name\": \"test\"}", None).unwrap()); // JSON detection
1043-
assert!(!detect_tool_call_start("Hello world", None).unwrap());
1044-
}
1045-
}
876+
// Note: tests for jailing and parser detection live in `lib/llm/tests/test_jail.rs`

lib/llm/src/protocols/openai/chat_completions/jail.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ impl ChoiceJailState {
140140
let full_content = format!("{}{}", marker, suffix);
141141

142142
// Check if this already contains the end marker
143-
let (should_unjail, split_pos) = jail_stream.should_end_jail(&full_content);
143+
let (should_end, split_pos) = jail_stream.should_end_jail(&full_content);
144144

145-
if should_unjail {
145+
if should_end {
146146
// Complete tool call found in this chunk
147147
tracing::debug!(
148148
"Choice {} complete tool call detected in single chunk",
@@ -272,9 +272,9 @@ impl ChoiceJailState {
272272
// Already jailed - accumulate and check for unjail
273273
self.accumulate(content);
274274

275-
let (should_unjail, split_pos) = jail_stream.should_end_jail(&self.accumulated_content);
275+
let (should_end, split_pos) = jail_stream.should_end_jail(&self.accumulated_content);
276276

277-
if should_unjail {
277+
if should_end {
278278
tracing::debug!(
279279
"Choice {} jail exit detected, releasing accumulated content",
280280
choice.index
@@ -919,22 +919,24 @@ impl JailedStreamBuilder {
919919
}
920920

921921
// Add common tool call markers to ensure we detect all formats
922-
// These are always included even when a specific parser is configured
923-
// to provide broad compatibility and prevent missed tool calls
924-
let common_markers = vec![
925-
"<TOOLCALL>".to_string(), // nemotron_deci format
926-
"<tool_call>".to_string(), // hermes format
927-
"[TOOL_CALLS]".to_string(), // mistral format
928-
"<|python_tag|>".to_string(), // llama3_json format
929-
"functools[".to_string(), // phi4 format
930-
// Add JSON start patterns for Mistral-style tool calls
931-
"[{".to_string(),
932-
"{".to_string(),
933-
// Note: Harmony parser uses JSON patterns, covered by "{" above
934-
];
935-
for marker in common_markers {
936-
if !all_patterns.contains(&marker) {
937-
all_patterns.push(marker);
922+
// Only include these when a specific parser is NOT configured,
923+
// to avoid unexpected false positives for explicit formats
924+
if self.tool_call_parser.is_none() {
925+
let common_markers = vec![
926+
"<TOOLCALL>".to_string(), // nemotron_deci format
927+
"<tool_call>".to_string(), // hermes format
928+
"[TOOL_CALLS]".to_string(), // mistral format
929+
"<|python_tag|>".to_string(), // llama3_json format
930+
"functools[".to_string(), // phi4 format
931+
// Add JSON start patterns for Mistral-style tool calls
932+
"[{".to_string(),
933+
"{".to_string(),
934+
// Note: Harmony parser uses JSON patterns, covered by "{" above
935+
];
936+
for marker in common_markers {
937+
if !all_patterns.contains(&marker) {
938+
all_patterns.push(marker);
939+
}
938940
}
939941
}
940942

lib/llm/tests/test_jail.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,11 @@ mod tests {
434434
let jailed_stream = jail.apply(input_stream);
435435
let results: Vec<_> = jailed_stream.collect().await;
436436

437+
// We should get 2 chunks:
438+
// 1. "Normal text " (before jail)
439+
// 2. Accumulated jailed content when jail ends via </jail>
440+
assert_eq!(results.len(), 2);
441+
437442
// First chunk should pass through
438443
assert_eq!(
439444
results[0].data.as_ref().unwrap().choices[0]
@@ -443,8 +448,17 @@ mod tests {
443448
Some("Normal text ")
444449
);
445450

446-
// Jail should trigger and accumulate
447-
assert!(results.len() >= 2);
451+
// Second chunk should contain the accumulated jailed content
452+
let jailed = results[1]
453+
.data
454+
.as_ref()
455+
.unwrap()
456+
.choices[0]
457+
.delta
458+
.content
459+
.as_ref()
460+
.expect("Expected accumulated jailed content");
461+
assert!(jailed.contains("<jail><TOOLCALL>Jailed content</jail>"));
448462
}
449463

450464
#[tokio::test]

0 commit comments

Comments
 (0)