11// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22// SPDX-License-Identifier: Apache-2.0
33
4+ use super :: nvext:: validate_top_k;
45use derive_builder:: Builder ;
56use serde:: { Deserialize , Serialize } ;
67use validator:: Validate ;
@@ -21,6 +22,19 @@ pub struct CommonExt {
2122 #[ builder( default , setter( strip_option) ) ]
2223 pub min_tokens : Option < u32 > ,
2324
25+ /// Integer that controls the number of top tokens to consider. Set to -1 to consider all tokens.
26+ #[ serde( default , skip_serializing_if = "Option::is_none" ) ]
27+ #[ builder( default , setter( strip_option) ) ]
28+ #[ validate( custom( function = "validate_top_k" ) ) ]
29+ pub top_k : Option < i32 > ,
30+
31+ /// How much to penalize tokens based on how frequently they occur in the text.
32+ /// A value of 1 means no penalty, while values larger than 1 discourage and values smaller encourage.
33+ #[ serde( default , skip_serializing_if = "Option::is_none" ) ]
34+ #[ builder( default , setter( strip_option) ) ]
35+ #[ validate( range( exclusive_min = 0.0 , max = 2.0 ) ) ]
36+ pub repetition_penalty : Option < f32 > ,
37+
2438 /// Guided Decoding Options
2539 /// If specified, the output will be a JSON object. Can be a string, an object, or null.
2640 #[ serde( default , skip_serializing_if = "Option::is_none" ) ]
@@ -65,6 +79,10 @@ pub trait CommonExtProvider {
6579 fn get_guided_grammar ( & self ) -> Option < String > ;
6680 fn get_guided_choice ( & self ) -> Option < Vec < String > > ;
6781 fn get_guided_decoding_backend ( & self ) -> Option < String > ;
82+
83+ /// Other sampling Options
84+ fn get_top_k ( & self ) -> Option < i32 > ;
85+ fn get_repetition_penalty ( & self ) -> Option < f32 > ;
6886}
6987
7088/// Helper function to emit deprecation warnings for nvext parameters
@@ -107,6 +125,8 @@ mod tests {
107125 let common_ext = CommonExt :: builder ( ) . build ( ) . unwrap ( ) ;
108126 assert_eq ! ( common_ext. ignore_eos, None ) ;
109127 assert_eq ! ( common_ext. min_tokens, None ) ;
128+ assert_eq ! ( common_ext. top_k, None ) ;
129+ assert_eq ! ( common_ext. repetition_penalty, None ) ;
110130 assert_eq ! ( common_ext. guided_json, None ) ;
111131 assert_eq ! ( common_ext. guided_regex, None ) ;
112132 assert_eq ! ( common_ext. guided_grammar, None ) ;
@@ -119,6 +139,8 @@ mod tests {
119139 let common_ext = CommonExt :: builder ( )
120140 . ignore_eos ( true )
121141 . min_tokens ( 10 )
142+ . top_k ( 50 )
143+ . repetition_penalty ( 1.2 )
122144 . guided_json ( serde_json:: json!( { "key" : "value" } ) )
123145 . guided_regex ( "regex" . to_string ( ) )
124146 . guided_grammar ( "grammar" . to_string ( ) )
@@ -129,6 +151,8 @@ mod tests {
129151
130152 assert_eq ! ( common_ext. ignore_eos, Some ( true ) ) ;
131153 assert_eq ! ( common_ext. min_tokens, Some ( 10 ) ) ;
154+ assert_eq ! ( common_ext. top_k, Some ( 50 ) ) ;
155+ assert_eq ! ( common_ext. repetition_penalty, Some ( 1.2 ) ) ;
132156 assert_eq ! (
133157 common_ext. guided_json. as_ref( ) ,
134158 Some ( & serde_json:: json!( { "key" : "value" } ) )
@@ -164,6 +188,8 @@ mod tests {
164188 let common_ext = CommonExt {
165189 ignore_eos : None ,
166190 min_tokens : Some ( 0 ) , // Should be valid (min = 0)
191+ top_k : None ,
192+ repetition_penalty : None ,
167193 guided_json : None ,
168194 guided_regex : None ,
169195 guided_grammar : None ,
@@ -180,6 +206,8 @@ mod tests {
180206
181207 assert_eq ! ( common_ext. ignore_eos, None ) ;
182208 assert_eq ! ( common_ext. min_tokens, None ) ;
209+ assert_eq ! ( common_ext. top_k, None ) ;
210+ assert_eq ! ( common_ext. repetition_penalty, None ) ;
183211 assert ! ( common_ext. validate( ) . is_ok( ) ) ;
184212 }
185213
@@ -190,6 +218,8 @@ mod tests {
190218
191219 assert_eq ! ( common_ext. ignore_eos, None ) ;
192220 assert_eq ! ( common_ext. min_tokens, None ) ;
221+ assert_eq ! ( common_ext. top_k, None ) ;
222+ assert_eq ! ( common_ext. repetition_penalty, None ) ;
193223 assert ! ( common_ext. validate( ) . is_ok( ) ) ;
194224 }
195225
0 commit comments