@@ -456,41 +456,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_model_get_set_embeddings(self):
         pass
 
-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        # See PR #38607 (to avoid flakiness)
-                        data = torch.flatten(param.data)
-                        n_elements = torch.numel(data)
-                        # skip 2.5% of elements on each side to avoid issues caused by `nn.init.trunc_normal_` described in
-                        # https://github.com/huggingface/transformers/pull/27906#issuecomment-1846951332
-                        n_elements_to_skip_on_each_side = int(n_elements * 0.025)
-                        data_to_check = torch.sort(data).values
-                        if n_elements_to_skip_on_each_side > 0:
-                            data_to_check = data_to_check[
-                                n_elements_to_skip_on_each_side:-n_elements_to_skip_on_each_side
-                            ]
-                        self.assertIn(
-                            ((data_to_check.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
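
The removed override guarded against flakiness documented in PR #38607: `nn.init.trunc_normal_` can produce a few extreme samples that drag a parameter's raw mean away from its expected value, so the test sorted the flattened values and dropped 2.5% on each tail before checking the mean. A minimal standalone sketch of that trimmed-mean check, for reference only (the function name and usage below are illustrative, not part of the transformers test suite):

import torch

def trimmed_mean(param: torch.Tensor, trim_fraction: float = 0.025) -> float:
    """Mean of `param` after discarding `trim_fraction` of the values on each tail."""
    data = torch.sort(torch.flatten(param)).values
    n_to_skip = int(data.numel() * trim_fraction)
    if n_to_skip > 0:
        data = data[n_to_skip:-n_to_skip]
    return data.mean().item()

# A zero-initialized weight should pass the same mean-in-[0.0, 1.0] assertion
# the test applied after rounding to 9 decimal places.
weight = torch.zeros(1024)
assert round(trimmed_mean(weight), 9) in (0.0, 1.0)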
@@ -981,30 +946,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
@@ -1194,30 +1135,6 @@ def test_training_gradient_checkpointing(self):
         loss = model(**inputs).loss
         loss.backward()
 
-    # override as the `logit_scale` parameter initialization is different for Blip
-    def test_initialization(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        configs_no_init = _config_zero_init(config)
-        for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            for name, param in model.named_parameters():
-                if param.requires_grad:
-                    # check if `logit_scale` is initialized as per the original implementation
-                    if name == "logit_scale":
-                        self.assertAlmostEqual(
-                            param.data.item(),
-                            np.log(1 / 0.07),
-                            delta=1e-3,
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-                    else:
-                        self.assertIn(
-                            ((param.data.mean() * 1e9).round() / 1e9).item(),
-                            [0.0, 1.0],
-                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
-                        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
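
All three removed overrides asserted the same target for `logit_scale`: it should start at np.log(1 / 0.07), i.e. the log of the inverse softmax temperature (1/0.07) used in the original implementation, roughly 2.6593. A quick sketch of that check in isolation (the parameter construction here is illustrative, not how the model builds it):

import numpy as np
import torch

logit_scale = torch.nn.Parameter(torch.tensor(np.log(1 / 0.07)))
assert abs(logit_scale.data.item() - np.log(1 / 0.07)) < 1e-3  # same delta the test used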