Enable finetuning with torchao quantized model #33361
Conversation
@@ -166,7 +166,8 @@ def is_serializable(self):

    @property
    def is_trainable(self):
        # torchao does not have official support for QAT (Quantization Aware Training)
So we do support QAT, and we have experimental support for quantized training as well.
We should probably create another property like is_qat_trainable and change this one to is_peft_trainable to distinguish these two types of training. Do you have a script for QAT training?
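For illustration, a rough sketch of what that split could look like on the torchao quantizer class, assuming the hypothetical property names from the suggestion above and the two int8 quant types this PR treats as trainable (this is not actual transformers code):

```python
# Illustrative stub only: neither property name exists in transformers; the class
# name and the list of trainable quant types are assumptions for this sketch.
class TorchAoHfQuantizer:
    @property
    def is_qat_trainable(self) -> bool:
        # Quantization-aware training, i.e. updating the quantized weights themselves.
        return False

    @property
    def is_peft_trainable(self) -> bool:
        # Training adapters (e.g. LoRA) on top of frozen quantized weights; this only
        # needs gradients to flow through the quantized linear layers.
        return self.quantization_config.quant_type in (
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
        )
```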
Sorry for the delay, but here you go: https://github.com/pytorch/ao/tree/main/torchao/quantization/prototype/qat (also cc @andrewor14)
Should the method also check if PEFT is installed? I guess it's not strictly necessary, so this is also fine, just wondering. Also, it would be nice to get confirmation whether others also run into problems with int4 or if it's just me.
If one uses Trainer to train the quantized model without PEFT, it will trigger an error in the Trainer class, so I guess we don't need to add a warning there.
Feel free to merge @SunMarc
supported_quant_types_for_training = [
    "int8_weight_only",
    "int8_dynamic_activation_int8_weight",
]
Sorry, actually I think these configs do not support training currently, since the underlying tensor subclass does not pass gradients correctly. There is an ongoing effort on our side to support this. Can you confirm, @jerryzh168?
WDYM by not passing the gradient correctly? cc @BenjaminBossan
Just to clarify, here we are performing PEFT fine-tuning, meaning that we are only training the adapters (added linear layers) and freezing the other modules (the quantized linear layers). Is the issue during the gradient calculation or at the update step (needed for QAT but not for PEFT)?
Oh I see, thanks for the clarification. So we do not actually train these quantized linear layers, but we still need gradients to flow through them, is that correct? I think a potential issue from the torchao side is that the tensor subclass AffineQuantizedTensor currently explicitly does not require gradients. However, if we're just freezing these layers then it might be fine. Were you able to verify that the end-to-end PEFT accuracies are as expected?
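As an aside, a minimal sketch of how one could check that gradients do reach the LoRA adapters through the frozen quantized layers; the model name, LoRA settings and prompt are illustrative, and it assumes torchao is installed and a CUDA device is available:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Illustrative model; PEFT picks its default LoRA target modules for this architecture.
model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16))

tokenizer = AutoTokenizer.from_pretrained(model_id)
batch = tokenizer("gradient flow check", return_tensors="pt").to(model.device)
model(**batch, labels=batch["input_ids"]).loss.backward()

# The quantized base weights stay frozen (requires_grad=False); only the LoRA
# adapter parameters should have received gradients after backward().
for name, param in model.named_parameters():
    if param.requires_grad:
        assert param.grad is not None, f"no gradient for {name}"
```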
The fact that requires_grad is False should not be an issue; it's the same for other quantization methods. But I can investigate performance compared to, say, bnb. I'll check next week.
I ran a small experiment using LoRA for text classification with google/gemma-2-2b as the base model. Memory was simply measured by observing nvidia-smi.
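For context, a hypothetical sketch of such a setup (not the actual notebook; quant type, LoRA hyper-parameters and target modules are illustrative):

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification, TorchAoConfig

# Quant type and LoRA settings are illustrative, not necessarily those from the notebook.
quant_config = TorchAoConfig("int8_weight_only")  # or "int8_dynamic_activation_int8_weight"
model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2-2b",
    num_labels=2,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
lora_config = LoraConfig(task_type="SEQ_CLS", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# ...then train with a standard loop or Trainer over the classification dataset...
```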
Results for 8bit bnb:
epoch 1 | train loss 1.0293 | {'accuracy': 0.6397058823529411, 'f1': 0.7282809611829945}
epoch 2 | train loss 0.5860 | {'accuracy': 0.7622549019607843, 'f1': 0.8369747899159664}
epoch 3 | train loss 0.4509 | {'accuracy': 0.7941176470588235, 'f1': 0.8546712802768166}
epoch 4 | train loss 0.3845 | {'accuracy': 0.8112745098039216, 'f1': 0.8683760683760684}
epoch 5 | train loss 0.3431 | {'accuracy': 0.8186274509803921, 'f1': 0.8745762711864407}
Wall time: 4min 29s
memory: 21520MiB
Results for int8_weight_only torchao (notebook):
epoch 1 | train loss 1.0672 | {'accuracy': 0.6715686274509803, 'f1': 0.7751677852348994}
epoch 2 | train loss 0.6261 | {'accuracy': 0.7377450980392157, 'f1': 0.8201680672268907}
epoch 3 | train loss 0.4743 | {'accuracy': 0.7867647058823529, 'f1': 0.8502581755593803}
epoch 4 | train loss 0.4006 | {'accuracy': 0.803921568627451, 'f1': 0.8586572438162544}
epoch 5 | train loss 0.3585 | {'accuracy': 0.8235294117647058, 'f1': 0.8791946308724832}
Wall time: 2min 46s
memory: 18098MiB
Results for int8_dynamic_activation_int8_weight torchao (notebook):
epoch 1 | train loss 1.7618 | {'accuracy': 0.46568627450980393, 'f1': 0.5458333333333333}
epoch 2 | train loss 1.1905 | {'accuracy': 0.5245098039215687, 'f1': 0.6325757575757576}
epoch 3 | train loss 1.1478 | {'accuracy': 0.5318627450980392, 'f1': 0.6456400742115028}
epoch 4 | train loss 1.1384 | {'accuracy': 0.5367647058823529, 'f1': 0.6506469500924215}
epoch 5 | train loss 1.1365 | {'accuracy': 0.5367647058823529, 'f1': 0.6506469500924215}
Wall time: 4min 2s
memory: 4122MiB
So int8_weight_only compares quite favorably to bnb 8bit, as the scores are very close but torchao is faster and requires a little bit less memory. int8_dynamic_activation_int8_weight is absolutely great when it comes to memory (~3.2 GB for the model itself, i.e. only 1 GB for hidden states etc.) while still being reasonably fast. However, the scores are considerably worse. Not sure if that's expected or if I should use different settings/params.
I tried increasing the learning rate for int8_dynamic_activation_int8_weight. With 10x the learning rate, I could get a final score of:
epoch 5 | train loss 0.6309 | {'accuracy': 0.6985294117647058, 'f1': 0.7890222984562607}
Still not as good as the other runs, but a significant improvement. Is it expected that int8_dynamic_activation_int8_weight requires different hyper-parameters compared to int8_weight_only?
What does this PR do?
This PR enables training with torchao quantized models. @BenjaminBossan conducted a few experiments with torchao + PEFT, and it works out of the box for the int8 quantizers. More details here.
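For readers landing here, a hedged end-to-end sketch of the workflow this PR enables, combining a torchao-quantized model with PEFT adapters and Trainer; the model, dataset and hyper-parameters are illustrative, not prescribed by the PR:

```python
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TorchAoConfig,
    Trainer,
    TrainingArguments,
)

# Illustrative model and dataset choices.
model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Trainer will not fine-tune a quantized model directly, so wrap it with PEFT adapters first.
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32))

tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("imdb", split="train[:1%]").map(
    lambda x: tokenizer(x["text"], truncation=True, max_length=256),
    batched=True,
    remove_columns=["text", "label"],
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="opt-torchao-lora", per_device_train_batch_size=4, num_train_epochs=1),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```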