
Fix: generate() with max_new_tokens=0 produces a single token. #28579

Closed · 4 commits

Conversation

danielkorat (Contributor) commented Jan 18, 2024

What does this PR do?

Currently, setting `max_new_tokens=0` causes `generate()` to produce 1 token instead of 0, and the resulting warning is unclear.
For example, given the following code:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/tiny_starcoder_py"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
inputs = tokenizer("def print_hello_world():", return_tensors="pt")
max_new_tokens = 0
outputs = model.generate(
    **inputs,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=max_new_tokens,
)
input_length = len(inputs["input_ids"][0])
output_length = len(outputs[0])
# The number of newly generated tokens should equal max_new_tokens
print(f"\nTest: {output_length - input_length == max_new_tokens}")

The output is:

utils.py:1134: UserWarning: Input length of input_ids is 7, but `max_length` is set to 7. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  warnings.warn(

Test: False

After the fix, this is the output:

`max_new_tokens`=0, no tokens will be generated.
utils.py:1134: UserWarning: Input length of input_ids is 7, but `max_length` is set to 7. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  warnings.warn(

Test: True

(Note the new warning).

This is currently fixed only for greedy_search(). Once this PR is reviewed, I'll extend the fix to all other generation modes.
@gante @amyeroberts
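The behavior described above can be illustrated with a toy standalone sketch (not the actual transformers code): an early check before the decoding loop that warns and returns the prompt unchanged when `max_new_tokens` is 0, instead of decoding a single token.

```python
import warnings


def greedy_generate(input_ids, step_fn, max_new_tokens):
    """Toy greedy loop: append one token per step, chosen by step_fn.

    If max_new_tokens == 0, warn and return the prompt unchanged
    rather than decoding a single extra token.
    """
    if max_new_tokens == 0:
        warnings.warn("`max_new_tokens`=0, no tokens will be generated.")
        return list(input_ids)
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        ids.append(step_fn(ids))  # the toy "model" picks the next token
    return ids


# Toy "model" that always predicts token 99.
out = greedy_generate([1, 2, 3], lambda ids: 99, max_new_tokens=0)
print(out)  # [1, 2, 3] -- prompt returned as-is, zero new tokens
```

Here `greedy_generate` and `step_fn` are hypothetical names used only for illustration; the real fix lives inside transformers' generation utilities.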

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

gante (Member) commented Jan 19, 2024

Hi @danielkorat 👋

In lower-level APIs like generate, we should strive to be more strict if some parameterization is causing problems to avoid creating unexpected patterns. In this case, generating with 0 tokens should not be possible, as there is no generation involved.

As such, I'd suggest increasing the severity of the associated warning to an exception instead of the changes proposed here :)
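The stricter behavior suggested here could be sketched as a simple validation guard (a hypothetical helper for illustration, not the actual transformers code):

```python
def validate_max_new_tokens(max_new_tokens):
    """Hypothetical guard: reject a zero or negative generation length outright."""
    if max_new_tokens is not None and max_new_tokens <= 0:
        raise ValueError(
            f"`max_new_tokens` must be greater than 0, but is {max_new_tokens}. "
            "Generating zero tokens is not supported; run a plain forward pass "
            "on the model instead."
        )


validate_max_new_tokens(20)  # passes silently

try:
    validate_max_new_tokens(0)
except ValueError as e:
    print(e)  # surfaces the parameterization error instead of a vague warning
```

Raising an exception here, rather than warning, stops the misconfiguration at the call site instead of letting a silently wrong output propagate.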

danielkorat (Contributor, Author) commented


Hi @gante, I made the requested change in a new PR: 28621.

danielkorat deleted the fix-zero-max-new-tokens branch on November 3, 2024.