Generating text with Llama 2 doesn't work when num_beams > 1 and only inputs_embeds is provided #29968
Comments
That looks like a regression and should not be failing. cc @gante I'll see what I can do in the meantime. The issue stems from …
Thanks for looking into this! Since my task is speech-related + LLMs, beam search decoding usually improves performance over greedy decoding. Also, I'm aware of other projects that use a similar approach, and everything works fine there (e.g., https://github.com/Sally-SH/VSP-LLM/blob/main/src/vsp_llm.py#L396).
Okay, this was introduced by …
>>> from transformers import LlamaForCausalLM, AutoTokenizer
>>> import torch
>>> llm_name = "meta-llama/Llama-2-7b-hf"
>>> llm = LlamaForCausalLM.from_pretrained(llm_name)
>>> tokenizer = AutoTokenizer.from_pretrained(llm_name)
>>> inputs = tokenizer.encode("my favorite condiment is", return_tensors="pt")
>>> text_embeddings = llm.get_input_embeddings()(inputs)
>>> decoded_ids = llm.generate(inputs_embeds=text_embeddings, max_new_tokens=10, num_beams=2)
>>> print(tokenizer.batch_decode(decoded_ids))
["ketchup.\nI'm not a"] this is before the failing commit. So it is not a feature request, and a regression. I can confirm that the beams were properly taken into account. |
Thank you, Arthur, for fixing the issue! Waiting for the final merge, then.
System Info
transformers version: 4.39.0
Who can help?
@gante
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
Hello guys!
For my project, I'm using Llama 2 as the LLM; it accepts multimodal tokens (audio/video + prompts + text). When I want to generate text given the audio and prompt token embeddings, everything works fine if I use greedy decoding. For example, if our text_embeddings has shape [1, 10, 4096], where 10 is the number of tokens and 4096 is the hidden size of Llama 2, I generate the output like the sketch below:
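A minimal sketch of the working greedy call, assuming llm, tokenizer, and text_embeddings are built as in the snippet above, with a hypothetical max_new_tokens budget:

# Greedy decoding from embeddings only (num_beams defaults to 1): this works.
decoded_ids = llm.generate(
    inputs_embeds=text_embeddings,  # shape [1, 10, 4096]
    max_new_tokens=50,              # hypothetical length budget
)
print(tokenizer.batch_decode(decoded_ids, skip_special_tokens=True))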
However, if I want to use beam search with N_beams=5, so I also pass num_beams=5 to generate, I get an error. The failing call looks like the sketch below:
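A sketch of the failing call, using the same hypothetical names as above:

# Beam search from embeddings only: this call raises the error on 4.39.0.
decoded_ids = llm.generate(
    inputs_embeds=text_embeddings,
    max_new_tokens=50,  # hypothetical length budget
    num_beams=5,        # switching from greedy to beam search triggers the failure
)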
What causes this error? Do I need to modify something when using beam search? Maybe it depends on the fact that I'm using inputs_embeds rather than input_ids in the generate method, and something must be adapted? Based on https://huggingface.co/docs/transformers/v4.38.2/en/generation_strategies#beam-search-decoding, it seems like adding num_beams=5 should be sufficient.
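For comparison, a sketch of the input_ids path with the same beam settings; per the issue title, the failure is specific to providing only inputs_embeds, so this variant is not expected to hit the error:

# Beam search from token ids instead of embeddings.
decoded_ids = llm.generate(
    inputs,             # token ids from tokenizer.encode(...) above
    max_new_tokens=50,  # hypothetical length budget
    num_beams=5,
)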
Thank you for your help!
Expected behavior
By setting num_beams = N > 1, I should switch from greedy decoding to beam search; with num_beams = 1 everything works fine, but with num_beams > 1 I get the error above. I've noticed some recent changes to cache_position and similar attributes in modeling_llama; maybe those pull requests fixed my error as well.