[bnb] Fix bnb skip modules #24043
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for fixing!
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))
We should also check that at least one other layer not in llm_int8_skip_modules is loaded in 8-bit, ideally one which will effectively exercise the recursion logic.
Awesome yes agreed! Will add that now
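For context, a check along those lines might look like the following sketch; it assumes the bitsandbytes Linear8bitLt class and the roberta-large-mnli module layout, and the exact module paths are illustrative:

import torch
import bitsandbytes as bnb

# A module listed in llm_int8_skip_modules should stay a plain nn.Linear ...
self.assertTrue(isinstance(seq_classification_model.classifier.dense, torch.nn.Linear))
self.assertFalse(isinstance(seq_classification_model.classifier.dense, bnb.nn.Linear8bitLt))

# ... while a nested layer outside the skip list should have been converted to
# Linear8bitLt, which exercises the recursive replacement logic.
self.assertTrue(
    isinstance(
        seq_classification_model.roberta.encoder.layer[0].attention.self.query,
        bnb.nn.Linear8bitLt,
    )
)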
seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
Just for my own understanding (not a comment to address): here we're checking that the layers of the classifier are nn.Linear. In test_linear_are_8bit, we check that the layers are nn.Linear too and that their dtype is torch.int8 (I didn't know this was possible!). Are we certain that this means these layers are loaded in correctly? Do we need a dtype check on the weights?
You are right, we also need a dtype check on the weights! Linear8bitLt has nn.Linear as a superclass. Adding new tests!
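A minimal sketch of why this matters, assuming bitsandbytes is installed (module paths below are illustrative): since Linear8bitLt inherits from nn.Linear, an isinstance check alone cannot tell quantized and non-quantized layers apart, so the weight dtype has to be asserted as well.

import torch
import bitsandbytes as bnb

# Linear8bitLt subclasses nn.Linear, so isinstance(..., nn.Linear) passes for
# quantized and non-quantized layers alike.
assert issubclass(bnb.nn.Linear8bitLt, torch.nn.Linear)

# Hence the test should also check the weight dtype, e.g.:
# self.assertEqual(model.roberta.encoder.layer[0].attention.self.query.weight.dtype, torch.int8)
# self.assertNotEqual(model.classifier.dense.weight.dtype, torch.int8)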
* fix skip modules test
* oops
* address comments
What does this PR do?
Fixes #24037
#23479 mistakenly removed the logic introduced in #21579 for handling modules that should not be converted.
The PR also adds a test to make sure this never happens again.
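For reference, a usage sketch of the skip-module path this fix restores, driven by BitsAndBytesConfig's llm_int8_skip_modules argument (checkpoint and module names are illustrative, mirroring the test above):

from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Quantize the model to 8-bit but keep the classification head in higher precision.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["classifier"],
)

model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli",
    quantization_config=quantization_config,
)

# Modules listed in llm_int8_skip_modules remain regular nn.Linear layers.
print(type(model.classifier.dense))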