Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: validate model_kwargs on TF (and catch typos in generate arguments) #18651

Merged
merged 5 commits into from
Sep 2, 2022

Conversation

gante
Copy link
Member

@gante gante commented Aug 16, 2022

What does this PR do?

TF version of #18261

Adds model_kwargs validation to TF generate, which also catches typos in the arguments. See the PR above for more details and an example of the error message users will see.

Since TF had no dedicated file for generate tests, I took the liberty to create it and move some existing tests there (>70% of the diff is due to moving things around :) ). The test for this new check was also added there.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 16, 2022

The documentation is not available anymore as the PR was closed or merged.

class UtilsFunctionsTest(unittest.TestCase):

# tests whether the top_k_top_p_filtering function behaves as expected
def test_top_k_top_p_filtering(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved from test_modeling_tf_common, no changes

class TFGenerationIntegrationTests(unittest.TestCase):

@slow
def test_generate_tf_function_export(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved from test_modeling_tf_common, added the @slow (takes >30s)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool that makes sense (similar to PyTorch).
Also just FYI in PyTorch we're testing currently much more than in TF mainly because we've allowed to return hidden_states and attentios. We could do the same for TF at some point

@@ -1288,6 +1290,29 @@ def adjust_logits_during_generation(
else:
return logits

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as for PyTorch (here), with self.forward replaced with self.call

@@ -2702,8 +2702,8 @@ def test_constrained_beam_search_mixin_type_checks(self):
model.generate(input_ids, force_words_ids=[[[-1]]])

def test_validate_generation_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The model is not relevant for the test, but not using a model from hf-internal-testing was an oversight in the previous PR :D

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes def a good idea!

@gante gante marked this pull request as ready for review August 16, 2022 14:17
@gante gante changed the title Generate: validate model_kwargs on TF (and catch typos in generate arguments) Generate: validate model_kwargs on TF (and catch typos in generate arguments) Aug 16, 2022
@@ -1483,6 +1508,9 @@ def _generate(
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
```"""
# 0. Validate model kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, fine with me! Can also increase all numbers otherwise

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clean-up

@gante gante force-pushed the tf_generate_kwarg_valid branch from 293120d to 3eb5072 Compare September 2, 2022 13:12
@gante gante merged commit 9196f48 into huggingface:main Sep 2, 2022
@gante gante deleted the tf_generate_kwarg_valid branch September 2, 2022 15:25
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants