Optimize T5 for sequence generation #2054
Conversation
stamp
```diff
@@ -79,7 +76,7 @@ def _t5_get_encoder(self, model, model_input, encoder_output):
     encoder = model.get_encoder()
     # Need to set the tgt_key_padding_mask to ensure the same results
     encoder_padding_mask = model_input.eq(model.padding_idx)
-    output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
+    output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
```
Is this change using a different set of existing arguments, or changing the names of the arguments?
If it's changing the argument names, that's BC-breaking unless it's prototype.
Changing the names of the arguments, but yes, this is prototype until tomorrow :)
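
For context on why renaming keyword arguments is BC-breaking: any caller that passed the old keywords by name now fails with a `TypeError`. A minimal sketch, using a simplified hypothetical stand-in for the encoder's forward signature (not the PR's actual code):

```python
from typing import Optional

import torch


# Hypothetical stand-in for the encoder forward after the rename from
# `tgt`/`tgt_key_padding_mask` to a positional input plus `src_key_padding_mask`.
def encoder_forward(src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None):
    return {"encoder_output": src}


tokens = torch.randint(1, 100, (2, 8))
mask = tokens.eq(0)

# New-style call (as in the updated test) works fine.
out = encoder_forward(tokens, src_key_padding_mask=mask)["encoder_output"]

# Old-style call is rejected, which is what makes the rename BC-breaking
# for anything outside a prototype namespace.
try:
    encoder_forward(tgt=tokens, tgt_key_padding_mask=mask)
except TypeError as err:
    print(f"old keywords rejected: {err}")
```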
```diff
@@ -56,13 +55,13 @@ def __post_init__(self):
             self.activation = "gelu_new"


-# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
```
Is this context no longer applicable?
We include in the header that several functions are based on HF, and I call it out in the docstrings of those functions as well. There's no need to say there is a comparable HF implementation for the ones that are just the normal Enc/Dec forward functions.
```python
@torch.jit.export
def _reorder_cache(
    self, past: List[Tuple[Tensor, Tensor, Tensor, Tensor]], beam_idx: Tensor
) -> List[Tuple[Tensor, Tensor, Tensor, Tensor]]:
```
It would be nice to have a comment/docstring explaining the why and the what, for future developers.
```python
for layer_past_states in past:
    # get the correct batch idx from layer past batch dim
    # batch dim of `past` is at 2nd position
    reordered_layer_past_states = ()
```
A list would be semantically better, but is it for TorchScript compatibility?
yup :(
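
For readers following this thread, here is a minimal sketch of what the cache reordering does during beam search. It is not the PR's exact code (which, per the exchange above, accumulates into a tuple for TorchScript compatibility); the dim=1 batch position is taken from the comment in the diff, and all shapes below are fabricated for illustration.

```python
from typing import List, Tuple

import torch
from torch import Tensor


def reorder_cache_sketch(
    past: List[Tuple[Tensor, Tensor, Tensor, Tensor]], beam_idx: Tensor
) -> List[Tuple[Tensor, Tensor, Tensor, Tensor]]:
    """Realign the cached key/value states with the beams that survived the
    last beam-search step, so incremental decoding continues from the right
    prefixes. Each entry of `past` holds one decoder layer's cached
    (self-attn key, self-attn value, cross-attn key, cross-attn value).
    """
    reordered: List[Tuple[Tensor, Tensor, Tensor, Tensor]] = []
    for layer_past_states in past:
        # Per the comment in the diff, the batch dim of the cached states is
        # at the 2nd position (dim=1), so gather the surviving beams there.
        reordered.append(
            (
                layer_past_states[0].index_select(1, beam_idx),
                layer_past_states[1].index_select(1, beam_idx),
                layer_past_states[2].index_select(1, beam_idx),
                layer_past_states[3].index_select(1, beam_idx),
            )
        )
    return reordered


# Tiny usage example with fabricated (seq_len, batch, head_dim) shapes.
k1, v1, k2, v2 = (torch.randn(4, 3, 8) for _ in range(4))
reordered = reorder_cache_sketch([(k1, v1, k2, v2)], beam_idx=torch.tensor([2, 0, 0]))
```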
```python
) -> Dict[
    str,
    Union[
        Tensor,
        Dict[str, Union[Optional[Tensor], List[Tensor], List[Optional[Tensor]]]],
        Optional[List[Tuple[Tensor, Tensor, Tensor, Tensor]]],
        bool,
    ],
]:
```
This annotation is complex and it seems to be repeated. Can we define a variable to store the annotation?
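
As an illustration of that suggestion, a sketch with module-level aliases for the repeated annotation. The alias names are hypothetical, and whether TorchScript accepts aliases in these positions would need to be verified.

```python
from typing import Dict, List, Optional, Tuple, Union

from torch import Tensor

# Hypothetical aliases for the repeated annotation; names are illustrative, not from the PR.
PAST_KEY_VALUES = List[Tuple[Tensor, Tensor, Tensor, Tensor]]
ENCODER_OUTPUTS = Dict[str, Union[Optional[Tensor], List[Tensor], List[Optional[Tensor]]]]
DECODER_OUTPUTS = Dict[str, Union[Tensor, ENCODER_OUTPUTS, Optional[PAST_KEY_VALUES], bool]]


def forward(tokens: Tensor) -> DECODER_OUTPUTS:  # signature simplified for the sketch
    ...
```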
This PR makes the following changes to T5 to improve generation capabilities:

- `prepare_inputs_for_generation` function to be compliant w/ the `GenerationWrapper` API (a sketch of the underlying pattern follows this list)
- `get_encoder` and `get_decoder` helper functions.
- `past_key_values` to implement incremental decoding. This also involves a custom reorder cache function that can be used for beam search.
- `T5Wrapper`
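
A minimal sketch of the input preparation that incremental decoding relies on. Everything except the idea behind `prepare_inputs_for_generation` is hypothetical (names, shapes, and the function body are not the PR's implementation): once `past_key_values` is populated, only the newest decoder token needs to be passed to the model.

```python
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor


def prepare_inputs_for_generation_sketch(
    decoder_tokens: Tensor,
    past: Optional[List[Tuple[Tensor, Tensor, Tensor, Tensor]]] = None,
) -> Dict[str, Any]:
    """With a key/value cache from earlier steps, only the last generated
    token has to be run through the decoder; prior positions are cached."""
    if past is not None:
        decoder_tokens = decoder_tokens[:, -1:]
    return {"decoder_tokens": decoder_tokens, "past_key_values": past}


# First step: no cache yet, the full prefix is passed through.
first = prepare_inputs_for_generation_sketch(torch.zeros(2, 1, dtype=torch.long))

# Later steps: with a cache present, only the newest token is passed.
tokens_so_far = torch.randint(1, 100, (2, 5))
k1, v1, k2, v2 = (torch.zeros(5, 2, 8) for _ in range(4))
later = prepare_inputs_for_generation_sketch(tokens_so_far, past=[(k1, v1, k2, v2)])
assert later["decoder_tokens"].shape[1] == 1
```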
Testing:
Todo: