Fix LR schedule handling for low-bit optimizers #736

Merged: 7 commits, Aug 23, 2024 (diff shows changes from all commits)

57 changes: 26 additions & 31 deletions benchmarks/benchmark_low_bit_adam.py
@@ -18,18 +18,18 @@
# To enable cosine learning rate scheduler, set --cosine_lr_scheduler

import argparse
import datetime
import json
import math
import time
from contextlib import nullcontext
from functools import partial
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm
@@ -72,22 +72,6 @@ def get_lr(self, step: int) -> float:
return self.final_lr


class WandbLogger:
def __init__(self, args):
if args.project is not None and not args.profile:
import wandb

Path("wandb_logs").mkdir(exist_ok=True)
self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")

else:
self.run = None

def log(self, *args, **kwargs):
if self.run is not None:
self.run.log(*args, **kwargs)


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)
@@ -190,7 +174,13 @@ def evaluate_model(model, args):
print(f"{k}: {v}")

# wandb is only enabled when args.project is set and args.profile is False
logger = WandbLogger(args)
logger = wandb.init(
project=args.project,
name=args.run_name,
config=args,
dir="/tmp",
mode="disabled" if args.project is None else None,
)
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

@@ -239,14 +229,14 @@ def evaluate_model(model, args):

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
log_interval = 10
t0 = time.perf_counter()

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")

start_time = datetime.datetime.now()

with torch.profiler.profile() if args.profile else nullcontext() as prof:
for batch in pbar:
if args.full_bf16:
@@ -265,13 +255,18 @@ def evaluate_model(model, args):
if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
param_group["lr"] = lr

if step % 100 == 0:
logger.log(
dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]),
step=step,
)
if isinstance(param_group["lr"], torch.Tensor):
param_group["lr"].fill_(lr)
else:
param_group["lr"] = lr

if step % log_interval == 0:
log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
if step > 0:
t1 = time.perf_counter()
log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
t0 = t1
logger.log(log_dict, step=step)

if args.optim_cpu_offload == "deepspeed":
model.step()
@@ -289,10 +284,10 @@ def evaluate_model(model, args):
prof.export_chrome_trace("trace.json")

else:
print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}")

val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
peak_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
logger.log(dict(max_memory_allocated=peak_mem))
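
For reference, a condensed sketch of the logging pattern introduced above: `wandb.init(mode="disabled")` turns the run into a no-op (which is what replaced the `WandbLogger` wrapper class), and images/second is measured over each `log_interval` window with `time.perf_counter()`. The tiny model and synthetic data below are placeholders, not part of the benchmark.

```python
# Sketch of the benchmark's logging pattern (assumes wandb is installed).
import time

import torch
import wandb

project = None  # set to a project name to actually log to W&B

# mode="disabled" makes the run a no-op, so logger.log() can be called
# unconditionally without a wrapper class.
logger = wandb.init(
    project=project,
    dir="/tmp",
    mode="disabled" if project is None else None,
)

model = torch.nn.Linear(16, 16)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

batch_size, log_interval = 8, 10
t0 = time.perf_counter()

for step in range(100):
    loss = model(torch.randn(batch_size, 16)).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

    if step % log_interval == 0:
        log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
        if step > 0:
            # throughput over the last log_interval steps, as in the benchmark
            t1 = time.perf_counter()
            log_dict["imgs_per_second"] = batch_size * log_interval / (t1 - t0)
            t0 = t1
        logger.log(log_dict, step=step)
```
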
24 changes: 12 additions & 12 deletions torchao/prototype/low_bit_optim/README.md
@@ -26,23 +26,23 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`.
NOTE:
- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- **Known issue**: when the learning rate is updated every step (e.g. with a cosine learning-rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since torch.compile() treats a Python float as a constant and triggers a recompile whenever the value changes (see the sketch after this list).
- The first training step is expected to be slow since the optimizer needs to be compiled.
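
A minimal sketch of the LR-update pattern used by the benchmark in this PR, assuming a CUDA GPU, PyTorch >= 2.3, and torchao's `low_bit_optim` prototype; the model, sizes, and schedule constants are illustrative. Because the learning rate lives in a CUDA tensor (per the note above), the schedule writes into it with `fill_()` rather than assigning a Python float, and falls back to plain assignment for optimizers that store a float:

```python
# Hedged sketch: AdamW8bit stands in for any of the low-bit optimizers here.
import math

import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(128, 128, bias=False).cuda()
optim = AdamW8bit(model.parameters(), lr=1e-3)

base_lr, total_steps = 1e-3, 100
for step in range(total_steps):
    loss = model(torch.randn(32, 128, device="cuda")).sum()
    loss.backward()

    # simple cosine decay from base_lr to 0, updated every step
    lr = base_lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))
    for group in optim.param_groups:
        if isinstance(group["lr"], torch.Tensor):
            group["lr"].fill_(lr)  # update in place, keep it a tensor
        else:
            group["lr"] = lr       # fallback for optimizers that store a float

    optim.step()  # the first step is slow while the optimizer compiles
    optim.zero_grad()
```
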

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with fixed random seed:

Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
---------------|-----------------|--------------------------|----------
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.31 | 6m 44s | 90.63
ao FP8 E4M3 | 8.32 | 6m 35s | 90.98
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 13s | 90.05
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
Results for fine-tuning ViT-H (630M params) with BF16 AMP for 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed:

Adam impl | max memory (GB) | imgs/s | accuracy
----------------|-----------------|--------|----------
PyTorch (fused) | 12.23 | 41.8 | 94.38
bnb 8-bit | 8.32 | 43.6 | 94.18
ao 8-bit | 8.33 | 42.6 | 94.25
ao FP8 E4M3 | 9.27 | 44.1 | 94.40
lpmm 4-bit      | 7.72            | 46.0   | 94.29
ao 4-bit        | 7.72            | 40.0   | 94.03
lpmm 4-bit (*)  | 7.74            | 26.6   | 94.25

(*) means rank-1 normalization is used for the 2nd optimizer state. Refer to the [paper](https://arxiv.org/abs/2309.01507) for more details.

Review comments (on the `ao FP8 E4M3` row):

@gau-nernst (Collaborator, Author), Aug 23, 2024:

FP8 max memory is ~1 GB higher than expected. I re-ran the benchmark on the main branch (without this PR) and the max memory for FP8 is the same, so I suspect something funny is happening with torch.compile. The benchmark was done with 2.5.0.dev20240820. I don't think it's a big issue, especially since the FP8 optimizer is not popular yet (that may change in the future though 👀). Re-running with 2.4 now (probably won't re-run the rest with 2.4 since I'm lazy).

(Accuracy is much better than before across the board thanks to the cosine LR scheduler.)

@gau-nernst (Collaborator, Author):

FP8 optimizer:

PyTorch version   | max mem (GB) | imgs/s | acc
------------------|--------------|--------|------
2.4.0             | 9.04         | 42.5   | 94.18
2.5.0.dev20240820 | 9.27         | 44.1   | 94.40

Definitely something funny with the newer torch.compile 🤔

@gau-nernst (Collaborator, Author):

On 2.3.1, Triton refused to run due to compute capability (even though my GPU should be supported; 2.4.0 and nightly are fine):

Conversion from/to f8e4m3nv is only supported on compute capability >= 90
UNREACHABLE executed at ../../../lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:823!

In the original benchmark numbers, the FP8 optimizer was fine (max memory was the same as for the 8-bit optimizer), but I don't remember which PyTorch version I used back then 😅.

Contributor:

This is just compiling the optimizer. There were some larger changes to the min-cut partitioner that affect how we split forward + backward graphs (i.e. determine what to recompute vs. save). These changes show up more in FP8, since you typically have long chains of ops to dequantize or quantize. But if there is no forward/backward, I am not totally sure what might be happening.

3 changes: 1 addition & 2 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,3 +1,2 @@
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import _AdamW, AdamW8bit, AdamW4bit, AdamWFp8
from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW
from .cpu_offload import CPUOffloadOptimizer
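
Since both the Adam and AdamW variants are now exported from `torchao.prototype.low_bit_optim` (as shown in the consolidated import above), switching quantization levels is a one-line change. A hedged usage sketch, assuming PyTorch >= 2.3 and, for `AdamFp8`, CUDA compute capability >= 8.9 (per the README notes); the model and hyperparameters are placeholders:

```python
import torch
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit, AdamFp8

model = torch.nn.Linear(256, 256).cuda()

# Pick the precision by swapping the class; the constructor call is the same.
optim_cls = Adam8bit  # or Adam4bit / AdamFp8
optim = optim_cls(model.parameters(), lr=1e-3)

loss = model(torch.randn(64, 256, device="cuda")).sum()
loss.backward()
optim.step()  # the first step is slow while the optimizer compiles
optim.zero_grad()
```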