Skip to content

Commit 63f5bbc

Browse files
authored
chore: deprecate nvext.top_k and nvext.repetition_penalty and make available top level (#2767)
Signed-off-by: Ryan Lempka <rlempka@nvidia.com>
1 parent 79a9d69 commit 63f5bbc

File tree

6 files changed

+99
-8
lines changed

6 files changed

+99
-8
lines changed

lib/llm/src/protocols/openai.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
9595
.map_err(|e| anyhow::anyhow!("Error validating frequency_penalty: {}", e))?;
9696
let presence_penalty = validate_range(self.get_presence_penalty(), &PRESENCE_PENALTY_RANGE)
9797
.map_err(|e| anyhow::anyhow!("Error validating presence_penalty: {}", e))?;
98+
let top_k = CommonExtProvider::get_top_k(self);
99+
let repetition_penalty = CommonExtProvider::get_repetition_penalty(self);
98100

99101
if let Some(nvext) = self.nvext() {
100102
let greedy = nvext.greed_sampling.unwrap_or(false);
@@ -130,10 +132,10 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
130132
best_of: None,
131133
frequency_penalty,
132134
presence_penalty,
133-
repetition_penalty: None,
135+
repetition_penalty,
134136
temperature,
135137
top_p,
136-
top_k: None,
138+
top_k,
137139
min_p: None,
138140
seed: None,
139141
use_beam_search: None,

lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
198198
.and_then(|nv| nv.guided_decoding_backend.as_ref()),
199199
)
200200
}
201+
202+
fn get_top_k(&self) -> Option<i32> {
203+
choose_with_deprecation(
204+
"top_k",
205+
self.common.top_k.as_ref(),
206+
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
207+
)
208+
}
209+
210+
fn get_repetition_penalty(&self) -> Option<f32> {
211+
choose_with_deprecation(
212+
"repetition_penalty",
213+
self.common.repetition_penalty.as_ref(),
214+
self.nvext
215+
.as_ref()
216+
.and_then(|nv| nv.repetition_penalty.as_ref()),
217+
)
218+
}
201219
}
202220

203221
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,

lib/llm/src/protocols/openai/common_ext.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
// SPDX-License-Identifier: Apache-2.0
33

