Update low-bit Adam benchmark (#481)
* update benchmark

* add rank1 option to lpmm

* add comma

* update readme

* remove unwanted file

* update
gau-nernst authored Jul 6, 2024
1 parent 34fedff commit 56d46a2
Showing 2 changed files with 20 additions and 14 deletions.
11 changes: 8 additions & 3 deletions benchmarks/benchmark_low_bit_adam.py
@@ -37,6 +37,7 @@
Adam8bitAo=Adam8bit,
Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
Adam4bitAo=Adam4bit,
Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
)
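As a rough illustration of how the new `Adam4bitRank1Lpmm` entry would be consumed, the sketch below mirrors the `partial` added above; the `model` and `lr` names are placeholders, and exactly how the benchmark binds parameters and learning rate is an assumption rather than something shown in this hunk.

```python
# Sketch (assumptions noted above): build the lpmm AdamW optimizer with the
# rank-1 qconfig exactly as registered in OPTIM_MAP, then bind it to a model.
import argparse
from functools import partial

import lpmm  # low-bit optimizer library used by the benchmark

make_optim = partial(
    lpmm.optim.AdamW,
    weight_decay=0,
    qconfig=argparse.Namespace(scale_type="rank1"),  # rank-1 scale for quantized states
)
# optimizer = make_optim(model.parameters(), lr=1e-4)  # placeholder model and lr
```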


@@ -92,6 +93,7 @@ def get_parser():
parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--seed", type=int)
return parser
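A hedged usage sketch, not part of the commit: launching the benchmark with the new `--seed` flag so a run matches the seeded numbers reported in the README below. Only flags visible in this diff (`--run_name`, `--seed`) are used; everything else is left at its defaults.

```python
# Minimal launch sketch; assumes it is run from the repository root.
import subprocess

subprocess.run(
    [
        "python", "benchmarks/benchmark_low_bit_adam.py",
        "--run_name", "adam_seeded",
        "--seed", "2024",
    ],
    check=True,
)
```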


@@ -155,6 +157,8 @@ def evaluate_model(model, args):

if args.profile:
args.n_epochs = 1
if args.seed is not None:
torch.manual_seed(args.seed)

for k, v in vars(args).items():
print(f"{k}: {v}")
@@ -176,11 +180,11 @@ def evaluate_model(model, args):

grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

start_time = datetime.datetime.now()
step = 0
for epoch_idx in range(args.n_epochs):
model.train()
prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()
start_time = datetime.datetime.now()

with prof:
for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
@@ -212,9 +216,10 @@ def evaluate_model(model, args):
prof.export_chrome_trace("trace.json")

else:
print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}")

val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Time taken: {(datetime.datetime.now() - start_time)}")
print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
23 changes: 12 additions & 11 deletions torchao/prototype/low_bit_optim/README.md
@@ -31,17 +31,18 @@ NOTE:

Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:

Adam impl | max memory (GB) | time taken | accuracy
-----------|-----------------|------------|----------
PyTorch | 12.98 | 10m 08s | 87.70
bnb 8-bit | 8.31 | 8m 38s | 86.22
ao 8-bit | 8.32 | 10m 54s | 86.67
lpmm 4-bit | 7.72 | 7m 48s | 84.70
ao 4-bit | 7.72 | 9m 17s | 85.60

NOTE: time taken includes validation time, and compile time for torchao optimizers.
Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on a 4070Ti SUPER, with a fixed random seed:

Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
---------------|-----------------|--------------------------|----------
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.32 | 9m 04s | 90.71
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 00s | 89.94
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71

(*) rank-1 normalization is used for the 2nd optimizer state. Refer to the [paper](https://arxiv.org/abs/2309.01507) for more details.
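For intuition, a sketch of the rank-1 normalization idea as described in the linked paper: instead of one max per quantization group, keep a row-wise max and a column-wise max and combine them element-wise. This is a simplified reading of the paper, not lpmm's actual implementation, and details may differ.

```python
# Simplified sketch of rank-1 normalization for a 2-D optimizer state
# (per the linked paper, as understood here; lpmm's code may differ).
import torch

def rank1_scale(state: torch.Tensor) -> torch.Tensor:
    """Per-element normalization factor built from two 1-D vectors.

    Keeping only a row-wise max and a column-wise max (O(m + n) memory)
    and combining them with an element-wise min approximates the per-group
    max used by standard block-wise quantization.
    """
    row_max = state.abs().amax(dim=1, keepdim=True)  # shape (m, 1)
    col_max = state.abs().amax(dim=0, keepdim=True)  # shape (1, n)
    return torch.minimum(row_max, col_max)           # broadcasts to (m, n)
```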

## Credits

