@@ -271,7 +271,6 @@ class GenerationConfig(PushToHubMixin):
271271
272272 def __init__ (self , ** kwargs ):
273273 # Parameters that control the length of the output
274- # if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
275274 self .max_length = kwargs .pop ("max_length" , 20 )
276275 self .max_new_tokens = kwargs .pop ("max_new_tokens" , None )
277276 self .min_length = kwargs .pop ("min_length" , 0 )
@@ -407,32 +406,34 @@ def validate(self, is_init=False):
407406 "used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
408407 + fix_location
409408 )
410- if self .temperature != 1.0 :
409+ if self .temperature is not None and self . temperature != 1.0 :
411410 warnings .warn (
412411 greedy_wrong_parameter_msg .format (flag_name = "temperature" , flag_value = self .temperature ),
413412 UserWarning ,
414413 )
415- if self .top_p != 1.0 :
414+ if self .top_p is not None and self . top_p != 1.0 :
416415 warnings .warn (
417416 greedy_wrong_parameter_msg .format (flag_name = "top_p" , flag_value = self .top_p ),
418417 UserWarning ,
419418 )
420- if self .typical_p != 1.0 :
419+ if self .typical_p is not None and self . typical_p != 1.0 :
421420 warnings .warn (
422421 greedy_wrong_parameter_msg .format (flag_name = "typical_p" , flag_value = self .typical_p ),
423422 UserWarning ,
424423 )
425- if self .top_k != 50 and self .penalty_alpha is None : # contrastive search uses top_k
424+ if (
425+ self .top_k is not None and self .top_k != 50 and self .penalty_alpha is None
426+ ): # contrastive search uses top_k
426427 warnings .warn (
427428 greedy_wrong_parameter_msg .format (flag_name = "top_k" , flag_value = self .top_k ),
428429 UserWarning ,
429430 )
430- if self .epsilon_cutoff != 0.0 :
431+ if self .epsilon_cutoff is not None and self . epsilon_cutoff != 0.0 :
431432 warnings .warn (
432433 greedy_wrong_parameter_msg .format (flag_name = "epsilon_cutoff" , flag_value = self .epsilon_cutoff ),
433434 UserWarning ,
434435 )
435- if self .eta_cutoff != 0.0 :
436+ if self .eta_cutoff is not None and self . eta_cutoff != 0.0 :
436437 warnings .warn (
437438 greedy_wrong_parameter_msg .format (flag_name = "eta_cutoff" , flag_value = self .eta_cutoff ),
438439 UserWarning ,
@@ -453,21 +454,21 @@ def validate(self, is_init=False):
453454 single_beam_wrong_parameter_msg .format (flag_name = "early_stopping" , flag_value = self .early_stopping ),
454455 UserWarning ,
455456 )
456- if self .num_beam_groups != 1 :
457+ if self .num_beam_groups is not None and self . num_beam_groups != 1 :
457458 warnings .warn (
458459 single_beam_wrong_parameter_msg .format (
459460 flag_name = "num_beam_groups" , flag_value = self .num_beam_groups
460461 ),
461462 UserWarning ,
462463 )
463- if self .diversity_penalty != 0.0 :
464+ if self .diversity_penalty is not None and self . diversity_penalty != 0.0 :
464465 warnings .warn (
465466 single_beam_wrong_parameter_msg .format (
466467 flag_name = "diversity_penalty" , flag_value = self .diversity_penalty
467468 ),
468469 UserWarning ,
469470 )
470- if self .length_penalty != 1.0 :
471+ if self .length_penalty is not None and self . length_penalty != 1.0 :
471472 warnings .warn (
472473 single_beam_wrong_parameter_msg .format (flag_name = "length_penalty" , flag_value = self .length_penalty ),
473474 UserWarning ,
@@ -491,7 +492,7 @@ def validate(self, is_init=False):
491492 raise ValueError (
492493 constrained_wrong_parameter_msg .format (flag_name = "do_sample" , flag_value = self .do_sample )
493494 )
494- if self .num_beam_groups != 1 :
495+ if self .num_beam_groups is not None and self . num_beam_groups != 1 :
495496 raise ValueError (
496497 constrained_wrong_parameter_msg .format (
497498 flag_name = "num_beam_groups" , flag_value = self .num_beam_groups
@@ -1000,6 +1001,9 @@ def update(self, **kwargs):
10001001 setattr (self , key , value )
10011002 to_remove .append (key )
10021003
1003- # remove all the attributes that were updated, without modifying the input dict
1004+ # Confirm that the updated instance is still valid
1005+ self .validate ()
1006+
1007+ # Remove all the attributes that were updated, without modifying the input dict
10041008 unused_kwargs = {key : value for key , value in kwargs .items () if key not in to_remove }
10051009 return unused_kwargs
0 commit comments