Beginning of generation utils and necessary refactors of T5 Model #2011
Conversation
@@ -0,0 +1,83 @@
{
I know we verify completeness often w/ internal notebooks - I thought for those that show parity with HuggingFace or external libraries, we could put those notebooks in the actual repo. Seems like a better way to keep track rather than some Bento notebooks w/ scattered ownership.
Can you upload the notebook to a GitHub gist and provide a link in the PR so it's easier to review the contents?
Sure, but for a quick fix you can right-click on the expand dots at the top right of this file and select "View file" and it'll give you a notebook view.
-@dataclass
+@dataclass(frozen=True)
Freezing this as we probably don't want people to be able to overwrite configs and still try to use the model - much more likely to run into bugs that way.
What if someone wants to experiment with a smaller model or modified architecture? Are there distilled or smaller T5 models out there? We don't freeze other configs, so I am not sure I agree with this.
Freezing the config won't make it impossible to try a smaller model or modified architecture. It just means that once they instantiate the config and pass the config to the model, they won't be able to modify it.
Example:

```python
config = T5Config(encoder_only=True)
t5_model = T5Model(config=config)
t5_model.config.encoder_only = False  # Currently allowed; with a frozen config, this would throw an error
```
Just to follow up here, in the example you just showed, would it affect the model behavior if users did end up changing the config after instantiating the model? IIUC the config is only used during model instantiation anyways. That being said I don't see any issues with freezing the config.
It wouldn't affect the model behavior, but it would throw an error saying "Config cannot be modified", which I think is what we want. It would be considered undefined behavior if someone e.g. instantiated a model without a decoder and then went back and changed the config to say that it did have a decoder.
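For reference, a minimal sketch of how a frozen dataclass behaves; the `MiniConfig` name is hypothetical, not the PR's `T5Config`:

```python
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class MiniConfig:
    encoder_only: bool = False

cfg = MiniConfig(encoder_only=True)
try:
    cfg.encoder_only = False  # attribute assignment is blocked on frozen dataclasses
except FrozenInstanceError as err:
    print(err)  # cannot assign to field 'encoder_only'
```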
 def forward(
     self,
-    encoder_tokens: Tensor,
+    encoder_tokens: Optional[Tensor] = None,
Don't have to include `encoder_tokens` if `encoder_outputs` are already provided.
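A minimal sketch of that either/or contract as a standalone helper; the `resolve_encoder_outputs` name and the exact validation are assumptions, not necessarily how the PR implements it:

```python
from typing import Optional

import torch
from torch import Tensor

def resolve_encoder_outputs(
    encoder: torch.nn.Module,
    encoder_tokens: Optional[Tensor] = None,
    encoder_outputs: Optional[Tensor] = None,
) -> Tensor:
    # If precomputed encoder outputs were passed, the raw tokens are unnecessary;
    # otherwise run the encoder on the tokens.
    if encoder_outputs is not None:
        return encoder_outputs
    if encoder_tokens is None:
        raise ValueError("Either encoder_tokens or encoder_outputs must be provided.")
    return encoder(encoder_tokens)
```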
 def forward(
     self,
-    tgt: Tensor,
+    tgt: Optional[Tensor] = None,
Similar to the forward of the entire model: if the target is already embedded, there's no need to include the raw tokenized tgt.
if self.is_encoder_decoder:
    encoder = self.model.get_encoder()
    model_kwargs["encoder_outputs"] = encoder(inputs)
This should be the necessary args for the forward method of whatever model is being used in decoding.
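A hedged sketch of how `model_kwargs` could flow through the decoding loop; the `greedy_decode` name, the `decoder_tokens` keyword, and the `"decoder_output"` key are assumptions for illustration:

```python
import torch

def greedy_decode(model, input_ids: torch.Tensor, max_len: int, **model_kwargs) -> torch.Tensor:
    # model_kwargs carries whatever the wrapped model's forward needs at each step,
    # e.g. the encoder_outputs precomputed once for an encoder/decoder model.
    for _ in range(max_len):
        outputs = model(decoder_tokens=input_ids, **model_kwargs)
        next_tokens = outputs["decoder_output"][:, -1, :].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
    return input_ids
```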
from torch import nn


class GenerationUtil:
Does this whole class have to be torchscriptable, as well??
That would make it extremely difficult to incorporate other models.
> Does this whole class have to be torchscriptable, as well??

If we expect that this util will be used in Predictor during inference time then yes it does. Can you explain what makes it difficult to make this torchscriptable?
As a first step, we can always implement this without torchscriptability support for customers to experiment with. And if there's enough demand to make it torchscriptable then we can come back and add this support.
    self, batch_size: int, device: Optional[torch.device] = None, **model_kwargs
):
    if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
        return model_kwargs.pop("decoder_input_ids")
Why pass around a model_kwargs dict instead of just having an optional decoder_input_ids param?
+1
    device: Optional[torch.device] = None,
    dtype=None,
) -> None:
    super().__init__()

    self.token_embeddings = token_embeddings
Can you add a description of this input argument to the docstring above?
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
Can we also add an explanation for this line? Having a hard time following the logic. Alternatively let's add a couple of lines to the docstring of this method explaining the approach.
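For context, a self-contained sketch of what this line is doing; the `pad_idx` handling at the end is an assumption about the surrounding loop, not copied from the PR:

```python
import torch

eos_idx, pad_idx = 1, 0
next_tokens = torch.tensor([5, eos_idx, 7])              # tokens chosen at this decoding step
unfinished_sequences = torch.ones(3, dtype=torch.long)   # 1 = still generating, 0 = finished

# A sequence that just emitted EOS flips its flag to 0; multiplying keeps sequences
# that finished on earlier steps at 0 permanently.
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
print(unfinished_sequences)  # tensor([1, 0, 1])

# Finished sequences are then typically forced to emit padding on later steps.
next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences)
```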
self.norm = T5LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
Why did we decide to move this from the model to the encoder/decoder?
Keeps the entire encoder forward method self-contained.
@Nayef211 Can I get some 👀 on this again when you have a chance?
@atalman @osalpekar Is this failing integration test related to the Nova migration? The process seems to be killed with no helpful error, and the integration tests pass on my local machine.
@joecummings looks like the integration tests are running out of memory, code 137: 3796110184
Silly follow-up question, but how would I go about allocating more memory for these integration tests?
Great to have sampling as part of the library!
    return input_ids

def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
If we put all the sampling methods in this class (beam_search, greedy), is that very extensible for the user? Or should these be separate classes that inherit from GenerationUtil or a general Sampler class?
Idk about inheriting from `GenerationUtil`, but as a standalone class, this makes sense.
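A rough sketch of the standalone-class alternative being discussed; all class names and the `generate` signature here are hypothetical:

```python
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import nn

class Sampler(ABC):
    """Base class for decoding strategies; each strategy is its own class."""

    @abstractmethod
    def generate(self, model: nn.Module, input_ids: torch.Tensor, max_len: Optional[int]) -> torch.Tensor:
        ...

class GreedySearch(Sampler):
    def generate(self, model, input_ids, max_len):
        raise NotImplementedError  # argmax decoding loop would live here

class BeamSearch(Sampler):
    def __init__(self, num_beams: int):
        self.num_beams = num_beams

    def generate(self, model, input_ids, max_len):
        raise NotImplementedError  # beam expansion and pruning would live here
```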
@@ -176,7 +176,8 @@ def build_model_from_huggingface_ckpt(

    t5_model_state_dict = {
        "token_embeddings.weight": hf_weights["shared.weight"],
        "norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
        "encoder.token_embeddings.weight": hf_weights["shared.weight"],
Is there any way to bundle generation parameters with a model so the user doesn't have to know the correct sampling defaults for a given model?
@@ -47,6 +48,8 @@ def __init__(
        strict (bool): Passed to :func:`torch.nn.Module.load_state_dict` method. (Default: `False`)
        dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`)
    """
    warnings.warn("`T5Wrapper` is being deprecated. Please use new `GenerationUtils`.", category=DeprecationWarning)
GenerationUtils, if made into an nn.Module, could be treated as a generic wrapper for any LLM. This might be easier for the user but would break from the Huggingface design. It would allow for generation parameters to be saved with the model.
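A hedged sketch of that idea; the `GenerationWrapper` class and its attributes are hypothetical and not part of this PR:

```python
import torch
from torch import nn

class GenerationWrapper(nn.Module):
    """Wraps any language model and keeps its generation defaults alongside it."""

    def __init__(self, model: nn.Module, eos_idx: int, pad_idx: int, max_len: int = 128):
        super().__init__()
        self.model = model
        # Generation defaults travel with the wrapper object, so users don't have to
        # know the right sampling settings for each model at call time.
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.max_len = max_len

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError  # would dispatch to greedy_search / beam_search here
```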
Context
We aim to add generation utils that support a number of encoder/decoder and decoder-based models. To do so, we also have to rework our current encoder/decoder model, T5.
Changes
1. Separated logic for encoder and decoder into self-contained `nn.Module`s.
   1a. Move dropout layers and norms to `T5Encoder` and `T5Decoder`.
   1b. Pass `token_embeddings` to the encoder if constructed through the `T5Model`. Now the encoder can take in tokenized text or embedded text.
   1c. Add `get_encoder` and `get_decoder` getter functions (not torchscriptable ATM).
   1d. Update type annotations to allow for padding_masks and encoder_outputs.
   1e. Change `T5Encoder` and `T5Decoder` return types to dictionaries.
2. Added `GenerationUtils` class and `greedy_search` generation technique.
   2a. Added deprecation warning to `T5Wrapper` until `beam_search` is added.
3. Froze configs to avoid mutating the model unnecessarily.
Testing
GenerationUtil
Notes
`T5Wrapper` is no longer torchscriptable.
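A hedged usage sketch stitched together from the names above; the import paths, the dummy input, and the `generate` call are assumptions rather than the final API:

```python
import torch
from torchtext.prototype.models import T5Config, T5Model   # import path assumed
from torchtext.prototype.generate import GenerationUtils    # import path assumed

config = T5Config(encoder_only=False)     # frozen config; attributes can't be reassigned later
model = T5Model(config=config)

generator = GenerationUtils(model)        # model-agnostic wrapper; greedy_search today, beam_search later
input_ids = torch.tensor([[13959, 1566, 12, 2968, 10, 1]])  # stand-in for tokenized input
output_ids = generator.generate(input_ids, max_len=64)      # signature assumed for illustration
```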