(WIP/RFC) FP8 full finetune distributed #2404
Conversation
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
ebsmothers left a comment:
Thanks for the PR! I'm excited to see this feature, especially with code changes that are pretty minimally invasive. Leaving a handful of comments here, many around testing:
- We should definitely add some unit tests for this. At least something to set up a toy model on one or more devices, call convert_to_float8_training, and validate that linears get swapped out correctly and we see the numerical results we'd expect (a sketch of such a test is included below).
- Especially since torchtitan's implementation is really just tested for Llama models, it'd be interesting to see whether we run into any issues with e.g. a Qwen model where the embedding weights are tied.
- We can also test with Llama 3.2 Vision just to make sure nothing breaks (would also be interested to see if there's a bigger accuracy dropoff there).
- Does it compose with tensor parallel?
- To assess model quality, we can train for a couple epochs on Alpaca, SlimOrca, or some other canonical dataset, then compare eval results using our Eleuther integration to the equivalent run with bf16. (This may take a bit more effort, so lmk if you need help on this one.)
- (random nit q): why is LR in the plots you shared 0? Is it just an issue with the y-axis scale?
Also cc a couple experts @vkuzo and @gau-nernst in case they have thoughts or suggestions here.
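For the unit-test suggestion above, here is a minimal sketch of the kind of check that could work, assuming torchao's float8 training API (`convert_to_float8_training` and `Float8Linear`); the test name and layer sizes are illustrative:

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear


def test_convert_to_float8_training_swaps_linears():
    # Toy model with two linears; dims are kept multiples of 16 since fp8
    # GEMMs require that at runtime.
    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
    convert_to_float8_training(model)
    linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
    # Every nn.Linear should now be (a subclass of) Float8Linear.
    assert len(linears) == 2
    assert all(isinstance(m, Float8Linear) for m in linears)
```

A numerics check could extend this by comparing the converted model's outputs against the bf16 original within a loose tolerance.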
```python
def precompute_float8_dynamic_scale_for_fsdp(
    self, model: Union[nn.Module, List[nn.Module]]
)
```
What's the actual case that we would be applying to a list of modules?
In this specific case it's due to the way that torchtitan supports pipeline parallelism with looped parts, using a model_parts object rather than a single model.
The initial state of this PR is very close to a lift-and-shift of their implementation (or, an implementation from about a week ago - they've now made Converters a more generic class).
If we don't anticipate PP support any time soon (I'm not sure where this sits in the priority list - I imagine even something like DP shard + replica patterns are simpler as a next level), I can remove the list support.
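For reference, a minimal sketch of what the list handling boils down to, assuming torchao's `precompute_float8_dynamic_scale_for_fsdp` helper; the wrapper name here is hypothetical and the list branch only matters for torchtitan-style `model_parts`:

```python
from typing import List, Union

import torch.nn as nn
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp


def precompute_fp8_scales(model: Union[nn.Module, List[nn.Module]]) -> None:
    # Normalize so a single model and PP-style model_parts are handled the same way.
    model_parts = model if isinstance(model, list) else [model]
    for part in model_parts:
        # Runs after the optimizer step; batches the per-parameter scale
        # computation into a single all-reduce when using fp8 all-gather with FSDP2.
        precompute_float8_dynamic_scale_for_fsdp(part)
```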
@ebsmothers to answer what I can for now:
Checking torchtitan, it looks like if we want float8 + TP, it uses the alternate rowwise_parallel and colwise_parallel classes conveniently provided by torchao (Float8RowwiseParallel / Float8ColwiseParallel).
I'm selfishly working just with LLaMA for now, for testing purposes. However, I'm having trouble getting standard TP working (without fp8) at the moment (at least in multi-node), so this is all very WIP until I can test properly.
Yes, just a bad y-axis scale in MLFlow 😅
+1 to tests, can we also have a README.md somewhere on how to use this from torchtune?
we are happy to help fix any issues that are uncovered
yes. If you are ok with bfloat16 all-gather, then TP is orthogonal to float8 training. If you want to use float8 all-gather, then there is an extra step of using the Float8 variants of the TP parallel styles (Float8RowwiseParallel / Float8ColwiseParallel).
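To make that concrete, here is a minimal sketch of what a float8 all-gather plan could look like, assuming the Float8 parallel styles from torchao (`torchao.float8.float8_tensor_parallel`); the layer names are illustrative, following torchtune's llama attention/MLP naming, rather than the actual plan:

```python
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)

# Float8 variants of the colwise/rowwise parallel styles, so the TP activation
# communication can happen in float8 rather than bfloat16.
FP8_LLAMA_TP_PLAN = {
    "attn.q_proj": Float8ColwiseParallel(),
    "attn.k_proj": Float8ColwiseParallel(),
    "attn.v_proj": Float8ColwiseParallel(),
    "attn.output_proj": Float8RowwiseParallel(),
    "mlp.w1": Float8ColwiseParallel(),
    "mlp.w3": Float8ColwiseParallel(),
    "mlp.w2": Float8RowwiseParallel(),
}
```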
```python
return BASE_LLAMA_TP_PLAN
if enable_float8:
    rowwise_parallel, colwise_parallel = (
        Float8RowwiseParallel,
```
These modules are specific to tensorwise scaling, so if in the future torchtune wants to enable more recipes (such as rowwise scaling), there will have to be additional gating on this.
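A minimal sketch of the kind of gating being suggested, assuming a recipe-name check on the torchtune side; the helper and argument names are hypothetical:

```python
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def _select_parallel_styles(enable_float8: bool, fp8_recipe_name: str = "tensorwise"):
    # The Float8 parallel styles only support tensorwise scaling, so fall back
    # to the plain TP styles for rowwise (and any future) recipes.
    if enable_float8 and fp8_recipe_name == "tensorwise":
        return Float8RowwiseParallel, Float8ColwiseParallel
    return RowwiseParallel, ColwiseParallel
```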
Tested this - I had no issues with TP in the standard case, but with activation checkpointing I get an error. It's not obvious why this occurs yet. If anybody has any insights, please share.
Any chance you can share how to reproduce this so we can take a look? Also, do you see a different (less cryptic) error message if you try without torch.compile?
Details beneath contain the full config I used for testing which failed. Unfortunately (but perhaps interestingly) it fully succeeds without … The main options I tweaked during testing were:

```yaml
batch_size: 1
gradient_accumulation_steps: 4
epochs: 1
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2.0e-05
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
max_steps_per_epoch: 20
clip_grad_norm: null
output_dir: outputs
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 8192
  path: outputs/base_model/Llama-3.1-8B-Instruct/original/tokenizer.model
dataset:
  _component_: torchtune.datasets.slimorca_dataset
  packed: true
  train_on_input: true
seed: null
shuffle: true
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
  checkpoint_dir: outputs/base_model/Llama-3.1-8B-Instruct
resume_from_checkpoint: false
fsdp_cpu_offload: false
enable_activation_checkpointing: true
enable_activation_offloading: false
custom_sharded_layers: []
compile: true
optimizer_in_bwd: false
dtype: bf16
enable_fp8_training: true
device: cuda
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
log_every_n_steps: 1
log_peak_memory_stats: false
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/monitoring
  cpu: true
  cuda: true
  profile_memory: true
  with_stack: true
  record_shapes: true
  with_flops: true
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
```

I noticed that TPS was outrageously higher with …

EDIT: Note that I am using … I am entirely unable to use the 2.6 version due to other torch.compile issues (reported here). If you are not able to repro this issue I wonder if it could be sensitive to the torch version?
@nathan-az, I patched your PR and float8 + compile + TP work as expected for me.
Based on ^, how about:
cc @ebsmothers, wdyt?
Hi @nathan-az, do you plan to work on this in the near future? Mind if I take it over so we can land this sooner?
Hey @andrewor14 - I do have time but I haven't started yet, so more than happy for you to take over! Feel free to use this (and the comments) as a reference, I'll focus on merging HSDP #2415. |
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release. To use this feature, add the following to your config.yaml:

```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```

The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True` (without tensor parallelism), which led to the largest speedups in our experiments.

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no change in memory usage or quantized accuracy compared to the bf16 baseline:

```
experiment_name          tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
----------------------   -------------------  ----------------  ----------------  -------------------
full                     2773.473 (+0.000%)   18.481 (+0.000%)  18.481 (+0.000%)  34.291 (+0.000%)
full_tp                  2773.598 (+0.005%)   18.481 (+0.000%)  18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname               3182.220 (+14.738%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_noname_tp            3159.515 (+13.919%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise           3159.676 (+13.925%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise_tp        3160.202 (+13.944%)  18.484 (+0.014%)  18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise              2790.424 (+0.611%)   18.496 (+0.078%)  18.496 (+0.078%)  34.327 (+0.103%)
fp8_rowwise_with_gw_hp   3171.742 (+14.360%)  18.492 (+0.060%)  18.492 (+0.060%)  34.405 (+0.330%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   ---------------  --------------------------
full                     0.584 (+0.000)   9.419 (+0.000)
full_tp                  0.584 (+0.000)   9.415 (-0.004)
fp8_noname               0.585 (+0.000)   9.431 (+0.012)
fp8_noname_tp            0.584 (-0.000)   9.425 (+0.006)
fp8_tensorwise           0.584 (+0.000)   9.421 (+0.002)
fp8_tensorwise_tp        0.584 (-0.000)   9.425 (+0.005)
fp8_rowwise              0.583 (-0.002)   9.421 (+0.002)
fp8_rowwise_with_gw_hp   0.585 (+0.001)   9.405 (-0.014)
```

A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallel did not seem to help, and even degraded tok/s for `fp8_noname`

For Llama3.1-8B, we observed similar results, with up to 14.3% faster finetuning and no change in quantized accuracy. However, memory usage did increase minimally (+2%) for most fp8 settings:

```
experiment_name          tok/s                peak_mem_active   peak_mem_alloc    peak_mem_reserved
----------------------   -------------------  ----------------  ----------------  -------------------
full                     2768.292 (+0.000%)   18.541 (+0.000%)  18.541 (+0.000%)  34.270 (+0.000%)
full_tp                  2764.611 (-0.133%)   18.541 (+0.000%)  18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname               3164.370 (+14.308%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_noname_tp            3144.787 (+13.600%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise           3136.952 (+13.317%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise_tp        3163.867 (+14.289%)  18.542 (+0.008%)  18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise              2790.672 (+0.808%)   18.554 (+0.073%)  18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp   3144.678 (+13.596%)  18.551 (+0.056%)  18.551 (+0.056%)  34.966 (+2.032%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   ---------------  --------------------------
full                     0.594 (+0.000)   9.087 (+0.000)
full_tp                  0.594 (+0.001)   9.089 (+0.002)
fp8_noname               0.593 (-0.001)   9.070 (-0.017)
fp8_noname_tp            0.593 (-0.000)   9.078 (-0.009)
fp8_tensorwise           0.593 (-0.001)   9.061 (-0.026)
fp8_tensorwise_tp        0.593 (-0.001)   9.060 (-0.026)
fp8_rowwise              0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp   0.595 (+0.001)   9.087 (+0.000)
```

**Test Plan:** Experiment command:

```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
  enable_fp8_training=true \
  fp8_recipe_name=tensorwise \
  epochs=1 \
  batch_size=16 \
  compile=true \
  dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
  checkpointer.output_dir="$LOG_DIR" \
  output_dir="${LOG_DIR}/metrics" \
  metric_logger.log_dir="${LOG_DIR}/metrics"
```

(full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:

```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_validate_float8_tp_plan
pytest tests -k test_is_fp8_tensorwise_scaling
```
However, memory usage did increase minimally (+2%) for most fp8 settings: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%) fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%) fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.594 (+0.000) 9.087 (+0.000) fp8_noname 0.593 (-0.001) 9.070 (-0.017) fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026) fp8_rowwise 0.593 (-0.000) 9.086 (-0.001) fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000) ``` **Test Plan:** Experiment command: ``` tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ enable_fp8_training=true \ fp8_recipe_name=tensorwise \ epochs=1 \ batch_size=16 \ compile=true \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir="$LOG_DIR" \ output_dir="${LOG_DIR}/metrics" \ metric_logger.log_dir="${LOG_DIR}/metrics" ``` (full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh) Unit tests: ``` pytest tests -k test_convert_to_float8_training pytest tests -k test_is_fp8_tensorwise_scaling ```
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release. To use this feature, add the following to your config.yaml: ``` enable_fp8_training: true fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp ``` The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments. Based on meta-pytorch#2404 by @nathan-az **Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos: ``` torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning) torchao: 5a78b70 torch: 1017927 ``` For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no change in memory usage or quantized accuracy compared to the bf16 baseline: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%) fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.584 (+0.000) 9.419 (+0.000) fp8_noname 0.585 (+0.000) 9.431 (+0.012) fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002) fp8_rowwise 0.583 (-0.002) 9.421 (+0.002) ``` A few more observations here: - The best tok/s improvement was from the default setting (`fp8_noname`) - `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline For Llama3.1-8B, we observed similar observations, with up to 14.3% faster finetuning and no change in quantized accuracy. 
However, memory usage did increase minimally (+2%) for most fp8 settings: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%) fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%) fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.594 (+0.000) 9.087 (+0.000) fp8_noname 0.593 (-0.001) 9.070 (-0.017) fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026) fp8_rowwise 0.593 (-0.000) 9.086 (-0.001) fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000) ``` **Test Plan:** Experiment command: ``` tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ enable_fp8_training=true \ fp8_recipe_name=tensorwise \ epochs=1 \ batch_size=16 \ compile=true \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir="$LOG_DIR" \ output_dir="${LOG_DIR}/metrics" \ metric_logger.log_dir="${LOG_DIR}/metrics" ``` (full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh) Unit tests: ``` pytest tests -k test_convert_to_float8_training pytest tests -k test_is_fp8_tensorwise_scaling ```
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release. To use this feature, add the following to your config.yaml: ``` enable_fp8_training: true fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp ``` The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments. Based on meta-pytorch#2404 by @nathan-az **Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos: ``` torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning) torchao: 5a78b70 torch: 1017927 ``` For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no change in memory usage or quantized accuracy compared to the bf16 baseline: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%) fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.584 (+0.000) 9.419 (+0.000) fp8_noname 0.585 (+0.000) 9.431 (+0.012) fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002) fp8_rowwise 0.583 (-0.002) 9.421 (+0.002) ``` A few more observations here: - The best tok/s improvement was from the default setting (`fp8_noname`) - `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline For Llama3.1-8B, we observed similar observations, with up to 14.3% faster finetuning and no change in quantized accuracy. 
However, memory usage did increase minimally (+2%) for most fp8 settings: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%) fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%) fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.594 (+0.000) 9.087 (+0.000) fp8_noname 0.593 (-0.001) 9.070 (-0.017) fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026) fp8_rowwise 0.593 (-0.000) 9.086 (-0.001) fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000) ``` **Test Plan:** Experiment command: ``` tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ enable_fp8_training=true \ fp8_recipe_name=tensorwise \ epochs=1 \ batch_size=16 \ compile=true \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir="$LOG_DIR" \ output_dir="${LOG_DIR}/metrics" \ metric_logger.log_dir="${LOG_DIR}/metrics" ``` (full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh) Unit tests: ``` pytest tests -k test_convert_to_float8_training pytest tests -k test_is_fp8_tensorwise_scaling ```
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release. To use this feature, add the following to your config.yaml: ``` enable_fp8_training: true fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp ``` The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments. Based on meta-pytorch#2404 by @nathan-az **Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos: ``` torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning) torchao: 5a78b70 torch: 1017927 ``` For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no change in memory usage or quantized accuracy compared to the bf16 baseline: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%) fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.584 (+0.000) 9.419 (+0.000) fp8_noname 0.585 (+0.000) 9.431 (+0.012) fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002) fp8_rowwise 0.583 (-0.002) 9.421 (+0.002) ``` A few more observations here: - The best tok/s improvement was from the default setting (`fp8_noname`) - `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline For Llama3.1-8B, we observed similar observations, with up to 14.3% faster finetuning and no change in quantized accuracy. 
However, memory usage did increase minimally (+2%) for most fp8 settings: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%) fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%) fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.594 (+0.000) 9.087 (+0.000) fp8_noname 0.593 (-0.001) 9.070 (-0.017) fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026) fp8_rowwise 0.593 (-0.000) 9.086 (-0.001) fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000) ``` **Test Plan:** Experiment command: ``` tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ enable_fp8_training=true \ fp8_recipe_name=tensorwise \ epochs=1 \ batch_size=16 \ compile=true \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir="$LOG_DIR" \ output_dir="${LOG_DIR}/metrics" \ metric_logger.log_dir="${LOG_DIR}/metrics" ``` (full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh) Unit tests: ``` pytest tests -k test_convert_to_float8_training pytest tests -k test_is_fp8_tensorwise_scaling ```
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to a 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there, but it will be available in the next torchtune release.

To use this feature, add the following to your config.yaml:

```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
```

The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments.

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB of memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning was 14.7% faster with no change in memory usage or quantized accuracy compared to the bf16 baseline:

```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
----------------------   -------------------  -----------------  -----------------  -------------------
full                     2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)   34.291 (+0.000%)
fp8_noname               3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_tensorwise           3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_rowwise              2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)   34.327 (+0.103%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   ---------------  --------------------------
full                     0.584 (+0.000)   9.419 (+0.000)
fp8_noname               0.585 (+0.000)   9.431 (+0.012)
fp8_tensorwise           0.584 (+0.000)   9.421 (+0.002)
fp8_rowwise              0.583 (-0.002)   9.421 (+0.002)
```

A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline

For Llama3.1-8B, we observed similar results, with up to 14.3% faster finetuning and no change in quantized accuracy. However, memory usage did increase minimally (+2%) for most fp8 settings:

```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
----------------------   -------------------  -----------------  -----------------  -------------------
full                     2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)   34.270 (+0.000%)
fp8_noname               3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_tensorwise           3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_rowwise              2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)   34.389 (+0.348%)
fp8_rowwise_with_gw_hp   3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)   34.966 (+2.032%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   ---------------  --------------------------
full                     0.594 (+0.000)   9.087 (+0.000)
fp8_noname               0.593 (-0.001)   9.070 (-0.017)
fp8_tensorwise           0.593 (-0.001)   9.061 (-0.026)
fp8_rowwise              0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp   0.595 (+0.001)   9.087 (+0.000)
```

Llama3.2-3B saw up to 16.5% faster finetuning with rowwise scaling and high-precision `grad_weight` (`rowwise_with_gw_hp`), a bigger improvement than tensorwise alone.
Similarly, there are no degradations in memory usage or quantized accuracy:

```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
----------------------   -------------------  -----------------  -----------------  -------------------
full                     6502.143 (+0.000%)   15.917 (+0.000%)   15.917 (+0.000%)   30.090 (+0.000%)
fp8_noname               7205.386 (+10.816%)  15.917 (+0.003%)   15.917 (+0.003%)   30.010 (-0.266%)
fp8_tensorwise           7222.198 (+11.074%)  15.917 (+0.003%)   15.917 (+0.003%)   30.010 (-0.266%)
fp8_rowwise              6387.968 (-1.756%)   15.916 (-0.002%)   15.916 (-0.002%)   29.158 (-3.096%)
fp8_rowwise_with_gw_hp   7573.698 (+16.480%)  15.917 (+0.001%)   15.917 (+0.001%)   29.516 (-1.908%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   ---------------  --------------------------
full                     0.533 (+0.000)   12.407 (+0.000)
fp8_noname               0.533 (+0.000)   12.414 (+0.007)
fp8_tensorwise           0.533 (+0.000)   12.412 (+0.005)
fp8_rowwise              0.533 (-0.000)   12.420 (+0.013)
fp8_rowwise_with_gw_hp   0.534 (+0.001)   12.416 (+0.009)
```

**Test Plan:**

Experiment command:

```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
  enable_fp8_training=true \
  fp8_recipe_name=tensorwise \
  epochs=1 \
  batch_size=16 \
  compile=true \
  dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
  checkpointer.output_dir="$LOG_DIR" \
  output_dir="${LOG_DIR}/metrics" \
  metric_logger.log_dir="${LOG_DIR}/metrics"
```

(full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:

```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_is_fp8_tensorwise_scaling
```
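For context on what the module-swap test exercises, here is a minimal, illustrative sketch, assuming torchao's `convert_to_float8_training`, `Float8LinearConfig.from_recipe_name`, and `Float8Linear`; the actual test in this change may differ in name and coverage:

```python
# Illustrative sketch only: check that fp8 conversion swaps nn.Linear modules
# for torchao's Float8Linear while leaving other modules (embeddings, norms)
# untouched. Dimensions and recipe name here are placeholders.
import torch
import torch.nn as nn
from torchao.float8 import Float8Linear, Float8LinearConfig, convert_to_float8_training


def test_convert_to_float8_training_sketch():
    model = nn.Sequential(
        nn.Embedding(128, 64),
        nn.Linear(64, 64, bias=False),
        nn.LayerNorm(64),
        nn.Linear(64, 128, bias=False),
    ).to(torch.bfloat16)

    # "tensorwise" mirrors the default fp8_recipe_name used in the experiments above.
    config = Float8LinearConfig.from_recipe_name("tensorwise")
    convert_to_float8_training(model, config=config)

    assert isinstance(model[1], Float8Linear)
    assert isinstance(model[3], Float8Linear)
    assert isinstance(model[0], nn.Embedding)  # non-linear modules are untouched
```

A full test would additionally compare numerics against a bf16 baseline and run under FSDP, per the test plan above.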
Context
What is the purpose of this PR? Is it to
This would solve #2201.
I'm far from an expert with `torchao`, but tested this in an effort to further reduce memory usage, combining it with `torchao.prototype.low_bit_optim.AdamW8bit` to get LLaMA 3.3 70B training comfortably on 8x H100s with sequence length over 8k and only activation offloading. I saw a significant (50%) improvement in tokens per second using this. A sense check shows near-identical loss curves as well.
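As a rough, illustrative sketch of that pairing (not the exact setup used for the numbers in this PR; the layer sizes and learning rate are placeholders):

```python
# Illustrative only: combine the fp8 module swap from torchao.float8 with the
# 8-bit AdamW from torchao.prototype.low_bit_optim to shrink optimizer state.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.Linear(4096, 4096, bias=False),
).to(torch.bfloat16).cuda()  # assumes an fp8-capable GPU (e.g. H100)

convert_to_float8_training(model)                   # swap linears to fp8 training linears
optimizer = AdamW8bit(model.parameters(), lr=2e-5)  # optimizer states kept in 8 bits
```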
This PR should probably be treated as equal parts PR and RFC, or as a reference for a future one. Key points for discussion are:
- Should this live in the `torchtune.memory` module, and are we happy with using the class as I have?

Changelog
What are the changes made in this PR?
This implements a WIP version of fp8 training, largely based on torchtitan's `Float8Converter`.
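For readers unfamiliar with the torchtitan approach, below is a hedged sketch of what such a converter-style utility does, built on torchao's float8 APIs. The class and method names are illustrative, not the exact API introduced by this PR:

```python
# Illustrative sketch in the spirit of torchtitan's Float8Converter:
# (1) swap nn.Linear modules for fp8 training linears, and
# (2) optionally precompute fp8 scales for FSDP after each optimizer step.
from typing import Optional

import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)


class Float8TrainingHandler:  # hypothetical name
    def __init__(
        self,
        recipe_name: Optional[str] = None,
        enable_fsdp_float8_all_gather: bool = True,
    ):
        if recipe_name is not None:
            # e.g. "tensorwise", "rowwise", "rowwise_with_gw_hp"
            self.config = Float8LinearConfig.from_recipe_name(recipe_name)
        else:
            # Default: tensorwise dynamic scaling, optionally with fp8 all-gather under FSDP2.
            self.config = Float8LinearConfig(
                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather
            )
        # Precomputing scales only applies to tensorwise scaling with fp8 all-gather enabled.
        self.precompute_scales = recipe_name is None and enable_fsdp_float8_all_gather

    def convert(self, model: nn.Module) -> nn.Module:
        convert_to_float8_training(model, config=self.config)  # swaps nn.Linear in place
        return model

    def post_optimizer_step(self, model: nn.Module) -> None:
        # Amortizes fp8 scale computation across the FSDP all-gather.
        if self.precompute_scales:
            precompute_float8_dynamic_scale_for_fsdp(model)
```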
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks (`pre-commit install`)
- run `pytest tests`
- run `pytest tests -m integration_test`

UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example