Fix LR schedule handling for low-bit optimizers (#736)
* update benchmark script

* try CPU Tensor lr

* update

* consolidate adam

* remove unnecessary requires_grad in subclass

* update README
gau-nernst committed Aug 23, 2024
1 parent d26bcca commit aacaf9b
Showing 8 changed files with 181 additions and 345 deletions.
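For context, a minimal sketch of the problem this commit works around (illustrative only, not taken from the diff): torch.compile() treats a Python float as a constant, so a learning rate passed as a plain float is baked into the compiled optimizer step and every change to it can trigger a recompile, whereas a 0-dim Tensor can be updated in place without recompiling.

```python
import torch

@torch.compile
def scale(x, lr):
    return x * lr

x = torch.randn(4)
scale(x, 1e-3)            # the float lr is specialized as a constant
scale(x, 5e-4)            # different float -> recompile (PyTorch-version dependent)

lr = torch.tensor(1e-3)   # CPU scalar tensor, as the benchmark now uses
scale(x, lr)              # compiled once against a Tensor input
lr.fill_(5e-4)            # in-place update, no recompile
scale(x, lr)
```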
57 changes: 26 additions & 31 deletions benchmarks/benchmark_low_bit_adam.py
@@ -18,18 +18,18 @@
# To enable cosine learning rate scheduler, set --cosine_lr_scheduler

import argparse
import datetime
import json
import math
import time
from contextlib import nullcontext
from functools import partial
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm
@@ -72,22 +72,6 @@ def get_lr(self, step: int) -> float:
return self.final_lr


class WandbLogger:
def __init__(self, args):
if args.project is not None and not args.profile:
import wandb

Path("wandb_logs").mkdir(exist_ok=True)
self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")

else:
self.run = None

def log(self, *args, **kwargs):
if self.run is not None:
self.run.log(*args, **kwargs)


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
@@ -190,7 +174,13 @@ def evaluate_model(model, args):
print(f"{k}: {v}")

# wandb logging is only enabled when args.project is set; otherwise the run is disabled
logger = WandbLogger(args)
logger = wandb.init(
project=args.project,
name=args.run_name,
config=args,
dir="/tmp",
mode="disabled" if args.project is None else None,
)
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

@@ -239,14 +229,14 @@ def evaluate_model(model, args):

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
log_interval = 10
t0 = time.perf_counter()

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")

start_time = datetime.datetime.now()

with torch.profiler.profile() if args.profile else nullcontext() as prof:
for batch in pbar:
if args.full_bf16:
@@ -265,13 +255,18 @@ def evaluate_model(model, args):
if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
param_group["lr"] = lr

if step % 100 == 0:
logger.log(
dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]),
step=step,
)
if isinstance(param_group["lr"], torch.Tensor):
param_group["lr"].fill_(lr)
else:
param_group["lr"] = lr

if step % log_interval == 0:
log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
if step > 0:
t1 = time.perf_counter()
log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
t0 = t1
logger.log(log_dict, step=step)

if args.optim_cpu_offload == "deepspeed":
model.step()
@@ -289,10 +284,10 @@ def evaluate_model(model, args):
prof.export_chrome_trace("trace.json")

else:
print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}")

val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
peak_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
logger.log(dict(max_memory_allocated=peak_mem))
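A note on the logging change above: the custom WandbLogger wrapper is replaced by calling wandb.init() directly with mode="disabled" when no project is given, which returns a no-op run so logger.log(...) can always be called. A minimal sketch of that pattern (assuming only the wandb package):

```python
import wandb

# A disabled run accepts log() calls and silently drops them,
# so no wrapper class or None-checks are needed.
logger = wandb.init(project=None, mode="disabled")
logger.log({"loss": 0.0}, step=0)
```

The new imgs_per_second metric is simply batch_size * log_interval divided by the elapsed time from time.perf_counter(), replacing the old per-epoch wall-clock print.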
24 changes: 12 additions & 12 deletions torchao/prototype/low_bit_optim/README.md
@@ -26,23 +26,23 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y
NOTE:
- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.
- The first training step is expected to be slow since the optimizer needs to be compiled.

## Benchmarks

Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with fixed random seed:

Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
---------------|-----------------|--------------------------|----------
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.31 | 6m 44s | 90.63
ao FP8 E4M3 | 8.32 | 6m 35s | 90.98
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 13s | 90.05
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
Results for fine-tuning ViT-H (630M params) with BF16 AMP for 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed:

Adam impl | max memory (GB) | imgs/s | accuracy
----------------|-----------------|--------|----------
PyTorch (fused) | 12.23 | 41.8 | 94.38
bnb 8-bit | 8.32 | 43.6 | 94.18
ao 8-bit | 8.33 | 42.6 | 94.25
ao FP8 E4M3 | 9.27 | 44.1 | 94.40
lpmm 4-bit | 7.72 | 46.0 | 94.29
ao 4-bit | 7.72 | 40.0 | 94.03
lpmm 4-bit (*) | 7.74 | 26.6 | 94.25

(*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.

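The README note above about per-step learning-rate updates is what this commit's benchmark change addresses: keep the learning rate as a scalar Tensor and update it in place, so the compiled optimizer step is not recompiled. A sketch of that usage (whether the optimizer constructor accepts a Tensor lr is an assumption based on this commit; the hand-rolled cosine decay only loosely mirrors the benchmark's CosineSchedule):

```python
import math
import torch
from torchao.prototype.low_bit_optim import Adam8bit

model = torch.nn.Linear(1024, 1024).cuda()
# Tensor lr (assumption): lets the schedule update it in place each step.
optim = Adam8bit(model.parameters(), lr=torch.tensor(1e-3))

base_lr, total_steps = 1e-3, 100
for step in range(total_steps):
    loss = model(torch.randn(8, 1024, device="cuda")).mean()
    loss.backward()

    new_lr = 0.5 * base_lr * (1 + math.cos(step / total_steps * math.pi))
    for group in optim.param_groups:
        if isinstance(group["lr"], torch.Tensor):
            group["lr"].fill_(new_lr)   # in-place update, avoids recompilation
        else:
            group["lr"] = new_lr        # plain-float fallback

    optim.step()
    optim.zero_grad()
```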
3 changes: 1 addition & 2 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,3 +1,2 @@
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8
from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
from .cpu_offload import CPUOffloadOptimizer
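With the consolidation above, the AdamW variants are now defined alongside Adam in adam.py, but the package-level names are unchanged. A quick sanity-check sketch:

```python
from torchao.prototype.low_bit_optim import Adam8bit, AdamW8bit, CPUOffloadOptimizer

# Same public names as before; only the internal module layout changed.
print(Adam8bit, AdamW8bit, CPUOffloadOptimizer)
```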
