
Enable finetuning with torchao quantized model #33361

Merged 2 commits from peft-torchao into main on Sep 13, 2024
Conversation

@SunMarc (Member) commented Sep 6, 2024

What does this PR do?

This PR enables training with torchao quantized models. @BenjaminBossan conducted a few experiments with torchao + peft, and it works out of the box with the int8 quantizers. More details here.
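For context, a minimal sketch of the workflow this enables; the model name and LoRA hyper-parameters are illustrative, not taken from this PR:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, TorchAoConfig

# Load a base model with torchao int8 weight-only quantization.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # illustrative; any supported causal LM works
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Attach LoRA adapters: only these small added layers are trained,
# while the quantized base weights stay frozen.
lora = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora)
model.print_trainable_parameters()
```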

@@ -166,7 +166,8 @@ def is_serializable(self):

@property
def is_trainable(self):
# torchao does not have official support for QAT (Quantization Aware Training)


So we do support QAT, and we have experimental support for quantized training as well.

@SunMarc (Member, Author) Sep 6, 2024

We should probably create another property like is_qat_trainable and rename this one to is_peft_trainable to distinguish these two types of training.
Do you have a script for QAT training?
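As a sketch of that proposed split, with hypothetical property names and placeholder bodies (this is not code from the PR):

```python
@property
def is_qat_trainable(self) -> bool:
    # Full quantization-aware training: the quantized weights themselves
    # receive updates. Placeholder value; would need its own check.
    return False

@property
def is_peft_trainable(self) -> bool:
    # Adapter-only (peft) training: quantized layers stay frozen and
    # only the added adapter layers are updated.
    return self.quantization_config.quant_type in [
        "int8_weight_only",
        "int8_dynamic_activation_int8_weight",
    ]
```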


@BenjaminBossan (Member)

Should the method also check if PEFT is installed? I guess it's not strictly necessary, so this is also fine, just wondering. Also, it would be nice to get confirmation if others also run into problems with int4 or if it's just me.

@SunMarc (Member, Author) commented Sep 6, 2024

> Should the method also check if PEFT is installed? I guess it's not strictly necessary, so this is also fine, just wondering. Also, it would be nice to get confirmation if others also run into problems with int4 or if it's just me.

If one uses the Trainer to train a quantized model, it will trigger an error in the Trainer class if the user doesn't use peft, so I guess we don't need to add a warning there.
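Schematically, the guard in question looks something like the following; this is a paraphrase, not the exact Trainer source:

```python
# Inside Trainer setup, roughly: a quantized model without a peft
# adapter attached is rejected before training starts.
if getattr(model, "is_quantized", False) and not getattr(
    model, "_hf_peft_config_loaded", False
):
    raise ValueError(
        "You cannot fine-tune a purely quantized model. Please attach "
        "trainable adapters (e.g. LoRA via peft) on top of it first."
    )
```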

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc changed the title from "Eenable finetuning with torchao quantized model" to "Enable finetuning with torchao quantized model" on Sep 6, 2024
@ArthurZucker (Collaborator) left a comment

Feel free to merge @SunMarc

@SunMarc merged commit 0963229 into main on Sep 13, 2024
22 checks passed
@SunMarc deleted the peft-torchao branch on September 13, 2024 at 13:07
supported_quant_types_for_training = [
"int8_weight_only",
"int8_dynamic_activation_int8_weight",
]


Sorry, actually I think these configs do not support training currently, since the underlying tensor subclass does not pass gradients correctly. There is an ongoing effort on our side to support this. Can you confirm, @jerryzh168?

@SunMarc (Member, Author)

What do you mean by "not passing the gradient correctly"? cc @BenjaminBossan
Just to clarify: here we are performing peft fine-tuning, meaning that we only train the adapters (added linear layers) and freeze the other modules (the quantized linear layers). Is the issue in the gradient calculation or at the update step (needed for QAT but not for PEFT)?


Oh I see, thanks for the clarification. So we do not actually train these quantized linear layers, but we still need gradients to flow through them, is that correct? I think a potential issue from the torchao side is that the tensor subclass AffineQuantizedTensor currently explicitly does not require gradients. However, if we're just freezing these layers then it might be fine. Were you able to verify that the end-to-end PEFT accuracies are as expected?

(Member)

The fact that requires_grad is False should not be an issue; it's the same for other quantization methods. But I can investigate performance compared to, say, bnb. I'll check next week.
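A toy check of the mechanics, with a plain nn.Linear standing in for the frozen quantized layer (whether torchao's AffineQuantizedTensor behaves the same way is exactly the open question here):

```python
import torch
import torch.nn as nn

frozen = nn.Linear(16, 16)    # stand-in for the frozen quantized layer
frozen.weight.requires_grad_(False)
frozen.bias.requires_grad_(False)
adapter = nn.Linear(16, 16)   # stand-in for a trainable adapter layer

# The adapter sits upstream of the frozen layer, so its gradient can
# only arrive if the backward pass flows *through* the frozen layer.
loss = frozen(adapter(torch.randn(4, 16))).sum()
loss.backward()
print(adapter.weight.grad is not None)  # True: gradient flowed through
print(frozen.weight.grad)               # None: frozen weight gets no update
```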

(Member)

I ran a small experiment using LoRA for text classification with google/gemma-2-2b as the base model. Memory was simply measured by observing nvidia-smi.
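For reference, a rough reconstruction of that kind of setup; the LoRA hyper-parameters and classification details are assumptions, not taken from the linked notebook:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification, TorchAoConfig

# Quantized base model with a classification head on top.
model = AutoModelForSequenceClassification.from_pretrained(
    "google/gemma-2-2b",
    num_labels=2,
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# LoRA adapters for sequence classification; hyper-parameters assumed.
model = get_peft_model(
    model,
    LoraConfig(
        task_type="SEQ_CLS", r=8, lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
    ),
)
# ...standard training loop over the classification dataset...
```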

Results for 8bit bnb:

epoch 1 | train loss 1.0293 | {'accuracy': 0.6397058823529411, 'f1': 0.7282809611829945}
epoch 2 | train loss 0.5860 | {'accuracy': 0.7622549019607843, 'f1': 0.8369747899159664}
epoch 3 | train loss 0.4509 | {'accuracy': 0.7941176470588235, 'f1': 0.8546712802768166}
epoch 4 | train loss 0.3845 | {'accuracy': 0.8112745098039216, 'f1': 0.8683760683760684}
epoch 5 | train loss 0.3431 | {'accuracy': 0.8186274509803921, 'f1': 0.8745762711864407}

Wall time: 4min 29s
memory: 21520MiB

Results for int8_weight_only torchao (notebook):

epoch 1 | train loss 1.0672 | {'accuracy': 0.6715686274509803, 'f1': 0.7751677852348994}
epoch 2 | train loss 0.6261 | {'accuracy': 0.7377450980392157, 'f1': 0.8201680672268907}
epoch 3 | train loss 0.4743 | {'accuracy': 0.7867647058823529, 'f1': 0.8502581755593803}
epoch 4 | train loss 0.4006 | {'accuracy': 0.803921568627451, 'f1': 0.8586572438162544}
epoch 5 | train loss 0.3585 | {'accuracy': 0.8235294117647058, 'f1': 0.8791946308724832}

Wall time: 2min 46s
memory: 18098MiB

Results for int8_dynamic_activation_int8_weight torchao (notebook):

epoch 1 | train loss 1.7618 | {'accuracy': 0.46568627450980393, 'f1': 0.5458333333333333}
epoch 2 | train loss 1.1905 | {'accuracy': 0.5245098039215687, 'f1': 0.6325757575757576}
epoch 3 | train loss 1.1478 | {'accuracy': 0.5318627450980392, 'f1': 0.6456400742115028}
epoch 4 | train loss 1.1384 | {'accuracy': 0.5367647058823529, 'f1': 0.6506469500924215}
epoch 5 | train loss 1.1365 | {'accuracy': 0.5367647058823529, 'f1': 0.6506469500924215}

Wall time: 4min 2s
memory: 4122MiB

So int8_weight_only compares quite favorably to bnb 8bit, as the scores are very close but torchao is faster and requires a little bit less memory.

int8_dynamic_activation_int8_weight is absolutely great when it comes to memory (~3.2 GB for the model itself, i.e. only ~1 GB for hidden states etc.) while still being reasonably fast. However, the scores are considerably worse. Not sure if that's expected or if I should use different settings/hyper-parameters.

(Member)

I tried increasing the learning rate for int8_dynamic_activation_int8_weight. With 10x the learning rate, I could get a final score of:

epoch 5 | train loss 0.6309 | {'accuracy': 0.6985294117647058, 'f1': 0.7890222984562607}

Still not as good as the other runs but a significant improvement. Is it expected that int8_dynamic_activation_int8_weight requires different hyper-parameters compared to int8_weight_only?

itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024