@@ -184,6 +184,9 @@ class SamplingParams(
184184 allowed_token_ids: If provided, the engine will construct a logits
185185 processor which only retains scores for the given token ids.
186186 Defaults to None.
187+ extra_args: Arbitrary additional args, that can be used by custom
188+ sampling implementations. Not used by any in-tree sampling
189+ implementations.
187190 """
188191
189192 n : int = 1
@@ -227,6 +230,7 @@ class SamplingParams(
227230 guided_decoding : Optional [GuidedDecodingParams ] = None
228231 logit_bias : Optional [dict [int , float ]] = None
229232 allowed_token_ids : Optional [list [int ]] = None
233+ extra_args : Optional [dict [str , Any ]] = None
230234
231235 @staticmethod
232236 def from_optional (
@@ -259,6 +263,7 @@ def from_optional(
259263 guided_decoding : Optional [GuidedDecodingParams ] = None ,
260264 logit_bias : Optional [Union [dict [int , float ], dict [str , float ]]] = None ,
261265 allowed_token_ids : Optional [list [int ]] = None ,
266+ extra_args : Optional [dict [str , Any ]] = None ,
262267 ) -> "SamplingParams" :
263268 if logit_bias is not None :
264269 # Convert token_id to integer
@@ -300,6 +305,7 @@ def from_optional(
300305 guided_decoding = guided_decoding ,
301306 logit_bias = logit_bias ,
302307 allowed_token_ids = allowed_token_ids ,
308+ extra_args = extra_args ,
303309 )
304310
305311 def __post_init__ (self ) -> None :
@@ -509,7 +515,8 @@ def __repr__(self) -> str:
509515 "spaces_between_special_tokens="
510516 f"{ self .spaces_between_special_tokens } , "
511517 f"truncate_prompt_tokens={ self .truncate_prompt_tokens } , "
512- f"guided_decoding={ self .guided_decoding } )" )
518+ f"guided_decoding={ self .guided_decoding } , "
519+ f"extra_args={ self .extra_args } )" )
513520
514521
515522class BeamSearchParams (
0 commit comments