Commit 2b7f8dc

tmp commit
1 parent 2436b5d commit 2b7f8dc

6 files changed: +21 −91 lines changed


src/transformers/configuration_utils.py

Lines changed: 2 additions & 1 deletion
@@ -1224,7 +1224,8 @@ def get_text_config(self, decoder: Optional[bool] = None, encoder: Optional[bool
        for text_config_name in possible_text_config_names:
            if hasattr(self, text_config_name):
                text_config = getattr(self, text_config_name, None)
-               if text_config is not None:
+               # Assumption: all text configs have a `vocab_size` attribute
+               if text_config is not None and "vocab_size" in text_config:
                    valid_text_config_names += [text_config_name]

        if len(valid_text_config_names) == 1:
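
In effect, `get_text_config` now only treats a candidate sub-config as a text config when it exposes a `vocab_size`. A minimal sketch of the intended behavior, using an illustrative multimodal checkpoint that is not part of this commit:

from transformers import AutoConfig

# Illustrative multimodal checkpoint whose config holds both a vision sub-config
# (no `vocab_size`) and a text sub-config (with `vocab_size`).
config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")

# With the `vocab_size` filter, only the text sub-config remains a valid candidate,
# so the language-model config is returned rather than the vision config.
text_config = config.get_text_config(decoder=True)
print(text_config.vocab_size)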

tests/generation/test_utils.py

Lines changed: 19 additions & 10 deletions
@@ -1591,27 +1591,36 @@ def test_generate_continue_from_past_key_values(self):

    @pytest.mark.generate
    def test_generate_continue_from_inputs_embeds(self):
-       """Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call."""
+       """
+       Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous
+       `generate` call.
+       """
        for model_class in self.all_generative_model_classes:
-           if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+           # To be more precise: technically we can run this test on all models that have `inputs_embeds` or
+           # `decoder_inputs_embeds` in their signatures, but the main use case of this feature is on LLMs.
+           # Let's prevent overwrites and additional test logic by adding this constraint.
+           if model_class.main_input_name != "input_ids":
+               self.skipTest(reason="This test is only for models that use `input_ids` as their main input")
+           if "inputs_embeds" not in inspect.signature(model_class.prepare_inputs_for_generation).parameters:
+               self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+           # these models have a different cache format/class
+           different_cache = ["gpt_bigcode", "zamba2"]
+           # these models require special input preparation logic for this test
+           non_llm = ["mllama", "idefics", "moshi"]
+           if any(model_name in model_class.__name__.lower() for model_name in different_cache + non_llm):
                self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
-           if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
-               self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

            config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-           if "token_type_ids" in inputs_dict:
-               del inputs_dict["token_type_ids"]
-
+           # TODO (joao, raushan): this shouldn't be a constraint to this test, `decoder_inputs_embeds` exists
            if config.is_encoder_decoder:
                self.skipTest(reason="This model is encoder-decoder")
            if not hasattr(config.get_text_config(decoder=True), "use_cache"):
                self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

            model = model_class(config).to(torch_device).eval()
-
-           if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters:
-               self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+           if "token_type_ids" in inputs_dict:
+               del inputs_dict["token_type_ids"]

            # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
            outputs = model(**inputs_dict)
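
For reference, the behavior this test exercises (continuing generation from `inputs_embeds` plus the `past_key_values` returned by a previous `generate` call) follows the pattern sketched below. This mirrors the logic of the IDEFICS override removed further down; the checkpoint name and token counts are placeholders rather than part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only model that accepts `inputs_embeds` in generation works.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)
generation_kwargs = {"return_dict_in_generate": True, "do_sample": False, "use_cache": True}

# Full generation in one call (4 new tokens)...
full = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=4, **generation_kwargs)

# ...should match generating 3 tokens, then continuing from the returned cache for 1 more.
first = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=3, **generation_kwargs)
continued_embeds = torch.cat([inputs_embeds, model.get_input_embeddings()(first.sequences)], dim=1)
second = model.generate(
    inputs_embeds=continued_embeds,
    past_key_values=first.past_key_values,
    max_new_tokens=1,
    **generation_kwargs,
)
combined = torch.cat([first.sequences, second.sequences], dim=1)
assert full.sequences.tolist() == combined.tolist()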

tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py

Lines changed: 0 additions & 4 deletions
@@ -436,10 +436,6 @@ def test_disk_offload(self):
    def test_past_key_values_format(self):
        pass

-   @unittest.skip(reason="BigCodeGPT has a non-standard KV cache format and breaks this test.")
-   def test_generate_continue_from_inputs_embeds(self):
-       pass
-
    def test_gpt_bigcode_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)

tests/models/idefics/test_modeling_idefics.py

Lines changed: 0 additions & 64 deletions
@@ -491,11 +491,6 @@ def test_retain_grad_hidden_states_attentions(self):
    def test_generate_without_input_ids(self):
        pass

-   @pytest.mark.generate
-   @unittest.skip(reason="""IDEFICS cannot generate with no images provided!""")
-   def test_generate_continue_from_inputs_embeds(self):
-       pass
-
    @pytest.mark.generate
    @unittest.skip(reason="""IDEFICS cannot do contrastive generation yet and it is not worth fixing""")
    def test_contrastive_generate(self):

@@ -776,65 +771,6 @@ def test_generate_without_input_ids(self):
        )
        self.assertIsNotNone(output_ids_generate)

-   @pytest.mark.generate
-   def test_generate_continue_from_inputs_embeds(self):
-       """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""
-
-       for model_class in self.all_generative_model_classes:
-           config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-           print(inputs)
-
-           model = model_class(config).to(torch_device).eval()
-
-           model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
-           model.generation_config.forced_eos_token_id = None
-           model.generation_config.use_cache = True
-
-           input_ids = inputs.pop("input_ids")
-           input_embeds = model.get_input_embeddings()(input_ids)
-
-           generation_kwargs = {
-               "return_dict_in_generate": True,
-               "do_sample": False,
-           }
-
-           inputs["inputs_embeds"] = input_embeds
-
-           # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
-           outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
-           # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
-           # inputs may need to be tweaked across `generate` calls (like the attention mask).
-           initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
-           inputs["past_key_values"] = initial_output.past_key_values
-
-           new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
-           continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
-           inputs["inputs_embeds"] = continued_embeds
-
-           if "attention_mask" in inputs:
-               inputs["attention_mask"] = torch.nn.functional.pad(
-                   inputs["attention_mask"],
-                   (0, new_attention_len - inputs["attention_mask"].shape[1]),
-                   mode="constant",
-                   value=1,
-               )
-           if "image_attention_mask" in inputs:
-               inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]
-
-           cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)
-
-           # Verify that the combined outputs match the full generation.
-           combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
-           self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
-           for layer_idx in range(len(cached_output.past_key_values)):
-               for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
-                   self.assertTrue(
-                       torch.allclose(
-                           outputs.past_key_values[layer_idx][kv_idx],
-                           cached_output.past_key_values[layer_idx][kv_idx],
-                       )
-                   )
-
    def _check_attentions_for_generate(
        self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
    ):

tests/models/moshi/test_modeling_moshi.py

Lines changed: 0 additions & 8 deletions
@@ -362,10 +362,6 @@ def test_disk_offload_bin(self):
    def test_disk_offload_safetensors(self):
        pass

-   @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
-   def test_generate_continue_from_inputs_embeds(self):
-       pass
-
    @is_flaky(max_attempts=5, description="flaky on some models.")
    def test_save_load(self):
        super().test_save_load()

@@ -872,10 +868,6 @@ def test_disk_offload_bin(self):
    def test_disk_offload_safetensors(self):
        pass

-   @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
-   def test_generate_continue_from_inputs_embeds(self):
-       pass
-
    @is_flaky(max_attempts=5, description="flaky on some models.")
    def test_save_load(self):
        super().test_save_load()

tests/models/zamba2/test_modeling_zamba2.py

Lines changed: 0 additions & 4 deletions
@@ -360,10 +360,6 @@ def test_past_key_values_format(self):
        all_cache_shapes.append([self_attention_cache_shape, self_attention_cache_shape])
        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)

-   @unittest.skip(reason="Zamba2 has hybrid mamba cache.")
-   def test_generate_continue_from_inputs_embeds(self):
-       pass
-
    @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
    def test_multi_gpu_data_parallel_forward(self):
        pass
