Add FSDP2 support for low-bit optimizers #484

Merged · 43 commits · Jul 16, 2024

Commits
9331495
add test for FSDP2
gau-nernst Jul 8, 2024
e159e93
fix optim
gau-nernst Jul 8, 2024
2fc8830
add some fsdp2 ops
gau-nernst Jul 8, 2024
e30f142
add DTensor
gau-nernst Jul 8, 2024
a88750c
try DTensor
gau-nernst Jul 8, 2024
9593c07
undo changes
gau-nernst Jul 8, 2024
e3408d8
add DTensor support for adamw
gau-nernst Jul 8, 2024
a71c9bc
update imports
gau-nernst Jul 8, 2024
1aa4600
remove whitespace
gau-nernst Jul 8, 2024
17e4cb8
fix view issue with compiler
gau-nernst Jul 8, 2024
a0aef48
Merge branch 'pytorch:main' into low_bit_optim_fsdp2
gau-nernst Jul 11, 2024
8bfff6d
small refactoring
gau-nernst Jul 11, 2024
7d28556
static horizontal fusion
gau-nernst Jul 11, 2024
96c4286
add test for CPU offload (may timeout)
gau-nernst Jul 11, 2024
f9817fb
update note
gau-nernst Jul 11, 2024
4549aa1
update benchmarks
gau-nernst Jul 11, 2024
bb34b00
remove CPUOffloadPolicy test
gau-nernst Jul 11, 2024
ddbdd9c
Merge branch 'pytorch:main' into low_bit_optim_fsdp2
gau-nernst Jul 12, 2024
5ffab6d
fix version test
gau-nernst Jul 12, 2024
ccc5904
fix typo
gau-nernst Jul 12, 2024
99ac35d
add custom 4-bit
gau-nernst Jul 12, 2024
d39d750
revert version check
gau-nernst Jul 12, 2024
8b19e23
replace 4-bit adam
gau-nernst Jul 12, 2024
e859fbd
switch 4-bit impl
gau-nernst Jul 12, 2024
0b05caa
refactor
gau-nernst Jul 12, 2024
a0390af
update test
gau-nernst Jul 12, 2024
4b0d22c
update test. some fixes
gau-nernst Jul 12, 2024
f9e40e5
bring back 4bit subclass
gau-nernst Jul 13, 2024
b7ccb60
fixes
gau-nernst Jul 13, 2024
9a5a722
add DTensor to custom optim_4bit
gau-nernst Jul 13, 2024
108df09
fix default block_size
gau-nernst Jul 13, 2024
7bbf46f
separate scale and quantize
gau-nernst Jul 13, 2024
f75c2ee
add 4-bit optim fix
gau-nernst Jul 13, 2024
c448e86
fix 4-bit subclass. replace default 4-bit optim impl
gau-nernst Jul 14, 2024
db207e4
update table
gau-nernst Jul 14, 2024
0c39755
add BF16 smoke test
gau-nernst Jul 14, 2024
320b0e5
remove unused 4-bit impl
gau-nernst Jul 15, 2024
ac72ea3
Merge branch 'pytorch:main' into low_bit_optim_fsdp2
gau-nernst Jul 16, 2024
6b1d7d2
debug FSDP test
gau-nernst Jul 16, 2024
d936a28
print cache size limit
gau-nernst Jul 16, 2024
fe5864c
increase cache size limit
gau-nernst Jul 16, 2024
fb12060
add compute capability check for FP8 in FSDP test
gau-nernst Jul 16, 2024
87ae147
revert debug tests
gau-nernst Jul 16, 2024
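
For context, the usage pattern this PR enables looks roughly like the sketch below: shard the model with FSDP2's `fully_shard` and drive it with one of torchao's low-bit optimizers, so the optimizer state is built on the sharded DTensor parameters. This is a minimal illustrative sketch, not code from the PR; the toy model, hyperparameters, and `torchrun` launch are assumptions, while `fully_shard` and `low_bit_optim.AdamW8bit` are the APIs exercised by the new `TestFSDP2` test further down.

```python
# Minimal sketch (assumptions: 2 GPUs launched via `torchrun --nproc_per_node=2`,
# torch >= 2.4). The toy model and training loop are illustrative only; the
# fully_shard + low_bit_optim combination is what the new TestFSDP2 test covers.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torchao.prototype import low_bit_optim

torch.distributed.init_process_group("nccl")
torch.cuda.set_device(torch.distributed.get_rank())

with torch.device("cuda"):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

for layer in model:
    if isinstance(layer, nn.Linear):
        fully_shard(layer)  # shard each large submodule first
fully_shard(model)          # then shard the root module

# With this PR, the 8-bit optimizer state is created on the sharded (DTensor) parameters.
optim = low_bit_optim.AdamW8bit(model.parameters(), lr=1e-3)

for _ in range(3):
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

torch.distributed.destroy_process_group()
```

Sharding each submodule before the root mirrors the test below, which calls `fully_shard` on every `TransformerBlock` and then on the whole `Transformer`.
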
157 changes: 114 additions & 43 deletions test/prototype/test_low_bit_optim.py
@@ -1,5 +1,4 @@
import copy
from functools import partial

import pytest
import torch
@@ -10,9 +9,11 @@
parametrize,
run_tests,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

try:
import bitsandbytes as bnb
@@ -31,52 +32,69 @@
class TestQuantize(TestCase):
@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_correctness(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)
x = torch.rand(32, 1024, device=device)
qmap = torch.rand(256, device=device).sort().values

actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0)
actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)
expected = quantize_8bit_with_qmap(x, qmap)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_compile(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)
x = torch.rand(32, 1024, device=device)
qmap = torch.rand(256, device=device).sort().values

compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True)
actual_codes, actual_scale = compiled_f(x, qmap, 256)
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256)
compiled_f = torch.compile(quantize_8bit_with_qmap, fullgraph=True)
actual = compiled_f(x, qmap)
expected = quantize_8bit_with_qmap(x, qmap)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_quantize_4bit_with_qmap_correctness(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)
x = torch.rand(32, 1024, device=device)
qmap = torch.rand(16, device=device).sort().values

actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1)
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0)
actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)
expected = quantize_4bit_with_qmap(x, qmap)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
def test_quantize_4bit_with_qmap_compile(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)
x = torch.rand(32, 1024, device=device)
qmap = torch.rand(16, device=device).sort().values

compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True)
actual_codes, actual_scale = compiled_f(x, qmap, 256)
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256)
compiled_f = torch.compile(quantize_4bit_with_qmap, fullgraph=True)
actual = compiled_f(x, qmap)
expected = quantize_4bit_with_qmap(x, qmap)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)
torch.testing.assert_close(actual, expected)


class TestOptim(TestCase):
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
@parametrize("dtype", [torch.float32, torch.bfloat16])
@parametrize("device", _DEVICES)
def test_optim_smoke(self, optim_name, dtype, device):
if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")

# reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
torch._dynamo.reset_code_caches()

model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
model.to(device=device, dtype=dtype)
optim = getattr(low_bit_optim, optim_name)(model.parameters())

x = torch.randn(4, 32, device=device, dtype=dtype)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@@ -139,21 +157,74 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
@parametrize("device", _DEVICES)
def test_optim_fp8_smoke(self, optim_name, device):
if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = getattr(low_bit_optim, optim_name)(model.parameters())

x = torch.randn(4, 32, device=device)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()
class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required")
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit]
if torch.cuda.get_device_capability() >= (8, 9):
optim_classes.append(low_bit_optim.AdamFp8)

self.run_subtests(
{"optim_cls": optim_classes},
self._test_fsdp2,
)

def _test_fsdp2(self, optim_cls):
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
TransformerBlock,
)

# seems like cache_size_limit is shared between FSDP processes?
torch._dynamo.config.cache_size_limit = 8 * self.world_size

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
n_layers=3,
n_heads=4,
dim=1024,
vocab_size=vocab_size,
max_seq_len=seq_len,
dropout_p=0,
)
torch.manual_seed(42)
with torch.device("cuda"):
base_model = Transformer(model_args)
base_optim = optim_cls(base_model.parameters(), lr=1e-2)

fsdp_model = copy.deepcopy(base_model)
for m in fsdp_model.modules():
if isinstance(m, TransformerBlock):
fully_shard(m)
fully_shard(fsdp_model)
fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)

torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = fsdp_model(inp).mean()
fsdp_loss.backward()
fsdp_optim.step()

base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
base_loss = base_model(inp).mean()
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)


instantiate_parametrized_tests(TestQuantize)
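
A note on the refactored quantization tests above: they now check the `quant_utils` helpers against a brute-force nearest-neighbor lookup into a sorted quantization map. The sketch below restates that reference outside the test harness; it assumes inputs are already scaled into the qmap's range and is not the block-wise implementation torchao actually ships.

```python
# Brute-force reference for qmap-based quantization, mirroring the check in
# TestQuantize above (an illustrative sketch, not torchao's implementation).
import torch

def quantize_with_qmap_reference(x: torch.Tensor, qmap: torch.Tensor) -> torch.Tensor:
    # For every element of x, pick the index of the closest entry in the sorted qmap.
    return (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)

x = torch.rand(32, 1024)
qmap_8bit = torch.rand(256).sort().values  # 256 codes for the 8-bit path, 16 for 4-bit
codes = quantize_with_qmap_reference(x, qmap_8bit)
assert codes.shape == x.shape and codes.dtype == torch.uint8
```
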
6 changes: 3 additions & 3 deletions torchao/prototype/low_bit_optim/README.md
@@ -38,10 +38,10 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
---------------|-----------------|--------------------------|----------
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.32 | 9m 04s | 90.71
ao FP8 E4M3 | 8.32 | 6m 38s | 91.08
ao 8-bit | 8.31 | 6m 44s | 90.63
ao FP8 E4M3 | 8.32 | 6m 35s | 90.98
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 00s | 89.94
ao 4-bit | 7.72 | 7m 13s | 90.05
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71

(*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.