Add FSDP2 support for low-bit optimizers (pytorch#484)
gau-nernst authored Jul 16, 2024
1 parent efda619 commit ade8feb
Showing 8 changed files with 467 additions and 287 deletions.
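For context, the usage this commit enables looks roughly like the sketch below, which mirrors the new TestFSDP2 test in the diff. This is an illustrative sketch, not code from the commit: the toy model, sizes, and learning rate are invented, and it assumes torch >= 2.4 with CUDA, launched via torchrun so that a process group is already initialized.

    # Illustrative sketch (not part of the commit): FSDP2 sharding + a torchao low-bit optimizer.
    import torch
    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API, torch >= 2.4
    from torchao.prototype import low_bit_optim

    model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
    for layer in model:
        if isinstance(layer, nn.Linear):
            fully_shard(layer)  # shard each submodule's parameters (they become DTensors)
    fully_shard(model)          # shard the root module

    optim = low_bit_optim.Adam8bit(model.parameters(), lr=1e-2)  # 8-bit optimizer state

    x = torch.randn(8, 1024, device="cuda")
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()

Because FSDP2 turns each parameter into a DTensor shard, the quantized optimizer-state tensor subclasses have to handle DTensor inputs, which is presumably why they needed the changes in this commit; the new TestFSDP2 test below checks the sharded run against a replicated baseline.
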
157 changes: 114 additions & 43 deletions test/prototype/test_low_bit_optim.py
@@ -1,5 +1,4 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
@@ -10,9 +9,11 @@
     parametrize,
     run_tests,
 )
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
 from torchao.prototype import low_bit_optim
-from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
+from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
 
 try:
     import bitsandbytes as bnb
@@ -31,52 +32,69 @@
 class TestQuantize(TestCase):
     @parametrize("device", _DEVICES)
     def test_quantize_8bit_with_qmap_correctness(self, device):
-        x = torch.randn(32, 1024, device=device)
-        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)
+        x = torch.rand(32, 1024, device=device)
+        qmap = torch.rand(256, device=device).sort().values
 
-        actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
-        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0)
+        actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)
+        expected = quantize_8bit_with_qmap(x, qmap)
 
-        torch.testing.assert_close(actual_codes, expected_codes)
-        torch.testing.assert_close(actual_scale, expected_scale)
+        torch.testing.assert_close(actual, expected)

     @parametrize("device", _DEVICES)
     def test_quantize_8bit_with_qmap_compile(self, device):
-        x = torch.randn(32, 1024, device=device)
-        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)
+        x = torch.rand(32, 1024, device=device)
+        qmap = torch.rand(256, device=device).sort().values
 
-        compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True)
-        actual_codes, actual_scale = compiled_f(x, qmap, 256)
-        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256)
+        compiled_f = torch.compile(quantize_8bit_with_qmap, fullgraph=True)
+        actual = compiled_f(x, qmap)
+        expected = quantize_8bit_with_qmap(x, qmap)
 
-        torch.testing.assert_close(actual_codes, expected_codes)
-        torch.testing.assert_close(actual_scale, expected_scale)
+        torch.testing.assert_close(actual, expected)

     @parametrize("device", _DEVICES)
     def test_quantize_4bit_with_qmap_correctness(self, device):
-        x = torch.randn(32, 1024, device=device)
-        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)
+        x = torch.rand(32, 1024, device=device)
+        qmap = torch.rand(16, device=device).sort().values
 
-        actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1)
-        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0)
+        actual = (x.unsqueeze(-1) - qmap).abs().argmin(-1).to(torch.uint8)
+        expected = quantize_4bit_with_qmap(x, qmap)
 
-        torch.testing.assert_close(actual_codes, expected_codes)
-        torch.testing.assert_close(actual_scale, expected_scale)
+        torch.testing.assert_close(actual, expected)

     @parametrize("device", _DEVICES)
     def test_quantize_4bit_with_qmap_compile(self, device):
-        x = torch.randn(32, 1024, device=device)
-        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)
+        x = torch.rand(32, 1024, device=device)
+        qmap = torch.rand(16, device=device).sort().values
 
-        compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True)
-        actual_codes, actual_scale = compiled_f(x, qmap, 256)
-        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256)
+        compiled_f = torch.compile(quantize_4bit_with_qmap, fullgraph=True)
+        actual = compiled_f(x, qmap)
+        expected = quantize_4bit_with_qmap(x, qmap)
 
-        torch.testing.assert_close(actual_codes, expected_codes)
-        torch.testing.assert_close(actual_scale, expected_scale)
+        torch.testing.assert_close(actual, expected)


 class TestOptim(TestCase):
+    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
+    @parametrize("dtype", [torch.float32, torch.bfloat16])
+    @parametrize("device", _DEVICES)
+    def test_optim_smoke(self, optim_name, dtype, device):
+        if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
+            pytest.skip("FP8 requires compute capability >= 8.9")
+
+        # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
+        torch._dynamo.reset_code_caches()
+
+        model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+        model.to(device=device, dtype=dtype)
+        optim = getattr(low_bit_optim, optim_name)(model.parameters())
+
+        x = torch.randn(4, 32, device=device, dtype=dtype)
+        loss = model(x).sum()
+        loss.backward()
+        optim.step()
+        optim.zero_grad()
+
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
     @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@@ -139,21 +157,74 @@ def test_optim_4bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
-    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
-    @parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
-    @parametrize("device", _DEVICES)
-    def test_optim_fp8_smoke(self, optim_name, device):
-        if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
-            pytest.skip("FP8 requires compute capability >= 8.9")
-
-        model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
-        optim = getattr(low_bit_optim, optim_name)(model.parameters())
-
-        x = torch.randn(4, 32, device=device)
-        loss = model(x).sum()
-        loss.backward()
-        optim.step()
-        optim.zero_grad()
+
+class TestFSDP2(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required")
+    @skip_if_lt_x_gpu(2)
+    def test_fsdp2(self):
+        optim_classes = [low_bit_optim.Adam8bit, low_bit_optim.Adam4bit]
+        if torch.cuda.get_device_capability() >= (8, 9):
+            optim_classes.append(low_bit_optim.AdamFp8)
+
+        self.run_subtests(
+            {"optim_cls": optim_classes},
+            self._test_fsdp2,
+        )
+
+    def _test_fsdp2(self, optim_cls):
+        from torch.distributed._composable.fsdp import fully_shard
+        from torch.testing._internal.distributed._tensor.common_dtensor import (
+            ModelArgs,
+            Transformer,
+            TransformerBlock,
+        )
+
+        # seems like cache_size_limit is shared between FSDP processes?
+        torch._dynamo.config.cache_size_limit = 8 * self.world_size
+
+        batch_size = 3
+        vocab_size = 1024
+        seq_len = 64
+        model_args = ModelArgs(
+            n_layers=3,
+            n_heads=4,
+            dim=1024,
+            vocab_size=vocab_size,
+            max_seq_len=seq_len,
+            dropout_p=0,
+        )
+        torch.manual_seed(42)
+        with torch.device("cuda"):
+            base_model = Transformer(model_args)
+        base_optim = optim_cls(base_model.parameters(), lr=1e-2)
+
+        fsdp_model = copy.deepcopy(base_model)
+        for m in fsdp_model.modules():
+            if isinstance(m, TransformerBlock):
+                fully_shard(m)
+        fully_shard(fsdp_model)
+        fsdp_optim = optim_cls(fsdp_model.parameters(), lr=1e-2)
+
+        torch.manual_seed(42 + self.rank + 1)
+        for iter_idx in range(5):
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+            fsdp_loss = fsdp_model(inp).mean()
+            fsdp_loss.backward()
+            fsdp_optim.step()
+
+            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+            base_loss = base_model(inp).mean()
+            base_loss.backward()
+            for param in base_model.parameters():
+                if param.grad is not None:
+                    torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+            base_optim.step()
+            self.assertEqual(fsdp_loss, base_loss)


 instantiate_parametrized_tests(TestQuantize)
6 changes: 3 additions & 3 deletions torchao/prototype/low_bit_optim/README.md
@@ -38,10 +38,10 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
 ---------------|-----------------|--------------------------|----------
 PyTorch | 12.94 | 8m 18s | 91.14
 bnb 8-bit | 8.31 | 6m 50s | 90.67
-ao 8-bit | 8.32 | 9m 04s | 90.71
-ao FP8 E4M3 | 8.32 | 6m 38s | 91.08
+ao 8-bit | 8.31 | 6m 44s | 90.63
+ao FP8 E4M3 | 8.32 | 6m 35s | 90.98
 lpmm 4-bit | 7.72 | 5m 59s | 89.97
-ao 4-bit | 7.72 | 7m 00s | 89.94
+ao 4-bit | 7.72 | 7m 13s | 90.05
 lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
 
 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.
