@@ -1209,7 +1209,7 @@ def test_generate_from_inputs_embeds(self, _, num_beams):
 
             # This test is for decoder-only models (encoder-decoder models have native input embeddings support in the
             # decoder)
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 continue
             config.is_decoder = True
 
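Note: every hunk in this commit makes the same substitution, dropping the `get_text_config(decoder=True)` indirection in favor of the top-level flag. A minimal sketch of the two checks, using a stock T5 config (the specific model choice is illustrative, not from the commit):

```python
# Sketch of the two checks being swapped; for a plain encoder-decoder config
# they agree, but for composite (e.g. vision-text) configs the nested decoder
# text config returned by get_text_config(decoder=True) can carry its own flag.
from transformers import T5Config

config = T5Config()  # stock encoder-decoder config
print(config.is_encoder_decoder)                                # True
print(config.get_text_config(decoder=True).is_encoder_decoder)  # True here too
```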
@@ -1288,7 +1288,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             model = model_class(config).to(torch_device).eval()
@@ -1439,7 +1439,7 @@ def test_generate_continue_from_inputs_embeds(self):
             if "token_type_ids" in inputs_dict:
                 del inputs_dict["token_type_ids"]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder")
             # TODO (joao, raushan): the correct line below is `if not hasattr(config.get_text_config(), "use_cache")`,
             # but it breaks a few models. Fix and then apply `has_similar_generate_outputs` pattern
@@ -1512,7 +1512,7 @@ def test_generate_with_static_cache(self):
             set_config_for_less_flaky_test(config)
             main_input = inputs_dict[model_class.main_input_name]
 
-            if config.get_text_config(decoder=True).is_encoder_decoder:
+            if config.is_encoder_decoder:
                 self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache")
 
             config.is_decoder = True
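For reference, the static-cache path this test exercises can be driven directly through `generate()`; a minimal sketch, assuming a decoder-only checkpoint such as gpt2:

```python
# Minimal sketch: generation with a fixed-size static KV cache.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Static caches preallocate their", return_tensors="pt")
# cache_implementation="static" makes generate() allocate a StaticCache with a
# fixed shape, which is the precondition for compiling the forward pass.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```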
@@ -1567,10 +1567,7 @@ def test_generate_with_quant_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
 
-            if (
-                config.get_text_config(decoder=True).is_encoder_decoder
-                or not model_class._supports_default_dynamic_cache()
-            ):
+            if config.is_encoder_decoder or not model_class._supports_default_dynamic_cache():
                 self.skipTest(reason="This model does not support the quantized cache format")
 
             config.is_decoder = True
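The quantized-cache test gets the same one-line skip condition. As a sketch of the cache format it guards, using the documented `cache_implementation="quantized"` knob (the quanto backend assumes the optimum-quanto package is installed):

```python
# Sketch: generation with a quantized KV cache; requires a quantization
# backend (here "quanto", i.e. the optimum-quanto package).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Quantized caches trade precision for", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=8,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```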
@@ -1670,7 +1667,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             if not has_defined_cache_implementation:
                 decoder_cache = (
                     gen_out.past_key_values.self_attention_cache
-                    if config.get_text_config(decoder=True).is_encoder_decoder
+                    if config.is_encoder_decoder
                     else gen_out.past_key_values
                 )
                 self.assertTrue(isinstance(decoder_cache, DynamicCache))
@@ -1696,7 +1693,7 @@ def test_generate_compile_model_forward_fullgraph(self):
             # sanity checks
             decoder_cache = (
                 gen_out.past_key_values.self_attention_cache
-                if config.get_text_config(decoder=True).is_encoder_decoder
+                if config.is_encoder_decoder
                 else gen_out.past_key_values
             )
             self.assertFalse(isinstance(decoder_cache, DynamicCache))
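Both hunks in this test unwrap the decoder cache the same way, because encoder-decoder models return an `EncoderDecoderCache` while decoder-only models return the cache directly. A small sketch of that structure:

```python
# Sketch of the unwrapping above: an EncoderDecoderCache bundles the decoder's
# self-attention cache with its cross-attention cache.
from transformers import DynamicCache, EncoderDecoderCache

cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

decoder_cache = cache.self_attention_cache           # the decoder's own KV cache
print(isinstance(decoder_cache, DynamicCache))       # True
print(cache.cross_attention_cache is decoder_cache)  # False: separate caches
```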