Commit 1075d9c
Enable FP8 full finetune distributed (#2546)
1 parent f478cc3
4 files changed: +234 additions, -29 deletions

recipes/full_finetune_distributed.py

Lines changed: 30 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed._tensor import DTensor
 from torch.distributed.tensor.parallel import parallelize_module
 from torch.optim import Optimizer
+from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
 from torchtune import config, modules, training, utils
@@ -33,6 +34,10 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import (
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+)

 from tqdm import tqdm

@@ -184,6 +189,8 @@ def __init__(self, cfg: DictConfig) -> None:
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
+        self._enable_fp8_training = cfg.get("enable_fp8_training", False)
+        self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)

         self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None)
         if self._run_val_every_n_steps is not None:
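Both new keys are read with cfg.get and default to disabled, so they can be set in the recipe config or passed as command-line overrides. A plausible invocation (the config name is illustrative, not part of this commit):

tune run --nproc_per_node 8 full_finetune_distributed --config llama3_1/8B_full enable_fp8_training=True fp8_recipe_name=tensorwise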
@@ -567,6 +574,19 @@ def _setup_model(
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

+        if self._enable_fp8_training:
+            # Requires https://github.com/pytorch/pytorch/pull/148922
+            if torch.__version__ < "2.8.0.dev20250318":
+                raise RuntimeError(
+                    "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
+                )
+            if self.tp_plan is not None:
+                raise ValueError(
+                    "FP8 training does not support tensor parallelism yet. "
+                    "This will be enabled in the near future."
+                )
+            model = convert_to_float8_training(model, self._fp8_recipe_name)
+
         # Apply tensor parallelism to the model
         if self.parallel_dims.tp_enabled:
             if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -922,6 +942,16 @@ def train(self) -> None:
                     if self._lr_scheduler is not None:
                         self._lr_scheduler.step()

+                    # If float8 training is enabled, perform a single all-reduce to compute the
+                    # scale for all float8 parameters efficiently instead of doing many small
+                    # all-reduces for each parameter
+                    if (
+                        self._enable_fp8_training
+                        and is_fp8_tensorwise_scaling(self._fp8_recipe_name)
+                        and self.dp_degree > 1
+                    ):
+                        precompute_float8_dynamic_scale_for_fsdp(self._model)
+
                     loss_to_log = running_loss.detach().item() / num_tokens
                     pbar.update(1)
                     pbar.set_description(
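Taken together, the recipe changes follow torchao's FSDP2 float8 flow: swap nn.Linear modules for Float8Linear before sharding, then refresh the per-parameter scales with one fused all-reduce after each optimizer step when tensorwise scaling and data parallelism are in use. A minimal sketch of that flow outside the recipe, assuming a torchrun launch where `model` and `dataloader` already exist (both names and the loss computation are placeholders, not part of this commit):

import torch
from torch.distributed.fsdp import fully_shard
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Default recipe behavior: tensorwise scaling with float8 all-gather under FSDP2.
fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
# Keep the output head in high precision, mirroring the recipe's module filter.
model = convert_to_float8_training(
    model, config=fp8_config, module_filter_fn=lambda mod, fqn: fqn != "output"
)
fully_shard(model)  # FSDP2 sharding is applied after the float8 swap

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in dataloader:
    loss = model(batch).mean()  # placeholder loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # One all-reduce to recompute scales for every float8 parameter (tensorwise only)
    precompute_float8_dynamic_scale_for_fsdp(model)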
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+import torch
+
+from torchao.float8.float8_linear import Float8Linear
+
+from torchtune.models.llama3 import base_llama_tp_plan
+from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
+from torchtune.training.quantization import (
+    _validate_float8_tp_plan,
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+)
+
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512).to(torch.float),)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.output(x)
+        return x
+
+
+class TestFloat8:
+    def test_convert_to_float8_training(self):
+        """
+        Test that target linear layers are converted to Float8Linear.
+        """
+        m = M()
+        example_inputs = torch.randn(1, 512).to(torch.float)
+        m = convert_to_float8_training(m)
+        assert isinstance(m.linear, Float8Linear)
+        assert not isinstance(m.output, Float8Linear)
+        with pytest.raises(Exception):
+            m = convert_to_float8_training(m, "unrecognized_recipe_name")
+
+    # TODO: enable when FP8 + TP is supported
+    def _test_validate_float8_tp_plan(self):
+        """
+        Test that the float8 TP plan is only valid for "tensorwise" float8 recipes.
+        """
+        _validate_float8_tp_plan(base_llama_tp_plan())
+        _validate_float8_tp_plan(base_llama_tp_plan(), "anything")
+        _validate_float8_tp_plan(_fp8_llama_tp_plan())
+        _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
+        with pytest.raises(ValueError):
+            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
+        with pytest.raises(ValueError):
+            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")
+
+    def test_is_fp8_tensorwise_scaling(self):
+        """
+        Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling.
+        """
+        assert is_fp8_tensorwise_scaling(None)
+        assert is_fp8_tensorwise_scaling("tensorwise")
+        assert not is_fp8_tensorwise_scaling("rowwise")
+        assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp")

torchtune/models/llama3/_parallelism.py

Lines changed: 64 additions & 26 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Dict, Type

 from torch import nn

@@ -17,31 +17,56 @@
 )
 from torch.distributed.tensor.parallel.style import ParallelStyle

-# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
-BASE_LLAMA_TP_TRAINING_PLAN = {
-    "tok_embeddings": RowwiseParallel(
-        input_layouts=Replicate(), output_layouts=Shard(1)
-    ),
-    "norm": SequenceParallel(),
-    "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
-    "layers.*.attn": PrepareModuleInput(
-        input_layouts=(Shard(1), None),
-        desired_input_layouts=(Replicate(), None),
-    ),
-    "layers.*.mlp": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
-    ),
-    "layers.*.sa_norm": SequenceParallel(),
-    "layers.*.mlp_norm": SequenceParallel(),
-    "layers.*.attn.q_proj": ColwiseParallel(),
-    "layers.*.attn.k_proj": ColwiseParallel(),
-    "layers.*.attn.v_proj": ColwiseParallel(),
-    "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w1": ColwiseParallel(),
-    "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
-    "layers.*.mlp.w3": ColwiseParallel(),
-}
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
+)
+
+
+def _get_base_llama_tp_training_plan(
+    layerwise_colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
+    layerwise_rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
+    layerwise_prepare_module_input_cls: Type[ParallelStyle] = PrepareModuleInput,
+) -> Dict[str, ParallelStyle]:
+    """
+    Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
+    """
+    return {
+        "tok_embeddings": RowwiseParallel(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        ),
+        "norm": SequenceParallel(),
+        "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
+        "layers.*.attn": layerwise_prepare_module_input_cls(
+            input_layouts=(Shard(1), None),
+            desired_input_layouts=(Replicate(), None),
+        ),
+        "layers.*.mlp": layerwise_prepare_module_input_cls(
+            input_layouts=(Shard(1),),
+            desired_input_layouts=(Replicate(),),
+        ),
+        "layers.*.sa_norm": SequenceParallel(),
+        "layers.*.mlp_norm": SequenceParallel(),
+        "layers.*.attn.q_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.k_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.v_proj": layerwise_colwise_parallel_cls(),
+        "layers.*.attn.output_proj": layerwise_rowwise_parallel_cls(
+            output_layouts=Shard(1)
+        ),
+        "layers.*.mlp.w1": layerwise_colwise_parallel_cls(),
+        "layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)),
+        "layers.*.mlp.w3": layerwise_colwise_parallel_cls(),
+    }
+
+
+BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan()
+
+FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan(
+    layerwise_colwise_parallel_cls=Float8ColwiseParallel,
+    layerwise_rowwise_parallel_cls=Float8RowwiseParallel,
+    layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput,
+)

 BASE_LLAMA_TP_INFERENCE_PLAN = {
     "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
@@ -70,3 +95,16 @@ def base_llama_tp_plan(
         Dict[str, Any]: The tensor parallel plan for Llama3 model.
     """
     return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN
+
+
+# TODO: expose this once tested
+def _fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
+    """
+    Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
+    rowwise and colwise computation, currently only compatible with float8 fine-tuning with
+    "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.
+
+    Returns:
+        Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
+    """
+    return FP8_LLAMA_TP_TRAINING_PLAN
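Because the float8 parallel styles drop in through the factory's class parameters, the base plan's layout is unchanged and FP8_LLAMA_TP_TRAINING_PLAN simply reuses it. A rough sketch of how the FP8 plan could eventually be applied once FP8 + TP is supported; the device-mesh setup is illustrative and `model` is assumed to be a torchtune Llama3 decoder already converted with the default tensorwise float8 recipe (the recipe does not wire this up yet):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan

# Assumes a torchrun launch; a single 8-way tensor-parallel mesh is illustrative.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# `model` must already be converted with convert_to_float8_training using
# tensorwise scaling, the only recipe the float8 parallel styles support.
model = parallelize_module(model, tp_mesh, parallelize_plan=_fp8_llama_tp_plan())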

torchtune/training/quantization.py

Lines changed: 69 additions & 3 deletions
@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Callable, Optional
+from typing import Callable, Dict, Optional

 from torch import nn
+from torch.distributed.tensor.parallel.style import ParallelStyle

 from torchao.dtypes import TensorCoreTiledLayout
-
+from torchao.float8 import (
+    convert_to_float8_training as _convert_to_float8_training_torchao,
+    Float8LinearConfig,
+)
+from torchao.float8.float8_tensor_parallel import (
+    Float8ColwiseParallel,
+    Float8RowwiseParallel,
+)
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-
 from torchao.quantization.qat import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
     enable_4w_fake_quant,
     enable_8da4w_fake_quant,
 )
+
 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


@@ -219,3 +227,61 @@ def swap_lora_linear_with_qat(
         activation_qat_config,
         weight_qat_config,
     )
+
+
+def convert_to_float8_training(
+    model: nn.Module,
+    fp8_recipe_name: Optional[str] = None,
+) -> nn.Module:
+    """
+    Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`.
+
+    Args:
+        model (nn.Module): The model to swap linear layers on
+        fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes,
+            one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
+            defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
+            https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
+            for more details.
+
+    Returns:
+        (nn.Module) The new model with `Float8Linear`.
+    """
+    if fp8_recipe_name is not None:
+        fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
+    else:
+        fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
+    return _convert_to_float8_training_torchao(
+        model,
+        config=fp8_config,
+        module_filter_fn=lambda mod, fqn: fqn != "output",
+    )
+
+
+# TODO: validate this in full_finetune_distributed recipe once FP8 + TP is enabled
+def _validate_float8_tp_plan(
+    tp_plan: Optional[Dict[str, ParallelStyle]],
+    fp8_recipe_name: Optional[str] = None,
+) -> None:
+    """
+    Validate that the provided tensor parallel plan is compatible with the
+    float8 settings. Specifically, float8 tensor parallel plans are only
+    supported when using 'tensorwise' float8 recipes.
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]):
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
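For a quick sense of the new helpers, a minimal usage sketch on a toy module (mirroring the unit test above, not part of the commit; only the module swap is exercised, not an fp8 forward pass):

import torch
from torchao.float8.float8_linear import Float8Linear

from torchtune.training.quantization import (
    convert_to_float8_training,
    is_fp8_tensorwise_scaling,
)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 256, bias=False)
        self.output = torch.nn.Linear(256, 512, bias=False)

    def forward(self, x):
        return self.output(self.linear(x))

model = convert_to_float8_training(TinyModel())
assert isinstance(model.linear, Float8Linear)      # swapped to float8
assert not isinstance(model.output, Float8Linear)  # skipped by fqn != "output"

# The recipe only calls precompute_float8_dynamic_scale_for_fsdp for recipes
# that use tensorwise scaling, i.e. the default (None) and "tensorwise".
assert is_fp8_tensorwise_scaling(None)
assert not is_fp8_tensorwise_scaling("rowwise")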