4+
use super::nvext::validate_top_k;
45
use derive_builder::Builder;
56
use serde::{Deserialize, Serialize};
67
use validator::Validate;
@@ -21,6 +22,19 @@ pub struct CommonExt {
2122
#[builder(default, setter(strip_option))]
2223
pub min_tokens: Option<u32>,
2324

25+
/// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
26+
#[serde(default, skip_serializing_if = "Option::is_none")]
27+
#[builder(default, setter(strip_option))]
28+
#[validate(custom(function = "validate_top_k"))]
29+
pub top_k: Option<i32>,
30+
31+
/// How much to penalize tokens based on how frequently they occur in the text.
32+
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
33+
#[serde(default, skip_serializing_if = "Option::is_none")]
34+
#[builder(default, setter(strip_option))]
35+
#[validate(range(exclusive_min = 0.0, max = 2.0))]
36+
pub repetition_penalty: Option<f32>,
37+
2438
/// Guided Decoding Options
2539
/// If specified, the output will be a JSON object. Can be a string, an object, or null.
2640
#[serde(default, skip_serializing_if = "Option::is_none")]
@@ -65,6 +79,10 @@ pub trait CommonExtProvider {
6579
fn get_guided_grammar(&self) -> Option<String>;
6680
fn get_guided_choice(&self) -> Option<Vec<String>>;
6781
fn get_guided_decoding_backend(&self) -> Option<String>;
82+
83+
/// Other sampling Options
84+
fn get_top_k(&self) -> Option<i32>;
85+
fn get_repetition_penalty(&self) -> Option<f32>;
6886
}
6987

7088
/// Helper function to emit deprecation warnings for nvext parameters
@@ -107,6 +125,8 @@ mod tests {
107125
let common_ext = CommonExt::builder().build().unwrap();
108126
assert_eq!(common_ext.ignore_eos, None);
109127
assert_eq!(common_ext.min_tokens, None);
128+
assert_eq!(common_ext.top_k, None);
129+
assert_eq!(common_ext.repetition_penalty, None);
110130
assert_eq!(common_ext.guided_json, None);
111131
assert_eq!(common_ext.guided_regex, None);
112132
assert_eq!(common_ext.guided_grammar, None);
@@ -119,6 +139,8 @@ mod tests {
119139
let common_ext = CommonExt::builder()
120140
.ignore_eos(true)
121141
.min_tokens(10)
142+
.top_k(50)
143+
.repetition_penalty(1.2)
122144
.guided_json(serde_json::json!({"key": "value"}))
123145
.guided_regex("regex".to_string())
124146
.guided_grammar("grammar".to_string())
@@ -129,6 +151,8 @@ mod tests {
129151

130152
assert_eq!(common_ext.ignore_eos, Some(true));
131153
assert_eq!(common_ext.min_tokens, Some(10));
154+
assert_eq!(common_ext.top_k, Some(50));
155+
assert_eq!(common_ext.repetition_penalty, Some(1.2));
132156
assert_eq!(
133157
common_ext.guided_json.as_ref(),
134158
Some(&serde_json::json!({"key": "value"}))
@@ -164,6 +188,8 @@ mod tests {
164188
let common_ext = CommonExt {
165189
ignore_eos: None,
166190
min_tokens: Some(0), // Should be valid (min = 0)
191+
top_k: None,
192+
repetition_penalty: None,
167193
guided_json: None,
168194
guided_regex: None,
169195
guided_grammar: None,
@@ -180,6 +206,8 @@ mod tests {
180206

181207
assert_eq!(common_ext.ignore_eos, None);
182208
assert_eq!(common_ext.min_tokens, None);
209+
assert_eq!(common_ext.top_k, None);
210+
assert_eq!(common_ext.repetition_penalty, None);
183211
assert!(common_ext.validate().is_ok());
184212
}
185213

@@ -190,6 +218,8 @@ mod tests {
190218

191219
assert_eq!(common_ext.ignore_eos, None);
192220
assert_eq!(common_ext.min_tokens, None);
221+
assert_eq!(common_ext.top_k, None);
222+
assert_eq!(common_ext.repetition_penalty, None);
193223
assert!(common_ext.validate().is_ok());
194224
}
195225

lib/llm/src/protocols/openai/completions.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,24 @@ impl CommonExtProvider for NvCreateCompletionRequest {
192192
.and_then(|nv| nv.guided_decoding_backend.as_ref()),
193193
)
194194
}
195+
196+
fn get_top_k(&self) -> Option<i32> {
197+
choose_with_deprecation(
198+
"top_k",
199+
self.common.top_k.as_ref(),
200+
self.nvext.as_ref().and_then(|nv| nv.top_k.as_ref()),
201+
)
202+
}
203+
204+
fn get_repetition_penalty(&self) -> Option<f32> {
205+
choose_with_deprecation(
206+
"repetition_penalty",
207+
self.common.repetition_penalty.as_ref(),
208+
self.nvext
209+
.as_ref()
210+
.and_then(|nv| nv.repetition_penalty.as_ref()),
211+
)
212+
}
195213
}
196214

197215
impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {

lib/llm/src/protocols/openai/nvext.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ pub struct NvExt {
3434
#[builder(default, setter(strip_option))] // NIM LLM might default to -1
3535
#[validate(custom(function = "validate_top_k"))]
3636
#[serde(default, skip_serializing_if = "Option::is_none")]
37-
pub top_k: Option<i64>,
37+
pub top_k: Option<i32>,
3838

3939
/// How much to penalize tokens based on how frequently they occur in the text.
4040
/// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
4141
#[builder(default, setter(strip_option))]
4242
#[validate(range(exclusive_min = 0.0, max = 2.0))]
43-
pub repetition_penalty: Option<f64>,
43+
pub repetition_penalty: Option<f32>,
4444

4545
/// If true, sampling will be forced to be greedy.
4646
/// The backend is responsible for selecting the correct backend-specific options to
@@ -118,7 +118,7 @@ fn validate_nv_ext(_nv_ext: &NvExt) -> Result<(), ValidationError> {
118118
Ok(())
119119
}
120120

121-
fn validate_top_k(top_k: i64) -> Result<(), ValidationError> {
121+
pub fn validate_top_k(top_k: i32) -> Result<(), ValidationError> {
122122
if top_k == -1 || (top_k >= 1) {
123123
return Ok(());
124124
}
@@ -200,7 +200,7 @@ mod tests {
200200
// Test invalid `top_k` validation using proptest
201201
proptest! {
202202
#[test]
203-
fn test_invalid_top_k_value(top_k in any::<i64>().prop_filter("Invalid top_k", |&k| k < -1 || (k > 0 && k < 1))) {
203+
fn test_invalid_top_k_value(top_k in any::<i32>().prop_filter("Invalid top_k", |&k| k < -1 || (k > 0 && k < 1))) {
204204
let nv_ext = NvExt::builder()
205205
.top_k(top_k)
206206
.build()
@@ -227,7 +227,7 @@ mod tests {
227227
// Test valid repetition_penalty values
228228
proptest! {
229229
#[test]
230-
fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f64..=2.0f64) {
230+
fn test_valid_repetition_penalty_values(repetition_penalty in 0.01f32..=2.0f32) {
231231
let nv_ext = NvExt::builder()
232232
.repetition_penalty(repetition_penalty)
233233
.build()
@@ -241,7 +241,7 @@ mod tests {
241241
// Test invalid repetition_penalty values
242242
proptest! {
243243
#[test]
244-
fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f64..0.0f64) {
244+
fn test_invalid_repetition_penalty_values(repetition_penalty in -10.0f32..0.0f32) {
245245
let nv_ext = NvExt::builder()
246246
.repetition_penalty(repetition_penalty)
247247
.build()

lib/llm/tests/test_common_ext.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,26 @@ fn test_min_tokens_only_at_root_level() {
280280
let stop_conditions = request.extract_stop_conditions().unwrap();
281281
assert_eq!(stop_conditions.min_tokens, Some(150));
282282
}
283+
284+
#[test]
285+
fn test_sampling_parameters_extraction() {
286+
use dynamo_llm::protocols::common::SamplingOptionsProvider;
287+
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
288+
use dynamo_llm::protocols::openai::common_ext::CommonExt;
289+
290+
// Test that top_k and repetition_penalty are extracted in sampling options when passed a top level
291+
let request = NvCreateChatCompletionRequest {
292+
inner: Default::default(),
293+
common: CommonExt::builder()
294+
.top_k(42)
295+
.repetition_penalty(1.3)
296+
.build()
297+
.unwrap(),
298+
nvext: None,
299+
};
300+
301+
let sampling_options = request.extract_sampling_options().unwrap();
302+
303+
assert_eq!(sampling_options.top_k, Some(42));
304+
assert_eq!(sampling_options.repetition_penalty, Some(1.3));
305+
}

0 commit comments

Comments
 (0)