-
Notifications
You must be signed in to change notification settings - Fork 169
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add 4bit * rename * simplify 4bit * add rank1 scaling * add lpmm to benchmark * remove rank-1 scaling * update * clean * rename * update test * fix * fix * update adam * add AdamW 4bit * update * remove lpmm from dev cuz CI can't compile * fix test * update README * Update README.md * update readme. small fixes * remove zero padding
- Loading branch information
1 parent
9f85488
commit 34fedff
Showing
12 changed files
with
492 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import copy | ||
from functools import partial | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
from torch.testing._internal.common_utils import ( | ||
TestCase, | ||
instantiate_parametrized_tests, | ||
parametrize, | ||
run_tests, | ||
) | ||
from torchao.prototype import low_bit_optim | ||
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit | ||
from torchao.utils import TORCH_VERSION_AFTER_2_3 | ||
|
||
try: | ||
import bitsandbytes as bnb | ||
except ImportError: | ||
bnb = None | ||
|
||
try: | ||
import lpmm | ||
except ImportError: | ||
lpmm = None | ||
|
||
|
||
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) | ||
|
||
|
||
class TestQuantize(TestCase): | ||
@parametrize("device", _DEVICES) | ||
def test_quantize_8bit_with_qmap_correctness(self, device): | ||
x = torch.randn(32, 1024, device=device) | ||
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) | ||
|
||
actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1) | ||
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0) | ||
|
||
torch.testing.assert_close(actual_codes, expected_codes) | ||
torch.testing.assert_close(actual_scale, expected_scale) | ||
|
||
@parametrize("device", _DEVICES) | ||
def test_quantize_8bit_with_qmap_compile(self, device): | ||
x = torch.randn(32, 1024, device=device) | ||
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) | ||
|
||
compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True) | ||
actual_codes, actual_scale = compiled_f(x, qmap, 256) | ||
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256) | ||
|
||
torch.testing.assert_close(actual_codes, expected_codes) | ||
torch.testing.assert_close(actual_scale, expected_scale) | ||
|
||
@parametrize("device", _DEVICES) | ||
def test_quantize_4bit_with_qmap_correctness(self, device): | ||
x = torch.randn(32, 1024, device=device) | ||
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) | ||
|
||
actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1) | ||
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0) | ||
|
||
torch.testing.assert_close(actual_codes, expected_codes) | ||
torch.testing.assert_close(actual_scale, expected_scale) | ||
|
||
@parametrize("device", _DEVICES) | ||
def test_quantize_4bit_with_qmap_compile(self, device): | ||
x = torch.randn(32, 1024, device=device) | ||
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) | ||
|
||
compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True) | ||
actual_codes, actual_scale = compiled_f(x, qmap, 256) | ||
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256) | ||
|
||
torch.testing.assert_close(actual_codes, expected_codes) | ||
torch.testing.assert_close(actual_scale, expected_scale) | ||
|
||
|
||
class TestOptim(TestCase): | ||
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") | ||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") | ||
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") | ||
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) | ||
def test_optim_8bit_correctness(self, optim_name): | ||
device = "cuda" | ||
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) | ||
model2 = copy.deepcopy(model1) | ||
|
||
optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) | ||
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) | ||
|
||
for _ in range(2): | ||
x = torch.randn(4, 32, device=device) | ||
|
||
loss1 = model1(x).sum() | ||
loss1.backward() | ||
optim1.step() | ||
optim1.zero_grad() | ||
|
||
loss2 = model2(x).sum() | ||
loss2.backward() | ||
optim2.step() | ||
optim2.zero_grad() | ||
|
||
for p1, p2 in zip(model1.parameters(), model2.parameters()): | ||
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) | ||
|
||
@pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") | ||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") | ||
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") | ||
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) | ||
def test_optim_4bit_correctness(self, optim_name): | ||
device = "cuda" | ||
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) | ||
model2 = copy.deepcopy(model1) | ||
|
||
# lpmm doesn't have Adam. use AdamW with no weight decay instead. | ||
if optim_name == "Adam4bit": | ||
optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0) | ||
elif optim_name == "AdamW4bit": | ||
optim1 = lpmm.optim.AdamW(model1.parameters()) | ||
else: | ||
raise ValueError(f"Unsupported {optim_name} optimizer for lpmm") | ||
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) | ||
|
||
for _ in range(2): | ||
x = torch.randn(4, 32, device=device) | ||
|
||
loss1 = model1(x).sum() | ||
loss1.backward() | ||
optim1.step() | ||
optim1.zero_grad() | ||
|
||
loss2 = model2(x).sum() | ||
loss2.backward() | ||
optim2.step() | ||
optim2.zero_grad() | ||
|
||
for p1, p2 in zip(model1.parameters(), model2.parameters()): | ||
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) | ||
|
||
|
||
instantiate_parametrized_tests(TestQuantize) | ||
instantiate_parametrized_tests(TestOptim) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Low-bit optimizers | ||
|
||
This folder implements: | ||
|
||
- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861 | ||
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507 | ||
|
||
The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. | ||
|
||
## Usage | ||
|
||
This is a drop-in replacement for `torch.optim.Adam` | ||
|
||
```python | ||
from torchao.prototype.low_bit_optim import Adam8bit | ||
|
||
model = ... | ||
optim = Adam8bit(model.parameters()) | ||
``` | ||
|
||
To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers. | ||
|
||
**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand. | ||
|
||
NOTE: | ||
- The low-bit optimizers require PyTorch >= 2.3 | ||
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. | ||
- **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed. | ||
|
||
## Benchmarks | ||
|
||
Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py). | ||
|
||
Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER: | ||
|
||
Adam impl | max memory (GB) | time taken | accuracy | ||
-----------|-----------------|------------|---------- | ||
PyTorch | 12.98 | 10m 08s | 87.70 | ||
bnb 8-bit | 8.31 | 8m 38s | 86.22 | ||
ao 8-bit | 8.32 | 10m 54s | 86.67 | ||
lpmm 4-bit | 7.72 | 7m 48s | 84.70 | ||
ao 4-bit | 7.72 | 9m 17s | 85.60 | ||
|
||
NOTE: time taken includes validation time, and compile time for torchao optimizers. | ||
|
||
## Credits | ||
|
||
Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .adam import Adam8bit, Adam4bit | ||
from .adamw import AdamW8bit, AdamW4bit |
Oops, something went wrong.