Skip to content

Commit f4d49b0

Browse files
bhuvan002 and dillon-cullinan
authored and committed
chore: frontend API changes for thinking budget (#2848)
Signed-off-by: Bhuvan Agrawal <11240550+bhuvan002@users.noreply.github.com>
1 parent d0959d0 commit f4d49b0

File tree

4 files changed

+55
-0
lines changed

4 files changed

+55
-0
lines changed

lib/llm/src/protocols/common.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,10 @@ pub struct StopConditions {
257257
/// tokens after the EOS token is generated.
258258
// TODO(ignore_eos) - improve this by masking the EOS token with logit bias
259259
pub ignore_eos: Option<bool>,
260+
261+
/// Maximum number of thinking tokens allowed
262+
/// NOTE: Currently a passthrough - no enforcement logic implemented
263+
pub max_thinking_tokens: Option<u32>,
260264
}
261265

262266
impl StopConditions {

lib/llm/src/protocols/openai.rs

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,12 @@ trait OpenAIStopConditionsProvider {
6767
self.nvext().and_then(|nv| nv.ignore_eos.as_ref()),
6868
)
6969
}
70+
71+
/// Get max_thinking_tokens from nvext
72+
/// NOTE: This is currently a passthrough for future thinking budget implementation
73+
fn get_max_thinking_tokens(&self) -> Option<u32> {
74+
self.nvext().and_then(|nv| nv.max_thinking_tokens)
75+
}
7076
}
7177

7278
trait OpenAIOutputOptionsProvider {
@@ -152,6 +158,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
152158
let max_tokens = self.get_max_tokens();
153159
let min_tokens = self.get_min_tokens();
154160
let stop = self.get_stop();
161+
let max_thinking_tokens = self.get_max_thinking_tokens();
155162

156163
if let Some(stop) = &stop
157164
&& stop.len() > 4
@@ -168,6 +175,7 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
168175
stop,
169176
stop_token_ids_hidden: None,
170177
ignore_eos,
178+
max_thinking_tokens,
171179
})
172180
}
173181
}

lib/llm/src/protocols/openai/nvext.rs

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,12 @@ pub struct NvExt {
100100
#[serde(default, skip_serializing_if = "Option::is_none")]
101101
#[builder(default, setter(strip_option))]
102102
pub guided_decoding_backend: Option<String>,
103+
104+
/// Maximum number of thinking tokens allowed
105+
/// NOTE: Currently passed through to backends as a no-op for future implementation
106+
#[serde(default, skip_serializing_if = "Option::is_none")]
107+
#[builder(default, setter(strip_option))]
108+
pub max_thinking_tokens: Option<u32>,
103109
}
104110

105111
impl Default for NvExt {
@@ -157,6 +163,7 @@ mod tests {
157163
assert_eq!(nv_ext.guided_regex, None);
158164
assert_eq!(nv_ext.guided_grammar, None);
159165
assert_eq!(nv_ext.guided_choice, None);
166+
assert_eq!(nv_ext.max_thinking_tokens, None);
160167
}
161168

162169
// Test valid builder configurations
@@ -172,6 +179,7 @@ mod tests {
172179
.guided_grammar("S -> 'a' S 'b' | 'c'".to_string())
173180
.guided_choice(vec!["choice1".to_string(), "choice2".to_string()])
174181
.guided_decoding_backend("xgrammar".to_string())
182+
.max_thinking_tokens(1024)
175183
.build()
176184
.unwrap();
177185

@@ -193,6 +201,7 @@ mod tests {
193201
Some(vec!["choice1".to_string(), "choice2".to_string()])
194202
);
195203
assert_eq!(nv_ext.guided_decoding_backend, Some("xgrammar".to_string()));
204+
assert_eq!(nv_ext.max_thinking_tokens, Some(1024));
196205
// Validate the built struct
197206
assert!(nv_ext.validate().is_ok());
198207
}

lib/llm/tests/test_common_ext.rs

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,40 @@ fn test_chat_completions_common_overrides_nvext() {
184184
assert_eq!(stop_conditions.min_tokens, Some(50));
185185
}
186186

187+
#[test]
188+
fn test_max_thinking_tokens_extraction() {
189+
// Test that max_thinking_tokens is extracted from nvext to StopConditions
190+
let json_str = r#"{
191+
"model": "test-model",
192+
"messages": [{"role": "user", "content": "Hello"}],
193+
"nvext": {
194+
"max_thinking_tokens": 1024
195+
}
196+
}"#;
197+
198+
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
199+
200+
// Verify nvext parsing
201+
assert_eq!(
202+
request.nvext.as_ref().unwrap().max_thinking_tokens,
203+
Some(1024)
204+
);
205+
206+
// Verify extraction to StopConditions
207+
let stop_conditions = request.extract_stop_conditions().unwrap();
208+
assert_eq!(stop_conditions.max_thinking_tokens, Some(1024));
209+
210+
// Test with None value
211+
let json_str_none = r#"{
212+
"model": "test-model",
213+
"messages": [{"role": "user", "content": "Hello"}]
214+
}"#;
215+
216+
let request_none: NvCreateChatCompletionRequest = serde_json::from_str(json_str_none).unwrap();
217+
let stop_conditions_none = request_none.extract_stop_conditions().unwrap();
218+
assert_eq!(stop_conditions_none.max_thinking_tokens, None);
219+
}
220+
187221
#[test]
188222
fn test_chat_completions_backward_compatibility() {
189223
// Test backward compatibility - ignore_eos and guided_json only in nvext

0 commit comments

Comments (0)