Error in _prepare_generated_length #32911
Comments
Hey! Could you share a snippet that reproduces the error? 🤗
Of course!
Indeed, there's an issue when we generate from embeds with static cache, since we rely on already-adjusted [...] I can open a PR to fix this @ArthurZucker
🫠 Let's first make sure this does not affect Gemma2 / Gemma on the v4.44-release, then let's fix for main
I am not sure we test static cache + input embeds; let's add this
I also tried the v4.44-release version, and the issue occurs there too. I switched to 4.45.dev because PR #32493 fixes the shape error in v4.44-release:
Yes, it is also in v4.44, as static cache never worked with input embeddings. I'll add a test, okay
Thank you for the updates. This PR fixed the problem, with a small typo on line 1824 of generation/utils.py
System Info
transformers version: 4.45.0.dev0
Who can help?
@ArthurZucker @zucchini-nlp
There might be a logic error in the generation utils. When inputs_embeds is passed instead of input_ids, the input_ids shape[-1] is 0, so max_length is set to max_new_tokens instead of generation_config.max_new_tokens + input_ids_length. This in turn causes a size mismatch error in _prepare_4d_causal_attention_mask_with_cache_position, because the target_length passed in is smaller than the mask_length.
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
...
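The arithmetic behind this report can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's code: `resolve_max_length` is a hypothetical helper, and the prompt length of 12 and `max_new_tokens` of 8 are made-up values. It shows how an empty input_ids placeholder shrinks max_length, and why slicing to mask_length then cannot yield a row as wide as the attention mask.

```python
# Illustrative sketch only; names and numbers are assumptions,
# not the transformers internals.


def resolve_max_length(max_new_tokens: int, input_ids_length: int) -> int:
    """Hypothetical helper: total length = new tokens + prompt tokens."""
    return max_new_tokens + input_ids_length


prompt_length = 12   # assumed sequence length of inputs_embeds
max_new_tokens = 8   # assumed generation budget

# With inputs_embeds, input_ids is an empty placeholder, so shape[-1] reads 0.
observed_max_length = resolve_max_length(max_new_tokens, 0)
# What the report expects: the prompt length should be counted.
expected_max_length = resolve_max_length(max_new_tokens, prompt_length)

# The cache is sized from max_length, so the causal mask has target_length
# columns, while the attention mask has mask_length columns.
target_length = observed_max_length
mask_length = prompt_length

causal_row = [0.0] * target_length      # one row of the causal mask
attention_row = [1.0] * mask_length     # one row of the attention mask

# Slicing past the end silently returns fewer columns than requested.
sliced_row = causal_row[:mask_length]
shapes_match = len(sliced_row) == len(attention_row)

print(observed_max_length, expected_max_length, len(sliced_row), shapes_match)
```

With these assumed values, the sliced row has 8 columns while the attention mask has 12, which mirrors the mismatch in the addition on the line quoted above.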
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
Simply replace input_ids with inputs_embeds when generating with Gemma2.
Expected behavior
Passing inputs_embeds should be supported and work the same way as passing input_ids.