[Low-bit optim] Add Llama2-7B finetune benchmarks #746

Merged: 5 commits into pytorch:main on Sep 2, 2024

Conversation

@gau-nernst (Collaborator) commented Aug 25, 2024

Update: change Llama3.1-8B-instruct to Llama2-7B

Fine-tune Llama2-7B on the Alpaca dataset: full BF16, 1 epoch, single A100, fixed random seed. Benchmarks are run with torchtune.

Summary

| AdamW impl | Max memory (GB) | toks/s | truthfulqa_mc2 acc | Compile time |
|---|---|---|---|---|
| Not fine-tuned | - | - | 38.95 | - |
| PyTorch (fused) | 52 | ~4500 | 42.12 | ~4 min |
| bnb 8-bit | 39 | ~4000 | 41.98 | ~4 min |
| ao 8-bit | 39 | ~4000 | 42.41 | ~12 min |
| ao 4-bit | 33 | ~3600 | 42.34 | ~4 min |

NOTE:

  • lpmm's 4-bit AdamW does not support BF16 weights -> not included in the benchmark
  • A100 does not support FP8 -> FP8 AdamW not included

Observations

  • The reduction in peak memory looks correct: going from 16-bit to 8-bit saves 52 - 39 = 13 GB; going from 8-bit to 4-bit saves 39 - 33 = 6 GB (see the back-of-envelope check after this list).
  • Our 8-bit AdamW is only slightly slower than bnb, which is nice.
  • The compile time for our 8-bit AdamW is huge (~12 min). We might need to find ways to mitigate this.
  • Our 4-bit AdamW is quite slow, but it compiles fast. This is expected because we compile with dynamic shapes for each param, while for 8-bit AdamW we compile with static shapes for all params. We do it this way because 4-bit AdamW hits a memory bug when compiling over all params.
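
As a rough sanity check of those deltas (my own back-of-envelope arithmetic, not part of the benchmark): AdamW keeps two state tensors (exp_avg and exp_avg_sq) per parameter, so for a ~7B-param model each bit dropped per state element saves about 7e9 × 2 / 8 bytes.

```python
# Back-of-envelope check of the optimizer-state savings (assumptions: ~7B params,
# 2 AdamW states per param; ignores quantization metadata such as per-block scales).
n_params = 7e9
n_states = 2  # exp_avg and exp_avg_sq

def state_gib(bits_per_element: float) -> float:
    return n_params * n_states * bits_per_element / 8 / 1024**3

print(state_gib(16) - state_gib(8))  # ~13 GiB, matches the 52 GB -> 39 GB drop
print(state_gib(8) - state_gib(4))   # ~6.5 GiB, close to the 39 GB -> 33 GB drop
```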

Command used (change `optimizer` and `checkpointer.output_dir` across runs):

```
tune run full_finetune_single_device --config llama2/7B_full optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.utils.metric_logging.WandBLogger log_peak_memory_stats=True batch_size=16 log_every_n_steps=10 epochs=1 seed=2024 checkpointer.output_dir=experiments/adamw_baseline
```
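
For reference, a minimal sketch (mine, not part of the PR) of what the "ao 8-bit" row corresponds to outside torchtune; the class path assumes torchao's prototype low-bit optim module, and the bnb row would use bitsandbytes.optim.AdamW8bit instead:

```python
# Sketch only: plug the low-bit optimizer into a plain PyTorch training step.
# Assumes torchao.prototype.low_bit_optim.AdamW8bit is available; the benchmark
# itself runs through torchtune with the optimizer= override shown above.
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-5)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
optim.step()       # optimizer states are stored in 8-bit; weights stay BF16
optim.zero_grad()
```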

Fancy graphs!

Compare across different n-bit optimizers

[image: comparison across different n-bit optimizers]

Compare 8-bit AdamW between ao and bnb. The fact that the two graphs overlap shows that our implementation is correct and competitive in speed (except for compile time 😭)!

[image: ao vs bnb 8-bit AdamW comparison]

pytorch-bot bot commented Aug 25, 2024

🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/746 (links to docs will display an error until the docs builds have been completed).

❌ 1 new failure, 5 unrelated failures as of commit 2de6df0 with merge base ba2d3b1. The unrelated jobs failed but were likely due to flakiness present on trunk. This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Aug 25, 2024
@msaroufim mentioned this pull request Aug 27, 2024
@gau-nernst changed the title from "[Low-bit optim] Add Llama3.1-8B finetune benchmarks" to "[Low-bit optim] Add Llama2-7B finetune benchmarks" Aug 27, 2024
@gau-nernst marked this pull request as ready for review August 27, 2024 15:36
@msaroufim (Member) left a comment

cc @mlazos who was looking at large compile times

@gau-nernst (Collaborator, Author) commented

@msaroufim Any blockers to merging this? The failing CPU test is unrelated, though I'm probably in charge of it since it's FP6-LLM 🌚. Seems like something changed with CPU inductor.

Some thoughts on reducing compile time. There are two approaches to compiling the optimizer step in low-bit optim:

  1. Compile the optim step for a single param, i.e. torch.compile(single_param_adam)
  2. Compile the optim step for all params, i.e. torch.compile(param_groups_adam)

Currently Adam8bit and AdamFp8 use approach (2) (with static shapes) since it is faster (but compiles much more slowly), while Adam4bit uses approach (1) (with dynamic shapes) since "Adam4bit + approach (2)" causes excessive memory usage. Approach (1) requires dynamic shapes to avoid hitting the recompile limit.
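
For illustration, a minimal sketch of the two approaches (my own simplification: plain BF16 states, no weight decay, and function names mirroring the description above; not the actual torchao code):

```python
import torch

def adam_update(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
    # Plain Adam math; a real low-bit optimizer would dequantize/requantize
    # the optimizer states around this update.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt_().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)

# Approach (1): compile the single-param update with dynamic shapes and loop
# over params in Python -> fast to compile, one graph covers all param shapes.
single_param_adam = torch.compile(adam_update, fullgraph=True, dynamic=True)

# Approach (2): compile one function that walks every param with static shapes
# -> faster per step, but compile time grows with the number of parameters.
@torch.compile(fullgraph=True)
def param_groups_adam(params, grads, exp_avgs, exp_avg_sqs, step, lr, beta1, beta2, eps):
    for p, g, m, v in zip(params, grads, exp_avgs, exp_avg_sqs):
        adam_update(p, g, m, v, step, lr, beta1, beta2, eps)
```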

Now looking back, perhaps we can do approach (1) with static shapes + temporarily remove the recompile limit? I have seen FlexAttention doing this:

https://github.com/pytorch/pytorch/blob/76710d4f95d1f920bdf56e4db4d6d71ef6c9aea2/torch/nn/attention/flex_attention.py#L989

It's probably safe to do so, since for a given model the number of recompiles for single_param_adam() is fixed, though some models may have more recompiles than others (e.g. ViT vs LLM).
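
A sketch of what "temporarily remove the recompile limit" could look like (the cache_size_limit knob and the save/restore pattern are my assumption about the mechanism; FlexAttention and an eventual torchao fix may use a different one):

```python
import contextlib
import torch

@contextlib.contextmanager
def unlimited_recompiles():
    # Temporarily raise torch.compile's recompile cache limit so a static-shape
    # single-param step can specialize once per distinct parameter shape without
    # tripping the limit. Restored afterwards so the rest of the program keeps
    # the default guard against runaway recompiles.
    prev = torch._dynamo.config.cache_size_limit
    torch._dynamo.config.cache_size_limit = 1 << 30  # effectively unlimited
    try:
        yield
    finally:
        torch._dynamo.config.cache_size_limit = prev

# Usage sketch: wrap the per-param optimizer loop; each new param shape compiles
# once, later steps hit the compile cache.
# with unlimited_recompiles():
#     for p in params:
#         single_param_adam(p, ...)  # compiled with dynamic=False in this scheme
```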

@msaroufim (Member) commented

I'm gonna add some of your comments here to the README since they're helpful

@msaroufim merged commit e5246fc into pytorch:main Sep 2, 2024
8 of 14 checks passed
@gau-nernst deleted the update_optim_bench branch September 2, 2024 19:25
jerryzh168 pushed a commit to jerryzh168/ao that referenced this pull request Sep 4, 2024
* add Llama3.1-8B finetune bench

* update doc

* Update README.md

---------

Co-authored-by: Mark Saroufim <marksaroufim@gmail.com>