**Summary:** This commit adds FP8 finetuning to the
`full_finetune_distributed` recipe as an optional feature.
For Llama3-8B, we saw up to a 14.7% improvement in finetuning
throughput with no degradation in memory usage or accuracy.
This feature is currently gated on PyTorch nightlies since it
depends on recently added features there; it will be
available in the next torchtune release.
To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true
fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```
The default setting uses tensorwise scaling +
`enable_fsdp_float8_all_gather=True` (without tensor parallelism),
which led to the largest speedups in our experiments.
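Under the hood, these options roughly correspond to torchao's float8 training API. Below is a minimal illustrative sketch (not the recipe's actual code) of how a model might be converted: `Float8LinearConfig`, `Float8LinearConfig.from_recipe_name`, and `convert_to_float8_training` are torchao APIs, while the helper function and the choice to skip the output projection are assumptions made for illustration.
```
# Minimal illustrative sketch (not the recipe's actual code) of applying
# torchao float8 training to a model, mirroring the config options above.
from typing import Optional

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def apply_fp8_training(model: nn.Module, recipe_name: Optional[str] = None) -> nn.Module:
    if recipe_name is None:
        # Default setting: tensorwise scaling + float8 all-gather for FSDP2.
        config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    else:
        # One of "tensorwise", "rowwise", "rowwise_with_gw_hp".
        config = Float8LinearConfig.from_recipe_name(recipe_name)

    # Swap eligible nn.Linear modules for torchao's Float8Linear. Skipping the
    # final output projection is an illustrative choice (fp8 matmuls require
    # dimensions divisible by 16, which the LM head may not satisfy).
    def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
        return fqn != "output"

    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
    return model
```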
**Experimentation:**
All experiments were run on 4x H100 GPUs with 94GB memory each.
We finetuned each model on the alpaca dataset for 1 epoch with a
batch size of 16 and `torch.compile` enabled, using the following
commits of the three repos:
```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```
For Llama3-8B, fp8 finetuning was up to 14.7% faster than the bf16
baseline, with no change in memory usage or quantized accuracy:
```
experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved
---------------------- ------------------- ----------------- ---------------- -------------------
full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%)
full_tp 2773.598 (+0.005%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%)
fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_noname_tp 3159.515 (+13.919%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_tensorwise_tp 3160.202 (+13.944%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%)
fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%)
fp8_rowwise_with_gw_hp 3171.742 (+14.360%) 18.492 (+0.060%) 18.492 (+0.060%) 34.405 (+0.330%)
experiment_name hellaswag_acc wikitext_word_perplexity
---------------------- --------------- --------------------------
full 0.584 (+0.000) 9.419 (+0.000)
full_tp 0.584 (+0.000) 9.415 (-0.004)
fp8_noname 0.585 (+0.000) 9.431 (+0.012)
fp8_noname_tp 0.584 (-0.000) 9.425 (+0.006)
fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002)
fp8_tensorwise_tp 0.584 (-0.000) 9.425 (+0.005)
fp8_rowwise 0.583 (-0.002) 9.421 (+0.002)
fp8_rowwise_with_gw_hp 0.585 (+0.001) 9.405 (-0.014)
```
A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallelism did not appear to help, and even slightly degraded tok/s for `fp8_noname`
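For reference, the parenthesized deltas in the tables are relative changes against the bf16 `full` baseline, e.g.:
```
# How the parenthesized deltas in the tables are derived
# (relative change versus the bf16 "full" baseline).
baseline_tok_s = 2773.473   # full
fp8_tok_s = 3182.220        # fp8_noname
speedup_pct = (fp8_tok_s / baseline_tok_s - 1) * 100
print(f"{speedup_pct:.3f}%")  # ~14.738%, matching the fp8_noname row
```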
For Llama3.1-8B, we saw similar results, with up to 14.3% faster
finetuning and no change in quantized accuracy. However, peak
reserved memory did increase slightly (about +2%) for most fp8
settings:
```
experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved
---------------------- ------------------- ----------------- ---------------- -------------------
full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%)
full_tp 2764.611 (-0.133%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%)
fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_noname_tp 3144.787 (+13.600%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_tensorwise_tp 3163.867 (+14.289%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%)
fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%)
fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%)
experiment_name hellaswag_acc wikitext_word_perplexity
---------------------- --------------- --------------------------
full 0.594 (+0.000) 9.087 (+0.000)
full_tp 0.594 (+0.001) 9.089 (+0.002)
fp8_noname 0.593 (-0.001) 9.070 (-0.017)
fp8_noname_tp 0.593 (-0.000) 9.078 (-0.009)
fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026)
fp8_tensorwise_tp 0.593 (-0.001) 9.060 (-0.026)
fp8_rowwise 0.593 (-0.000) 9.086 (-0.001)
fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000)
```
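For reference, `hellaswag_acc` and `wikitext_word_perplexity` are the metrics reported by EleutherAI's lm-evaluation-harness (which torchtune's eval tooling wraps). A rough, hypothetical standalone sketch of computing them for a finetuned HF-format checkpoint follows; the checkpoint path is a placeholder and the result keys may differ across lm-eval versions.
```
# Hypothetical sketch: evaluating a finetuned checkpoint on the same metrics
# as the tables above, using EleutherAI's lm-evaluation-harness directly.
import lm_eval
from lm_eval.models.huggingface import HFLM

model = HFLM(pretrained="/tmp/llama3_8b_fp8_finetuned")  # placeholder path
results = lm_eval.simple_evaluate(model=model, tasks=["hellaswag", "wikitext"])

print(results["results"]["hellaswag"]["acc,none"])             # hellaswag_acc
print(results["results"]["wikitext"]["word_perplexity,none"])  # wikitext_word_perplexity
```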
Based on #2404 by @nathan-az