
Gemma2ForCausalLM: ValueError in prepare_inputs_for_generation when using custom input embeddings #32479

Closed

serteal opened this issue Aug 6, 2024 · 1 comment · Fixed by #32493

serteal commented Aug 6, 2024

System Info

  • transformers version: 4.44.0
  • Platform: Linux-5.4.0-189-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: NO
  • Using GPU in script?: YES
  • GPU type: NVIDIA RTX A6000

Who can help?

@ArthurZucker
@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

messages = [
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
template = tokenizer.apply_chat_template(messages, tokenize=False)
input_ids = tokenizer(template, return_tensors="pt").to("cuda")

# Build input embeddings manually instead of passing token ids
embedding_layer = model.get_input_embeddings()
inputs_embeds = embedding_layer(input_ids["input_ids"])

# Raises ValueError in prepare_inputs_for_generation (see trace below)
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=32,
)

Expected behavior

I'm encountering an error when attempting to use custom input embeddings with the Gemma2ForCausalLM model's .generate() method. Specifically, I'm modifying the input embeddings for the model and then trying to generate output from these custom embeddings.

Expected behavior:
When custom input embeddings are supplied via the inputs_embeds argument, .generate() should process them and produce output, as it already does for other models such as meta-llama/Meta-Llama-3-8B-Instruct.
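For comparison, a minimal sketch of the same pattern with that Llama checkpoint (assuming access to the gated model; variable names here are illustrative) runs without error:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

llama_name = "meta-llama/Meta-Llama-3-8B-Instruct"
llama_tokenizer = AutoTokenizer.from_pretrained(llama_name)
llama_model = AutoModelForCausalLM.from_pretrained(
    llama_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

text = llama_tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write me a poem about Machine Learning."}],
    tokenize=False,
)
llama_inputs = llama_tokenizer(text, return_tensors="pt").to("cuda")
llama_embeds = llama_model.get_input_embeddings()(llama_inputs["input_ids"])

# LlamaForCausalLM's prepare_inputs_for_generation handles the 3-D inputs_embeds
# tensor, so generation proceeds normally
outputs = llama_model.generate(inputs_embeds=llama_embeds, max_new_tokens=32)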

Actual behavior:
The .generate() call raises a ValueError inside prepare_inputs_for_generation: the function unpacks inputs_embeds.shape into two values, but the tensor is 3-D ([batch_size, sequence_length, embedding_dimension]), so the unpack fails with "too many values to unpack (expected 2)".
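To see the failure mode in isolation (the hidden size below is only illustrative), unpacking a 3-D shape into two variables reproduces the same exception:

import torch

# inputs_embeds always has three dimensions: [batch_size, sequence_length, hidden_size]
inputs_embeds = torch.zeros(1, 24, 2304)

# This mirrors the unpacking in Gemma2ForCausalLM.prepare_inputs_for_generation
batch_size, sequence_length = inputs_embeds.shape
# ValueError: too many values to unpack (expected 2)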

The error trace looks like the following:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[5], line 21
     18 embedding_layer = model.get_input_embeddings()
     19 inputs_embeds = embedding_layer(input_ids["input_ids"])
---> 21 outputs = model.generate(
     22     inputs_embeds=inputs_embeds,
     23     max_new_tokens=32,
     24 )

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/generation/utils.py:2975, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2969 model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
   2971 while self._has_unfinished_sequences(
   2972     this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
   2973 ):
   2974     # prepare model inputs
-> 2975     model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2977     # prepare variable output controls (note: some models won't accept all output controls)
   2978     model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})

File <full_path>/miniforge3/envs/myenv/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:1223, in Gemma2ForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, position_ids, use_cache, **kwargs)
   1221 if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
   1222     if inputs_embeds is not None:
-> 1223         batch_size, sequence_length = inputs_embeds.shape
   1224         device = inputs_embeds.device
   1225     else:

ValueError: too many values to unpack (expected 2)
serteal added the bug label Aug 6, 2024

molbap (Contributor) commented Aug 7, 2024

Thanks @serteal, indeed, it looks like a recent addition broke this: inputs_embeds is always of shape [batch_size, sequence_length, embedding_dimension], so I'll open a PR for that!
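For reference, a hypothetical sketch of what the corrected branch could look like (the actual change lands in #32493 and may differ):

# Inside Gemma2ForCausalLM.prepare_inputs_for_generation (sketch, not the merged fix):
if inputs_embeds is not None:
    # inputs_embeds is [batch_size, sequence_length, embedding_dimension],
    # so unpack three values and discard the embedding dimension
    batch_size, sequence_length, _ = inputs_embeds.shape
    device = inputs_embeds.device
else:
    batch_size, sequence_length = input_ids.shape
    device = input_ids.device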
