Fix LR schedule handling for low-bit optimizers #736

Merged: 7 commits, Aug 23, 2024. Changes shown from 4 commits.
51 changes: 25 additions & 26 deletions benchmarks/benchmark_low_bit_adam.py
@@ -23,13 +23,13 @@
 import math
 from contextlib import nullcontext
 from functools import partial
-from pathlib import Path

 import bitsandbytes as bnb
 import datasets
 import timm
 import torch
 import torch.nn.functional as F
+import wandb
 from torch.utils.data import DataLoader
 from torchvision.transforms import v2
 from tqdm import tqdm
@@ -72,22 +72,6 @@ def get_lr(self, step: int) -> float:
         return self.final_lr


-class WandbLogger:
-    def __init__(self, args):
-        if args.project is not None and not args.profile:
-            import wandb
-
-            Path("wandb_logs").mkdir(exist_ok=True)
-            self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")
-
-        else:
-            self.run = None
-
-    def log(self, *args, **kwargs):
-        if self.run is not None:
-            self.run.log(*args, **kwargs)
-
-
 def get_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", required=True)
@@ -190,7 +174,13 @@ def evaluate_model(model, args):
print(f"{k}: {v}")

# wandb is only enabled when args.project is set and args.profile is False
logger = WandbLogger(args)
logger = wandb.init(
project=args.project,
name=args.run_name,
config=args,
dir="/tmp",
mode="disabled" if args.project is None else None,
)
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

@@ -239,13 +229,15 @@ def evaluate_model(model, args):

     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
+    log_interval = 10

     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
         pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")

         start_time = datetime.datetime.now()
+        t0 = start_time

         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
@@ -265,13 +257,18 @@ def evaluate_model(model, args):
                 if args.cosine_lr_scheduler:
                     lr = lr_schedule.get_lr(step)
                     for param_group in optim.param_groups:
-                        param_group["lr"] = lr
-
-                if step % 100 == 0:
-                    logger.log(
-                        dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]),
-                        step=step,
-                    )
+                        if isinstance(param_group["lr"], torch.Tensor):
+                            param_group["lr"].fill_(lr)
+                        else:
+                            param_group["lr"] = lr
+
+                if step % log_interval == 0:
+                    log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
+                    if step > 0:
+                        t1 = datetime.datetime.now()
+                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0).total_seconds()
+                        t0 = t1
+                    logger.log(log_dict, step=step)

                 if args.optim_cpu_offload == "deepspeed":
                     model.step()
@@ -295,4 +292,6 @@ def evaluate_model(model, args):
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
peak_mem = torch.cuda.max_memory_allocated() / 1e9
print(f"Max memory used: {peak_mem:.02f} GB")
logger.log(dict(max_memory_allocated=peak_mem))
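
For readers skimming the diff, the LR handling above boils down to the pattern sketched below. This is a condensed toy (not part of the PR): the schedule still produces a Python float each step, and the loop writes it into each param group in place when the stored lr is a Tensor (as the low-bit optimizers now require), falling back to plain assignment for an ordinary float lr. The model, optimizer, and schedule here are stand-ins for the benchmark's timm model, low-bit Adam, and CosineSchedule.

import math

import torch

model = torch.nn.Linear(8, 8)                        # stand-in model
optim = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in optimizer (float lr)


def cosine_lr(step: int, total_steps: int, base_lr: float = 0.1) -> float:
    # simplified stand-in for the benchmark's CosineSchedule.get_lr
    return 0.5 * base_lr * (1 + math.cos(math.pi * step / total_steps))


total_steps = 100
for step in range(total_steps):
    lr = cosine_lr(step, total_steps)
    for param_group in optim.param_groups:
        if isinstance(param_group["lr"], torch.Tensor):
            param_group["lr"].fill_(lr)  # low-bit optimizers: update the Tensor in place
        else:
            param_group["lr"] = lr       # ordinary float lr: rebind as before

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()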
7 changes: 5 additions & 2 deletions torchao/prototype/low_bit_optim/adam.py
@@ -20,7 +20,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size)
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
Contributor: Is this the right place to set it? I would advocate for making it very obvious to users what is going on, and I would modify the AdamW constructor below.

Contributor: Unless you're adamant that this optimizer will only support Tensor lrs. In pytorch/pytorch, we support Python float lrs in eager because it is faster to compute Python math than to launch kernels, though that may be less relevant here.

Contributor: The alternative I'd suggest for maximum visibility to the user is changing line 165 below to torch.tensor(1e-3).

Collaborator (author): I think to keep it simple, we just enforce lr to be a scalar tensor here. If lr is not changed during training, whether it is a tensor or not does not matter. But if it is changed during training, we need lr to be a tensor anyway, since torch.compile will recompile when a Python float lr changes value.
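
As an illustration of that point, here is a minimal sketch (not from this PR) using a toy compiled step: a scalar Tensor lr can have its value changed in place between calls without invalidating the compiled graph, whereas, as noted above, a Python float lr gets specialized into the graph, so changing its value forces a recompile.

import torch


@torch.compile
def sgd_step(param: torch.Tensor, grad: torch.Tensor, lr):
    param.sub_(lr * grad)


p, g = torch.randn(8), torch.randn(8)
lr = torch.tensor(1e-2)  # scalar Tensor lr

for step in range(5):
    lr.fill_(1e-2 * 0.9**step)  # value changes in place; the compiled graph is reused
    sgd_step(p, g, lr)

# Calling sgd_step(p, g, some_float) with a float that changes every step would
# typically trigger a recompile each time, since scalar arguments are treated as
# constants by the compiler.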

Contributor: Yeah, but I still think it'd be most visible/clear to users if the constructors of Adam and AdamW clearly set the default as Tensors and this base class simply errored if lr was not a Tensor.

Collaborator (author) @gau-nernst, Aug 23, 2024: Just to clarify, do you mean forcing users to pass in the LR as a Tensor instead of a Python float, i.e. an error is raised if the LR is not a Tensor? And is it your view that converting a Python float to a Tensor in the constructor might seem unexpected to users?

Personally, when using optimizers, I have never passed in a Tensor LR, so it feels strange to me 😅 (that doesn't mean I'm correct, just a feeling from my limited experience). I think converting the LR from float to Tensor inside the constructor is an implementation detail that users shouldn't need to care about. Also, most (if not all?) other optimizers work if I pass in a Python float LR, so it feels kind of strange (again 🤣) if users are forced to pass in a Tensor LR to this particular optimizer.

Contributor: Yes, I think it'd be good to have users pass in Tensor lrs, so they're most aware of what is going on. I think it is not great to switch it up under the hood and have the user be confused if there's ever an error regarding the Tensor-ness of the lr.

         super().__init__(params, defaults)
         self.block_size = block_size

@@ -81,7 +81,10 @@ def _prepare_param_groups(self):
                 # practically, only lr is changed during training.
                 # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%.
                 if not isinstance(group["lr"], Tensor):
-                    group["lr"] = torch.tensor(group["lr"], device=p.device)
+                    raise ValueError(
+                        "lr was changed to a non-Tensor object. If you want to update lr, please use "
+                        "optim.param_groups[0]['lr'].fill_(new_lr)"
+                    )

                 p_grad_state = (
                     p,
7 changes: 5 additions & 2 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -20,7 +20,7 @@ def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size)
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
super().__init__(params, defaults)
self.block_size = block_size

@@ -81,7 +81,10 @@ def _prepare_param_groups(self):
                 # practically, only lr is changed during training.
                 # NOTE: if lr is changed at every step, moving lr to CUDA can slow down training 3-4%.
                 if not isinstance(group["lr"], Tensor):
-                    group["lr"] = torch.tensor(group["lr"], device=p.device)
+                    raise ValueError(
+                        "lr was changed to a non-Tensor object. If you want to update lr, please use "
+                        "optim.param_groups[0]['lr'].fill_(new_lr)"
+                    )

                 p_grad_state = (
                     p,
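
Taken together, the two optimizer changes define a small contract for LR schedules: the constructor stores lr as a scalar Tensor, and a schedule must update that Tensor in place instead of rebinding the param-group entry to a Python float. A minimal usage sketch follows; it assumes the Adam8bit class exported by torchao.prototype.low_bit_optim and a CUDA device, and the toy model and decay factor are placeholders.

import torch

from torchao.prototype.low_bit_optim import Adam8bit  # assumed export of this package

model = torch.nn.Linear(128, 128).cuda()  # low-bit optimizer state is assumed to live on CUDA
optim = Adam8bit(model.parameters(), lr=1e-3)

for step in range(10):
    loss = model(torch.randn(4, 128, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

    # Supported: write the new value into the existing scalar Tensor.
    optim.param_groups[0]["lr"].fill_(1e-3 * 0.9**step)

    # Not supported after this PR: rebinding to a Python float would hit the new
    # ValueError in _prepare_param_groups on the next optim.step().
    # optim.param_groups[0]["lr"] = 1e-3 * 0.9**step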