-
Notifications
You must be signed in to change notification settings - Fork 27.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mamba & RecurrentGemma: enable strict signature #31549
Conversation
@@ -545,7 +545,6 @@ def forward( | |||
use_cache: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alternatively, we can accept attention_mask
and raise an exception when it is not None
or not all ones
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's googoogogogogo 🚀
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) | ||
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yesssss I think I have a PR open where I dod this! Finally!
@@ -545,7 +545,6 @@ def forward( | |||
use_cache: Optional[bool] = None, | |||
output_hidden_states: Optional[bool] = None, | |||
return_dict: Optional[bool] = None, | |||
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing this will break FDSP :( See #31161
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@amyeroberts I had a look and it should be fine: this PR removes **kwargs
from the model class (e.g. MambaModel
), while the FSDP PR ensures there are **kwargs
in the decoder layers (e.g. FalconDecoderLayer
).
We can see on main
that the model themselves don't have **kwargs
, even after the FSDP fix (e.g. llama) 🤗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK!
What does this PR do?
Fixes #31540
Mamba accepts
**kwargs
, and thusattention_mask
can be passed. Many users thus assume it behaves just like other models and can support left-padding.RecurrentGemma also accept
**kwargs
, but simply not to crashgenerate
.This PR enables a strict signature on Mamba and RecurrentGemma.