Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize T5 for sequence generation #2054

Merged
merged 10 commits into from
Feb 17, 2023
Merged

Conversation

joecummings
Copy link
Contributor

@joecummings joecummings commented Feb 11, 2023

This PR makes the following changes to T5 to improve generation capabilities.

  • Adds prepare_inputs_for_generation function to be compliant w/ GenerationWrapper API
  • Adds get_encoder and get_decoder helper functions.
  • Utilizes past_key_values to implement incremental decoding. This involves also a custom reorder cache function that can be used for beam search.
  • Updates docstrings.
  • Updates model weights to fit new architecture.
  • Remove T5Wrapper
  • Fix TorchScripting w/ new APIs

Testing:

  • Passes existing tests
  • Tests to come w/ generation integration tests (link)

Todo:

  • Add license to files
  • Move to main folder
  • Add model to README

@joecummings joecummings force-pushed the optimize-t5 branch 3 times, most recently from daa19d7 to 73465d8 Compare February 17, 2023 02:10
@joecummings joecummings changed the title Optimize t5 Optimize T5 for sequence generation Feb 17, 2023
@joecummings joecummings marked this pull request as ready for review February 17, 2023 06:00
Copy link
Contributor

@mthrok mthrok left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp

@@ -79,7 +76,7 @@ def _t5_get_encoder(self, model, model_input, encoder_output):
encoder = model.get_encoder()
# Need to set the tgt_key_padding_mask to ensure the same results
encoder_padding_mask = model_input.eq(model.padding_idx)
output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change, is it using different set of existing arguments or changing the name of the arguments?
If changing the name of the arguments, that's BC-breaking unless it's prototype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changing the name of arguments, but yes this is prototype until tomorrow :)

@@ -56,13 +55,13 @@ def __post_init__(self):
self.activation = "gelu_new"


# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this context no-longer applicable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We include in the header that several functions are based on HF and I call it out in the docstring of those functions, as well. No need to say that there is a comparable HF implementation for ones that are just the normal Enc/Dec forward functions.

@torch.jit.export
def _reorder_cache(
self, past: List[Tuple[Tensor, Tensor, Tensor, Tensor]], beam_idx: Tensor
) -> List[Tuple[Tensor, Tensor, Tensor, Tensor]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice if there is a comment/docstring of why and what, for the future developer.

for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List would be semantically better, but is it for TorchScript compaibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup :(

Comment on lines 195 to 203
) -> Dict[
str,
Union[
Tensor,
Dict[str, Union[Optional[Tensor], List[Tensor], List[Optional[Tensor]]]],
Optional[List[Tuple[Tensor, Tensor, Tensor, Tensor]]],
bool,
],
]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This annotation is complex and it seems to be repeated. Can we define variable to store the annotation?

@joecummings joecummings merged commit 19f8bc9 into pytorch:main Feb 17, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants