
Commit aaaed80

Enable FP8 full finetune distributed
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to a 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature currently requires a PyTorch nightly build since it depends on recently added features there; it will be available in the next torchtune release.

To use this feature, add the following to your config.yaml:
```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```
The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True` (without tensor parallelism), which led to the largest speedups in our experiments.

Based on #2404 by @nathan-az

**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetuned the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile, and used the following commits for all three repos:
```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao:   5a78b70
torch:     1017927
```

For Llama3-8B, fp8 finetuning was up to 14.7% faster, with no change in memory usage or quantized accuracy compared to the bf16 baseline:
```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
----------------------   ------------------   ----------------   ----------------   -----------------
full                     2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)   34.291 (+0.000%)
full_tp                  2773.598 (+0.005%)   18.481 (+0.000%)   18.481 (+0.000%)   34.291 (+0.000%)
fp8_noname               3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_noname_tp            3159.515 (+13.919%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_tensorwise           3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_tensorwise_tp        3160.202 (+13.944%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_rowwise              2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)   34.327 (+0.103%)
fp8_rowwise_with_gw_hp   3171.742 (+14.360%)  18.492 (+0.060%)   18.492 (+0.060%)   34.405 (+0.330%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   --------------   ------------------------
full                     0.584 (+0.000)   9.419 (+0.000)
full_tp                  0.584 (+0.000)   9.415 (-0.004)
fp8_noname               0.585 (+0.000)   9.431 (+0.012)
fp8_noname_tp            0.584 (-0.000)   9.425 (+0.006)
fp8_tensorwise           0.584 (+0.000)   9.421 (+0.002)
fp8_tensorwise_tp        0.584 (-0.000)   9.425 (+0.005)
fp8_rowwise              0.583 (-0.002)   9.421 (+0.002)
fp8_rowwise_with_gw_hp   0.585 (+0.001)   9.405 (-0.014)
```
A few more observations:
- The best tok/s improvement came from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallelism did not appear to help, and even degraded tok/s for `fp8_noname`

For Llama3.1-8B, results were similar, with up to 14.3% faster finetuning and no change in quantized accuracy. However, peak reserved memory increased slightly (about +2%) for most fp8 settings:
```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
----------------------   ------------------   ----------------   ----------------   -----------------
full                     2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)   34.270 (+0.000%)
full_tp                  2764.611 (-0.133%)   18.541 (+0.000%)   18.541 (+0.000%)   34.270 (+0.000%)
fp8_noname               3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_noname_tp            3144.787 (+13.600%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_tensorwise           3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_tensorwise_tp        3163.867 (+14.289%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_rowwise              2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)   34.389 (+0.348%)
fp8_rowwise_with_gw_hp   3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)   34.966 (+2.032%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
----------------------   --------------   ------------------------
full                     0.594 (+0.000)   9.087 (+0.000)
full_tp                  0.594 (+0.001)   9.089 (+0.002)
fp8_noname               0.593 (-0.001)   9.070 (-0.017)
fp8_noname_tp            0.593 (-0.000)   9.078 (-0.009)
fp8_tensorwise           0.593 (-0.001)   9.061 (-0.026)
fp8_tensorwise_tp        0.593 (-0.001)   9.060 (-0.026)
fp8_rowwise              0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp   0.595 (+0.001)   9.087 (+0.000)
```

**Test Plan:** Experiment command:
```
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \
  enable_fp8_training=true \
  fp8_recipe_name=tensorwise \
  epochs=1 \
  batch_size=16 \
  compile=true \
  dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \
  checkpointer.output_dir="$LOG_DIR" \
  output_dir="${LOG_DIR}/metrics" \
  metric_logger.log_dir="${LOG_DIR}/metrics"
```
(full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh)

Unit tests:
```
pytest tests -k test_convert_to_float8_training
pytest tests -k test_validate_float8_tp_plan
pytest tests -k test_is_fp8_tensorwise_scaling
```
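For readers unfamiliar with torchao's float8 flow, the snippet below is a minimal sketch (not part of the commit message) of what the default setting resolves to when `fp8_recipe_name` is left unset. It mirrors the `convert_to_float8_training` helper added in `torchtune/training/quantization.py` in this commit; the helper name `default_fp8_convert` is invented for illustration.

```python
# Sketch: the default fp8 path = tensorwise scaling + float8 all-gather under FSDP2,
# swapping every nn.Linear except the final "output" projection for Float8Linear.
from torchao.float8 import Float8LinearConfig
from torchao.float8 import convert_to_float8_training as torchao_convert


def default_fp8_convert(model):
    config = Float8LinearConfig.from_recipe_name("tensorwise")
    config.enable_fsdp_float8_all_gather = True  # default used by the recipe
    return torchao_convert(
        model,
        config=config,
        module_filter_fn=lambda mod, fqn: fqn != "output",  # keep output layer in bf16
    )
```

Passing an explicit `fp8_recipe_name` instead builds the config via `Float8LinearConfig.from_recipe_name(name)` without the FSDP all-gather override, as shown in the quantization helper diff below.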
1 parent 7d92c10 commit aaaed80

File tree

4 files changed (+234, -29 lines)

recipes/full_finetune_distributed.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -33,6 +34,11 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import (
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+    validate_float8_tp_plan,
+)
 
 from tqdm import tqdm
 
@@ -184,6 +190,13 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
+        if self._enable_fp8_training and self.tp_plan is not None:
+            raise ValueError(
+                "FP8 training does not support tensor parallelism yet. "
+                "This will be enabled in the near future."
+            )
 
         self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None)
         if self._run_val_every_n_steps is not None:
@@ -567,6 +580,15 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
 
+        if self._enable_fp8_training:
+            # Requires https://github.com/pytorch/pytorch/pull/148922
+            if torch.__version__ < "2.8.0.dev20250318":
+                raise RuntimeError(
+                    "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
+                )
+            validate_float8_tp_plan(self.tp_plan, self._fp8_recipe_name)
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -922,6 +944,16 @@ def train(self) -> None:
                 if self._lr_scheduler is not None:
                     self._lr_scheduler.step()
 
+                # If float8 training is enabled, perform a single all-reduce to compute the
+                # scale for all float8 parameters efficiently instead of doing many small
+                # all-reduces for each parameter
+                if (
+                    self._enable_fp8_training
+                    and is_fp8_tensorwise_scaling(self._fp8_recipe_name)
+                    and self.dp_degree > 1
+                ):
+                    precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                 loss_to_log = running_loss.detach().item() / num_tokens
                 pbar.update(1)
                 pbar.set_description(
```
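For readers who want the same behavior outside this recipe, here is a rough sketch of where the new per-step hook sits in a generic training step. The names `train_step`, `loss_fn`, `batch`, and `dp_degree` are placeholders, not the recipe's actual variables, and the batch keys are assumptions.

```python
# Sketch of a generic FSDP2 training step with the float8 scale precompute hook.
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp


def train_step(model, optimizer, loss_fn, batch, fp8_tensorwise: bool, dp_degree: int):
    loss = loss_fn(model(batch["tokens"]), batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    # After the parameter update, recompute the float8 scales for all parameters
    # with a single all-reduce rather than one small all-reduce per parameter.
    # Only relevant for tensorwise scaling with data-parallel sharding.
    if fp8_tensorwise and dp_degree > 1:
        precompute_float8_dynamic_scale_for_fsdp(model)
    return loss
```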
Lines changed: 69 additions & 0 deletions
```diff
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+import torch
+
+from torchao.float8.float8_linear import Float8Linear
+
+from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan
+from torchtune.training.quantization import (
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+    validate_float8_tp_plan,
+)
+
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512).to(torch.float),)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.output(x)
+        return x
+
+
+class TestFloat8:
+    def test_convert_to_float8_training(self):
+        """
+        Test that target linear layers are converted to Float8Linear.
+        """
+        m = M()
+        example_inputs = torch.randn(1, 512).to(torch.float)
+        m = convert_to_float8_training(m)
+        assert isinstance(m.linear, Float8Linear)
+        assert not isinstance(m.output, Float8Linear)
+        with pytest.raises(Exception):
+            m = convert_to_float8_training(m, "unrecognized_recipe_name")
+
+    def test_validate_float8_tp_plan(self):
+        """
+        Test that the float8 TP plan is only valid for "tensorwise" float8 recipes.
+        """
+        validate_float8_tp_plan(base_llama_tp_plan())
+        validate_float8_tp_plan(base_llama_tp_plan(), "anything")
+        validate_float8_tp_plan(fp8_llama_tp_plan())
+        validate_float8_tp_plan(fp8_llama_tp_plan(), "tensorwise")
+        with pytest.raises(ValueError):
+            validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise")
+        with pytest.raises(ValueError):
+            validate_float8_tp_plan(fp8_llama_tp_plan(), "rowwise_with_gw_hp")
+
+    def test_is_fp8_tensorwise_scaling(self):
+        """
+        Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling.
+        """
+        assert is_fp8_tensorwise_scaling(None)
+        assert is_fp8_tensorwise_scaling("tensorwise")
+        assert not is_fp8_tensorwise_scaling("rowwise")
+        assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp")
```

torchtune/models/llama3/_parallelism.py

Lines changed: 64 additions & 26 deletions
```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict
+from typing import Dict, Type
 
 from torch import nn
 
@@ -17,31 +17,56 @@
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle
 
-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_TRAINING_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)
+
+
+def _get_base_llama_tp_training_plan(
+    layerwise_colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    layerwise_rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+    layerwise_prepare_module_input_cls: Type[ParallelStyle] = PrepareModuleInput,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": RowwiseParallel(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": SequenceParallel(),
+        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "layers.*.attn": layerwise_prepare_module_input_cls(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": layerwise_prepare_module_input_cls(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": SequenceParallel(),
+        "layers.*.mlp_norm": SequenceParallel(),
+        "layers.*.attn.q_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.k_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.v_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.output_proj": layerwise_rowwise_parallel_cls(
+            output_layouts=Shard(1)
+        ),
+        "layers.*.mlp.w1": layerwise_colwise_parallel_cls(),
+        "layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": layerwise_colwise_parallel_cls(),
+    }
+
+
+BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan()
+
+FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan(
+    layerwise_colwise_parallel_cls=Float8ColwiseParallel,
+    layerwise_rowwise_parallel_cls=Float8RowwiseParallel,
+    layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput,
+)
 
 BASE_LLAMA_TP_INFERENCE_PLAN = {
     "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
@@ -70,3 +95,16 @@ def base_llama_tp_plan(
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
     return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN
+
+
+# TODO: expose this once tested
+def _fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
+    rowwise and colwise computation, currently only compatible with float8 fine-tuning with
+    "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return FP8_LLAMA_TP_TRAINING_PLAN
```
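As a usage note, a plan dict like `FP8_LLAMA_TP_TRAINING_PLAN` is consumed by `parallelize_module`, which the recipe already imports. The snippet below is a rough, hypothetical sketch of that wiring, not code from this commit: the device-mesh setup and the `llama3_8b` builder are assumptions, it must run under `torchrun` with multiple GPUs, and per the commit message float8 tensor parallelism is not yet enabled in the recipe itself.

```python
# Sketch: apply the float8-aware TP plan to a Llama3 model.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3 import llama3_8b
from torchtune.models.llama3._parallelism import FP8_LLAMA_TP_TRAINING_PLAN

# One-dimensional tensor-parallel mesh over all visible GPUs (assumes torchrun).
tp_mesh = init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("tp",))

model = llama3_8b()  # or any Llama3-family model built by torchtune
# Each FQN pattern in the plan is matched against submodules and its ParallelStyle
# applied; the Float8* styles let the TP collectives run in float8, which is only
# valid together with "tensorwise" scaling (see validate_float8_tp_plan below).
parallelize_module(model, tp_mesh, FP8_LLAMA_TP_TRAINING_PLAN)
```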

torchtune/training/quantization.py

Lines changed: 69 additions & 3 deletions
```diff
@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional
 
 from torch import nn
+from torch.distributed.tensor.parallel.style import ParallelStyle
 
 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
 
 
@@ -219,3 +227,61 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on
+        fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig.from_recipe_name("tensorwise")
+        fp8_config.enable_fsdp_float8_all_gather = True
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
+
+
+def validate_float8_tp_plan(
+    tp_plan: Optional[Dict[str, ParallelStyle]],
+    fp8_recipe_name: Optional[str] = None,
+) -> None:
+    """
+    Validate that the provided tensor parallel plan is compatible with the
+    float8 settings. Specifically, float8 tensor parallel plans are only
+    supported when using 'tensorwise' float8 recipes.
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]):
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
```