Optimizer CPU offload for single GPU training (#584)
* initial commit

* use fused=True by default for PyTorch adam

* detach param

* try overlap D2H grad copy with backward

* add customizable profile num steps

* add v2

* fix various bugs

* fix v1 impl

* add full BF16 option

* change n_profile_steps to 5

* add v3

* fix gradient accumulation

* add note

* add deepspeed offload

* update deepspeed config

* add some notes

* update instructions. make some packages optional. change to AdamW

* add last updated ordered dict

* update deepspeed version

* remove old versions

* update docs

* say deepspeed is untuned

* add test

* add test for offload_gradients. update benchmark script

* update benchmark results. fix test. fix benchmark script

* fix language

* add save and load

* pre-allocate CPU params. add note about gradient clipping

* update README and remove unused imports
gau-nernst committed Aug 6, 2024
1 parent de4a1fb commit 1b1e94c
Showing 5 changed files with 330 additions and 28 deletions.
126 changes: 99 additions & 27 deletions benchmarks/benchmark_low_bit_adam.py
@@ -1,46 +1,60 @@
# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core
# pip install timm wandb tqdm datasets bitsandbytes
#
# optional:
# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
# - DeepSpeed (ZeRO-Offload):
# sudo apt install libopenmpi-dev
# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
#
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
# python benchmark_low_bit_adam.py \
# --model "timm/vit_base_patch16_224.augreg_in21k" \
# --amp bf16 \
# --optim Adam
# --optim AdamW
#
# See OPTIM_MAP for the available optimizer options
# To profile and export chrome trace, set --profile
# To enable cosine learning rate scheduler, set --cosine_lr_scheduler

import argparse
import datetime
import json
import math
from contextlib import nullcontext
from functools import partial
from pathlib import Path

import bitsandbytes as bnb
import datasets
import lpmm
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype import low_bit_optim

# lpmm doesn't have Adam, only AdamW
OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=low_bit_optim.Adam8bit,
AdamFp8Ao=low_bit_optim.AdamFp8,
Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
Adam4bitAo=low_bit_optim.Adam4bit,
Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
AdamW=partial(torch.optim.AdamW, fused=True),
AdamW8bitBnb=bnb.optim.AdamW8bit,
AdamW8bitAo=low_bit_optim.AdamW8bit,
AdamWFp8Ao=low_bit_optim.AdamWFp8,
AdamW4bitAo=low_bit_optim.AdamW4bit,
)

try:
import lpmm

OPTIM_MAP.update(
AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
)

except ImportError:
pass


class CosineSchedule:
def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
@@ -77,19 +91,23 @@ def log(self, *args, **kwargs):
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
parser.add_argument("--model_kwargs", type=json.loads, default=dict())
parser.add_argument("--checkpoint_activations", action="store_true")

parser.add_argument("--amp", default="none")
parser.add_argument("--full_bf16", action="store_true")
parser.add_argument("--channels_last", action="store_true")
parser.add_argument("--compile", action="store_true")

parser.add_argument("--n_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--n_workers", type=int, default=4)

parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys())
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--cosine_lr_scheduler", action="store_true")
parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])

parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
@@ -140,6 +158,8 @@ def evaluate_model(model, args):

for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"):
all_labels.append(batch["label"].clone())
if args.full_bf16:
batch["image"] = batch["image"].bfloat16()
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

@@ -156,6 +176,11 @@ def evaluate_model(model, args):
if __name__ == "__main__":
args = get_parser().parse_args()

if args.full_bf16:
assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
if args.optim_cpu_offload == "deepspeed":
assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
if args.profile:
args.n_epochs = 1
if args.seed is not None:
@@ -169,48 +194,95 @@ def evaluate_model(model, args):
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
if args.checkpoint_activations:
model.set_grad_checkpointing()
if args.full_bf16:
model.bfloat16()
if args.channels_last:
model.to(memory_format=torch.channels_last)
model.cuda() # move model to CUDA after optionally converting it to BF16
if args.compile:
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
if args.optim_cpu_offload == "deepspeed":
import deepspeed

model, optim, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=dict(
train_batch_size=args.batch_size,
optimizer=dict(
type="Adam",
params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False),
),
bf16=dict(enabled=args.full_bf16),
zero_optimization=dict(
stage=2, # requires ZeRO-2 to enable overlap_comm
overlap_comm=True, # interleave grad D2H with backward
offload_optimizer=dict(device="cpu", pin_memory=True),
),
),
)

else:
optim_cls = OPTIM_MAP[args.optim]

if args.optim_cpu_offload == "ao":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
elif args.optim_cpu_offload == "ao_offload_grads":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)

optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()
pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")

start_time = datetime.datetime.now()

with prof:
for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
with torch.profiler.profile() if args.profile else nullcontext() as prof:
for batch in pbar:
if args.full_bf16:
batch["image"] = batch["image"].bfloat16()
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

with get_amp_ctx(args.amp):
loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
grad_scaler.scale(loss).backward()

if args.optim_cpu_offload == "deepspeed":
model.backward(loss)
else:
grad_scaler.scale(loss).backward()

if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
param_group["lr"] = lr

if step % 100 == 0:
logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

grad_scaler.step(optim)
grad_scaler.update()
optim.zero_grad()
logger.log(
dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]),
step=step,
)

if args.optim_cpu_offload == "deepspeed":
model.step()
else:
grad_scaler.step(optim)
grad_scaler.update()
optim.zero_grad()

step += 1

if args.profile and step == 20:
if args.profile and step == 5:
break

if args.profile:
64 changes: 64 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -1,4 +1,5 @@
import copy
import tempfile

import pytest
import torch
@@ -157,6 +158,69 @@ def test_optim_4bit_correctness(self, optim_name):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
@parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)])
def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1)

optim1 = torch.optim.AdamW(model1.parameters())
optim2 = low_bit_optim.CPUOffloadOptimizer(
model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad,
)

for _ in range(2):
for _ in range(grad_accum):
x = torch.randn(4, 32, device=device)
model1(x).sum().backward()
model2(x).sum().backward()

optim1.step()
optim1.zero_grad()

optim2.step()
optim2.zero_grad()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA")
def test_optim_cpu_offload_save_load(self):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)

for _ in range(2):
x = torch.randn(4, 32, device=device)
model1(x).sum().backward()
optim1.step()
optim1.zero_grad()

# save checkpoint. make sure it can be serialized by torch.save()
with tempfile.NamedTemporaryFile() as file:
torch.save(optim1.state_dict(), file.name)
state_dict = torch.load(file.name)

# resume training
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
optim2.load_state_dict(state_dict)

for _ in range(2):
x = torch.randn(4, 32, device=device)

model1(x).sum().backward()
optim1.step()
optim1.zero_grad()

model2(x).sum().backward()
optim2.step()
optim2.zero_grad()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)


class TestFSDP2(FSDPTest):
@property
63 changes: 62 additions & 1 deletion torchao/prototype/low_bit_optim/README.md
@@ -46,6 +67,67 @@ lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71

(*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.

```python
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = ...

# only offload optimizer state
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

# offload optimizer state AND gradients
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True, fused=True)
```

This reduces GPU memory usage by the size of the optimizer state, plus the size of the gradients if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer.
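
As a sketch, the base optimizer can be swapped for any other `torch.optim` optimizer; extra keyword arguments are forwarded to it, as with `fused=True` above (plain SGD is used here purely for illustration):

```python
# offload SGD state (momentum buffers) to CPU instead of AdamW state
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.SGD, lr=1e-3, momentum=0.9)
```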

For saving and loading `CPUOffloadOptimizer`, it is important that you load the model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want a method to synchronize CUDA and CPU params in either direction, CPU->CUDA or CUDA->CPU, in case they get out of sync.)

```python
ckpt = torch.load("checkpoint.pth")

model = ...
model.load_state_dict(ckpt["model"])

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
```
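
For completeness, a save-side sketch that produces the checkpoint consumed above (the `"model"`/`"optim"` keys simply mirror the loading snippet; `optim.state_dict()` returns a regular dict that `torch.save()` can serialize):

```python
# save model and optimizer state so training can be resumed later
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "checkpoint.pth")
```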

NOTE:
- Since the optimizer step is done on the CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try applying `torch.compile()` to their optimizer step.
- To minimize the amount of CPU<->GPU data transfer, we keep a copy of the parameters and pre-allocate gradient memory on the CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is itself 2x model size for Adam).
- It is recommended NOT to `torch.compile()` your whole model when `CPUOffloadOptimizer` is used, as that prevents us from interleaving gradient device-to-host transfers with the backward pass. To minimize the impact, you can compile parts of your model separately; see the sketch after this list and [#584](https://github.com/pytorch/ao/pull/584) for more information.
- The CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) do full BF16 training (instead of AMP), so that parameters, gradients, and optimizer states are all in BF16; and (2) give the GPU more work per optimizer step (e.g. a larger batch size with activation checkpointing, or gradient accumulation).
- `offload_gradients=True` is not compatible with gradient accumulation, since we clear the gradients on the GPU after every backward pass.
- Gradient clipping is currently not supported.
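
A minimal sketch of compiling sub-modules separately instead of the whole model (this assumes a ViT-style `model` with a `blocks` container, as in timm ViTs; adapt the attribute to your architecture):

```python
# compile each transformer block on its own rather than model.compile(fullgraph=True),
# which, per the note above, keeps gradient D2H transfers overlapping with backward
for i, block in enumerate(model.blocks):
    model.blocks[i] = torch.compile(block)
```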

Benchmark done with `timm/vit_giant_patch14_dinov2.lvd142m` (1.1B params), eager mode, full BF16 training, activation checkpointing, batch size 32, on a 4070Ti SUPER (16GB VRAM), Ryzen 5600, and DDR4 RAM. DeepSpeed is untuned.

Adam offload | Time per step | Max memory
-----------------------|---------------|------------
None | 1.27s/it | 9.82 GB
DeepSpeed ZeRO-Offload | 3.13s/it | 6.85 GB
ao | 1.52s/it | 5.24 GB
ao (offload gradients) | 1.53s/it | 4.01 GB

Ablations on AMP and `torch.compile()`

Training config | Adam offload | Time per step | Max memory
--------------------|--------------|---------------|------------
Full BF16, compiled | None | 1.18s/it | 9.90 GB
Full BF16, compiled | ao | 1.75s/it | 5.33 GB
BF16 AMP, eager | None | OOM | OOM
BF16 AMP, eager | ao | 2.18s/it | 9.90 GB

## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
Credits to

- Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library.
- [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
- [DeepSpeed](https://github.com/microsoft/DeepSpeed) team for [ZeRO-Offload](https://arxiv.org/abs/2101.06840).
1 change: 1 addition & 0 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,3 @@
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
from .cpu_offload import CPUOffloadOptimizer