From af8de677ee5624924292c52e460b0f303972ce40 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 10:10:55 +0800 Subject: [PATCH 1/6] update benchmark script --- benchmarks/benchmark_low_bit_adam.py | 40 ++++++++++++---------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 33efb4e52..a7b02e8c9 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -23,13 +23,13 @@ import math from contextlib import nullcontext from functools import partial -from pathlib import Path import bitsandbytes as bnb import datasets import timm import torch import torch.nn.functional as F +import wandb from torch.utils.data import DataLoader from torchvision.transforms import v2 from tqdm import tqdm @@ -72,22 +72,6 @@ def get_lr(self, step: int) -> float: return self.final_lr -class WandbLogger: - def __init__(self, args): - if args.project is not None and not args.profile: - import wandb - - Path("wandb_logs").mkdir(exist_ok=True) - self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs") - - else: - self.run = None - - def log(self, *args, **kwargs): - if self.run is not None: - self.run.log(*args, **kwargs) - - def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--model", required=True) @@ -190,7 +174,13 @@ def evaluate_model(model, args): print(f"{k}: {v}") # wandb is only enabled when args.project is set and args.profile is False - logger = WandbLogger(args) + logger = wandb.init( + project=args.project, + name=args.run_name, + config=args, + dir="/tmp", + mode="disabled" if args.project is not None else None, + ) dloader = get_dloader(args, True) print(f"Train dataset: {len(dloader.dataset):,} images") @@ -239,6 +229,7 @@ def evaluate_model(model, args): lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") + log_interval = 10 step = 0 for epoch_idx in range(args.n_epochs): @@ -246,6 +237,7 @@ def evaluate_model(model, args): pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}") start_time = datetime.datetime.now() + t0 = start_time with torch.profiler.profile() if args.profile else nullcontext() as prof: for batch in pbar: @@ -267,11 +259,13 @@ def evaluate_model(model, args): for param_group in optim.param_groups: param_group["lr"] = lr - if step % 100 == 0: - logger.log( - dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), - step=step, - ) + if step % log_interval == 0: + log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]) + if step > 0: + t1 = datetime.datetime.now() + log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0).total_seconds() + t0 = t1 + logger.log(log_dict, step=step) if args.optim_cpu_offload == "deepspeed": model.step() From aecc1c40425c0a1820ecfffaf3aff37fbc758572 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 10:31:35 +0800 Subject: [PATCH 2/6] try CPU Tensor lr --- benchmarks/benchmark_low_bit_adam.py | 12 +++++++++--- torchao/prototype/low_bit_optim/adam.py | 6 +++++- torchao/prototype/low_bit_optim/adamw.py | 6 +++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index a7b02e8c9..9ff7127a8 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -179,7 +179,7 @@ def 
evaluate_model(model, args): name=args.run_name, config=args, dir="/tmp", - mode="disabled" if args.project is not None else None, + mode="disabled" if args.project is None else None, ) dloader = get_dloader(args, True) print(f"Train dataset: {len(dloader.dataset):,} images") @@ -257,7 +257,11 @@ def evaluate_model(model, args): if args.cosine_lr_scheduler: lr = lr_schedule.get_lr(step) for param_group in optim.param_groups: - param_group["lr"] = lr + if isinstance(param_group["lr"], torch.Tensor): + param_group["lr"].fill_(lr) + else: + assert False + param_group["lr"] = lr if step % log_interval == 0: log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]) @@ -289,4 +293,6 @@ def evaluate_model(model, args): print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") logger.log(dict(val_acc=val_acc), step=step) - print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + peak_mem = torch.cuda.max_memory_allocated() / 1e9 + print(f"Max memory used: {peak_mem:.02f} GB") + logger.log(dict(max_memory_allocated=peak_mem)) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index a5425e984..948b922b7 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -20,7 +20,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) + defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super().__init__(params, defaults) self.block_size = block_size @@ -81,6 +81,10 @@ def _prepare_param_groups(self): # practically, only lr is changed during training. # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. if not isinstance(group["lr"], Tensor): + raise ValueError( + "lr was changed to a non-Tensor object. If you want to update lr, please use " + "optim.param_groups[0]['lr'].fill_(new_lr)" + ) group["lr"] = torch.tensor(group["lr"], device=p.device) p_grad_state = ( diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 9d1df8e6c..77b27ea33 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -20,7 +20,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) + defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super().__init__(params, defaults) self.block_size = block_size @@ -81,6 +81,10 @@ def _prepare_param_groups(self): # practically, only lr is changed during training. # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. if not isinstance(group["lr"], Tensor): + raise ValueError( + "lr was changed to a non-Tensor object. 
If you want to update lr, please use " + "optim.param_groups[0]['lr'].fill_(new_lr)" + ) group["lr"] = torch.tensor(group["lr"], device=p.device) p_grad_state = ( From 55077cf908a7c69f7e427d61a24b182a949077e0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 11:37:21 +0800 Subject: [PATCH 3/6] update --- benchmarks/benchmark_low_bit_adam.py | 1 - torchao/prototype/low_bit_optim/adam.py | 1 - torchao/prototype/low_bit_optim/adamw.py | 1 - 3 files changed, 3 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 9ff7127a8..712af76ab 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -260,7 +260,6 @@ def evaluate_model(model, args): if isinstance(param_group["lr"], torch.Tensor): param_group["lr"].fill_(lr) else: - assert False param_group["lr"] = lr if step % log_interval == 0: diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 948b922b7..01d8aa00f 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -85,7 +85,6 @@ def _prepare_param_groups(self): "lr was changed to a non-Tensor object. If you want to update lr, please use " "optim.param_groups[0]['lr'].fill_(new_lr)" ) - group["lr"] = torch.tensor(group["lr"], device=p.device) p_grad_state = ( p, diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 77b27ea33..3e3e4b706 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -85,7 +85,6 @@ def _prepare_param_groups(self): "lr was changed to a non-Tensor object. If you want to update lr, please use " "optim.param_groups[0]['lr'].fill_(new_lr)" ) - group["lr"] = torch.tensor(group["lr"], device=p.device) p_grad_state = ( p, From dc0e2f9b9ea062532c563483a933548f98773bd9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 13:02:11 +0800 Subject: [PATCH 4/6] consolidate adam --- torchao/prototype/low_bit_optim/__init__.py | 3 +- torchao/prototype/low_bit_optim/adam.py | 151 +++++++++-- torchao/prototype/low_bit_optim/adamw.py | 268 -------------------- 3 files changed, 136 insertions(+), 286 deletions(-) delete mode 100644 torchao/prototype/low_bit_optim/adamw.py diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index 5e9cc50c6..4ad75d4ab 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,3 +1,2 @@ -from .adam import Adam8bit, Adam4bit, AdamFp8 -from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8 +from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW from .cpu_offload import CPUOffloadOptimizer diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 01d8aa00f..4b0b29534 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -11,7 +11,7 @@ class _AdamBase(Optimizer): - def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: + def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, is_adamw) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -23,6 +23,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, 
amsgrad=amsgrad) super().__init__(params, defaults) self.block_size = block_size + self.is_adamw = is_adamw def __setstate__(self, state): super().__setstate__(state) @@ -69,7 +70,7 @@ def _prepare_param_groups(self): # State initialization if len(state) == 0: - state["step"] = torch.tensor(0.0, device=p.device) + state["step"] = torch.tensor(0.0) state["exp_avg"] = self._new_buffer(p, True) state["exp_avg_sq"] = self._new_buffer(p, False) if group["amsgrad"]: @@ -77,11 +78,8 @@ def _prepare_param_groups(self): state["step"] += 1 - # must explicitly convert lr to Tensor since torch.compile() will treat Python float as constant. - # practically, only lr is changed during training. - # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. if not isinstance(group["lr"], Tensor): - raise ValueError( + raise RuntimeError( "lr was changed to a non-Tensor object. If you want to update lr, please use " "optim.param_groups[0]['lr'].fill_(new_lr)" ) @@ -110,14 +108,16 @@ def step(self, closure=None): param_groups = self._prepare_param_groups() # static compile optim step for all params in a single graph - torch.compile(param_groups_adam, fullgraph=True)(param_groups) + torch.compile(param_groups_adam, fullgraph=True)(param_groups, self.is_adamw) return loss -def param_groups_adam(param_groups): +def param_groups_adam(param_groups, is_adamw): for group, lr, (beta1, beta2), weight_decay, eps in param_groups: for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: - single_param_adam(p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps) + single_param_adam( + p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps, is_adamw + ) # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default @@ -133,12 +133,13 @@ def single_param_adam( beta2: float, weight_decay: float, eps: float, + is_adamw: bool, ): - if weight_decay != 0: + if not is_adamw: grad = grad.add(p, alpha=weight_decay) - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step # keep high precision copy for param update new_exp_avg = exp_avg.lerp(grad, 1 - beta1) @@ -155,7 +156,11 @@ def single_param_adam( denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) step_size = lr / bias_correction1 - p.addcdiv_(new_exp_avg, denom, value=-step_size) + if is_adamw: + # merge weight decay and param update in a single .add_() to make this work with quantized param + p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom) + else: + p.addcdiv_(new_exp_avg, denom, value=-step_size) class Adam8bit(_AdamBase): @@ -170,7 +175,7 @@ def __init__( *, block_size=2048, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -189,7 +194,7 @@ def __init__( *, block_size=128, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): @@ -230,6 +235,7 @@ def step(self, closure=None): beta2, weight_decay, eps, + self.is_adamw, ) p.requires_grad_(True) @@ 
-248,8 +254,121 @@ def __init__( *, block_size=2048, ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False) @staticmethod def _subclass_zeros(p: Tensor, signed: bool, block_size: int): return OptimStateFp8.zeros(p.shape, block_size, p.device) + + +class AdamW8bit(_AdamBase): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=2048, + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState8bit.zeros(p.shape, signed, block_size, p.device) + + +class AdamW4bit(_AdamBase): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=128, + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimState4bit.zeros(p.shape, signed, block_size, p.device) + + @staticmethod + def _unwrap_dtensor(p: Tensor): + return p._local_tensor if isinstance(p, DTensor) else p + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + param_groups = self._prepare_param_groups() + + # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. + # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. + + # unwrap DTensor since DTensor does not work well with dynamic compile + # flatten p, grad, and optim state to avoid recompilation + for group, lr, (beta1, beta2), weight_decay, eps in param_groups: + for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: + # DTensor._local_tensor has .requires_grad = False + # to avoid recompilation, set p.requires_grad = False and restore it after optim step + p.requires_grad_(False) + torch.compile(single_param_adam, fullgraph=True, dynamic=True)( + self._unwrap_dtensor(p).view(-1), + self._unwrap_dtensor(grad).view(-1), + step, + self._unwrap_dtensor(exp_avg).view(-1), + self._unwrap_dtensor(exp_avg_sq).view(-1), + self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, + lr, + beta1, + beta2, + weight_decay, + eps, + self.is_adamw, + ) + p.requires_grad_(True) + + return loss + + +class AdamWFp8(_AdamBase): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=2048, + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True) + + @staticmethod + def _subclass_zeros(p: Tensor, signed: bool, block_size: int): + return OptimStateFp8.zeros(p.shape, block_size, p.device) + + +class _AdamW(_AdamBase): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + ) -> None: + """AdamW optimizer that supports quantized training (parameter is quantized). 
This optimizer should + only be used with torchao's quantized training.""" + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf"), is_adamw=True) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py deleted file mode 100644 index 3e3e4b706..000000000 --- a/torchao/prototype/low_bit_optim/adamw.py +++ /dev/null @@ -1,268 +0,0 @@ -from typing import Optional - -import torch -from torch import Tensor -from torch.optim import Optimizer -from torch.distributed._tensor import DTensor - -from .subclass_8bit import OptimState8bit -from .subclass_4bit import OptimState4bit -from .subclass_fp8 import OptimStateFp8 - - -class _AdamWBase(Optimizer): - def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) - super().__init__(params, defaults) - self.block_size = block_size - - def __setstate__(self, state): - super().__setstate__(state) - for group in self.param_groups: - group.setdefault("amsgrad", False) - - # bring your own function to create zero-filled subclass - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - raise NotImplementedError - - # follow bitsandbytes, only quantize tensors >= 4096 values - # also wrap subclass in DTensor when needed - def _new_buffer(self, p: Tensor, signed: bool): - if p.numel() >= 4096 and p.numel() % self.block_size == 0: - if isinstance(p, DTensor): - out = DTensor.from_local( - local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size), - device_mesh=p.device_mesh, - placements=p.placements, - run_check=False, - ) - else: - out = self._subclass_zeros(p, signed, self.block_size) - else: - out = torch.zeros_like(p) - return out - - def _prepare_param_groups(self): - param_groups = [] - - for group in self.param_groups: - _group = [] - - for p in group["params"]: - if p.grad is None: - continue - - grad = p.grad - if grad.is_sparse: - raise RuntimeError("Sparse gradient is not supported") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p, True) - state["exp_avg_sq"] = self._new_buffer(p, False) - if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p, False) - - state["step"] += 1 - - # must explicitly convert lr to Tensor since torch.compile() will treat Python float as constant. - # practically, only lr is changed during training. - # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%. - if not isinstance(group["lr"], Tensor): - raise ValueError( - "lr was changed to a non-Tensor object. 
If you want to update lr, please use " - "optim.param_groups[0]['lr'].fill_(new_lr)" - ) - - p_grad_state = ( - p, - grad, - state["step"], - state["exp_avg"], - state["exp_avg_sq"], - state.get("max_exp_avg_sq", None), - ) - _group.append(p_grad_state) - - param_groups.append((_group, group["lr"], group["betas"], group["weight_decay"], group["eps"])) - - return param_groups - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - param_groups = self._prepare_param_groups() - - # static compile optim step for all params in a single graph - torch.compile(param_groups_adamw, fullgraph=True)(param_groups) - return loss - - -def param_groups_adamw(param_groups): - for group, lr, (beta1, beta2), weight_decay, eps in param_groups: - for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: - single_param_adamw(p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq, lr, beta1, beta2, weight_decay, eps) - - -# this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default -def single_param_adamw( - p: Tensor, - grad: Tensor, - step: Tensor, - exp_avg: Tensor, - exp_avg_sq: Tensor, - max_exp_avg_sq: Optional[Tensor], - lr: Tensor, - beta1: float, - beta2: float, - weight_decay: float, - eps: float, -): - bias_correction1 = 1 - beta1 ** step - bias_correction2 = 1 - beta2 ** step - - # keep high precision copy for param update - new_exp_avg = exp_avg.lerp(grad, 1 - beta1) - new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2) - - exp_avg.copy_(new_exp_avg) - exp_avg_sq.copy_(new_exp_avg_sq) - - if max_exp_avg_sq is not None: - new_max_exp_avg_sq = torch.maximum(max_exp_avg_sq, new_exp_avg_sq) - max_exp_avg_sq.copy_(new_max_exp_avg_sq) - denom = (new_max_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) - else: - denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) - - # merge weight decay and param update in a single .add_() to make this work with quantized param - step_size = lr / bias_correction1 - p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom) - - -class _AdamW(_AdamWBase): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - ) -> None: - """AdamW optimizer that supports quantized training (parameter is quantized). 
This optimizer should - only be used with torchao's quantized training.""" - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=float("inf")) - - -class AdamW8bit(_AdamWBase): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - *, - block_size=2048, - ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState8bit.zeros(p.shape, signed, block_size, p.device) - - -class AdamW4bit(_AdamWBase): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - *, - block_size=128, - ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimState4bit.zeros(p.shape, signed, block_size, p.device) - - @staticmethod - def _unwrap_dtensor(p: Tensor): - return p._local_tensor if isinstance(p, DTensor) else p - - @torch.no_grad() - def step(self, closure=None): - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - param_groups = self._prepare_param_groups() - - # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim. - # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param. - - # unwrap DTensor since DTensor does not work well with dynamic compile - # flatten p, grad, and optim state to avoid recompilation - for group, lr, (beta1, beta2), weight_decay, eps in param_groups: - for p, grad, step, exp_avg, exp_avg_sq, max_exp_avg_sq in group: - # DTensor._local_tensor has .requires_grad = False - # to avoid recompilation, set p.requires_grad = False and restore it after optim step - p.requires_grad_(False) - torch.compile(single_param_adamw, fullgraph=True, dynamic=True)( - self._unwrap_dtensor(p).view(-1), - self._unwrap_dtensor(grad).view(-1), - step, - self._unwrap_dtensor(exp_avg).view(-1), - self._unwrap_dtensor(exp_avg_sq).view(-1), - self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None, - lr, - beta1, - beta2, - weight_decay, - eps, - ) - p.requires_grad_(True) - - return loss - - -class AdamWFp8(_AdamWBase): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - *, - block_size=2048, - ) -> None: - super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) - - @staticmethod - def _subclass_zeros(p: Tensor, signed: bool, block_size: int): - return OptimStateFp8.zeros(p.shape, block_size, p.device) From 87d3be77cce9e29fd56db91907983d73fda2b55b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 14:22:54 +0800 Subject: [PATCH 5/6] remove unnecessary requires_grad in subclass --- torchao/prototype/low_bit_optim/subclass_4bit.py | 7 +------ torchao/prototype/low_bit_optim/subclass_8bit.py | 7 +------ torchao/prototype/low_bit_optim/subclass_fp8.py | 7 +------ 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 5e02a5e04..5c83d8377 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -24,12 +24,7 @@ class OptimState4bit(Tensor): @staticmethod def 
__new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): - return Tensor._make_wrapper_subclass( - cls, - shape, - device=codes.device, - requires_grad=False, - ) + return Tensor._make_wrapper_subclass(cls, shape, device=codes.device) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507 diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 77459a2a3..9a4f54f71 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -19,12 +19,7 @@ class OptimState8bit(Tensor): @staticmethod def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): - return Tensor._make_wrapper_subclass( - cls, - codes.shape, - device=codes.device, - requires_grad=False, - ) + return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861 diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index ee97fffc7..883d33118 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -27,12 +27,7 @@ class OptimStateFp8(Tensor): @staticmethod def __new__(cls, codes: Tensor, scale: Tensor): - return Tensor._make_wrapper_subclass( - cls, - codes.shape, - device=codes.device, - requires_grad=False, - ) + return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device) def __init__(self, codes: Tensor, scale: Tensor): """Create quantized FP8 optimizer state. 
From cdf283cec3f4e3120c4902c05b3edf72af8304dd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 Aug 2024 14:56:24 +0800 Subject: [PATCH 6/6] update README --- benchmarks/benchmark_low_bit_adam.py | 12 ++++-------- torchao/prototype/low_bit_optim/README.md | 24 +++++++++++------------ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 712af76ab..d9f03a88b 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -18,9 +18,9 @@ # To enable cosine learning rate scheduler, set --cosine_lr_scheduler import argparse -import datetime import json import math +import time from contextlib import nullcontext from functools import partial @@ -230,15 +230,13 @@ def evaluate_model(model, args): lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") log_interval = 10 + t0 = time.perf_counter() step = 0 for epoch_idx in range(args.n_epochs): model.train() pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}") - start_time = datetime.datetime.now() - t0 = start_time - with torch.profiler.profile() if args.profile else nullcontext() as prof: for batch in pbar: if args.full_bf16: @@ -265,8 +263,8 @@ def evaluate_model(model, args): if step % log_interval == 0: log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]) if step > 0: - t1 = datetime.datetime.now() - log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0).total_seconds() + t1 = time.perf_counter() + log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0) t0 = t1 logger.log(log_dict, step=step) @@ -286,8 +284,6 @@ def evaluate_model(model, args): prof.export_chrome_trace("trace.json") else: - print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}") - val_acc = evaluate_model(model, args) print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") logger.log(dict(val_acc=val_acc), step=step) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index f307be092..5968b2a79 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -26,23 +26,23 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y NOTE: - The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9. - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. -- **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed. +- The first training step is expected to be slow since the optimizer needs to be compiled. ## Benchmarks Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py). 
-Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with fixed random seed: - -Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy ----------------|-----------------|--------------------------|---------- -PyTorch | 12.94 | 8m 18s | 91.14 -bnb 8-bit | 8.31 | 6m 50s | 90.67 -ao 8-bit | 8.31 | 6m 44s | 90.63 -ao FP8 E4M3 | 8.32 | 6m 35s | 90.98 -lpmm 4-bit | 7.72 | 5m 59s | 89.97 -ao 4-bit | 7.72 | 7m 13s | 90.05 -lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71 +Results for fine-tuning ViT-H (630M params) with BF16 AMP for 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed: + +Adam impl | max memory (GB) | imgs/s | accuracy +----------------|-----------------|--------|---------- +PyTorch (fused) | 12.23 | 41.8 | 94.38 +bnb 8-bit | 8.32 | 43.6 | 94.18 +ao 8-bit | 8.33 | 42.6 | 94.25 +ao FP8 E4M3 | 9.27 | 44.1 | 94.40 +lpmm 4-bit | 7.72 | 46.0 | 94.29 +ao 4-bit | 7.72 | 40.0 | 94.03 +lpmm 4-bit (*) | 7.74 | 26.6 | 94.25 (*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.
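
A minimal sketch (not part of the patch series) of the lr-update contract these patches introduce: the low-bit optimizers now store lr as a torch.Tensor, and a scheduler is expected to write the new value in place with .fill_(); assigning a plain Python float makes _prepare_param_groups raise. The toy model, cosine schedule, and step count below are illustrative assumptions, and a CUDA device with a torch.compile-capable PyTorch (>= 2.3, per the README) is assumed.

import math

import torch
from torch import nn

from torchao.prototype.low_bit_optim import AdamW8bit  # re-exported from the consolidated adam.py

# Hypothetical toy setup: a parameter with >= 4096 elements (and divisible by the block size)
# gets 8-bit quantized optimizer state; smaller tensors fall back to regular zero buffers.
model = nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-3)

base_lr, total_steps = 1e-3, 10
for step in range(total_steps):
    out = model(torch.randn(8, 128, device="cuda", dtype=torch.bfloat16))
    out.sum().backward()
    optim.step()
    optim.zero_grad()

    # simple cosine decay, standing in for the benchmark script's CosineSchedule
    new_lr = base_lr * 0.5 * (1 + math.cos(math.pi * (step + 1) / total_steps))
    lr = optim.param_groups[0]["lr"]
    if isinstance(lr, torch.Tensor):
        lr.fill_(new_lr)                      # in-place update keeps lr a Tensor, avoiding recompiles
    else:
        optim.param_groups[0]["lr"] = new_lr  # fallback for optimizers that keep lr as a plain float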