Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torchscriptable T5 generation #2146

Merged
merged 3 commits into from
Apr 11, 2023

Conversation

joecummings
Copy link
Contributor

@joecummings joecummings commented Apr 6, 2023

Makes GenerationUtils TorchScript-compatible w/ T5 model.

This PR makes the following changes:

  • Makes GenerationUtils an nn.Module and TorchScript-compatible
  • Modifes the output type of T5 (and variants) and has a simpler update function from encoder and decoder
  • Adds parameter with_generation_utils to T5Bundle.get_model() that returns a T5 model wrapped in a TorchScript-compatible GenerationUtils class

Testing:

  • Passing all old tests for T5 and GenerationUtils
  • Adds a new test for calling get_model with with_generation_utils=True and confirming that it is Torchscriptable and provides the correct results.

@joecummings
Copy link
Contributor Author

@Nayef211 Thoughts on the with_generation_utils on the model bundle versus having this as a separate class to import?

Copy link
Contributor

@Nayef211 Nayef211 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM. Left a couple of nits and suggestions. Thanks for doing the additional work to make T5 torchscriptable!

torchtext/prototype/generate.py Outdated Show resolved Hide resolved
Comment on lines 123 to 146
model_inputs = (
self.model.prepare_inputs_for_generation(input_ids, model_kwargs=model_kwargs)
if torch.jit.is_scripting()
else self._call_to_prepare_inputs_for_generation_with_kwargs(input_ids, model_kwargs)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave a comment here explaining what's going on? These helper method names are quite long and it's not immediately clear what the differences are in the model inputs based on if we're in script mode or eager mode.

torchtext/prototype/generate.py Outdated Show resolved Hide resolved
torchtext/prototype/generate.py Outdated Show resolved Hide resolved
torchtext/models/t5/modules.py Show resolved Hide resolved
torchtext/prototype/generate.py Outdated Show resolved Hide resolved
@@ -78,10 +80,12 @@ class T5Bundle:
def get_model(
self,
*,
with_generation_utils: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add an explanation of this param to the docstring.

@Nayef211
Copy link
Contributor

@Nayef211 Thoughts on the with_generation_utils on the model bundle versus having this as a separate class to import?

I think the with_generation_utils is probably a cleaner way to do it and reduces code duplication. I would just make sure to clearly document this parameter so users understand what it does and how it's used.

Copy link
Contributor

@yohann-benchetrit yohann-benchetrit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments.
+1 on @Nayef211 comment on naming non-script and script versions of preparation-for-decoding methods so that one can see that they do the same (up to scripting) at a glance.

torchtext/models/t5/bundler.py Show resolved Hide resolved
model_kwargs = {
"encoder_outputs": encoder_outputs,
"encoder_padding_mask": encoder_padding_mask,
"past_key_values": past,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor: rename past into past_key_values to follow other key-value naming ?

@joecummings joecummings force-pushed the torchscripting-gen-utils branch from 2f1d38b to 09d56fd Compare April 11, 2023 19:02
@joecummings
Copy link
Contributor Author

@atalman @osalpekar Why is torchdata no longer being picked up in smoke_tests.py?

@joecummings joecummings force-pushed the torchscripting-gen-utils branch from 09d56fd to e020b62 Compare April 11, 2023 19:26
@joecummings joecummings force-pushed the torchscripting-gen-utils branch from e020b62 to 5ab9025 Compare April 11, 2023 19:56
@joecummings joecummings merged commit 79100a6 into pytorch:main Apr 11, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants