@@ -491,11 +491,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_generate_without_input_ids(self):
         pass
 
-    @pytest.mark.generate
-    @unittest.skip(reason="""IDEFICS cannot generate with no images provided!""")
-    def test_generate_continue_from_inputs_embeds(self):
-        pass
-
     @pytest.mark.generate
     @unittest.skip(reason="""IDEFICS cannot do contrastive generation yet and it is not worth fixing""")
     def test_contrastive_generate(self):
@@ -776,65 +771,6 @@ def test_generate_without_input_ids(self):
         )
         self.assertIsNotNone(output_ids_generate)
 
-    @pytest.mark.generate
-    def test_generate_continue_from_inputs_embeds(self):
-        """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""
-
-        for model_class in self.all_generative_model_classes:
-            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-            print(inputs)
-
-            model = model_class(config).to(torch_device).eval()
-
-            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
-            model.generation_config.forced_eos_token_id = None
-            model.generation_config.use_cache = True
-
-            input_ids = inputs.pop("input_ids")
-            input_embeds = model.get_input_embeddings()(input_ids)
-
-            generation_kwargs = {
-                "return_dict_in_generate": True,
-                "do_sample": False,
-            }
-
-            inputs["inputs_embeds"] = input_embeds
-
-            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
-            outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
-            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
-            # inputs may need to be tweaked across `generate` calls (like the attention mask).
-            initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
-            inputs["past_key_values"] = initial_output.past_key_values
-
-            new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
-            continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
-            inputs["inputs_embeds"] = continued_embeds
-
-            if "attention_mask" in inputs:
-                inputs["attention_mask"] = torch.nn.functional.pad(
-                    inputs["attention_mask"],
-                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
-                    mode="constant",
-                    value=1,
-                )
-            if "image_attention_mask" in inputs:
-                inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]
-
-            cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)
-
-            # Verify that the combined outputs match the full generation.
-            combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
-            self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
-            for layer_idx in range(len(cached_output.past_key_values)):
-                for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
-                    self.assertTrue(
-                        torch.allclose(
-                            outputs.past_key_values[layer_idx][kv_idx],
-                            cached_output.past_key_values[layer_idx][kv_idx],
-                        )
-                    )
-
     def _check_attentions_for_generate(
         self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
     ):