Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/llm/src/protocols/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ pub struct StopConditions {
/// tokens after the EOS token is generated.
// TODO(ignore_eos) - improve this by masking the EOS token with logit bias
pub ignore_eos: Option<bool>,

/// Maximum number of thinking tokens allowed
/// NOTE: Currently a passthrough - no enforcement logic implemented
pub max_thinking_tokens: Option<u32>,
}

impl StopConditions {
Expand Down
8 changes: 8 additions & 0 deletions lib/llm/src/protocols/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ trait OpenAIStopConditionsProvider {
self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
)
}

/// Read the thinking-token budget from the request's `nvext` extension, if any.
///
/// NOTE: the value is currently a pure passthrough — no enforcement logic
/// exists yet; it is forwarded for a future thinking-budget implementation.
fn get_max_thinking_tokens(&self) -> Option<u32> {
    // `?` short-circuits to None when no nvext block was supplied;
    // the field itself is Copy, so this reads it out directly.
    self.nvext()?.max_thinking_tokens
}
}

trait OpenAIOutputOptionsProvider {
Expand Down Expand Up @@ -152,6 +158,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
let max_tokens = self.get_max_tokens();
let min_tokens = self.get_min_tokens();
let stop = self.get_stop();
let max_thinking_tokens = self.get_max_thinking_tokens();

if let Some(stop) = &stop
&& stop.len() > 4
Expand All @@ -168,6 +175,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
stop,
stop_token_ids_hidden: None,
ignore_eos,
max_thinking_tokens,
})
}
}
Expand Down
9 changes: 9 additions & 0 deletions lib/llm/src/protocols/openai/nvext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ pub struct NvExt {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub guided_decoding_backend: Option<String>,

/// Maximum number of thinking tokens allowed
/// NOTE: Currently passed through to backends as a no-op for future implementation
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub max_thinking_tokens: Option<u32>,
}

impl Default for NvExt {
Expand Down Expand Up @@ -157,6 +163,7 @@ mod tests {
assert_eq!(nv_ext.guided_regex, None);
assert_eq!(nv_ext.guided_grammar, None);
assert_eq!(nv_ext.guided_choice, None);
assert_eq!(nv_ext.max_thinking_tokens, None);
}

// Test valid builder configurations
Expand All @@ -172,6 +179,7 @@ mod tests {
.guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
.guided_decoding_backend("xgrammar".to_string())
.max_thinking_tokens(1024)
.build()
.unwrap();

Expand All @@ -193,6 +201,7 @@ mod tests {
Some(vec!["choice1".to_string(), "choice2".to_string()])
);
assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
// Validate the built struct
assert!(nv_ext.validate().is_ok());
}
Expand Down
34 changes: 34 additions & 0 deletions lib/llm/tests/test_common_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ fn test_chat_completions_common_overrides_nvext() {
assert_eq!(stop_conditions.min_tokens, Some(50));
}

#[test]
fn test_max_thinking_tokens_extraction() {
    // A max_thinking_tokens supplied via nvext must survive both JSON
    // deserialization and the conversion into StopConditions.
    let with_budget = r#"{
        "model": "test-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "nvext": {
            "max_thinking_tokens": 1024
        }
    }"#;

    let parsed: NvCreateChatCompletionRequest = serde_json::from_str(with_budget).unwrap();

    // The raw nvext field holds the value right after deserialization.
    assert_eq!(
        parsed.nvext.as_ref().unwrap().max_thinking_tokens,
        Some(1024)
    );

    // And extract_stop_conditions() carries it through unchanged.
    assert_eq!(
        parsed.extract_stop_conditions().unwrap().max_thinking_tokens,
        Some(1024)
    );

    // Omitting nvext entirely leaves the budget unset.
    let without_budget = r#"{
        "model": "test-model",
        "messages": [{"role": "user", "content": "Hello"}]
    }"#;

    let parsed_none: NvCreateChatCompletionRequest =
        serde_json::from_str(without_budget).unwrap();
    assert_eq!(
        parsed_none.extract_stop_conditions().unwrap().max_thinking_tokens,
        None
    );
}

#[test]
fn test_chat_completions_backward_compatibility() {
// Test backward compatibility - ignore_eos and guided_json only in nvext
Expand Down
Loading