[Low-bit optim] Add Llama2-7B finetune benchmarks (pytorch#746)
* add Llama3.1-8B finetune bench

* update doc

* Update README.md

---------

Co-authored-by: Mark Saroufim <marksaroufim@gmail.com>
2 people authored and jerryzh168 committed Sep 4, 2024
1 parent fa78b5b commit 97274ad
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions torchao/prototype/low_bit_optim/README.md
@@ -30,11 +30,9 @@ NOTE:

## Benchmarks

Fine-tune [timm](https://github.com/huggingface/pytorch-image-models)'s ViT-H (630M params) on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset. BF16 AMP, 1 epoch, batch size 8, cosine LR scheduler, 4070Ti SUPER, fixed random seed. The benchmark script is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

AdamW impl | Max memory (GB) | imgs/s | accuracy
----------------|-----------------|--------|----------
PyTorch (fused) | 12.23 | 41.8 | 94.38
bnb 8-bit | 8.32 | 43.6 | 94.18
@@ -46,6 +44,27 @@ lpmm 4-bit (*) | 7.74 | 26.6 | 94.25

(*) means rank-1 normalization is used for the 2nd optimizer state. Refer to the [paper](https://arxiv.org/abs/2309.01507) for more details.
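
The low-bit optimizers benchmarked above are meant as drop-in replacements for `torch.optim.AdamW`. Below is a minimal sketch of the BF16 AMP setup, assuming the `AdamW8bit` export from `torchao.prototype.low_bit_optim`; a toy model and random data stand in for timm's ViT-H and resisc45:

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed export; AdamW4bit / AdamWFp8 are analogous

# Toy stand-in for the ViT-H fine-tuning setup (45 resisc45 classes).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 45)).cuda()
optim = AdamW8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)  # used like torch.optim.AdamW
loss_fn = nn.CrossEntropyLoss()

for _ in range(10):  # stand-in for one epoch at batch size 8
    images = torch.randn(8, 3, 224, 224, device="cuda")
    labels = torch.randint(0, 45, (8,), device="cuda")
    with torch.autocast("cuda", dtype=torch.bfloat16):  # BF16 AMP as in the benchmark
        loss = loss_fn(model(images), labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```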

Fine-tune [Llama2-7B](https://huggingface.co/meta-llama/Llama-2-7b) on the [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. Full BF16, 1 epoch, A100, fixed random seed. The benchmark is done with [torchtune](https://github.com/pytorch/torchtune). See [#746](https://github.com/pytorch/ao/pull/746) for more details.

AdamW impl | Max memory (GB) | toks/s | `truthfulqa_mc2` acc | Compile time
-----------------|-----------------|--------|----------------------|-------------
Not fine-tuned | - | - | 38.95 | -
PyTorch (fused) | 52 | ~4500 | 42.12 | ~4 min
bnb 8-bit | 39 | ~4000 | 41.98 | ~4 min
ao 8-bit | 39 | ~4000 | 42.41 | ~12 min
ao 4-bit | 33 | ~3600 | 42.34 | ~4 min

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
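
To illustrate the "Full BF16" setup from the table (the benchmark itself was run through torchtune), here is a minimal sketch assuming the `AdamW4bit` export from `torchao.prototype.low_bit_optim`, with a tiny stand-in module instead of Llama2-7B; unlike lpmm's 4-bit AdamW, the ao implementation accepts BF16 weights:

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamW4bit  # assumed export

# Tiny stand-in for Llama2-7B: weights, activations and grads all in BF16 (no AMP).
model = nn.Sequential(
    nn.Embedding(32000, 256),
    nn.Linear(256, 32000, bias=False),
).cuda().to(torch.bfloat16)
optim = AdamW4bit(model.parameters(), lr=1e-5)

tokens = torch.randint(0, 32000, (8, 128), device="cuda")
for _ in range(5):
    logits = model(tokens)  # BF16 forward pass, no autocast
    loss = nn.functional.cross_entropy(logits.view(-1, 32000), tokens.view(-1))
    loss.backward()
    optim.step()
    optim.zero_grad()
```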

### Note on compile times

There are 2 approaches to compile the optimizer step in low-bit optim:

1. Compile the optim step for a single param, i.e. `torch.compile(single_param_adam)`
2. Compile the optim step for all params, i.e. `torch.compile(param_groups_adam)`

Currently Adam8bit and AdamFp8 use approach (2) (with static shapes) since it is faster (though much slower to compile), while Adam4bit uses approach (1) (with dynamic shapes) since approach (2) uses excessive memory for Adam4bit. Approach (1) requires dynamic shapes to avoid hitting the recompile limit.
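
A rough, self-contained illustration of the two approaches, using a plain Adam update on dummy tensors rather than this folder's actual kernels (the function names simply mirror the ones mentioned above):

```python
import torch

def single_param_adam(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Standard Adam update for one parameter tensor, applied in place.
    exp_avg.lerp_(grad, 1 - b1)
    exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)
    denom = (exp_avg_sq / (1 - b2 ** step)).sqrt_().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / (1 - b1 ** step))

params = [torch.randn(n) for n in (4096, 8192)]  # two parameters with different shapes
grads = [torch.randn_like(p) for p in params]
exp_avgs = [torch.zeros_like(p) for p in params]
exp_avg_sqs = [torch.zeros_like(p) for p in params]

# Approach (1): compile the single-param step once; dynamic shapes let parameters of
# different sizes reuse the same compiled code instead of triggering recompiles.
step_fn = torch.compile(single_param_adam, dynamic=True)
for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
    step_fn(p, g, m, v, step=1)

# Approach (2): compile one function that loops over every parameter, so the whole
# update is traced into a single static-shape graph (faster steps, slower compile).
@torch.compile
def param_groups_adam(params, grads, exp_avgs, exp_avg_sqs, step):
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
        single_param_adam(p, g, m, v, step)

param_groups_adam(params, grads, exp_avgs, exp_avg_sqs, step=1)
```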

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single-GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.
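
A minimal usage sketch, assuming the `CPUOffloadOptimizer` wrapper exported from this folder (it wraps a regular PyTorch optimizer class and keeps its states on the CPU):

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # assumed export

model = nn.Linear(4096, 4096).cuda()
# Pass the optimizer *class*; remaining kwargs are assumed to be forwarded to it.
# Parameters stay on the GPU; optimizer states live on the CPU and the step runs there.
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-4)

x = torch.randn(8, 4096, device="cuda")
model(x).pow(2).mean().backward()
optim.step()
optim.zero_grad()
```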