diff --git a/lib/llm/src/protocols/common.rs b/lib/llm/src/protocols/common.rs
index f3c8a894b3..7bf3467df7 100644
--- a/lib/llm/src/protocols/common.rs
+++ b/lib/llm/src/protocols/common.rs
@@ -257,6 +257,10 @@ pub struct StopConditions {
     /// tokens after the EOS token is generated.
     // TODO(ignore_eos) - improve this my masking the EOS token with logit bias
     pub ignore_eos: Option<bool>,
+
+    /// Maximum number of thinking tokens allowed
+    /// NOTE: Currently a passthrough - no enforcement logic implemented
+    pub max_thinking_tokens: Option<u32>,
 }
 
 impl StopConditions {
diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs
index 6d0839200a..c816c56a48 100644
--- a/lib/llm/src/protocols/openai.rs
+++ b/lib/llm/src/protocols/openai.rs
@@ -67,6 +67,12 @@ trait OpenAIStopConditionsProvider {
             self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
         )
     }
+
+    /// Get max_thinking_tokens from nvext
+    /// NOTE: This is currently a passthrough for future thinking budget implementation
+    fn get_max_thinking_tokens(&self) -> Option<u32> {
+        self.nvext().and_then(|nv| nv.max_thinking_tokens)
+    }
 }
 
 trait OpenAIOutputOptionsProvider {
@@ -152,6 +158,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
         let max_tokens = self.get_max_tokens();
         let min_tokens = self.get_min_tokens();
         let stop = self.get_stop();
+        let max_thinking_tokens = self.get_max_thinking_tokens();
 
         if let Some(stop) = &stop
             && stop.len() > 4
@@ -168,6 +175,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
             stop,
             stop_token_ids_hidden: None,
             ignore_eos,
+            max_thinking_tokens,
         })
     }
 }
diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs
index 68f6253900..5335f83245 100644
--- a/lib/llm/src/protocols/openai/nvext.rs
+++ b/lib/llm/src/protocols/openai/nvext.rs
@@ -100,6 +100,12 @@ pub struct NvExt {
     #[serde(default, skip_serializing_if = "Option::is_none")]
     #[builder(default, setter(strip_option))]
     pub guided_decoding_backend: Option<String>,
+
+    /// Maximum number of thinking tokens allowed
+    /// NOTE: Currently passed through to backends as a no-op for future implementation
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[builder(default, setter(strip_option))]
+    pub max_thinking_tokens: Option<u32>,
 }
 
 impl Default for NvExt {
@@ -157,6 +163,7 @@ mod tests {
         assert_eq!(nv_ext.guided_regex, None);
         assert_eq!(nv_ext.guided_grammar, None);
         assert_eq!(nv_ext.guided_choice, None);
+        assert_eq!(nv_ext.max_thinking_tokens, None);
     }
 
     // Test valid builder configurations
@@ -172,6 +179,7 @@ mod tests {
             .guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
             .guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
             .guided_decoding_backend("xgrammar".to_string())
+            .max_thinking_tokens(1024)
             .build()
             .unwrap();
 
@@ -193,6 +201,7 @@ mod tests {
             Some(vec!["choice1".to_string(), "choice2".to_string()])
         );
         assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
+        assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
         // Validate the built struct
         assert!(nv_ext.validate().is_ok());
     }
diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs
index f62cb74e27..c49fea7855 100644
--- a/lib/llm/tests/test_common_ext.rs
+++ b/lib/llm/tests/test_common_ext.rs
@@ -184,6 +184,40 @@ fn test_chat_completions_common_overrides_nvext() {
     assert_eq!(stop_conditions.min_tokens, Some(50));
 }
 
+#[test]
+fn test_max_thinking_tokens_extraction() {
+    // Test that max_thinking_tokens is extracted from nvext to StopConditions
+    let json_str = r#"{
+        "model": "test-model",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "nvext": {
+            "max_thinking_tokens": 1024
+        }
+    }"#;
+
+    let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
+
+    // Verify nvext parsing
+    assert_eq!(
+        request.nvext.as_ref().unwrap().max_thinking_tokens,
+        Some(1024)
+    );
+
+    // Verify extraction to StopConditions
+    let stop_conditions = request.extract_stop_conditions().unwrap();
+    assert_eq!(stop_conditions.max_thinking_tokens, Some(1024));
+
+    // Test with None value
+    let json_str_none = r#"{
+        "model": "test-model",
+        "messages": [{"role": "user", "content": "Hello"}]
+    }"#;
+
+    let request_none: NvCreateChatCompletionRequest = serde_json::from_str(json_str_none).unwrap();
+    let stop_conditions_none = request_none.extract_stop_conditions().unwrap();
+    assert_eq!(stop_conditions_none.max_thinking_tokens, None);
+}
+
 #[test]
 fn test_chat_completions_backward_compatibility() {
     // Test backward compatibility - ignore_eos and guided_json only in nvext