Skip to content

Commit 2a2325f

Browse files
KrishnanPrash, rmccorm4, and ayushag-nv
authored and committed
feat: Add frontend support for min_tokens and ignore_eos (outside of nvext) and Structured Output / Guided Decoding (#2380)
Signed-off-by: KrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
Co-authored-by: Ayush Agarwal <ayushag@nvidia.com>
Signed-off-by: Hannah Zhang <hannahz@nvidia.com>
1 parent ab6c7cb commit 2a2325f

File tree

12 files changed

+625
-42
lines changed

12 files changed

+625
-42
lines changed

lib/llm/src/entrypoint/input/batch.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,11 @@ async fn evaluate(
222222
)
223223
.temperature(template.as_ref().map_or(0.7, |t| t.temperature))
224224
.build()?;
225-
let req = NvCreateChatCompletionRequest { inner, nvext: None };
225+
let req = NvCreateChatCompletionRequest {
226+
inner,
227+
common: Default::default(),
228+
nvext: None,
229+
};
226230
let mut stream = engine.generate(Context::new(req)).await?;
227231
let mut output = String::new();
228232
while let Some(item) = stream.next().await {

lib/llm/src/entrypoint/input/text.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,7 @@ async fn main_loop(
118118

119119
let req = NvCreateChatCompletionRequest {
120120
inner,
121+
common: Default::default(),
121122
nvext: Some(nvext),
122123
};
123124

lib/llm/src/http/service/openai.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1237,6 +1237,7 @@ mod tests {
12371237
messages: vec![],
12381238
..Default::default()
12391239
},
1240+
common: Default::default(),
12401241
nvext: None,
12411242
};
12421243
let result = validate_chat_completion_required_fields(&request);
@@ -1263,6 +1264,7 @@ mod tests {
12631264
)],
12641265
..Default::default()
12651266
},
1267+
common: Default::default(),
12661268
nvext: None,
12671269
};
12681270
let result = validate_chat_completion_required_fields(&request);

lib/llm/src/protocols/openai.rs

Lines changed: 38 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,10 @@ use super::{
2020
common::{self, SamplingOptionsProvider, StopConditionsProvider},
2121
ContentProvider,
2222
};
23+
use crate::protocols::openai::common_ext::CommonExtProvider;
2324

2425
pub mod chat_completions;
26+
pub mod common_ext;
2527
pub mod completions;
2628
pub mod embeddings;
2729
pub mod models;
@@ -61,9 +63,23 @@ trait OpenAIStopConditionsProvider {
6163
fn get_stop(&self) -> Option<Vec<String>>;
6264

6365
fn nvext(&self) -> Option<&nvext::NvExt>;
66+
67+
/// Get ignore_eos from CommonExt if the type supports it.
68+
/// Default returns None for types without CommonExt support.
69+
fn get_common_ignore_eos(&self) -> Option<bool> {
70+
None
71+
}
72+
73+
/// Get the effective ignore_eos value, considering both CommonExt and NvExt.
74+
/// CommonExt (root-level) takes precedence over NvExt.
75+
fn get_ignore_eos(&self) -> Option<bool> {
76+
// Check common first (takes precedence), then fall back to nvext
77+
self.get_common_ignore_eos()
78+
.or_else(|| self.nvext().and_then(|nv| nv.ignore_eos))
79+
}
6480
}
6581

