-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate: validate model_kwargs
on TF (and catch typos in generate arguments)
#18651
Conversation
The documentation is not available anymore as the PR was closed or merged. |
class UtilsFunctionsTest(unittest.TestCase): | ||
|
||
# tests whether the top_k_top_p_filtering function behaves as expected | ||
def test_top_k_top_p_filtering(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved from test_modeling_tf_common
, no changes
class TFGenerationIntegrationTests(unittest.TestCase): | ||
|
||
@slow | ||
def test_generate_tf_function_export(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved from test_modeling_tf_common
, added the @slow
(takes >30s)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool that makes sense (similar to PyTorch).
Also just FYI in PyTorch we're testing currently much more than in TF mainly because we've allowed to return hidden_states
and attentios
. We could do the same for TF at some point
@@ -1288,6 +1290,29 @@ def adjust_logits_during_generation( | |||
else: | |||
return logits | |||
|
|||
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as for PyTorch (here), with self.forward
replaced with self.call
@@ -2702,8 +2702,8 @@ def test_constrained_beam_search_mixin_type_checks(self): | |||
model.generate(input_ids, force_words_ids=[[[-1]]]) | |||
|
|||
def test_validate_generation_inputs(self): | |||
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") | |||
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random") | |||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model is not relevant for the test, but not using a model from hf-internal-testing
was an oversight in the previous PR :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes def a good idea!
model_kwargs
on TF (and catch typos in generate arguments)
@@ -1483,6 +1508,9 @@ def _generate( | |||
# generate sequences without allowing bad_words to be generated | |||
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) | |||
```""" | |||
# 0. Validate model kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Haha, fine with me! Can also increase all numbers otherwise
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clean-up
293120d
to
3eb5072
Compare
What does this PR do?
TF version of #18261
Adds
model_kwargs
validation to TFgenerate
, which also catches typos in the arguments. See the PR above for more details and an example of the error message users will see.Since TF had no dedicated file for
generate
tests, I took the liberty to create it and move some existing tests there (>70% of the diff is due to moving things around :) ). The test for this new check was also added there.