-
Notifications
You must be signed in to change notification settings - Fork 811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Torchscriptable T5 generation #2146
Torchscriptable T5 generation #2146
Conversation
682b21a
to
2f1d38b
Compare
@Nayef211 Thoughts on the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM. Left a couple of nits and suggestions. Thanks for doing the additional work to make T5 torchscriptable!
torchtext/prototype/generate.py
Outdated
model_inputs = ( | ||
self.model.prepare_inputs_for_generation(input_ids, model_kwargs=model_kwargs) | ||
if torch.jit.is_scripting() | ||
else self._call_to_prepare_inputs_for_generation_with_kwargs(input_ids, model_kwargs) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leave a comment here explaining what's going on? These helper method names are quite long and it's not immediately clear what the differences are in the model inputs based on if we're in script mode or eager mode.
@@ -78,10 +80,12 @@ class T5Bundle: | |||
def get_model( | |||
self, | |||
*, | |||
with_generation_utils: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add an explanation of this param to the docstring.
I think the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments.
+1 on @Nayef211 comment on naming non-script and script versions of preparation-for-decoding methods so that one can see that they do the same (up to scripting) at a glance.
model_kwargs = { | ||
"encoder_outputs": encoder_outputs, | ||
"encoder_padding_mask": encoder_padding_mask, | ||
"past_key_values": past, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very minor: rename past
into past_key_values
to follow other key-value naming ?
2f1d38b
to
09d56fd
Compare
@atalman @osalpekar Why is |
09d56fd
to
e020b62
Compare
e020b62
to
5ab9025
Compare
Makes
GenerationUtils
TorchScript-compatible w/ T5 model.This PR makes the following changes:
GenerationUtils
annn.Module
and TorchScript-compatiblewith_generation_utils
toT5Bundle.get_model()
that returns a T5 model wrapped in a TorchScript-compatibleGenerationUtils
classTesting:
GenerationUtils
get_model
withwith_generation_utils=True
and confirming that it is Torchscriptable and provides the correct results.