Integrate INT8 mixed-precision from torchao 0.7 #1552
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1552
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
It's very simple and it looks great! Thanks for the PR!
My two cents:
- We need some tests to make sure it works with compile, AC, offloading (not landed yet), optimizer in backward (I guess?), etc.
- The configs can be updated in bulk using something like this dummy script: https://gist.github.com/felipemello1/5f2002433c6da3a21f33d6cdf82e702a
Let me know if you want me to help with any of these.
n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? torchao is not a requirement anymore.
cc: @ebsmothers
Does this work with older GPUs? Does it work on CPU? Maybe we need something like this:
_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)
)
Good question! I use Triton for this so it will probably run on any GPUs Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above have INT8 tensor cores. To be safe, I think we should just guard for sm80 and above.
CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / see it useful.
This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).
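For reference, a minimal sketch of the stricter guard discussed here, assuming torchtune's torch_version_ge helper (import path is an assumption) and guarding for Ampere (sm80) and newer; exact names and placement are illustrative, not the PR's final code.

import torch

from torchtune.utils import torch_version_ge  # assumed import path

# PyTorch 2.4+ (per the comment above), CUDA available, and an Ampere-or-newer
# GPU (sm80+), since only those have INT8 tensor cores.
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    torch_version_ge("2.4.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (8, 0)
)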
Btw doesn't QAT also need some kind of guards like this? 🤔
Sorry I completely missed this convo previously. Actually, now that we've asked users to install torchao manually, I am kinda taking a similar stance there to what we have with PyTorch: people should always be running on the latest stable version. So actually I claim that the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyway, that should hopefully answer the question about why we don't have similar guards for QAT.
@ebsmothers @weifengpy I'm not too keen on adding the hack (i.e. manually shard the original weight and construct the tensor subclass from it) inside training.load_from_full_model_state_dict(). If we add more training recipes that are implemented with tensor subclasses, such as FP8 with FP8 all-gather, torchtune would need to again manually add more logic there. Also, just to clarify, this PR currently uses the tensor-subclass implementation. Actually this can be implemented as module-swap too. Since it doesn't implement quantized all-gather yet (it only modifies the forward and backward pass), module-swap will save trouble for torchtune -> the state dict is still pure tensors (though iirc, module-swap was slower than the tensor-subclass implementation due to some torch.compile quirks). And to clarify, today, INT8 mixed-precision training in torchao is only implemented as a tensor subclass, though I did briefly try a module-swap implementation in the past.
Totally agree that float8 and INT8 can be done in a cleaner way. NF4 is unique in 3 ways: 1) NF4 has double-quantization that requires computation over global tensors, and quantize-then-chunk is much easier than chunk-then-quantize, 2) NF4 is static quantization and thus loses the high-precision tensor and stays quantized in the training loop, 3) NF4 dispatch had a hack on
Great insights!
Thanks @gau-nernst and @weifengpy for the helpful discussion here. I guess we don't need to block this PR on figuring out the plan for enabling in distributed recipes. But for the sake of my understanding, let me summarize the options that have been discussed:
It seems like (1)+(a) is unclear, and (3) is not easy to generalize. Also I imagine (4) is a nontrivial effort. So then it is either the vanilla version of (1) or (2). It seems to me like (2) should be similar to the memory and speed of the current approach.
@ebsmothers (4) is actually somewhat trivial, but I think it's best to be done in torchao instead of torchtune. Maybe I will add module-swap UX to torchao first before continuing with this PR (this is more of a torchao problem; I'm not sure if supporting 2 UX flows is the wise thing to do, but perhaps since it's experimental, having more UX options is good to evaluate which flow is better for downstream users like torchtune).
@felipemello1 @ebsmothers With pytorch/ao#1179, we can revisit this PR again. Do you mind resolving previous conversations if they are indeed resolved, so it's less cluttered? And I can address the unresolved ones. I have updated the PR description. You can see it for new changes + benchmarks. In summary:
Great work, @gau-nernst! @mori360 this module-swap UX and torchtune example should help with the torchtitan recipe as well.
torchtune/training/quantization.py (Outdated)
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    torch_version_ge("2.4.0")
    and Version(torchao.__version__) >= Version("0.7")
Apparently I need to add .dev:
https://github.com/pytorch/ao/blob/2ba1a61fe1244560325b5051b5d3c10044553be0/torchao/utils.py#L529-L532
I was installing torchao from source, which has a version number like 0.7.0+gitxxxxx, so the comparison works as expected. Will push the change.
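To illustrate why the .dev suffix matters, a quick check (assuming packaging.version.Version is what is used for the comparison):

from packaging.version import Version

# Nightly builds sort below the plain release tag unless ".dev" is appended,
# while source builds with a local "+git" suffix already compare as >= 0.7.
assert Version("0.7.0.dev20241112") < Version("0.7")         # nightly fails the ">= 0.7" check
assert Version("0.7.0.dev20241112") >= Version("0.7.0.dev")  # passes once ".dev" is appended
assert Version("0.7.0+gitabc1234") >= Version("0.7")         # source build passes either way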
grad_weight=grad_weight,
)

def prepare(self, model: nn.Module) -> nn.Module:
Neither 3B nor 8B full single device is working for me on an A100:
torch 2.6.0.dev20241107+cu124
torchao 0.7.0.dev20241111+cu124
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20241107+cu124
I also tried with torch 2.5.1.
I think that the issue is with recompiling. It takes several minutes on this step (10+ min). I wonder if it's recompiling non-stop until it runs out of memory and errors with RuntimeError: std::bad_alloc.
8b
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
output = self.chunked_output(h)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
return [
^
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
self.output(chunk)
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc
3b
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
output = self.chunked_output(h)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
return [
^
File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
self.output(chunk)
File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 66, in __call__
return self.linear(x, self.tied_module.weight)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 27, in forward
return F.linear(x, weight)
^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc
PS: for 3B we use the TiedLinear implementation:
class TiedLinear:
I printed the shape and dtype, not that it matters
torch.Size([5, 256, 3072]) torch.bfloat16
torch.Size([128256, 3072]) torch.bfloat16
I also checked for 3b, and filter_fn returns false for it.
Regarding std::bad_alloc, I think it's an issue with recent pytorch nightlies. ao CI has been seeing it too: pytorch/ao#1229. Currently CI pins to 20241101: https://github.com/pytorch/ao/blob/b2642fb33e360ffe478fe19665b1c4efd80537c6/.github/workflows/regression_test.yml#L64. Can you try with this version?
I don't exactly recall, but I think usually I had to kill the first run (it kinda hangs?), then the subsequent run is fast. Didn't think much of it. I will try to reproduce it again in a fresh environment to see if my memory serves me right.
OK, I additionally set quantize_(set_inductor_config=False) to avoid exhaustive torch.compile tuning. Can you try again? It seemed fast to me (<2 min). Worked with both Llama3.1-8B and Llama3.2-3B, with torch==2.5.1+cu121 and torchao==0.7.0.dev20241112+cu121.
The command:
tune run full_finetune_single_device --config llama3_2/3B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
@@ -144,6 +166,84 @@ def quantize(self, model):
] = enable_8da4w_fake_quant_module_swap


class Int8MixedPrecisionTrainingQuantizer:
Here is my understanding:
- torchao has a nice API that does quantize_(model, config)
- This PR creates a wrapper specific to int8. The wrapper is needed for two reasons: i) ignore lora_a/lora_b, ii) we don't do nested configs
- This means that if we want to support fp8, bitnet, or any other torchao technique, we have to create a custom wrapper, instead of doing quantize_(model, torchao_config)
IMO, this makes a lot of sense if:
a) every torchao technique interacts differently with torchtune, e.g. one doesn't work with offloading, another needs some extra work for ckpt, etc., and we can't solve it with a config parser
b) there are realistically only a couple of quantization methods we will use from torchao (e.g. int8 and fp8)
But if that's not the case, then we should probably avoid having a custom torchtune wrapper per torchao technique.
I guess we had a similar situation with OptimizerCPUOffload. We ended up giving up on a custom torchtune wrapper and just instantiated directly from the config.
In summary, instead of "Int8MixedPrecisionTrainingQuantizer", should we have a generalist torchtune "PrecisionTrainingQuantizer", just to handle LoRA, etc.?
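(For reference, a rough sketch of the "call quantize_() directly" path being discussed, assuming torchao 0.7 exposes int8_mixed_precision_training() under torchao.prototype.quantized_training as described in its README; the filter_fn and function names here are illustrative, not this PR's code.)

import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training  # assumed import

def _skip_lora_adapters(module: nn.Module, fqn: str) -> bool:
    # Only convert plain linear layers; leave lora_a/lora_b adapters untouched.
    return isinstance(module, nn.Linear) and "lora" not in fqn

def apply_int8_mixed_precision(model: nn.Module) -> nn.Module:
    quantize_(
        model,
        int8_mixed_precision_training(),
        filter_fn=_skip_lora_adapters,
        set_inductor_config=False,  # avoid exhaustive inductor tuning, as discussed above
    )
    return model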
I think PrecisionTrainingQuantizer is not needed atm. We can revisit it once there are more things like this in the future (like you said, realistically I can only think of int8 and fp8 for now, but it may change).
# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false
thanks for adding "enabled"!
@gau-nernst, can you help me understand why memory is not changed? Do we do a bf16->int8->bf16 projection in place?
Memory is not expected to change. In this scheme, weights and activations are quantized dynamically to INT8 (with scaling, so that we can do INT8 matmul), then the result is scaled back to BF16. Activations and weights are still in BF16 throughout the model. This is the same strategy as the current approach.
To reduce memory, either weights or activations must stay in low precision. Using only low-bit weights for training is a bit challenging (not impossible, but there will be convergence/accuracy issues). There are some research works on using low-bit activations (FP8/INT8), but we don't have that in torchao yet.
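For intuition, here is a minimal numeric sketch of the scheme described above (not torchao's actual implementation, which codegens a Triton INT8 kernel via torch.compile): quantize row-wise to INT8 with scales, matmul in INT8, then rescale back to BF16.

import torch

def int8_mixed_precision_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Row-wise symmetric quantization: absmax scale so values fit in [-127, 127].
    def quantize_rowwise(t: torch.Tensor):
        scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
        return torch.round(t / scale).clamp(-127, 127).to(torch.int8), scale

    x_i8, x_scale = quantize_rowwise(x)  # (M, K) int8 + (M, 1) scales
    w_i8, w_scale = quantize_rowwise(w)  # (N, K) int8 + (N, 1) scales
    # INT8 matmul accumulated in int32, then rescaled back to the input dtype.
    # Inputs and outputs stay BF16, which is why peak memory does not change.
    out = x_i8.to(torch.int32) @ w_i8.to(torch.int32).t()
    return (out.float() * x_scale * w_scale.t()).to(x.dtype)

x = torch.randn(16, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)
print(int8_mixed_precision_linear(x, w).shape)  # torch.Size([16, 32])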
Context
What is the purpose of this PR? Is it to
Recent INT8 mixed-precision work in torchao shows very promising results.
Known major limitations
- Requires torch.compile() to enjoy speedup (to codegen efficient dynamic quantization code)
- Assumes seq_len is static
- Does not work with training.load_from_full_model_state_dict() -> cannot integrate with distributed recipes atm. -> solved by using module-swap UX instead; pending "Add module-swap UX for INT8 mixed-precision training" ao#1179
See https://github.com/pytorch/ao/tree/v0.5.0/torchao/prototype/quantized_training#int8-mixed-precision for more details.
For now, I only added the code to show the necessary changes. I'm open to suggestions on how to expose this in torchtune. One idea of mine:
- Add an int8_mixed_precision flag (similar to the compile flag). This will be a boolean.
- Handle it in _setup_model() -> repeated code for each recipe.
-> UPDATE: from previous feedback, add a new flag mixed_precision instead (a sketch of how a recipe could consume it follows below).
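A hypothetical sketch of how a recipe's _setup_model() could consume that mixed_precision entry (the helper name and cfg handling here are assumptions for illustration; only Int8MixedPrecisionTrainingQuantizer and its prepare() method come from this PR):

from torchtune import config

def _maybe_apply_mixed_precision(cfg_mixed_precision, model):
    # Nothing to do if the config section is absent.
    if cfg_mixed_precision is None:
        return model
    # Instantiates e.g. Int8MixedPrecisionTrainingQuantizer(enabled=...).
    quantizer = config.instantiate(cfg_mixed_precision)
    # prepare() is expected to be a no-op when enabled is false.
    return quantizer.prepare(model)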
Some concerns:
- Int8MixedPrecisionTrainingConfig has several knobs (see doc). Should we expose it to torchtune's users? From my testing, the default config works well. There might be more knobs to customize in the future too.
- Int8MixedPrecisionTrainingQuantizer does not go through the quantize_() API (though the quantizers can be re-implemented to do so).
- LoRALinear will always call F.linear() on the NF4 weight. If we make the base weight in LoRALinear a separate nn.Linear module (instead of a plain nn.Parameter()), then we can swap the linear module to change the outer op (see the sketch below).
These concerns can be addressed in the future I think, when torchao's training subclasses become more mature/stable.
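As an illustration of that last concern, a hypothetical sketch of a LoRALinear variant whose base weight lives in a child nn.Linear (this is not torchtune's current class; names are made up):

import torch.nn as nn

class LoRALinearWithBaseModule(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        # The frozen base weight lives in a child nn.Linear, so a module-swap
        # quantizer can replace self.base instead of torchtune hard-coding
        # F.linear(x, weight) on a plain nn.Parameter.
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank

    def forward(self, x):
        # The "outer op" is now whatever self.base implements (BF16, INT8, ...).
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))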
Note: I can't test with 4090 since FlexAttention errors out on 4090
It's pretty strange since it works fine for another repo of mine 🤔.
Changelog
What are the changes made in this PR?
Integrate INT8 mixed-precision from torchao 0.7 (was 0.5)
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models
- Llama3.1-8B single device A100: 40% speedup. torch=2.5.0.dev20240911, torchao=0.5.0
- Llama3.1-8B FSDP2 2x A100: 24% speedup. torch=2.5.1, pytorch/ao#1179
- Llama3.1-8B single device A100 LoRA: 50% speedup. torch==2.6.0.dev20240914, torchao=0.5.0
- Llama3.2-1B single device 4070Ti SUPER QLoRA: 60% speedup. torch==2.6.0.dev20241102+cu124, pytorch/ao#1179. Proof-of-concept only since it requires quite significant changes to the LoRALinear class. See main...gau-nernst:qlora