Generating text with Llama 2 doesn't work when num_beams > 1 and only inputs_embeds is provided #29968

Closed
umbertocappellazzo opened this issue Mar 30, 2024 · 5 comments · Fixed by #29976


umbertocappellazzo commented Mar 30, 2024

System Info

  • transformers version: 4.39.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: No

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello guys!

For my project, I’m using Llama 2 as the LLM, and it accepts multimodal tokens (audio/video + prompts + text). When I generate text from the audio and prompt token embeddings, everything works fine with greedy decoding. For example, if text_embeddings has shape [1, 10, 4096], where 10 is the number of tokens and 4096 is the hidden size of Llama 2, I generate the output like this:

from transformers import LlamaForCausalLM
import torch

llm_name = "meta-llama/Llama-2-7b-hf"
llm = LlamaForCausalLM.from_pretrained(llm_name)
text_embeddings = torch.randn(1, 10, 4096)

decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)
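
For illustration, here is a minimal sketch of how embeddings like these could be assembled (the audio_features tensor and the prompt text below are made-up placeholders standing in for projected audio tokens, not my actual pipeline):

from transformers import AutoTokenizer

# Builds on the snippet above (torch, llm_name and llm are already defined).
tokenizer = AutoTokenizer.from_pretrained(llm_name)
prompt_ids = tokenizer("Transcribe the audio:", return_tensors="pt").input_ids
prompt_embeds = llm.get_input_embeddings()(prompt_ids)   # [1, T_prompt, 4096]

audio_features = torch.randn(1, 4, 4096)                 # placeholder for projected audio tokens
text_embeddings = torch.cat([audio_features, prompt_embeds], dim=1)

decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10)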

However, if I want to use beam search with 5 beams and therefore also pass num_beams=5 to generate, I get this error:

>>> decoded_ids = llm.generate(inputs_embeds = text_embeddings, max_new_tokens=10,num_beams=5)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1648, in generate
    result = self._beam_sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3402, in _beam_sample
    outputs = self(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 990, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1077, in _update_causal_mask
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
RuntimeError: The size of tensor a (10) must match the size of tensor b (0) at non-singleton dimension 0

What causes this error? Do I need to modify something to use beam search? Maybe it’s because I’m passing inputs_embeds rather than input_ids to generate and something must be adapted? Based on https://huggingface.co/docs/transformers/v4.38.2/en/generation_strategies#beam-search-decoding, it seems that adding num_beams=5 should be sufficient.

Thank you for your help!

Expected behavior

Setting num_beams = N > 1 should switch from greedy decoding to beam search, but while everything works fine with num_beams = 1, with num_beams > 1 I get the error above. I’ve noticed some recent changes to cache_position and related attributes in modeling_llama; maybe those pull requests fix my error as well.

ArthurZucker (Collaborator) commented Mar 30, 2024

That looks like a regression and should not be failing. cc @gante. I'll see what I can do in the meantime. The issue stems from _beam_sample, which seems to slice the cache positions to torch.tensor([]). I think this needs either a hard error if beam search is not possible with inputs_embeds (from the doc?) or a fix.
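
For illustration, the shape mismatch in the traceback can be reproduced in isolation like this (the shapes below are assumptions chosen to mimic the failing line in _update_causal_mask, not the actual internals):

import torch

target_length = 10
cache_position = torch.tensor([])            # empty, as _beam_sample appears to end up passing it
causal_mask = torch.ones(10, target_length)  # stand-in for the real causal mask

# The comparison broadcasts to shape (0, 10), which cannot be multiplied in place
# into the (10, 10) mask:
# RuntimeError: The size of tensor a (10) must match the size of tensor b (0) at non-singleton dimension 0
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)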

umbertocappellazzo (Author) commented Mar 30, 2024

Thanks for looking into this! Since my task involves speech + LLMs, beam search decoding usually improves performance over greedy decoding. I'm also aware of other projects that use a similar approach and work fine (e.g., https://github.com/Sally-SH/VSP-LLM/blob/main/src/vsp_llm.py#L396).

ArthurZucker (Collaborator) commented

Okay, this was introduced by 7b87ecb04712eed50793e65a2b39376f4570fcf2 (#29467) and it's not expected.

ArthurZucker (Collaborator) commented Apr 1, 2024

>>> from transformers import LlamaForCausalLM, AutoTokenizer
>>> import torch

>>> llm_name = "meta-llama/Llama-2-7b-hf"
>>> llm = LlamaForCausalLM.from_pretrained(llm_name)
>>> tokenizer = AutoTokenizer.from_pretrained(llm_name)
>>> inputs = tokenizer.encode("my favorite condiment is", return_tensors="pt")

>>> text_embeddings = llm.get_input_embeddings()(inputs)
>>> decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10, num_beams=2)
>>> print(tokenizer.batch_decode(decoded_ids))
["ketchup.\nI'm not a"]

This is before the failing commit, so it is not a feature request but a regression. I can confirm that the beams were properly taken into account.
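
One quick sanity check (a sketch building on the snippet above, not something from the original report): with num_beams > 1 and return_dict_in_generate=True, the output should carry per-beam sequences_scores, which greedy decoding does not produce.

>>> out = llm.generate(
...     inputs_embeds=text_embeddings,
...     max_new_tokens=10,
...     num_beams=2,
...     return_dict_in_generate=True,
...     output_scores=True,
... )
>>> out.sequences_scores  # beam-search-only field with the final sequence log-probabilities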

umbertocappellazzo (Author) commented

Thank you, Arthur, for fixing the issue! Waiting for the final merge then.
