Generate: validate arguments #18218
Conversation
@@ -572,14 +572,14 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to
     @staticmethod
     def _expand_inputs_for_generation(
         input_ids: torch.LongTensor,
-        expand_size: int = 1,
+        num_return_sequences: int = 1,
Renamed this argument: it is a private method, the new name is more readable, and it is useful for name matching (as you'll see below).
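As a loose illustration of the name-matching point — a minimal sketch with a stand-in for the real private method (not the actual transformers code):

import inspect

def _expand_inputs_for_generation(input_ids, num_return_sequences=1, **model_kwargs):
    ...  # stand-in for the real private method

# With the rename, the user-facing kwarg and the method's parameter share a
# name, so signature inspection can match them with no special-case mapping.
assert "num_return_sequences" in inspect.signature(_expand_inputs_for_generation).parameters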
# Excludes arguments that are handled before calling any model function
if self.config.is_encoder_decoder:
    for key in ["decoder_input_ids"]:
        model_kwargs.pop(key, None)
If `decoder_input_ids` is present, it will be converted to `input_ids`.
# Transfo_XL does not have "attention_mask" as an argument, and it is harmless (it is being passed in the
# tests, though)
if "transfoxl" in str(self).lower():
    model_kwargs.pop("attention_mask", None)
Ugly ad hoc exception, but the alternative would be to rewrite most GenerationMixin tests for this particular model (most pass `attention_mask=attention_mask` when calling generate).
if unused_model_args:
    raise ValueError(
        f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
        " generate arguments will also show up in this list)"
    )
Here is an example of the output for the following snippet:
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, do_samples=True, foo="bar")
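For reference, based on the message format in the check above, the raised exception should read roughly like this (reconstructed, not a verbatim log):

ValueError: The following `model_kwargs` are not used by the model: ['do_samples', 'foo'] (note: typos in the generate arguments will also show up in this list)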
generation_method_name: str,
supporting_objects: List[Callable],
Each generation submethod has a set of preparation functions/classes (hence `objects`), and some input arguments are consumed there -- we will need their signatures to do the detection correctly.
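A minimal sketch of how these signatures might be gathered, assuming `inspect.signature` is used on each object (names here are illustrative, not the actual implementation):

import inspect
from typing import Callable, List

def _collect_accepted_args(generation_method: Callable, supporting_objects: List[Callable]) -> set:
    # Accepted arguments = the submethod's own parameters, plus the parameters
    # of every preparation function/class it relies on (for classes,
    # inspect.signature resolves to the __init__ signature).
    accepted = set(inspect.signature(generation_method).parameters)
    for obj in supporting_objects:
        accepted |= set(inspect.signature(obj).parameters)
    return accepted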
if "do_early_stopping" in generation_method_args: | ||
generation_method_args.remove("do_early_stopping") | ||
generation_method_args.add("early_stopping") |
The `early_stopping` argument is consumed by a class that takes `do_early_stopping` as an argument. Since the class is public, I can't touch it :( At most I could add a new argument doing the same, but it's probably not worth it.
raise ValueError(
    f"From the generation arguments, `{generation_method_name}` was triggered. The following arguments are"
    f" not used by `{generation_method_name}`: {unused_args}. Please remove them from the generation"
    " arguments or check the documentation for more information."
)
Here is an example of the output for the following snippet:
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
prompt = tokenizer(["hello world"], return_tensors="pt")
model.generate(**prompt, num_return_sequences=2, temperature=2.0)
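Reconstructed from the message format above (without `do_sample`, greedy search is triggered; the exact contents and ordering of the set may vary):

ValueError: From the generation arguments, `greedy_search` was triggered. The following arguments are not used by `greedy_search`: {'num_return_sequences', 'temperature'}. Please remove them from the generation arguments or check the documentation for more information.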
if is_group_beam_gen_mode and do_sample is True:
    raise ValueError(
        "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
    )
Moved to the `if is_group_beam_gen_mode:` block.
if num_return_sequences > 1:
    raise ValueError(
        f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
    )
redundant with the new checks
@@ -2052,7 +2052,7 @@ def test_max_new_tokens_decoder_only(self):

         # max_new_tokens and max_length serve the same purpose and should not be used together.
         with self.assertWarns(UserWarning):
-            gpt2_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
+            gpt2_model.generate(input_ids=input_ids, max_new_tokens=10, max_length=20)
gpt2 is decoder-only, so it takes `input_ids` rather than `decoder_input_ids`.
@@ -360,7 +360,7 @@ def test_generate_fp16(self):
         if torch_device == "cuda":
             model.half()
         model.generate(**input_dict)
-        model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3)
+        model.generate(**input_dict, do_sample=True, num_return_sequences=3)
`greedy_search` doesn't accept `early_stopping`.
if num_beam_groups is not None and num_beam_groups > 1:
    raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
redundant check
First, I think it's a tremendous idea to have better argument validation in generate. This PR tackles that problem in two ways:
- validating model kwargs (and thus catching typos)
- validating that the arguments passed are actually used by the generation algorithm picked.
1 is easier and I have no problem with that part of the PR; maybe it should be done in its own PR first, before diving deeper into 2 :-)
For 2, the way you chose feels very magical, with lots of ad-hoc code that is going to be hard to maintain. I wonder if it wouldn't be better to just centralize all inputs passed, like you did in `generation_inputs`, then have each of the private methods of generate return its regular result as well as the unused inputs; at the end of the function you can then inspect the unused inputs and check that they are empty (a sketch of this pattern follows below).
I also think this case warrants a warning more than an error by the way.
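A minimal sketch of this suggestion — all names here are illustrative, not the actual transformers API:

import warnings

def _greedy_search_step(generation_inputs: dict):
    # Each private method pops only the inputs it actually consumes...
    max_length = generation_inputs.pop("max_length", 20)
    result = f"generated up to {max_length} tokens"  # placeholder for the real work
    # ...and returns its result together with the still-unconsumed inputs.
    return result, generation_inputs

def generate(**generation_inputs):
    result, unused = _greedy_search_step(generation_inputs)
    if unused:
        # Per the review, a warning may fit better than an error here.
        warnings.warn(f"The following generate arguments were not used: {sorted(unused)}")
    return result

generate(max_length=10, temperatur=0.7)  # the typo surfaces as a warning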
Yeah, I agree, that was the number 1 reason why I left so many comments and caveats. It works, but it would be annoying to maintain. (@sgugger) If I got it right, the suggestion was to pop used arguments from the centralized inputs as they get consumed? Meanwhile, I'm going to do as suggested and move the model kwargs validation to its own PR :)
No, something more like
Closing in favor of two PRs:
What does this PR do?
NOTE: this PR is very experimental, feel free to trash it in the review process :)
A common cause of issues in `generate` is it not behaving as expected, as arguments can be silently ignored by the selected generation submethod (greedy_search, sample, ...). Typos also often fly under the radar, as the method accepts `**model_kwargs`, which in turn are passed to models that also accept `**kwargs`.

This PR adds argument validation to `generate` in two separate steps:
1. `model_kwargs` are verified as soon as the method is called. Only arguments that the model actually uses in `prepare_inputs_for_generation` or in its forward pass are accepted. This means that typos are caught immediately. The exception enumerates all arguments that triggered this failed check, so the user can correct them.
2. After the generation submethod is selected, the remaining arguments are checked against the signatures of that submethod and of its supporting functions/classes; arguments not used by any of them raise an exception.

Although I think the checks are super useful, the code around them is not the prettiest. The first check has some logic for edge cases, and the second check requires passing the list of methods that will be called before the submethod in question. The PR is heavily commented in GH, feel free to cast your judgment!
P.S.: (seemingly) unrelated accelerate tests are failing in `run_examples_torch`.
Related issues