
Commit bd94d84

Enable FP8 full finetune distributed
**Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to a 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recently added features there, but it will be available in the next torchtune release. To use this feature, add the following to your config.yaml:

```
enable_fp8_training: true
fp8_recipe_name: tensorwise  # or rowwise, or rowwise_with_gw_hp
tensor_parallel_plan._component_: torchtune.models.llama3.fp8_llama_tp_plan
```

The default setting uses tensorwise scaling with `enable_fsdp_float8_all_gather=True` (without tensor parallelism), which led to the largest speedups in our experiments.

**Experimentation:** All experiments were run on 4x H100 GPUs with 94GB of memory each. We finetune the model on the alpaca dataset for 1 epoch with a batch size of 16 and torch.compile enabled, using the following commits from the 3 repos:

```
torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning)
torchao:   5a78b70
torch:     1017927
```

For Llama3-8B, fp8 finetuning was up to 14.7% faster with no change in memory usage or quantized accuracy compared to the bf16 baseline:

```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
-----------------------  -------------------  -----------------  -----------------  -------------------
full                     2773.473 (+0.000%)   18.481 (+0.000%)   18.481 (+0.000%)   34.291 (+0.000%)
full_tp                  2773.598 (+0.005%)   18.481 (+0.000%)   18.481 (+0.000%)   34.291 (+0.000%)
fp8_noname               3182.220 (+14.738%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_noname_tp            3159.515 (+13.919%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_tensorwise           3159.676 (+13.925%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_tensorwise_tp        3160.202 (+13.944%)  18.484 (+0.014%)   18.484 (+0.014%)   34.325 (+0.097%)
fp8_rowwise              2790.424 (+0.611%)   18.496 (+0.078%)   18.496 (+0.078%)   34.327 (+0.103%)
fp8_rowwise_with_gw_hp   3171.742 (+14.360%)  18.492 (+0.060%)   18.492 (+0.060%)   34.405 (+0.330%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
-----------------------  ---------------  --------------------------
full                     0.584 (+0.000)   9.419 (+0.000)
full_tp                  0.584 (+0.000)   9.415 (-0.004)
fp8_noname               0.585 (+0.000)   9.431 (+0.012)
fp8_noname_tp            0.584 (-0.000)   9.425 (+0.006)
fp8_tensorwise           0.584 (+0.000)   9.421 (+0.002)
fp8_tensorwise_tp        0.584 (-0.000)   9.425 (+0.005)
fp8_rowwise              0.583 (-0.002)   9.421 (+0.002)
fp8_rowwise_with_gw_hp   0.585 (+0.001)   9.405 (-0.014)
```

A few more observations:
- The best tok/s improvement came from the default setting (`fp8_noname`)
- `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline
- Float8 tensor parallelism did not appear to help, and even slightly degraded tok/s for `fp8_noname`

For Llama3.1-8B, we observed similar results, with up to 14.3% faster finetuning and no change in quantized accuracy.
However, memory usage (peak reserved) did increase slightly (+2%) for most fp8 settings:

```
experiment_name          tok/s                peak_mem_active    peak_mem_alloc     peak_mem_reserved
-----------------------  -------------------  -----------------  -----------------  -------------------
full                     2768.292 (+0.000%)   18.541 (+0.000%)   18.541 (+0.000%)   34.270 (+0.000%)
full_tp                  2764.611 (-0.133%)   18.541 (+0.000%)   18.541 (+0.000%)   34.270 (+0.000%)
fp8_noname               3164.370 (+14.308%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_noname_tp            3144.787 (+13.600%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_tensorwise           3136.952 (+13.317%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_tensorwise_tp        3163.867 (+14.289%)  18.542 (+0.008%)   18.542 (+0.008%)   34.963 (+2.021%)
fp8_rowwise              2790.672 (+0.808%)   18.554 (+0.073%)   18.554 (+0.073%)   34.389 (+0.348%)
fp8_rowwise_with_gw_hp   3144.678 (+13.596%)  18.551 (+0.056%)   18.551 (+0.056%)   34.966 (+2.032%)

experiment_name          hellaswag_acc    wikitext_word_perplexity
-----------------------  ---------------  --------------------------
full                     0.594 (+0.000)   9.087 (+0.000)
full_tp                  0.594 (+0.001)   9.089 (+0.002)
fp8_noname               0.593 (-0.001)   9.070 (-0.017)
fp8_noname_tp            0.593 (-0.000)   9.078 (-0.009)
fp8_tensorwise           0.593 (-0.001)   9.061 (-0.026)
fp8_tensorwise_tp        0.593 (-0.001)   9.060 (-0.026)
fp8_rowwise              0.593 (-0.000)   9.086 (-0.001)
fp8_rowwise_with_gw_hp   0.595 (+0.001)   9.087 (+0.000)
```

Based on #2404 by @nathan-az
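For reference, the default setting described above maps onto torchao's `Float8LinearConfig` roughly as sketched below. This is an illustrative sketch only: the `build_fp8_config` helper is not part of this commit, and the actual logic lives in `convert_to_float8_training` in `torchtune/training/quantization.py` (see the diff below).

```python
from typing import Optional

from torchao.float8 import Float8LinearConfig


def build_fp8_config(fp8_recipe_name: Optional[str] = None) -> Float8LinearConfig:
    """Illustrative helper mirroring how the recipe picks its float8 config."""
    if fp8_recipe_name is not None:
        # One of the pre-made recipes: "tensorwise", "rowwise", "rowwise_with_gw_hp"
        return Float8LinearConfig.from_recipe_name(fp8_recipe_name)
    # Default when fp8_recipe_name is unset: tensorwise scaling with
    # float8 all-gather under FSDP2, the fastest setting in the tables above.
    return Float8LinearConfig(enable_fsdp_float8_all_gather=True)
```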
1 parent f1ecdd6 commit bd94d84

4 files changed (+157 / -31 lines)


recipes/full_finetune_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -33,6 +34,11 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import (
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+    validate_float8_tp_plan,
+)

 from tqdm import tqdm

@@ -184,6 +190,8 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -545,6 +553,15 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

+        if self._enable_fp8_training:
+            # Requires https://github.com/pytorch/pytorch/pull/148922
+            if torch.__version__ < "2.8.0.dev20250318":
+                raise RuntimeError(
+                    "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
+                )
+            validate_float8_tp_plan(self.tp_plan, self._fp8_recipe_name)
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -846,6 +863,16 @@ def train(self) -> None:
                 if self._lr_scheduler is not None:
                     self._lr_scheduler.step()

+                # If float8 training is enabled, perform a single all-reduce to compute the
+                # scale for all float8 parameters efficiently instead of doing many small
+                # all-reduces for each parameter
+                if (
+                    self._enable_fp8_training
+                    and is_fp8_tensorwise_scaling(self._fp8_recipe_name)
+                    and self.dp_degree > 1
+                ):
+                    precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                 loss_to_log = running_loss.item() / num_tokens
                 pbar.update(1)
                 pbar.set_description(
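Outside of this recipe, the `precompute_float8_dynamic_scale_for_fsdp` pattern added above looks roughly like the following. This is a minimal sketch assuming an FSDP2-sharded model that has already been converted with a tensorwise float8 recipe and `enable_fsdp_float8_all_gather=True`; the `model`, `optimizer`, and `batch` arguments are placeholders, not names from this commit.

```python
import torch
from torch import nn
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp


def train_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: dict,
    dp_degree: int,
) -> torch.Tensor:
    loss = model(**batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if dp_degree > 1:
        # Refresh the scales of all float8 parameters with one bucketed
        # all-reduce instead of many small per-parameter all-reduces.
        # Only relevant for tensorwise scaling with fsdp float8 all-gather.
        precompute_float8_dynamic_scale_for_fsdp(model)
    return loss
```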

torchtune/models/llama3/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,7 @@
     qlora_llama3_70b,
     qlora_llama3_8b,
 )
-from ._parallelism import base_llama_tp_plan
+from ._parallelism import base_llama_tp_plan, fp8_llama_tp_plan
 from ._tokenizer import Llama3Tokenizer

 __all__ = [
@@ -30,4 +30,5 @@
     "qlora_llama3_8b",
     "qlora_llama3_70b",
     "base_llama_tp_plan",
+    "fp8_llama_tp_plan",
 ]

torchtune/models/llama3/_parallelism.py

Lines changed: 60 additions & 27 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Dict, Type

 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.parallel import (
@@ -15,32 +15,53 @@
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle

+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
+

-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+def _get_base_llama_tp_plan(
+    _sequence_parallel_cls: Type[ParallelStyle] = SequenceParallel,
+    _colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    _rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": _rowwise_parallel_cls(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": _sequence_parallel_cls(),
+        "output": _colwise_parallel_cls(
+            input_layouts=Shard(1), output_layouts=Replicate()
+        ),
+        "layers.*.attn": PrepareModuleInput(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": PrepareModuleInput(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": _sequence_parallel_cls(),
+        "layers.*.mlp_norm": _sequence_parallel_cls(),
+        "layers.*.attn.q_proj": _colwise_parallel_cls(),
+        "layers.*.attn.k_proj": _colwise_parallel_cls(),
+        "layers.*.attn.v_proj": _colwise_parallel_cls(),
+        "layers.*.attn.output_proj": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w1": _colwise_parallel_cls(),
+        "layers.*.mlp.w2": _rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": _colwise_parallel_cls(),
+    }
+
+
+_BASE_LLAMA_TP_PLAN = _get_base_llama_tp_plan()
+_FP8_LLAMA_TP_PLAN = _get_base_llama_tp_plan(
+    _colwise_parallel_cls=Float8ColwiseParallel,
+    _rowwise_parallel_cls=Float8RowwiseParallel,
+)


 def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
@@ -50,4 +71,16 @@ def base_llama_tp_plan() -> Dict[str, ParallelStyle]:
     Returns:
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
-    return BASE_LLAMA_TP_PLAN
+    return _BASE_LLAMA_TP_PLAN
+
+
+def fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
+    rowwise and colwise computation, currently only compatible with float8 fine-tuning with
+    "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return _FP8_LLAMA_TP_PLAN
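As a quick illustration of the new builder, a sketch using only the public functions added in this file; it assumes the torchao float8 parallel styles can be constructed without float8-capable hardware (only running the resulting model requires it):

```python
from torchtune.models.llama3 import base_llama_tp_plan, fp8_llama_tp_plan

base_plan = base_llama_tp_plan()
fp8_plan = fp8_llama_tp_plan()

# Both plans cover the same module paths; only the colwise/rowwise parallel
# styles differ (Float8ColwiseParallel / Float8RowwiseParallel in the fp8 plan).
assert base_plan.keys() == fp8_plan.keys()
for name, style in sorted(fp8_plan.items()):
    print(f"{name:35s} {type(style).__name__}")
```

In the recipe, the selected plan is instantiated from `tensor_parallel_plan._component_` in the config and handed to `parallelize_module`, which is why `fp8_llama_tp_plan` only swaps the parallel styles rather than introducing new module paths.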

torchtune/training/quantization.py

Lines changed: 68 additions & 3 deletions
@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Callable, Optional
+from typing import Callable, Dict, Optional

 from torch import nn
+from torch.distributed.tensor.parallel.style import ParallelStyle

 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


@@ -219,3 +227,60 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on
+        fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
+
+
+def validate_float8_tp_plan(
+    tp_plan: Optional[Dict[str, ParallelStyle]],
+    fp8_recipe_name: Optional[str] = None,
+) -> None:
+    """
+    Validate that the provided tensor parallel plan is compatible with the
+    float8 settings. Specifically, float8 tensor parallel plans are only
+    supported when using 'tensorwise' float8 recipes.
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]):
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
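To illustrate the new helpers in isolation, a minimal sketch is shown below. The toy `nn.Sequential` stands in for the Llama3 `TransformerDecoder` used by the recipe, and actually running the converted model would still require float8-capable hardware (e.g. H100); only the conversion and validation calls are being demonstrated here.

```python
from torch import nn

from torchtune.training.quantization import (
    convert_to_float8_training,
    is_fp8_tensorwise_scaling,
    validate_float8_tp_plan,
)

# The default (no recipe name) counts as tensorwise scaling,
# and a missing TP plan is always considered compatible.
assert is_fp8_tensorwise_scaling(None) and is_fp8_tensorwise_scaling("tensorwise")
validate_float8_tp_plan(tp_plan=None, fp8_recipe_name="rowwise")

# Swap every nn.Linear except the module named "output" for Float8Linear.
model = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 256))
model = convert_to_float8_training(model, fp8_recipe_name=None)
print({name: type(mod).__name__ for name, mod in model.named_modules() if name})
```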
