
Add module-swap UX for INT8 mixed-precision training #1179

Merged: 11 commits merged into pytorch:main from the int8mp_module branch on Nov 7, 2024

Conversation

@gau-nernst (Collaborator) commented Oct 26, 2024

Background

The current INT8 mixed-precision training recipe is conceptually an op modifier: torch.matmul is replaced with dynamic_int8_mm. It is implemented with a tensor subclass, even though it doesn't use any tensor-subclass-specific features such as quantized storage or quantized FSDP all-gather. An alternative module-swap UX would have the following benefits:

  1. The state dict remains plain tensors. This is beneficial for model checkpointing, as well as for more complex use cases such as sharding-then-loading pre-trained weights for FSDP fine-tuning (cc Integrate INT8 mixed-precision from torchao 0.7 torchtune#1552)
  2. It is easier to hack on and compose with other techniques, e.g. use an NF4 weight for storage and an INT8 matmul for compute -> QLoRA integration. (A rough sketch of what a module-swap INT8 linear could look like follows this list.)
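
To make the module-swap idea concrete, here is a minimal, hypothetical sketch of what such a linear layer could look like. The class name matches the one mentioned in the Usage section below, but the body is purely illustrative: it only shows the forward computation with a plain-tensor weight, while the actual recipe also swaps the matmuls in the backward pass via a custom autograd function.

import torch
import torch.nn as nn

def _dynamic_int8_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Illustrative dynamic INT8 matmul: compute per-row / per-column scales at
    # runtime, quantize both operands, matmul in INT8, then rescale the result.
    a_scale = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    b_scale = b.abs().amax(dim=0, keepdim=True).clamp(min=1e-12) / 127
    a_i8 = (a / a_scale).round().clamp(-127, 127).to(torch.int8)
    b_i8 = (b / b_scale).round().clamp(-127, 127).to(torch.int8)
    # torch._int_mm is CUDA-only and has shape constraints; it is used here only
    # to illustrate the INT8 x INT8 -> INT32 matmul.
    out = torch._int_mm(a_i8, b_i8)
    return out.to(a.dtype) * a_scale * b_scale

class Int8MixedPrecisionTrainingLinear(nn.Linear):
    # The weight stays a plain nn.Parameter, so the state dict is unchanged;
    # only the matmul op is swapped (forward shown here, backward omitted).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = _dynamic_int8_mm(x.flatten(0, -2), self.weight.t())
        out = out.unflatten(0, x.shape[:-1])
        return out if self.bias is None else out + self.bias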

Usage

from torchao import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

model = ...
# nn.Linear -> Int8MixedPrecisionTrainingLinear
quantize_(model, int8_mixed_precision_training(module_swap=True))
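
Since benefit 1 above is that the state dict stays plain tensors, a quick sanity check after the swap could look like the snippet below (illustrative, not from this PR):

import torch

sd = model.state_dict()
# With the module-swap UX there should be no tensor subclasses in the checkpoint,
# so standard state-dict save/load paths work unchanged.
assert all(type(v) is torch.Tensor for v in sd.values())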

Benchmarks

Pre-train Llama2-1B on a 4070Ti SUPER with torch==2.6.0.dev20241029. No regression: module swap has the same perf as the tensor subclass.

python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision_module_swap --model 1B --activation_checkpointing
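
For reference, the corresponding tensor-subclass run would presumably use the same script with the existing recipe selected via --quantize; the exact flag value below (int8_mixed_precision) is an assumption based on the naming of the module-swap variant:

python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision --model 1B --activation_checkpointing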

Pre-train Llama3-8B with torchtitan on 2x A100, torch==2.6.0.dev20241104+cu124. No regression: module swap has the same perf as the tensor subclass.


QLoRA fine-tune of Llama3-1B with torchtune (using the pytorch/torchtune@main...gau-nernst:qlora branch).


pytorch-bot bot commented Oct 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1179

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit ca8c85a with merge base f99b667:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 26, 2024
@vkuzo (Contributor) commented Oct 28, 2024

It is implemented with tensor subclass, though it doesn't use any tensor subclass-specific features, such as quantized storage and quantized FSDP all-gather. Having an alternative module-swap UX would have the following benefits:

You can still use a tensor subclass inside your module swap to interact with FSDP and TP/SP and get low-precision all_gather. See Float8Linear for reference. I think this is a good way to go for training in general.

@gau-nernst (Collaborator, Author) commented:
You can still use tensor subclass inside of your module swap

The way I see it, whatever can be done with a module swap can also be done with a tensor subclass (maybe it's better to hold persistent state with modules, e.g. for delayed scaling, but at least for my use cases I don't need persistent state). So using a module swap together with a tensor subclass feels redundant to me.
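
As a concrete illustration of what module-held persistent state could mean here (e.g. something in the spirit of float8 delayed scaling), a module could register a buffer that is checkpointed alongside the weights. This is a hypothetical sketch unrelated to this PR's implementation:

import torch
import torch.nn as nn

class DelayedScalingLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, history_len: int = 16):
        super().__init__(in_features, out_features, bias)
        # Persistent state: a rolling history of activation amax values, saved in
        # the state dict and used to derive the quantization scale.
        self.register_buffer("amax_history", torch.zeros(history_len))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            self.amax_history = torch.roll(self.amax_history, 1)
            self.amax_history[0] = x.abs().amax()
        scale = self.amax_history.max().clamp(min=1e-12) / 127  # where a real recipe would quantize
        # Keep the sketch self-contained: run the regular matmul instead of a quantized one.
        return nn.functional.linear(x, self.weight, self.bias)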

@gau-nernst gau-nernst marked this pull request as ready for review November 4, 2024 13:49
@gau-nernst gau-nernst requested review from msaroufim and andrewor14 and removed request for msaroufim November 4, 2024 13:50
@msaroufim msaroufim requested a review from vkuzo November 7, 2024 04:17
@vkuzo (Contributor) left a comment

lgtm

@gau-nernst gau-nernst merged commit e41ca4e into pytorch:main Nov 7, 2024
16 of 17 checks passed
@gau-nernst gau-nernst deleted the int8mp_module branch November 7, 2024 05:36
jainapurva pushed a commit that referenced this pull request Nov 7, 2024
* add module swap UX

* update

* fix typing. add small notes

* try NF4 support

* fix

* fix unpacking

* fix

* update nf4 integration

* update backward pass