@@ -459,41 +459,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_model_get_set_embeddings(self):
         pass

-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        # See PR #38607 (to avoid flakiness)
-                        data = torch.flatten(param.data)
-                        n_elements = torch.numel(data)
-                        # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
-                        # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
-                        n_elements_to_skip_on_each_side = int(n_elements * 0.025)
-                        data_to_check = torch.sort(data).values
-                        if n_elements_to_skip_on_each_side > 0:
-                            data_to_check = data_to_check[
-                                n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side
-                            ]
-                        self.assertIn(
-                            ((data_to_check.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
@@ -990,30 +955,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass

-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
@@ -1208,30 +1149,6 @@ def test_training_gradient_checkpointing(self):
             loss = model(**inputs).loss
             loss.backward()

-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
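
Note: each removed override asserted that the learnable `logit_scale` starts at `np.log(1 / 0.07)`, the CLIP-style temperature initialization that the test's comment calls "the original implementation". As a point of reference only, below is a minimal, hypothetical sketch (not the actual Blip module; class and parameter names are illustrative) of a contrastive head initialized that way, followed by the check the removed test effectively performed on it.

import numpy as np
import torch
from torch import nn


class ToyContrastiveHead(nn.Module):
    # Illustrative only: a CLIP-style head whose learnable log-temperature
    # `logit_scale` starts at log(1 / 0.07), the value the removed test asserted.
    def __init__(self, logit_scale_init_value: float = float(np.log(1 / 0.07))):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.tensor(logit_scale_init_value))

    def forward(self, image_embeds: torch.Tensor, text_embeds: torch.Tensor) -> torch.Tensor:
        # cosine similarities between L2-normalized embeddings, scaled by exp(logit_scale)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        return self.logit_scale.exp() * (image_embeds @ text_embeds.t())


head = ToyContrastiveHead()
# what the removed `test_initialization` checked for the `logit_scale` parameter
assert abs(head.logit_scale.item() - np.log(1 / 0.07)) < 1e-3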