Allow `bos_token_id is None` during the generation with `inputs_embeds` #29772

Conversation
Change looks good to me - thanks for adding!
Just need to add a test to make sure:
- Error is triggered when `input_embeds` is not passed and `bos_token_id` is None
- Not triggered when `input_embeds` is passed and `bos_token_id` is None

cc @gante to confirm the desired behaviour
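The two requested assertions could be sketched roughly as below. This is only an illustration of the test's shape: `fake_generate` is a stand-in for a real `model.generate` call on a test checkpoint, and its error message is an assumption, not actual transformers code.

```python
import unittest

def fake_generate(input_ids=None, inputs_embeds=None, bos_token_id=None, max_length=20):
    # Stand-in for model.generate() with the behaviour discussed above:
    # a BOS token is only required when input_ids must be created from scratch.
    if input_ids is None and inputs_embeds is None and bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    return [[0] * max_length]

class BosTokenIdNoneTest(unittest.TestCase):
    def test_inputs_embeds_without_bos_token_id(self):
        inputs_embeds = [[[0.1, 0.2]]]  # dummy embeddings
        # Must not raise: inputs_embeds are provided, so no BOS is needed
        fake_generate(inputs_embeds=inputs_embeds, bos_token_id=None)
        # Must raise: nothing to start generation from
        with self.assertRaises(ValueError):
            fake_generate(bos_token_id=None)
```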
@amyeroberts @gante
Makes sense -- BOS is not required in all cases.
Thank you for fixing 🙏
@amyeroberts should be ready for a final check FYI @zucchini-nlp
Thanks for working on this and adding a test!
Just a small comment on the test logic which I think has to be addressed before merge
```python
model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None)
with self.assertRaises(ValueError):
    model.generate(max_length=20, bos_token_id=None)
```
I think? Otherwise the error is being raised from the lack of inputs

Suggested change:
```diff
-model.generate(max_length=20, bos_token_id=None)
+model.generate(input_ids, max_length=20, bos_token_id=None)
```
I think we should not pass in `input_ids`. I tested the following code on the `main` branch and found that generation works well if we pass in `input_ids` with `bos_token_id=None` (see case 3).
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
article = "Today a dragon flew over Paris."
input_ids = tokenizer(article, return_tensors="pt").input_ids

# Case 1
ids = model.generate(max_length=20)[0]
print(tokenizer.decode(ids))
# �vedvedvedvedvedvedvedved Wh Wh Wh Wh Wh Wh Wh Wh Wh Wh Wh

# Case 2
ids = model.generate(input_ids, max_length=20)[0]
print(tokenizer.decode(ids))
# Today a dragon flew over Paris. fe fe fe fe

# Case 3
ids = model.generate(input_ids, max_length=20, bos_token_id=None)[0]
print(tokenizer.decode(ids))
# Today a dragon flew over Paris. fe fe fe fe

# Case 4: below code will raise a ValueError
# ids = model.generate(max_length=20, bos_token_id=None)[0]
```
Ah, you're right, sorry. Thanks for showing the cases so clearly!
Thanks!
…s` (huggingface#29772) * update * add ut * update
What does this PR do?
Allow `bos_token_id is None` during the generation with `inputs_embeds`. This is important for multi-modal inputs / generation for LLMs whose `bos_token_id` is `None`.
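The change can be pictured with the following simplified sketch of the input-initialization step inside `generate` (the function name and error message here are illustrative assumptions, not the actual transformers source): the BOS check only fires when there is truly nothing to start generation from.

```python
def init_input_ids(input_ids=None, inputs_embeds=None, bos_token_id=None):
    # Simplified sketch: decide what the decoder starts from.
    if input_ids is not None:
        return input_ids
    if inputs_embeds is not None:
        # With inputs_embeds the prompt is already encoded, so generation can
        # start from an empty token sequence -- no BOS token is required.
        return []
    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
    return [bos_token_id]
```

With this ordering, a call with `inputs_embeds` and `bos_token_id=None` succeeds, while a call with no inputs and no `bos_token_id` still raises.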
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.