66-
impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
82+
impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvider for T {
6783
fn extract_sampling_options(&self) -> Result<common::SamplingOptions> {
6884
// let result = self.validate();
6985
// if let Err(e) = result {
@@ -88,29 +104,26 @@ impl<T: OpenAISamplingOptionsProvider> SamplingOptionsProvider for T {
88104
}
89105
}
90106

91-
let mut guided_decoding = None;
92-
if let Some(nvext) = self.nvext() {
93-
let guided_decoding_backend = nvext.guided_decoding_backend.clone();
94-
let guided_json = nvext.guided_json.clone();
95-
let guided_regex = nvext.guided_regex.clone();
96-
let guided_grammar = nvext.guided_grammar.clone();
97-
let guided_choice = nvext.guided_choice.clone();
98-
99-
match common::GuidedDecodingOptions::from_optional(
100-
guided_json,
101-
guided_regex,
102-
guided_choice,
103-
guided_grammar,
104-
guided_decoding_backend,
105-
) {
106-
Ok(options) => guided_decoding = options,
107-
Err(e) => {
108-
// Handle the validation error (log, return error, etc.)
109-
tracing::error!("Invalid guided decoding options: {}", e);
110-
return Err(e);
111-
}
107+
let guided_decoding_backend = self.get_guided_decoding_backend();
108+
let guided_json = self.get_guided_json();
109+
let guided_regex = self.get_guided_regex();
110+
let guided_grammar = self.get_guided_grammar();
111+
let guided_choice = self.get_guided_choice();
112+
113+
let guided_decoding = match common::GuidedDecodingOptions::from_optional(
114+
guided_json.cloned(),
115+
guided_regex,
116+
guided_choice,
117+
guided_grammar,
118+
guided_decoding_backend,
119+
) {
120+
Ok(options) => options,
121+
Err(e) => {
122+
// Handle the validation error (log, return error, etc.)
123+
tracing::error!("Invalid guided decoding options: {:?}", e);
124+
return Err(e);
112125
}
113-
}
126+
};
114127

115128
Ok(common::SamplingOptions {
116129
n: None,
@@ -142,11 +155,8 @@ impl<T: OpenAIStopConditionsProvider> StopConditionsProvider for T {
142155
}
143156
}
144157

145-
let mut ignore_eos = None;
146-
147-
if let Some(nvext) = self.nvext() {
148-
ignore_eos = nvext.ignore_eos;
149-
}
158+
// Use the trait method to get ignore_eos, which handles precedence
159+
let ignore_eos = self.get_ignore_eos();
150160

151161
Ok(common::StopConditions {
152162
max_tokens,

lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 65 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,10 @@ use validator::Validate;
2020
use crate::engines::ValidateRequest;
2121

2222
use super::{
23-
nvext::NvExt, nvext::NvExtProvider, validate, OpenAISamplingOptionsProvider,
24-
OpenAIStopConditionsProvider,
23+
common_ext::{CommonExt, CommonExtProvider},
24+
nvext::NvExt,
25+
nvext::NvExtProvider,
26+
validate, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
2527
};
2628

2729
mod aggregator;
@@ -31,17 +33,21 @@ pub use aggregator::DeltaAggregator;
3133
pub use delta::DeltaGenerator;
3234

3335
/// A request structure for creating a chat completion, extending OpenAI's
34-
/// `CreateChatCompletionRequest` with [`NvExt`] extensions.
36+
/// `CreateChatCompletionRequest` with [`NvExt`] extensions and common fields.
3537
///
3638
/// # Fields
3739
/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
38-
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for
39-
/// more details.
40+
/// - `common`: Common extension fields (ignore_eos, min_tokens) at root level, embedded using `serde(flatten)`.
41+
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for more details.
42+
/// Note: If ignore_eos is specified in both common and nvext, the common (root-level) value takes precedence.
4043
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
4144
pub struct NvCreateChatCompletionRequest {
4245
#[serde(flatten)]
4346
pub inner: async_openai::types::CreateChatCompletionRequest,
4447

48+
#[serde(flatten, default)]
49+
pub common: CommonExt,
50+
4551
#[serde(skip_serializing_if = "Option::is_none")]
4652
pub nvext: Option<NvExt>,
4753
}
@@ -139,6 +145,52 @@ impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
139145
}
140146
}
141147

148+
/// Implements `CommonExtProvider` for `NvCreateChatCompletionRequest`,
149+
/// providing access to common extension fields.
150+
impl CommonExtProvider for NvCreateChatCompletionRequest {
151+
/// Returns a reference to the CommonExt struct.
152+
fn common_ext(&self) -> Option<&CommonExt> {
153+
Some(&self.common)
154+
}
155+
156+
/// Guided Decoding Options
157+
fn get_guided_json(&self) -> Option<&serde_json::Value> {
158+
self.common
159+
.guided_json
160+
.as_ref()
161+
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_json.as_ref()))
162+
}
163+
164+
fn get_guided_regex(&self) -> Option<String> {
165+
self.common
166+
.guided_regex
167+
.clone()
168+
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_regex.clone()))
169+
}
170+
171+
fn get_guided_grammar(&self) -> Option<String> {
172+
self.common
173+
.guided_grammar
174+
.clone()
175+
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_grammar.clone()))
176+
}
177+
178+
fn get_guided_choice(&self) -> Option<Vec<String>> {
179+
self.common
180+
.guided_choice
181+
.clone()
182+
.or_else(|| self.nvext.as_ref().and_then(|nv| nv.guided_choice.clone()))
183+
}
184+
185+
fn get_guided_decoding_backend(&self) -> Option<String> {
186+
self.common.guided_decoding_backend.clone().or_else(|| {
187+
self.nvext
188+
.as_ref()
189+
.and_then(|nv| nv.guided_decoding_backend.clone())
190+
})
191+
}
192+
}
193+
142194
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
143195
/// providing access to stop conditions that control chat completion behavior.
144196
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
@@ -149,12 +201,10 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
149201
}
150202

151203
/// Retrieves the minimum number of tokens required in the response.
152-
///
153-
/// # Note
154-
/// This method is currently a placeholder and always returns `None`
155-
/// since `min_tokens` is not an OpenAI-supported parameter.
204+
/// Returns `min_tokens` Value
205+
/// `min_tokens` is not an OpenAI-supported parameter.
156206
fn get_min_tokens(&self) -> Option<u32> {
157-
None
207+
self.common.min_tokens
158208
}
159209

160210
/// Retrieves the stop conditions that terminate the chat completion response.
@@ -175,6 +225,11 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
175225
fn nvext(&self) -> Option<&NvExt> {
176226
self.nvext.as_ref()
177227
}
228+
229+
/// Get ignore_eos from CommonExt.
230+
fn get_common_ignore_eos(&self) -> Option<bool> {
231+
self.common.ignore_eos
232+
}
178233
}
179234

180235
/// Implements `ValidateRequest` for `NvCreateChatCompletionRequest`,

0 commit comments

Comments
 (0)