
Integrate INT8 mixed-precision from torchao 0.7 #1552

Open · wants to merge 41 commits into main

Conversation

@gau-nernst gau-nernst (Contributor) commented Sep 12, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Recent INT8 mixed-precision work in torchao shows very promising results.

  • Single device A100 -> ~40% speedup
  • Single device 4090 -> ~70% speedup (consumer 4000-series GPUs see unusually large speedups, which is nice)
  • Works with FSDP2

Known major limitations

  • Requires torch.compile() to enjoy speedup (to codegen efficient dynamic quantization code)
  • Input sizes should not vary too much, since each new shape triggers autotuning of the Triton INT8 matmul kernel -> this only works well for PackedDataset w/ FlexAttention, since seq_len is static.
  • Does not work with training.load_from_full_model_state_dict() -> cannot integrate with distributed recipes atm. -> solved by using module-swap UX instead. Pending Add module-swap UX for INT8 mixed-precision training ao#1179

See https://github.com/pytorch/ao/tree/v0.5.0/torchao/prototype/quantized_training#int8-mixed-precision for more details.

For now, I only added the code to show the necessary changes. I'm open to suggestions on how to expose this in torchtune. One idea of mine:

  • Add a global config flag int8_mixed_precision (similar to compile flag). This will be a boolean
  • Handle it inside _setup_model() -> repeated code for each recipe
    -> UPDATE: from previous feedback, add a new flag mixed_precision

Some concerns:

  • It's possible to customize INT8 mixed-precision via Int8MixedPrecisionTrainingConfig (see doc). Should we expose it to torchtune's users? From my testing, the default config works well. There might be more knobs to customize in the future too.
    • UPDATE: expose all options via Int8MixedPrecisionTrainingQuantizer
  • Ability to extend to other torchao's subclasses? e.g. Float8 and NF4 (right now they don't use quantize_() API, though they can be re-implemented to do so).
    • UPDATE: the better question is how to compose this with QLoRA (i.e. NF4). LoRALinear will always call F.linear() on the NF4 weight. If we make the base weight in LoRALinear a separate nn.Linear module (instead of a plain nn.Parameter()), then we can swap the linear module to change the outer op (a rough sketch follows below).
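As an illustration of that idea, a minimal sketch (not torchtune's actual LoRALinear; the base attribute name and the omission of dropout/bias are assumptions for brevity):

import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    """Illustrative only: keep the frozen base weight in a child nn.Linear so a
    module-swap quantizer can replace `self.base` without touching the LoRA math."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # swappable base projection
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The base matmul goes through the child module, so swapping self.base
        # (e.g. for an INT8 mixed-precision linear) changes the outer op in one place.
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))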

I think these concerns can be addressed in the future, when torchao's training subclasses become more mature/stable.

Note: I can't test with 4090 since FlexAttention errors out on 4090

triton.runtime.errors.OutOfResources: out of resource: shared memory

It's pretty strange since it works fine for another repo of mine 🤔.

Changelog

What are the changes made in this PR?

Integrate INT8 mixed-precision from torchao 0.7 (originally 0.5)

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
  • I did not change any public API;
  • I have added an example to docs or docstrings;

Llama3.1-8B single device A100 40% speedup. torch=2.5.0.dev20240911, torchao=0.5.0

tune run full_finetune_single_device --config llama3_1/8B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True

image

Llama3.1-8B FSDP2 2x A100 24% speedup. torch=2.5.1, pytorch/ao#1179

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full tokenizer.max_seq_len=8192 dataset.packed=True optimizer.fused=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True

image

Llama3.1-8B single device A100 LoRA 50% speedup. torch==2.6.0.dev20240914, torchao=0.5.0

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device tokenizer.max_seq_len=8192 dataset.packed=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True gradient_accumulation_steps=1

image

Llama3.2-1B single device 4070Ti SUPER QLoRA 60% speedup. torch==2.6.0.dev20241102+cu124, pytorch/ao#1179. Proof-of-concept only since it requires quite significant changes to the LoRALinear class. See main...gau-nernst:qlora

tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True batch_size=1 enable_activation_checkpointing=True

image

pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1552


@facebook-github-bot facebook-github-bot added the CLA Signed label on Sep 12, 2024
@felipemello1 felipemello1 (Contributor) left a comment

It's very simple and it looks great! Thanks for the PR!

My two cents:

  1. We need some tests to make sure it works with compile, AC, offloading (not landed yet), optimizer in backward (I guess?), etc.
  2. The configs can be updated in bulk using something like this dummy script: https://gist.github.com/felipemello1/5f2002433c6da3a21f33d6cdf82e702a

Let me know if you want me to help with any of these

Contributor

n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? Torchao is not a requirement anymore

cc: @ebsmothers

Contributor

Does this work with older GPUs? Does it work on CPU? Maybe we need something like this:

_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)
)


Contributor Author

Good question! I use Triton for this, so it will probably run on any GPU Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above have INT8 tensor cores. To be safe, I think we just guard for sm80 and above.

CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / see it useful.

This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).

Contributor Author

Btw doesn't QAT also need some kind of guards like this? 🤔

Contributor

Sorry, I completely missed this convo previously. Now that we've asked users to install torchao manually, I'm taking a similar stance to the one we have with PyTorch: people should always be running on the latest stable version. So I'd claim the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). Anyways, that should hopefully answer the question about why we don't have similar guards for QAT.

@gau-nernst (Contributor Author)

@ebsmothers @weifengpy I'm not too keen on adding the hack (i.e. manually shard the original weight and construct the tensor subclass from it) inside load_from_full_model_state_dict() like NF4 does. I think this is an internal implementation detail that downstream users like torchtune should not rely on, since it can change in the future, either because of new features or some kind of refactoring (I think NF4 might have this problem too: it is hard to refactor/change the NF4 dtype in torchao now).

If we add more training recipes that are implemented with tensor subclasses, such as FP8 with FP8 all-gather, torchtune would again need to manually add more logic to load_from_full_model_state_dict().

Also, just to clarify, this PR will use Int8MixedPrecisionTrainingLinearWeight from here. The weight is kept at the original precision and quantized dynamically to use INT8 tensor cores, just like FP8 training. So when saving the state dict, it's just a matter of unwrapping the tensor subclass. The weight itself is not quantized.

Actually this can be implemented as module-swap too. Since it doesn't implement quantized all-gather yet (it only modifies the forward and backward pass), module-swap would save torchtune some trouble -> the state dict is still pure tensors (though IIRC, the module-swap implementation was slower than the tensor-subclass one due to some torch.compile quirks). And to clarify, today INT8 mixed-precision training in torchao is only implemented as a tensor subclass, though I did briefly try a module-swap implementation in the past.

@weifengpy weifengpy (Contributor) commented Oct 5, 2024

@ebsmothers @weifengpy I'm not too keen on adding the hack (i.e. manually shard original weight and construct tensor subclass from it) inside load_from_full_model_state_dict() like NF4

Totally agree that float8 and INT8 can be done in a cleaner way. NF4 is unique in three ways: 1) NF4 has double quantization that requires computation over global tensors, and quantize-then-chunk is much easier than chunk-then-quantize; 2) NF4 is static quantization, so it loses the high-precision tensor and stays quantized in the training loop; 3) NF4 dispatch had a hack on copy_ in the old days to quickly unblock single-device QLoRA.

when saving state dict, it's just the matter of unwrapping the tensor subclass. The weight itself is not quantized.
module-swap was slower than tensor subclass implementation due to some torch.compile quirks

great insights!

@ebsmothers (Contributor)

Thanks @gau-nernst and @weifengpy for the helpful discussion here. I guess we don't need to block this PR on figuring out the plan for enabling in distributed recipes. But for the sake of my understanding, let me summarize the options that have been discussed:

  1. Perform load_state_dict without assign=True
    a) Optionally shard with CPU process group
  2. Call load_state_dict(strict=False) one tensor at a time
  3. Manually shard original weight and build tensor subclass inside load_from_full_model_state_dict
  4. Enable module-swap int8 mixed precision training

It seems like (1)+(a) is unclear, and (3) is not easy to generalize. Also I imagine (4) is a nontrivial effort. So then it is either the vanilla version of (1) or (2).

It seems to me like (2) should be similar to the memory and speed of the current assign=True usage, but it is a bit awkward. Then the question for me is whether the doubled memory of (1) would actually be the peak memory for the entire run -- if so I think we shouldn't do it. My very basic math is that this would be the case whenever memory(model weights) > max(memory(grads + optimizer states + activations)), is that right? If so I guess it could hold for LoRA + low-precision optimizer at smaller sequence lengths (but as mentioned in a comment elsewhere, maybe LoRA is not the primary use case here).
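To make the inequality above concrete, a rough back-of-the-envelope sketch for Llama3.1-8B in BF16 (the sizes and the AdamW/LoRA assumptions are illustrative, not measurements from this PR):

# Rough numbers, assuming 1 GB ~= 1e9 bytes and ignoring buffers/fragmentation.
params_b = 8.0                              # Llama3.1-8B parameters, in billions
weights_gb = params_b * 2                   # BF16 weights            -> ~16 GB
doubled_load_gb = 2 * weights_gb            # two copies during load  -> ~32 GB

# Full finetune with FP32 AdamW: BF16 grads + exp_avg + exp_avg_sq
full_ft_other_gb = params_b * 2 + params_b * 4 * 2     # ~16 + ~64 = ~80 GB
# 80 GB >> 16 GB, so the extra weight copy would NOT be the run's peak here.

# LoRA + low-precision optimizer: adapter grads/states are negligible, so the
# peak is roughly weights + activations; at short seq_len the extra ~16 GB copy
# could indeed dominate, matching the concern above.
print(weights_gb, doubled_load_gb, full_ft_other_gb)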

@gau-nernst (Contributor Author)

@ebsmothers (4) is actually somewhat trivial, but I think it's best done in torchao instead of torchtune. Maybe I will add the module-swap UX to torchao first before continuing with this PR (this is more of a torchao problem: I'm not sure if supporting 2 UX flows is the wise thing to do, but since it's experimental, having more UX options is good for evaluating which flow is better for downstream users like torchtune).

@gau-nernst gau-nernst changed the title Integrate INT8 mixed-precision from torchao 0.5 Integrate INT8 mixed-precision from torchao 0.7 Nov 8, 2024
@gau-nernst gau-nernst marked this pull request as ready for review November 8, 2024 01:36
@gau-nernst gau-nernst (Contributor Author) commented Nov 8, 2024

@felipemello1 @ebsmothers With pytorch/ao#1179, we can revisit this PR again. Do you mind resolving the previous conversations if they are indeed resolved, so it's less cluttered? Then I can address the unresolved ones.

I have updated the PR description. You can see it for new changes + benchmarks.

In summary:

  1. Add module-swap UX for INT8 mixed-precision training ao#1179 adds module-swap UX -> torchtune will use the module swap UX -> resolve the issues with (1) state dict (everything is still plain tensors) and (2) distributed training (shard and load pre-trained weights)
    • This will only be available for the next torchao release (0.7) or from nightly/source.
  2. Expose this as mixed_precision flag in the config file, with enabled option. Once we are ok with this, I will add it to the rest of the configs (if you want)
    • Some alternative names: low_bit_linear/matmul -> a bit more precise than mixed_precision, since it refers to the fact that it uses low-bit for matmul (and the output is still BF16). I think nowadays "mixed precision" can mean many things, thus might be a bit vague. (Just raising some ideas/options to think about. No strong preference about the naming) (cc https://github.com/pytorch/torchtitan/blob/3247841423429faf37bdf6918204350db293e482/train_configs/llama3_8b.toml#L58)
    • torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer might not be the right name/place to put this? Perhaps torchtune.training.mixed_precision (or whatever we're gonna call the flag from the previous point). To give a better picture, in the future we can do float8 in a similar manner. Again, no strong preference, I can follow whatever you all decide.
  3. This is no longer compatible with LoRA/QLoRA (previously the tensor subclass was compatible with LoRA, but not QLoRA). This is because LoRALinear hard-codes F.linear(x, self.weight) for the base weight, making it less flexible to modify this op. See Add support for QAT + LoRA #1931 (comment) for a more detailed explanation and my proposed solution (this would be for a future PR)

@weifengpy (Contributor)

great work, @gau-nernst!

@mori360 this module swap UX and torchtune example should help with torchtitan recipe as well.


_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    torch_version_ge("2.4.0")
    and Version(torchao.__version__) >= Version("0.7")
Contributor

this didn't work for me

image

Contributor Author

Apparently I need to add .dev https://github.com/pytorch/ao/blob/2ba1a61fe1244560325b5051b5d3c10044553be0/torchao/utils.py#L529-L532

I was installing torchao from source, which has version number as 0.7.0+gitxxxxx, and the comparison works as expected. Will push the change.
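For reference, a self-contained version of the adjusted guard (a sketch using only packaging/torch/torchao, not necessarily the exact helpers torchtune uses):

import torch
import torchao
from packaging.version import Version

# Compare against "0.7.0.dev" so nightly builds (0.7.0.devYYYYMMDD) and source builds
# (0.7.0+gitxxxx) both pass, and strip the local suffix (e.g. "+cu124") from torch.
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    Version(torch.__version__.split("+")[0]) >= Version("2.4.0")
    and Version(torchao.__version__) >= Version("0.7.0.dev")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (8, 0)  # INT8 tensor cores: sm80+, per the discussion above
)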

grad_weight=grad_weight,
)

def prepare(self, model: nn.Module) -> nn.Module:
@felipemello1 felipemello1 (Contributor) commented Nov 11, 2024

Neither the 3B nor the 8B full single-device run is working for me on an A100.

torch 2.6.0.dev20241107+cu124
torchao 0.7.0.dev20241111+cu124
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20241107+cu124

I also tried with torch 2.5.1.

I think the issue is with recompiling. It takes several minutes on this step (10+ min). I wonder if it's recompiling non-stop until it runs out of memory and errors with RuntimeError: std::bad_alloc.

8b

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

3b

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 66, in __call__
    return self.linear(x, self.tied_module.weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 27, in forward
    return F.linear(x, weight)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

PS: for 3b we use TiedLinear implementation:

I printed the shape and dtype, not that it matters

torch.Size([5, 256, 3072]) torch.bfloat16
torch.Size([128256, 3072]) torch.bfloat16

I also checked for 3b, and filter_fn returns false for it.

Contributor Author

Regarding std::bad_alloc, I think it's an issue with recent pytorch nightlies. ao CI has been seeing it too pytorch/ao#1229. Currently CI pins to 20241101 https://github.com/pytorch/ao/blob/b2642fb33e360ffe478fe19665b1c4efd80537c6/.github/workflows/regression_test.yml#L64. Can you try with this version?

I don't exactly recall, but I think usually I had to kill the first run (it kinda hangs?), then the subsequent run is fast. Didn't think much of it. I will try to reproduce it again in a fresh environment to see if my memory serves me right.

Contributor Author

ok I additionally set quantize_(set_inductor_config=False) to avoid exhaustive torch.compile tuning. Can you try again? It seemed fast to me (<2min). Worked with both Llama3.1-8B and Llama3.2-3B. torch==2.5.1+cu121 and torchao==0.7.0.dev20241112+cu121

The command

tune run full_finetune_single_device --config llama3_2/3B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
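For context, a condensed sketch of what a prepare() along these lines does, following the torchao prototype README linked in the PR description (the LoRA-skipping filter_fn and the bare int8_mixed_precision_training() call are illustrative assumptions, not a copy of this PR's code):

import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

def prepare(model: nn.Module) -> nn.Module:
    # Swap eligible nn.Linear modules for INT8 mixed-precision training; skip LoRA
    # adapters, and keep set_inductor_config=False to avoid the exhaustive
    # torch.compile tuning mentioned above.
    def filter_fn(module: nn.Module, fqn: str) -> bool:
        return isinstance(module, nn.Linear) and "lora" not in fqn

    quantize_(
        model,
        int8_mixed_precision_training(),
        filter_fn=filter_fn,
        set_inductor_config=False,
    )
    return model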

@@ -144,6 +166,84 @@ def quantize(self, model):
] = enable_8da4w_fake_quant_module_swap


class Int8MixedPrecisionTrainingQuantizer:
Contributor

Here is my understanding:

  1. torchao has a nice api that does quantize_(model, config)

  2. This PR creates a wrapper specific to int8. The wrapper is needed for two reasons:
    i) ignore lora_a/lora_b
    ii) we don't do nested configs

  3. this means that if we want to support fp8, bitnet, or any other torchao technique, we have to create a custom wrapper, instead of doing quantize_(model, torchao_config)


IMO, this makes a lot of sense if:
a) every torchao technique interacts differently with torchtune, e.g. one doesn't work with offloading, another needs some extra work for ckpt, etc., and we can't solve it with a config parser
b) there are realistically only a couple of quantization methods we will use from torchao (e.g. int8 and fp8)

But if thats not the case, then we should probably avoid having a custom torchtune wrapper per torchao technique.

I guess we had a similar situation with OptimizerCPUOffload. We ended up giving up on a custom torchtune wrapper and just instantiated directly from the config.


In summary, instead of "Int8MixedPrecisionTrainingQuantizer", should we have a generalist torchtune "PrecisionTrainingQuantizer", just to handle LoRA, etc?

Contributor Author

I think PrecisionTrainingQuantizer is not needed atm. We can revisit it once there are more things like this in the future (like you said, realistically I can only think of int8 and fp8 for now, but that may change).

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false
Contributor

thanks for adding "enabled"!
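For illustration, recipe-side handling of this config block could look roughly like the sketch below (the helper name and the enabled check are assumptions about the integration, not code from this PR; config.instantiate and the prepare() method shown in the diff above are the only pieces taken from the sources here):

from torchtune import config

def _maybe_apply_mixed_precision(mp_cfg, model):
    # Hypothetical helper: with enabled=False the component stays listed in the
    # config but the model is returned untouched.
    if mp_cfg is None or not mp_cfg.get("enabled", False):
        return model
    quantizer = config.instantiate(mp_cfg)  # e.g. Int8MixedPrecisionTrainingQuantizer
    return quantizer.prepare(model)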

@felipemello1 (Contributor)

@gau-nernst , can you help me understand why memory is not changed? Do we do a bf16->int8->bf16 projection inplace?

@gau-nernst (Contributor Author)

@felipemello1

can you help me understand why memory is not changed?

Memory is not expected to change. In this scheme, weights and activations are quantized dynamically to INT8 (with scaling, so that we can do an INT8 matmul), then the result is scaled back to BF16. Activations and weights stay in BF16 throughout the model. This is the same strategy as the current torchao.float8 (in fact, in some bad cases, torchao.float8 can consume more memory in FSDP2 due to some autograd shenanigans).
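To illustrate why the footprint stays flat, here is a toy version of the dynamic quantize -> INT8 matmul -> rescale pattern (just the idea, not torchao's kernel: the per-row/per-column absmax scaling is a simplification, and torch._int_mm is CUDA-only with shape constraints):

import torch

def dynamic_int8_mm(a_bf16: torch.Tensor, b_bf16: torch.Tensor) -> torch.Tensor:
    # Both operands are quantized on the fly, so nothing is stored in INT8 between steps.
    a_scale = (a_bf16.float().abs().amax(dim=1, keepdim=True) / 127).clamp_min(1e-12)
    b_scale = (b_bf16.float().abs().amax(dim=0, keepdim=True) / 127).clamp_min(1e-12)
    a_i8 = (a_bf16.float() / a_scale).round().clamp(-128, 127).to(torch.int8)
    b_i8 = (b_bf16.float() / b_scale).round().clamp(-128, 127).to(torch.int8)
    out_i32 = torch._int_mm(a_i8, b_i8)  # INT8 tensor cores, INT32 accumulation
    return (out_i32.float() * a_scale * b_scale).to(torch.bfloat16)  # rescale back to BF16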

To reduce memory, either the weights or the activations must stay in low precision. Using only low-bit weights for training is a bit challenging (not impossible, but there will be convergence/accuracy issues). There are some research works on using low-bit activations (FP8/INT8), but we don't have that in torchao yet.
