From 1b1e94c5b537a2bccdf705e11009555646cf6ae6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 6 Aug 2024 23:19:13 +0800 Subject: [PATCH] Optimizer CPU offload for single GPU training (#584) * initial commit * use fused=True by default for PyTorch adam * detach param * try overlap D2H grad copy with backward * add customizable profile num steps * add v2 * fix various bugs * fix v1 impl * add full BF16 option * change n_profile_steps to 5 * add v3 * fix gradient accumulation * add note * add deepspeed offload * update deepspeed config * add some notes * update instructions. make some packages optional. change to AdamW * add last updated ordered dict * update deepspeed version * remove old versions * update docs * say deepspeed is untuned * add test * add test for offload_gradients. update benchmark script * update benchmark results. fix test. fix benchmark script * fix language * add save and load * pre-allocate CPU params. add note about gradient clipping * update README and remove unused imports --- benchmarks/benchmark_low_bit_adam.py | 126 ++++++++++++++---- test/prototype/test_low_bit_optim.py | 64 +++++++++ torchao/prototype/low_bit_optim/README.md | 63 ++++++++- torchao/prototype/low_bit_optim/__init__.py | 1 + .../prototype/low_bit_optim/cpu_offload.py | 104 +++++++++++++++ 5 files changed, 330 insertions(+), 28 deletions(-) create mode 100644 torchao/prototype/low_bit_optim/cpu_offload.py diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 6517eac8f..33efb4e52 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -1,10 +1,17 @@ -# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git -# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core +# pip install timm wandb tqdm datasets bitsandbytes # +# optional: +# - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git +# - DeepSpeed (ZeRO-Offload): +# sudo apt install libopenmpi-dev +# LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p +# DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir +# +# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core # python benchmark_low_bit_adam.py \ # --model "timm/vit_base_patch16_224.augreg_in21k" \ # --amp bf16 \ -# --optim Adam +# --optim AdamW # # See OPTIM_MAP for the available optimizer options # To profile and export chrome trace, set --profile @@ -12,6 +19,7 @@ import argparse import datetime +import json import math from contextlib import nullcontext from functools import partial @@ -19,28 +27,34 @@ import bitsandbytes as bnb import datasets -import lpmm import timm import torch import torch.nn.functional as F -from torch.profiler import ProfilerActivity, profile from torch.utils.data import DataLoader from torchvision.transforms import v2 from tqdm import tqdm from torchao.prototype import low_bit_optim -# lpmm doesn't have Adam, only AdamW OPTIM_MAP = dict( - Adam=torch.optim.Adam, - Adam8bitBnb=bnb.optim.Adam8bit, - Adam8bitAo=low_bit_optim.Adam8bit, - AdamFp8Ao=low_bit_optim.AdamFp8, - Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), - Adam4bitAo=low_bit_optim.Adam4bit, - Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")), + AdamW=partial(torch.optim.AdamW, fused=True), + 
AdamW8bitBnb=bnb.optim.AdamW8bit, + AdamW8bitAo=low_bit_optim.AdamW8bit, + AdamWFp8Ao=low_bit_optim.AdamWFp8, + AdamW4bitAo=low_bit_optim.AdamW4bit, ) +try: + import lpmm + + OPTIM_MAP.update( + AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True), + AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")), + ) + +except ImportError: + pass + class CosineSchedule: def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None: @@ -77,8 +91,11 @@ def log(self, *args, **kwargs): def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--model", required=True) + parser.add_argument("--model_kwargs", type=json.loads, default=dict()) + parser.add_argument("--checkpoint_activations", action="store_true") parser.add_argument("--amp", default="none") + parser.add_argument("--full_bf16", action="store_true") parser.add_argument("--channels_last", action="store_true") parser.add_argument("--compile", action="store_true") @@ -86,10 +103,11 @@ def get_parser(): parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--n_workers", type=int, default=4) - parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys()) + parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys()) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--cosine_lr_scheduler", action="store_true") + parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]) parser.add_argument("--project") parser.add_argument("--run_name", default="debug") @@ -140,6 +158,8 @@ def evaluate_model(model, args): for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"): all_labels.append(batch["label"].clone()) + if args.full_bf16: + batch["image"] = batch["image"].bfloat16() if args.channels_last: batch["image"] = batch["image"].to(memory_format=torch.channels_last) @@ -156,6 +176,11 @@ def evaluate_model(model, args): if __name__ == "__main__": args = get_parser().parse_args() + if args.full_bf16: + assert args.amp == "none", "When --full_bf16 is set, --amp must be none" + if args.optim_cpu_offload == "deepspeed": + assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none" + assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW" if args.profile: args.n_epochs = 1 if args.seed is not None: @@ -169,32 +194,73 @@ def evaluate_model(model, args): dloader = get_dloader(args, True) print(f"Train dataset: {len(dloader.dataset):,} images") - model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda() + model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs) + if args.checkpoint_activations: + model.set_grad_checkpointing() + if args.full_bf16: + model.bfloat16() if args.channels_last: model.to(memory_format=torch.channels_last) + model.cuda() # move model to CUDA after optionally convert it to BF16 if args.compile: model.compile(fullgraph=True) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay) - lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) + if args.optim_cpu_offload == "deepspeed": + import deepspeed + + model, optim, _, _ = deepspeed.initialize( + model=model, + model_parameters=model.parameters(), + config=dict( + train_batch_size=args.batch_size, + 
optimizer=dict( + type="Adam", + params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False), + ), + bf16=dict(enabled=args.full_bf16), + zero_optimization=dict( + stage=2, # requires ZeRO-2 to enable overlap_comm + overlap_comm=True, # interleave grad D2H with backward + offload_optimizer=dict(device="cpu", pin_memory=True), + ), + ), + ) + + else: + optim_cls = OPTIM_MAP[args.optim] + + if args.optim_cpu_offload == "ao": + optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls) + elif args.optim_cpu_offload == "ao_offload_grads": + optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True) + optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") step = 0 for epoch_idx in range(args.n_epochs): model.train() - prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext() + pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}") + start_time = datetime.datetime.now() - with prof: - for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"): + with torch.profiler.profile() if args.profile else nullcontext() as prof: + for batch in pbar: + if args.full_bf16: + batch["image"] = batch["image"].bfloat16() if args.channels_last: batch["image"] = batch["image"].to(memory_format=torch.channels_last) with get_amp_ctx(args.amp): loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda()) - grad_scaler.scale(loss).backward() + + if args.optim_cpu_offload == "deepspeed": + model.backward(loss) + else: + grad_scaler.scale(loss).backward() if args.cosine_lr_scheduler: lr = lr_schedule.get_lr(step) @@ -202,15 +268,21 @@ def evaluate_model(model, args): param_group["lr"] = lr if step % 100 == 0: - logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step) - - grad_scaler.step(optim) - grad_scaler.update() - optim.zero_grad() + logger.log( + dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), + step=step, + ) + + if args.optim_cpu_offload == "deepspeed": + model.step() + else: + grad_scaler.step(optim) + grad_scaler.update() + optim.zero_grad() step += 1 - if args.profile and step == 20: + if args.profile and step == 5: break if args.profile: diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 5eb0a54b6..97d6ea9da 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -1,4 +1,5 @@ import copy +import tempfile import pytest import torch @@ -157,6 +158,69 @@ def test_optim_4bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) + def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): + device = "cuda" + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model2 = copy.deepcopy(model1) + + optim1 = torch.optim.AdamW(model1.parameters()) + optim2 = low_bit_optim.CPUOffloadOptimizer( + model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad, + ) + + for _ in range(2): + for _ in 
range(grad_accum): + x = torch.randn(4, 32, device=device) + model1(x).sum().backward() + model2(x).sum().backward() + + optim1.step() + optim1.zero_grad() + + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + def test_optim_cpu_offload_save_load(self): + device = "cuda" + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW) + + for _ in range(2): + x = torch.randn(4, 32, device=device) + model1(x).sum().backward() + optim1.step() + optim1.zero_grad() + + # save checkpoint. make sure it can be serialized by torch.save() + with tempfile.NamedTemporaryFile() as file: + torch.save(optim1.state_dict(), file.name) + state_dict = torch.load(file.name) + + # resume training + model2 = copy.deepcopy(model1) + optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW) + optim2.load_state_dict(state_dict) + + for _ in range(2): + x = torch.randn(4, 32, device=device) + + model1(x).sum().backward() + optim1.step() + optim1.zero_grad() + + model2(x).sum().backward() + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1) + class TestFSDP2(FSDPTest): @property diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 641207414..f307be092 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -46,6 +46,67 @@ lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details. +## Optimizer CPU offload + +This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload. + +```python +import torch +from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + +model = ... + +# only offload optimizer state +optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True) + +# offload optimizer state AND gradients +optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True, fused=True) +``` + +This will reduce GPU memory usage by optimizer state size, and additionally gradient size if `offload_gradients=True`. `CPUOffloadOptimizer` can wrap any base optimizer. + +For saving and loading `CPUOffloadOptimizer`, it is important that you load model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside `CPUOffloadOptimizer.__init__()`. (TODO: we might want to have a method to synchronize CUDA and CPU params in either direction CPU->CUDA and CUDA->CPU, in case they are out of sync.) + +```python +ckpt = torch.load("checkpoint.pth") + +model = ... +model.load_state_dict(ckpt["model"]) + +optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True) +optim.load_state_dict(ckpt["optim"]) +``` + +NOTE: +- Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try `torch.compile()` their optimizer step. 
+- To minimize the amount of CPU<->GPU data transfer, we keep a copy of parameters and pre-allocate gradients memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam). +- It is recommended NOT to `torch.compile()` your whole model when `CPUOffloadOptimizer` is used, as it prevents us from interleaving gradient device-to-host transfer with backward pass. To minimize such impact, you can compile parts of your model separately. See [#584](https://github.com/pytorch/ao/pull/584) for more information. +- CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) do full BF16 training (instead of AMP), so that parameters, gradients, and optimizer states are in BF16; and (2) give GPU more work per optimizer step (e.g. larger batch size with activation checkpointing, gradient accumulation). +- `offload_gradients=True` is not compatible with gradient accumulation, since we clear gradients on GPU every backward pass. +- Gradient clipping is currently not supported. + +Benchmark done for `timm/vit_giant_patch14_dinov2.lvd142m` (1.1B params), eager mode, full BF16 training, activations checkpointing, batch size 32, on 4070Ti SUPER (16GB VRAM), Ryzen 5600, DDR4 RAM. DeepSpeed is untuned. + +Adam offload | Time per step | Max memory +-----------------------|---------------|------------ +None | 1.27s/it | 9.82 GB +DeepSpeed ZeRO-Offload | 3.13s/it | 6.85 GB +ao | 1.52s/it | 5.24 GB +ao (offload gradients) | 1.53s/it | 4.01 GB + +Ablations on AMP and `torch.compile()` + +Training config | Adam offload | Time per step | Max memory +--------------------|--------------|---------------|------------ +Full BF16, compiled | None | 1.18s/it | 9.90 GB +Full BF16, compiled | ao | 1.75s/it | 5.33 GB +BF16 AMP, eager | None | OOM | OOM +BF16 AMP, eager | ao | 2.18s/it | 9.90 GB + ## Credits -Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers. +Credits to + +- Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library. +- [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers. +- [DeepSpeed](https://github.com/microsoft/DeepSpeed) team for [ZeRO-Offload](https://arxiv.org/abs/2101.06840). diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index 962726b96..01729bc6a 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,2 +1,3 @@ from .adam import Adam8bit, Adam4bit, AdamFp8 from .adamw import AdamW8bit, AdamW4bit, AdamWFp8 +from .cpu_offload import CPUOffloadOptimizer diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py new file mode 100644 index 000000000..69ee4c240 --- /dev/null +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -0,0 +1,104 @@ +from typing import Type + +import torch +from torch.optim.optimizer import Optimizer + + +class CPUOffloadOptimizer: + def __init__(self, params, optimizer_class: Type[Optimizer], *, offload_gradients: bool = False, **kwargs) -> None: + """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state. + Optimizer step will be done on CPU. 
+ + Args + params: a list of parameters or parameter groups. + optimizer_class: constructor of the base optimizer. + offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation. + kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. + """ + param_groups = list(params) + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + self.param_cuda2cpu_map = dict() + self.optim_dict = dict() + self.stream = torch.cuda.Stream() + + # the queue maintains the order which param we should do optim step on first. + self.queue = dict() + + def backward_hook(p_cuda): + if p_cuda.grad is not None: + p_cpu = self.param_cuda2cpu_map[p_cuda] + + # make sure backward for this param finishes + self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + p_cpu.grad.copy_(p_cuda.grad, non_blocking=True) + + # we rely on CPython implementation of dictionary, which preserves insertion order. + # if a param is added again (e.g. due to gradient accumulation), it is moved to the + # end of the queue by removing and inserting it again. + if p_cuda in self.queue: + del self.queue[p_cuda] + self.queue[p_cuda] = self.stream.record_event() + + # deallocate CUDA gradients once D2H transfer finishes. + if offload_gradients: + p_cuda.grad.record_stream(self.stream) + p_cuda.grad = None + + for param_group in param_groups: + params = param_group.pop("params") + + for p_cuda in params: + # pre-allocate CPU params and grads + p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True) + p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True) + + p_cpu.copy_(p_cuda.detach(), non_blocking=True) + self.param_cuda2cpu_map[p_cuda] = p_cpu + + p_cuda.register_post_accumulate_grad_hook(backward_hook) + self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for p_cuda, grad_d2h_event in self.queue.items(): + grad_d2h_event.synchronize() + self.optim_dict[p_cuda].step() + + # submit more job to self.stream. it guarantees that we only start + # moving param H2D once all backwards finish, since self.stream + # will wait for current_stream when moving grad D2H. + p_cpu = self.param_cuda2cpu_map[p_cuda] + with torch.cuda.stream(self.stream): + p_cuda.copy_(p_cpu, non_blocking=True) + + self.queue.clear() + return loss + + def zero_grad(self, set_to_none=True): + assert set_to_none + + # only clear CUDA grad. CPU grad will always be overwritten by CUDA grad. + for p_cuda in self.param_cuda2cpu_map.keys(): + p_cuda.grad = None + + @property + def param_groups(self): + # each param group will only has 1 parameter + # TODO: we might want to return the original param_groups instead. + return sum((optim.param_groups for optim in self.optim_dict.values()), start=[]) + + def state_dict(self): + return [optim.state_dict() for optim in self.optim_dict.values()] + + def load_state_dict(self, state_dict): + for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict): + optim.load_state_dict(optim_state_dict)
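
For reference, below is a minimal end-to-end sketch of how the pieces added in this patch fit together in a single-GPU training loop: a full-BF16 model, `CPUOffloadOptimizer` wrapping PyTorch's fused `AdamW`, and checkpoint saving as described in the README. The toy model, data shapes, and hyperparameters are illustrative placeholders only.

```python
import torch
import torch.nn.functional as F
from torch import nn

from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

# full BF16 training keeps params, grads, and optimizer states in BF16,
# which reduces the work done by the CPU optimizer step
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10)).bfloat16().cuda()

# load/initialize model weights BEFORE constructing the optimizer:
# CPUOffloadOptimizer makes a pinned CPU copy of every parameter in __init__()
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-4, fused=True)

for step in range(10):
    x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    y = torch.randint(0, 10, (32,), device="cuda")

    loss = F.cross_entropy(model(x), y)
    loss.backward()    # per-param grad D2H copies are queued on a side stream during backward

    optim.step()       # optimizer step runs on CPU, then updated params are copied back H2D
    optim.zero_grad()  # clears CUDA grads; pre-allocated CPU grads are overwritten on the next backward

# the offloaded state dict is a list of per-parameter base-optimizer state dicts
# and can be serialized with torch.save() like any other optimizer state
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, "checkpoint.pth")
```

The updated benchmark script exposes the same functionality via `--optim_cpu_offload ao` / `--optim_cpu_offload ao_offload_grads` (plus `--optim_cpu_offload deepspeed` for the ZeRO-Offload baseline), typically combined with `--full_bf16` and `--checkpoint_activations`.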