Allow bos_token_id is None during the generation with inputs_embeds #29772

Merged · 3 commits · Mar 26, 2024

Conversation

@LZHgrla (Contributor) commented Mar 21, 2024

What does this PR do?

Allow bos_token_id to be None during generation with inputs_embeds.

This is important for multi-modal inputs and generation with LLMs whose bos_token_id is None.
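
For illustration, the call pattern this change unblocks might look like the following (a minimal sketch, not taken from the PR itself; the tiny checkpoint stands in for a real multi-modal model, and random embeddings stand in for projected multi-modal features):

import torch
from transformers import AutoModelForCausalLM

# Assumption: a tiny test checkpoint stands in for a real multi-modal
# LLM whose bos_token_id is None.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Assumption: random embeddings stand in for projected image/audio
# features coming out of a modality encoder.
inputs_embeds = torch.randn(1, 5, model.config.hidden_size)

# Before this PR, bos_token_id=None here raised a ValueError even
# though inputs_embeds already provides the generation prefix.
ids = model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None)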

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts (Collaborator) left a comment

Change looks good to me - thanks for adding!

Just need to add a test (sketched after this comment) to make sure:

  • The error is triggered when inputs_embeds is not passed and bos_token_id is None
  • It is not triggered when inputs_embeds is passed and bos_token_id is None

cc @gante to confirm the desired behaviour
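
A minimal, self-contained version of these two checks might look like the following sketch (the tiny checkpoint is borrowed from later in this thread, and the test actually added to the PR may be organized differently):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs_embeds = torch.randn(1, 5, model.config.hidden_size)

# Check 1: inputs_embeds is passed, so bos_token_id=None must not raise.
model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None)

# Check 2: no inputs at all and bos_token_id=None must raise a ValueError.
try:
    model.generate(max_length=20, bos_token_id=None)
except ValueError:
    pass  # expected: nothing to start generation from
else:
    raise AssertionError("expected a ValueError when no inputs and no BOS token are given")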

@LZHgrla (Contributor, Author) commented Mar 25, 2024

@amyeroberts @gante
Hi! I have added the related tests for this PR.

@gante (Member) left a comment

Makes sense -- BOS is not required in all cases.

Thank you for fixing 🙏

@gante (Member) commented Mar 25, 2024

@amyeroberts should be ready for a final check

FYI @zucchini-nlp

@amyeroberts (Collaborator) left a comment

Thanks for working on this and adding a test!

Just a small comment on the test logic, which I think has to be addressed before merging.


model.generate(inputs_embeds=inputs_embeds, max_length=20, bos_token_id=None)
with self.assertRaises(ValueError):
    model.generate(max_length=20, bos_token_id=None)

@amyeroberts (Collaborator):

I think? Otherwise the error is being raised from the lack of inputs.

Suggested change:
- model.generate(max_length=20, bos_token_id=None)
+ model.generate(input_ids, max_length=20, bos_token_id=None)

@LZHgrla (Contributor, Author):

@amyeroberts

I think we should not pass in input_ids. I tested the following code on the main branch and found that generation works well (see Case 3) if we pass input_ids with bos_token_id=None.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
article = "Today a dragon flew over Paris."
input_ids = tokenizer(article, return_tensors="pt").input_ids

# Case 1
ids = model.generate(max_length=20)[0]
print(tokenizer.decode(ids))
# �vedvedvedvedvedvedvedved Wh Wh Wh Wh Wh Wh Wh Wh Wh Wh Wh

# Case 2
ids = model.generate(input_ids, max_length=20)[0]
print(tokenizer.decode(ids))
# Today a dragon flew over Paris. fe fe fe fe

# Case 3
ids = model.generate(input_ids, max_length=20, bos_token_id=None)[0]
print(tokenizer.decode(ids))
# Today a dragon flew over Paris. fe fe fe fe

# Case 4: error
# The line below raises a ValueError: there are no input_ids, no
# inputs_embeds, and no bos_token_id to build a starting sequence from.
# ids = model.generate(max_length=20, bos_token_id=None)[0]

@amyeroberts (Collaborator):

Ah, you're right, sorry. Thanks for showing the cases so clearly!

@amyeroberts (Collaborator) left a comment

Thanks!

@amyeroberts merged commit 998b5bb into huggingface:main on Mar 26, 2024
20 checks passed
hovnatan pushed a commit to hovnatan/transformers that referenced this pull request Mar 27, 2024
itazap pushed a commit that referenced this pull request May 14, 2024