
Conversation

nathan-az (Collaborator) commented Feb 17, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This would solve #2201.

I'm far from an expert with torchao, but I tested this in an effort to further reduce memory usage, combining it with torchao.prototype.low_bit_optim.AdamW8bit to get LLaMA 3.3 70B training comfortably on 8x H100s with a sequence length over 8k, using only activation offloading.

I saw a significant (50%) improvement in tokens per second using this. A sense check shows near-identical loss curves as well.

[Screenshot: MLFlow training curves comparing the bf16 and fp8 runs; the loss curves are near-identical.]

This PR should probably be treated in equal parts as a PR and an RFC, or as a reference for a future one. Key points for discussion are:

  • does this code make sense or conflict with any current features? (e.g. is it correct to do the conversion on the meta device before compilation)
  • does the new functionality sit in the correct place in the torchtune.memory module, and are we happy with using the class as I have used it?
  • what is the best way to test this and ensure correctness and (near) equality versus standard bf16 training?

Changelog

What are the changes made in this PR?
This implements a WIP version of fp8 training largely based on torchtitan's Float8Converter.
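
For reference, here is a minimal sketch of what the conversion boils down to when written directly against torchao's public float8 API. The `module_filter_fn` that skips the output projection is an illustrative assumption, not necessarily what this PR does:

```
# Minimal sketch, assuming torchao's float8 training API.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def apply_fp8_training(model: nn.Module) -> nn.Module:
    # Swap eligible nn.Linear modules for their float8 training equivalents,
    # optionally enabling float8 all-gather under FSDP2.
    config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    return convert_to_float8_training(
        model,
        config=config,
        # Skipping the final output projection is an assumption for illustration.
        module_filter_fn=lambda mod, fqn: fqn != "output",
    )
```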

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

nathan-az and others added 15 commits February 8, 2025 15:23
pytorch-bot bot commented Feb 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2404

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the "CLA Signed" label Feb 17, 2025
ebsmothers (Contributor) left a comment

Thanks for the PR! I'm excited to see this feature, especially with code changes that are pretty minimally invasive. Leaving a handful of comments here, many around testing:

  • We should definitely add some unit tests for this. At least something to set up a toy model on one or more devices, call convert_to_float8_training, and validate that linears get swapped out correctly and that we see the numerical results we'd expect (a rough sketch follows below).
  • Especially since torchtitan's implementation is really only tested on Llama models, it'd be interesting to see whether we run into any issues with e.g. a Qwen model where the embedding weights are tied.
  • We can also test with Llama 3.2 Vision just to make sure nothing breaks (would also be interested to see if there's a bigger accuracy dropoff there).
  • Does it compose with tensor parallel?
  • To assess model quality, we can train for a couple epochs on Alpaca, SlimOrca, or some other canonical dataset, then compare eval results using our Eleuther integration to the equivalent run with bf16. (This may take a bit more effort, so lmk if you need help on this one.)
  • (random nit q): why is LR in the plots you shared 0? Is it just an issue with the y-axis scale?

Also cc a couple experts @vkuzo and @gau-nernst in case they have thoughts or suggestions here.
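
To make the first testing suggestion above concrete, here is a rough sketch of the kind of unit test that could be added. It assumes torchao's `convert_to_float8_training` and `Float8Linear`; the toy model and test name are made up:

```
# Rough unit-test sketch, assuming torchao's float8 API; the toy model and
# test name are illustrative, not part of this PR.
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear


def test_convert_to_float8_training_swaps_linears():
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32))
    convert_to_float8_training(model)
    # Every nn.Linear should now be a Float8Linear; other modules are untouched.
    assert isinstance(model[0], Float8Linear)
    assert isinstance(model[1], nn.ReLU)
    assert isinstance(model[2], Float8Linear)
```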

)

def precompute_float8_dynamic_scale_for_fsdp(
    self, model: Union[nn.Module, List[nn.Module]]

Contributor:
What's the actual case that we would be applying to a list of modules?

nathan-az (Collaborator, Author):

In this specific case it's due to the way that torchtitan supports pipeline parallelism with looped parts, using a model_parts object rather than a single model.

The initial state of this PR is very close to a lift-and-shift of their implementation (or rather, an implementation from about a week ago; they've since made Converters a more generic class).

If we don't anticipate PP support any time soon (I'm not sure where this sits in the priority list - I imagine even something like DP shard + replica patterns are simpler as a next level), I can remove the list support.

nathan-az (Collaborator, Author):

@ebsmothers to answer what I can for now:

Does it compose with tensor parallel?

Checking torchtitan, it looks like if we want float8 + TP, it uses alternative rowwise_parallel and colwise_parallel classes conveniently provided by torchao 😌. I just made a commit to do this swapping, although I'm not certain that it's sufficient yet.
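
For concreteness, the swap amounts to something like the sketch below. It assumes torchao's float8 tensor-parallel classes; the helper function and the module names in the returned plan are illustrative rather than torchtune's actual TP plan:

```
# Sketch of selecting float8-aware TP styles, assuming torchao provides
# Float8ColwiseParallel / Float8RowwiseParallel. The helper and the module
# names below are illustrative only.
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def attention_tp_plan(enable_float8: bool) -> dict:
    colwise = Float8ColwiseParallel if enable_float8 else ColwiseParallel
    rowwise = Float8RowwiseParallel if enable_float8 else RowwiseParallel
    return {
        "attn.q_proj": colwise(),
        "attn.k_proj": colwise(),
        "attn.v_proj": colwise(),
        "attn.output_proj": rowwise(),
    }
```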

Especially since torchtitan's implementation is really just tested for Llama models it'd be interesting to see whether we run into any issues with e.g. a Qwen model where the embedding weights are tied.

I'm selfishly working just with LLaMA for now, for testing purposes. However, I'm having trouble getting standard TP working (without fp8) at the moment (at least in multi-node), so this is all very WIP until I can test properly.

(random nit q): why is LR in the plots you shared 0? Is it just an issue with the y-axis scale?

Yes, just a bad y-axis scale in MLFlow 😅

vkuzo commented Feb 18, 2025

We should definitely add some unit tests for this.

+1 to tests, can we also have a README.md somewhere on how to use this from torchtune?

it'd be interesting to see whether we run into any issues with e.g. a Qwen model where the embedding weights are tied.

we are happy to help fix any issues that are uncovered

Does it compose with tensor parallel?

yes. If you are ok with bfloat16 all-gather, then TP is orthogonal to float8 training. If you want to use float8 all-gather, then there is an extra step of using Float8ColwiseParallel and Float8RowwiseParallel. Note that TP + tensorwise scaling is well tested, and there is an ongoing issue tracking enablement of TP + rowwise scaling here: pytorch/torchtitan#845

return BASE_LLAMA_TP_PLAN
if enable_float8:
rowwise_parallel, colwise_parallel = (
Float8RowwiseParallel,
these modules are specific to tensorwise scaling, so if in the future torchtune wants to enable more recipes (such as rowwise scaling) there will have to be additional gating on this

nathan-az (Collaborator, Author) commented Feb 19, 2025

Does it compose with tensor parallel?

yes. If you are ok with bfloat16 all-gather, then TP is orthogonal to float8 training. If you want to use float8 all-gather, then there is an extra step of using Float8ColwiseParallel and Float8RowwiseParallel. Note that TP + tensorwise scaling is well tested, and there is an ongoing issue tracking enablement of TP + rowwise scaling here: pytorch/torchtitan#845

Tested this - I had no issues with TP in the standard case, but with activation checkpointing I get an error: [rank0]: torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'AsyncCollectiveTensor' object has no attribute 'elem'.

It's not obvious why this occurs yet. If anybody has any insights, please share.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/recipes/full_finetune_distributed.py", line 981, in <module>
[rank0]:     sys.exit(recipe_main())
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank0]:     sys.exit(recipe_main(conf))
[rank0]:              ^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/recipes/full_finetune_distributed.py", line 976, in recipe_main
[rank0]:     recipe.train()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/recipes/full_finetune_distributed.py", line 847, in train
[rank0]:     current_loss.backward()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1740, in backward
[rank0]:     ctx_saved_tensors = ctx.saved_tensors
[rank0]:                         ^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 1125, in unpack_hook
[rank0]:     frame.recompute_fn(*args)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/checkpoint.py", line 1519, in recompute_fn
[rank0]:     fn(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1734, in _wrapped_call_impl
[rank0]:     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 952, in _compile
[rank0]:     raise InternalTorchDynamoError(
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
[rank0]:     super().run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
[rank0]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
[rank0]:     return variables.UserFunctionVariable(fn, source=source).call_function(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1710, in LOAD_METHOD
[rank0]:     self._load_attr(inst)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1734, in _load_attr
[rank0]:     result = BuiltinVariable(getattr).call_function(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 967, in call_function
[rank0]:     return handler(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 712, in <lambda>
[rank0]:     tx, [v.realize() for v in args], kwargs
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 712, in <listcomp>
[rank0]:     tx, [v.realize() for v in args], kwargs
[rank0]:          ^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 63, in realize
[rank0]:     self._cache.realize()
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 29, in realize
[rank0]:     self.vt = VariableBuilder(tx, self.source)(self.value)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 377, in __call__
[rank0]:     vt = self._wrap(value)
[rank0]:          ^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 559, in _wrap
[rank0]:     return self.wrap_tensor(value)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1542, in wrap_tensor
[rank0]:     self.assert_not_wrapped_by_this_graph(value)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1452, in assert_not_wrapped_by_this_graph
[rank0]:     if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode:
[rank0]:        ^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 177, in is_fake
[rank0]:     flattened_tensors = [getattr(x, attr) for attr in attrs]
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 177, in <listcomp>
[rank0]:     flattened_tensors = [getattr(x, attr) for attr in attrs]
[rank0]:                          ^^^^^^^^^^^^^^^^
[rank0]: torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'AsyncCollectiveTensor' object has no attribute 'elem'

[rank0]: from user code:
[rank0]:    File "/opt/conda/lib/python3.11/site-packages/torchtune/modules/transformer.py", line 121, in forward
[rank0]:     h = self.sa_norm(x)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torchtune/modules/rms_norm.py", line 39, in forward
[rank0]:     x.float(),

vkuzo commented Feb 19, 2025

Tested this - I had no issues with TP in the standard case, but with activation checkpointing I get an error: [rank0]: torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'AsyncCollectiveTensor' object has no attribute 'elem'.

Any chance you can share how to reproduce this so we can take a look? Also, do you see a different (less cryptic) error message if you try without torch.compile?

nathan-az (Collaborator, Author) commented Feb 19, 2025

Any chance you can share how to reproduce this so we can take a look? Also, do you see a different (less cryptic) error message if you try without torch.compile?

The details below contain the full config I used for the test that failed. Unfortunately (but perhaps interestingly), it fully succeeds without torch.compile.

The main options I tweaked during testing were compile, enable_activation_checkpointing, and enable_fp8_training. I found that all configurations of these succeed except the one below, with all three set to true.

batch_size: 1
gradient_accumulation_steps: 4
epochs: 1
optimizer:
  _component_: torchao.prototype.low_bit_optim.AdamW8bit
  lr: 2.0e-05
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
tensor_parallel_dim: 8
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
max_steps_per_epoch: 20
clip_grad_norm: null
output_dir: outputs
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 8192
  path: outputs/base_model/Llama-3.1-8B-Instruct/original/tokenizer.model
dataset:
  _component_: torchtune.datasets.slimorca_dataset
  packed: true
  train_on_input: true
seed: null
shuffle: true
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00004'
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
  checkpoint_dir: outputs/base_model/Llama-3.1-8B-Instruct
resume_from_checkpoint: false
fsdp_cpu_offload: false
enable_activation_checkpointing: true
enable_activation_offloading: false
custom_sharded_layers: []
compile: true
optimizer_in_bwd: false
dtype: bf16
enable_fp8_training: true
device: cuda
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
log_every_n_steps: 1
log_peak_memory_stats: false
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/monitoring
  cpu: true
  cuda: true
  profile_memory: true
  with_stack: true
  record_shapes: true
  with_flops: true
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

I noticed that TPS was outrageously higher with compile: true in all configurations of those settings (~30k with compile, under 10k without). Certainly makes me want to get compile working if possible!

EDIT: Note that I am using docker pull pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel.

I am entirely unable to use the 2.6 version due to other torch.compile issues (reported here). If you are not able to repro this issue, I wonder if it could be sensitive to the torch version?

@nathan-az nathan-az marked this pull request as draft February 20, 2025 11:27
vkuzo commented Feb 20, 2025

@nathan-az, I patched your PR and float8 + compile + TP works as expected for me.

  • I'm working off the LLaMa 3 8B config, with these changes https://gist.github.com/vkuzo/670adda972f65e1e02061019c7770a34 (to mimic your custom config).
  • Command I used: with-proxy tune run --nnodes 1 --nproc-per-node 8 full_finetune_distributed --config llama3/8B_full epochs=1
  • PyTorch version: 2.7.0a0+git382fbcc (main branch, built from source)

Based on ^, how about:

  • we can land the float8 support with TP disabled and throw an exception if float8 + TP is on
  • we file an issue to enable float8 + TP, and I or someone else from PT can look into this further and enable it with the right gating.

cc @ebsmothers , wdyt?
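
If we go with that interim plan, the gating could be as simple as the sketch below; the function name and arguments are hypothetical, not torchtune's actual API:

```
# Hypothetical validation sketch for the interim "float8 without TP" plan
# discussed above; names are illustrative.
def validate_fp8_config(enable_fp8_training: bool, tensor_parallel_dim: int) -> None:
    if enable_fp8_training and tensor_parallel_dim > 1:
        raise ValueError(
            "float8 training is not yet supported together with tensor parallelism; "
            "set enable_fp8_training=False or tensor_parallel_dim=1."
        )
```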

andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Feb 20, 2025
andrewor14 (Contributor) commented:

Hi @nathan-az, do you plan to work on this in the near future? Mind if I take it over so we can land this sooner?

nathan-az (Collaborator, Author) commented:

Hey @andrewor14 - I do have time but I haven't started yet, so I'm more than happy for you to take over! Feel free to use this (and the comments) as a reference; I'll focus on merging HSDP #2415.

@nathan-az nathan-az closed this Mar 11, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 1, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 2, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 2, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 2, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 2, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 4, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 4, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 4, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 8, 2025
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 8, 2025
**Summary:** This commit adds FP8 finetuning to the
`full_finetune_distributed` recipe as an optional feature.
For Llama3-8B, we saw up to 14.7% improvement in finetuning
throughput with no degradation in memory usage or accuracy.
This feature is currently gated on PyTorch nightlies since it
depends on recent features added there. However, it will be
available in the next torchtune release.

To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true

fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```

The default setting uses tensorwise scaling +
`enable_fsdp_float8_all_gather=True` (without tensor parallelism),
which led to the largest speedups in our experiments.

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs
with 94GB memory each. We finetune the model on the cleaned alpaca
dataset for 1 epoch, using a batch size of 16 with torch.compile.
We use the following commits from all 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning was 14.7% faster, with no change in
memory usage or quantized accuracy compared to the bf16 baseline:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
full_tp                 2773.598 (+0.005%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname              3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_noname_tp           3159.515 (+13.919%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise          3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise_tp       3160.202 (+13.944%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise             2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)  34.327 (+0.103%)
fp8_rowwise_with_gw_hp  3171.742 (+14.360%)  18.492 (+0.060%)   18.492 (+0.060%)  34.405 (+0.330%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.584 (+0.000)   9.419 (+0.000)
full_tp                 0.584 (+0.000)   9.415 (-0.004)
fp8_noname              0.585 (+0.000)   9.431 (+0.012)
fp8_noname_tp           0.584 (-0.000)   9.425 (+0.006)
fp8_tensorwise          0.584 (+0.000)   9.421 (+0.002)
fp8_tensorwise_tp       0.584 (-0.000)   9.425 (+0.005)
fp8_rowwise             0.583 (-0.002)   9.421 (+0.002)
fp8_rowwise_with_gw_hp  0.585 (+0.001)   9.405 (-0.014)
```

A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallel did not seem to have helped, and even degraded tok/s for `fp8_noname`

For Llama3.1-8B, we observed similar results, with up to 14.3%
faster finetuning and no change in quantized accuracy. However,
memory usage did increase slightly (+2%) for most fp8 settings:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
full_tp                 2764.611 (-0.133%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname              3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_noname_tp           3144.787 (+13.600%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise          3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise_tp       3163.867 (+14.289%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise             2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp  3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)  34.966 (+2.032%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.594 (+0.000)   9.087 (+0.000)
full_tp                 0.594 (+0.001)   9.089 (+0.002)
fp8_noname              0.593 (-0.001)   9.070 (-0.017)
fp8_noname_tp           0.593 (-0.000)   9.078 (-0.009)
fp8_tensorwise          0.593 (-0.001)   9.061 (-0.026)
fp8_tensorwise_tp       0.593 (-0.001)   9.060 (-0.026)
fp8_rowwise             0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp  0.595 (+0.001)   9.087 (+0.000)
```

**Test Plan:**

Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
    enable_fp8_training=true \
    fp8_recipe_name=tensorwise \
    epochs=1 \
    batch_size=16 \
    compile=true \
    dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
    checkpointer.output_dir="$LOG_DIR" \
    output_dir="${LOG_DIR}/metrics" \
    metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script:
https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_validate_float8_tp_plan
pytest tests -k test_is_fp8_tensorwise_scaling
```
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 11, 2025
**Summary:** This commit adds FP8 finetuning to the
`full_finetune_distributed` recipe as an optional feature.
For Llama3-8B, we saw up to 14.7% improvement in finetuning
throughput with no degradation in memory usage or accuracy.
This feature is currently gated on PyTorch nightlies since it
depends on recent features added there. However, it will be
available in the next torchtune release.

To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true

fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```

The default setting uses tensorwise scaling +
`enable_fsdp_float8_all_gather=True` (without tensor parallelism),
which led to the largest speedups in our experiments.

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs
with 94GB memory each. We finetune the model on the cleaned alpaca
dataset for 1 epoch, using a batch size of 16 with torch.compile.
We use the following commits from all 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no
change in memory usage or quantized accuracy compared to the bf16
baseline:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
full_tp                 2773.598 (+0.005%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname              3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_noname_tp           3159.515 (+13.919%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise          3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise_tp       3160.202 (+13.944%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise             2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)  34.327 (+0.103%)
fp8_rowwise_with_gw_hp  3171.742 (+14.360%)  18.492 (+0.060%)   18.492 (+0.060%)  34.405 (+0.330%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.584 (+0.000)   9.419 (+0.000)
full_tp                 0.584 (+0.000)   9.415 (-0.004)
fp8_noname              0.585 (+0.000)   9.431 (+0.012)
fp8_noname_tp           0.584 (-0.000)   9.425 (+0.006)
fp8_tensorwise          0.584 (+0.000)   9.421 (+0.002)
fp8_tensorwise_tp       0.584 (-0.000)   9.425 (+0.005)
fp8_rowwise             0.583 (-0.002)   9.421 (+0.002)
fp8_rowwise_with_gw_hp  0.585 (+0.001)   9.405 (-0.014)
```

A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallel did not seem to have helped, and even degraded tok/s for `fp8_noname`

For Llama3.1-8B, we observed similar observations, with up to
14.3% faster finetuning and no change in quantized accuracy.
However, memory usage did increase minimally (+2%) for most fp8
settings:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
full_tp                 2764.611 (-0.133%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname              3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_noname_tp           3144.787 (+13.600%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise          3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise_tp       3163.867 (+14.289%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise             2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp  3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)  34.966 (+2.032%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.594 (+0.000)   9.087 (+0.000)
full_tp                 0.594 (+0.001)   9.089 (+0.002)
fp8_noname              0.593 (-0.001)   9.070 (-0.017)
fp8_noname_tp           0.593 (-0.000)   9.078 (-0.009)
fp8_tensorwise          0.593 (-0.001)   9.061 (-0.026)
fp8_tensorwise_tp       0.593 (-0.001)   9.060 (-0.026)
fp8_rowwise             0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp  0.595 (+0.001)   9.087 (+0.000)
```

**Test Plan:**

Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
    enable_fp8_training=true \
    fp8_recipe_name=tensorwise \
    epochs=1 \
    batch_size=16 \
    compile=true \
    dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
    checkpointer.output_dir="$LOG_DIR" \
    output_dir="${LOG_DIR}/metrics" \
    metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script:
https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_validate_float8_tp_plan
pytest tests -k test_is_fp8_tensorwise_scaling
```
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request Apr 17, 2025
**Summary:** This commit adds FP8 finetuning to the
`full_finetune_distributed` recipe as an optional feature.
For Llama3-8B, we saw up to 14.7% improvement in finetuning
throughput with no degradation in memory usage or accuracy.
This feature is currently gated on PyTorch nightlies since it
depends on recent features added there. However, it will be
available in the next torchtune release.

To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
```

The default setting uses tensorwise scaling +
`enable_fsdp_float8_all_gather=True`, which led to the largest
speedups in our experiments.
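
For context on why this default wins: with tensorwise scaling and
`enable_fsdp_float8_all_gather=True`, weights can be cast to fp8 before the
FSDP all-gather, roughly halving that communication relative to bf16. A
minimal sketch of the intended ordering (convert first, then shard), assuming
PyTorch 2.6+ exposes `fully_shard` under `torch.distributed.fsdp` and using a
hypothetical `model.layers` container for the transformer blocks:

```
# Sketch only: fp8 conversion happens before FSDP2 sharding so FSDP2 shards
# the Float8Linear modules and can all-gather their weights in fp8.
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def prepare_fp8_fsdp_model(model: nn.Module) -> nn.Module:
    config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    convert_to_float8_training(model, config=config)
    for block in model.layers:  # hypothetical attribute for transformer blocks
        fully_shard(block)
    fully_shard(model)
    return model
```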

Based on meta-pytorch#2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs
with 94GB memory each. We finetune the model on the cleaned alpaca
dataset for 1 epoch, using a batch size of 16 with torch.compile.
We use the following commits from all 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao: 5a78b70
torch: 1017927
```

For Llama3-8B, fp8 finetuning was up to 14.7% faster than the bf16
baseline, with no change in memory usage or quantized accuracy:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)  34.291 (+0.000%)
fp8_noname              3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_tensorwise          3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)  34.325 (+0.097%)
fp8_rowwise             2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)  34.327 (+0.103%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.584 (+0.000)   9.419 (+0.000)
fp8_noname              0.585 (+0.000)   9.431 (+0.012)
fp8_tensorwise          0.584 (+0.000)   9.421 (+0.002)
fp8_rowwise             0.583 (-0.002)   9.421 (+0.002)
```

A few more observations here:
- The best tok/s improvement was from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline

For Llama3.1-8B, we observed similar results, with up to
14.3% faster finetuning and no change in quantized accuracy.
However, peak reserved memory increased slightly (about 2%) for
most fp8 settings:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)  34.270 (+0.000%)
fp8_noname              3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_tensorwise          3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)  34.963 (+2.021%)
fp8_rowwise             2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)  34.389 (+0.348%)
fp8_rowwise_with_gw_hp  3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)  34.966 (+2.032%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.594 (+0.000)   9.087 (+0.000)
fp8_noname              0.593 (-0.001)   9.070 (-0.017)
fp8_tensorwise          0.593 (-0.001)   9.061 (-0.026)
fp8_rowwise             0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp  0.595 (+0.001)   9.087 (+0.000)
```

Llama3.2-3B was up to 16.5% faster with rowwise scaling and
high-precision `grad_weight` (`fp8_rowwise_with_gw_hp`), a larger
improvement than with tensorwise scaling. As before, there was no
degradation in memory usage or quantized accuracy:

```
experiment_name         tok/s                peak_mem_active    peak_mem_alloc    peak_mem_reserved
----------------------  -------------------  -----------------  ----------------  -------------------
full                    6502.143 (+0.000%)   15.917 (+0.000%)   15.917 (+0.000%)  30.090 (+0.000%)
fp8_noname              7205.386 (+10.816%)  15.917 (+0.003%)   15.917 (+0.003%)  30.010 (-0.266%)
fp8_tensorwise          7222.198 (+11.074%)  15.917 (+0.003%)   15.917 (+0.003%)  30.010 (-0.266%)
fp8_rowwise             6387.968 (-1.756%)   15.916 (-0.002%)   15.916 (-0.002%)  29.158 (-3.096%)
fp8_rowwise_with_gw_hp  7573.698 (+16.480%)  15.917 (+0.001%)   15.917 (+0.001%)  29.516 (-1.908%)

experiment_name         hellaswag_acc    wikitext_word_perplexity
----------------------  ---------------  --------------------------
full                    0.533 (+0.000)   12.407 (+0.000)
fp8_noname              0.533 (+0.000)   12.414 (+0.007)
fp8_tensorwise          0.533 (+0.000)   12.412 (+0.005)
fp8_rowwise             0.533 (-0.000)   12.420 (+0.013)
fp8_rowwise_with_gw_hp  0.534 (+0.001)   12.416 (+0.009)
```

**Test Plan:**

Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
    enable_fp8_training=true \
    fp8_recipe_name=tensorwise \
    epochs=1 \
    batch_size=16 \
    compile=true \
    dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
    checkpointer.output_dir="$LOG_DIR" \
    output_dir="${LOG_DIR}/metrics" \
    metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script:
https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_is_fp8_tensorwise_scaling
```