diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 9dbed787db..255f826b33 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -19,6 +19,7 @@ from torch.distributed._tensor import DTensor from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer +from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils @@ -33,6 +34,10 @@ TrainingProgress, ) from torchtune.training.lr_schedulers import get_lr +from torchtune.training.quantization import ( + convert_to_float8_training, + is_fp8_tensorwise_scaling, +) from tqdm import tqdm @@ -184,6 +189,8 @@ def __init__(self, cfg: DictConfig) -> None: self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) self._clip_grad_norm = cfg.get("clip_grad_norm", None) self._checkpoint_client = CheckpointClient(cfg) + self._enable_fp8_training = cfg.get("enable_fp8_training", False) + self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) if self._run_val_every_n_steps is not None: @@ -567,6 +574,19 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) + if self._enable_fp8_training: + # Requires https://github.com/pytorch/pytorch/pull/148922 + if torch.__version__ < "2.8.0.dev20250318": + raise RuntimeError( + "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later." + ) + if self.tp_plan is not None: + raise ValueError( + "FP8 training does not support tensor parallelism yet. " + "This will be enabled in the near future." + ) + model = convert_to_float8_training(model, self._fp8_recipe_name) + # Apply tensor parallelism to the model if self.parallel_dims.tp_enabled: if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload: @@ -922,6 +942,16 @@ def train(self) -> None: if self._lr_scheduler is not None: self._lr_scheduler.step() + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter + if ( + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self._model) + loss_to_log = running_loss.detach().item() / num_tokens pbar.update(1) pbar.set_description( diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py new file mode 100644 index 0000000000..6581dca99c --- /dev/null +++ b/tests/torchtune/training/test_quantization.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import pytest
+
+import torch
+
+from torchao.float8.float8_linear import Float8Linear
+
+from torchtune.models.llama3 import base_llama_tp_plan
+from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
+from torchtune.training.quantization import (
+    _validate_float8_tp_plan,
+    convert_to_float8_training,
+    is_fp8_tensorwise_scaling,
+)
+
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 512).to(torch.float),)
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.output(x)
+        return x
+
+
+class TestFloat8:
+    def test_convert_to_float8_training(self):
+        """
+        Test that target linear layers are converted to Float8Linear.
+        """
+        m = M()
+        example_inputs = torch.randn(1, 512).to(torch.float)
+        m = convert_to_float8_training(m)
+        assert isinstance(m.linear, Float8Linear)
+        assert not isinstance(m.output, Float8Linear)
+        with pytest.raises(Exception):
+            m = convert_to_float8_training(m, "unrecognized_recipe_name")
+
+    # TODO: enable when FP8 + TP is supported
+    def _test_validate_float8_tp_plan(self):
+        """
+        Test that the float8 TP plan is only valid for "tensorwise" float8 recipes.
+        """
+        _validate_float8_tp_plan(base_llama_tp_plan())
+        _validate_float8_tp_plan(base_llama_tp_plan(), "anything")
+        _validate_float8_tp_plan(_fp8_llama_tp_plan())
+        _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
+        with pytest.raises(ValueError):
+            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
+        with pytest.raises(ValueError):
+            _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")
+
+    def test_is_fp8_tensorwise_scaling(self):
+        """
+        Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling.
+        """
+        assert is_fp8_tensorwise_scaling(None)
+        assert is_fp8_tensorwise_scaling("tensorwise")
+        assert not is_fp8_tensorwise_scaling("rowwise")
+        assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp")
diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py
index 3d636c653b..67261018ed 100644
--- a/torchtune/models/llama3/_parallelism.py
+++ b/torchtune/models/llama3/_parallelism.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict +from typing import Dict, Type from torch import nn @@ -17,31 +17,56 @@ ) from torch.distributed.tensor.parallel.style import ParallelStyle -# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models -BASE_LLAMA_TP_TRAINING_PLAN = { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), output_layouts=Shard(1) - ), - "norm": SequenceParallel(), - "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()), - "layers.*.attn": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), - ), - "layers.*.mlp": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "layers.*.sa_norm": SequenceParallel(), - "layers.*.mlp_norm": SequenceParallel(), - "layers.*.attn.q_proj": ColwiseParallel(), - "layers.*.attn.k_proj": ColwiseParallel(), - "layers.*.attn.v_proj": ColwiseParallel(), - "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)), - "layers.*.mlp.w1": ColwiseParallel(), - "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)), - "layers.*.mlp.w3": ColwiseParallel(), -} +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, +) + + +def _get_base_llama_tp_training_plan( + layerwise_colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel, + layerwise_rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel, + layerwise_prepare_module_input_cls: Type[ParallelStyle] = PrepareModuleInput, +) -> Dict[str, ParallelStyle]: + """ + Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models. + """ + return { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "norm": SequenceParallel(), + "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()), + "layers.*.attn": layerwise_prepare_module_input_cls( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "layers.*.mlp": layerwise_prepare_module_input_cls( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "layers.*.sa_norm": SequenceParallel(), + "layers.*.mlp_norm": SequenceParallel(), + "layers.*.attn.q_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.k_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.v_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.output_proj": layerwise_rowwise_parallel_cls( + output_layouts=Shard(1) + ), + "layers.*.mlp.w1": layerwise_colwise_parallel_cls(), + "layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)), + "layers.*.mlp.w3": layerwise_colwise_parallel_cls(), + } + + +BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan() + +FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan( + layerwise_colwise_parallel_cls=Float8ColwiseParallel, + layerwise_rowwise_parallel_cls=Float8RowwiseParallel, + layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput, +) BASE_LLAMA_TP_INFERENCE_PLAN = { "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), @@ -70,3 +95,16 @@ def base_llama_tp_plan( Dict[str, Any]: The tensor parallel plan for Llama3 model. 
""" return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN + + +# TODO: expose this once tested +def _fp8_llama_tp_plan() -> Dict[str, ParallelStyle]: + """ + Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both + rowwise and colwise computation, currently only compatible with float8 fine-tuning with + "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models. + + Returns: + Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. + """ + return FP8_LLAMA_TP_TRAINING_PLAN diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 32432f2853..17995b7c47 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -4,18 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Optional +from typing import Callable, Dict, Optional from torch import nn +from torch.distributed.tensor.parallel.style import ParallelStyle from torchao.dtypes import TensorCoreTiledLayout - +from torchao.float8 import ( + convert_to_float8_training as _convert_to_float8_training_torchao, + Float8LinearConfig, +) +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, +) from torchao.quantization import ( int4_weight_only, int8_dynamic_activation_int4_weight, quantize_, ) - from torchao.quantization.qat import ( Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, @@ -26,6 +33,7 @@ enable_4w_fake_quant, enable_8da4w_fake_quant, ) + from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear @@ -219,3 +227,61 @@ def swap_lora_linear_with_qat( activation_qat_config, weight_qat_config, ) + + +def convert_to_float8_training( + model: nn.Module, + fp8_recipe_name: Optional[str] = None, +) -> nn.Module: + """ + Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`. + + Args: + model (nn.Module): The model to swap linear layers on + fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes, + one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified, + defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See + https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150 + for more details. + + Returns: + (nn.Module) The new model with `Float8Linear`. + """ + if fp8_recipe_name is not None: + fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name) + else: + fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True) + return _convert_to_float8_training_torchao( + model, + config=fp8_config, + module_filter_fn=lambda mod, fqn: fqn != "output", + ) + + +# TODO: validate this in full_finetune_distributed recipe once FP8 + TP is enabled +def _validate_float8_tp_plan( + tp_plan: Optional[Dict[str, ParallelStyle]], + fp8_recipe_name: Optional[str] = None, +) -> None: + """ + Validate that the provided tensor parallel plan is compatible with the + float8 settings. Specifically, float8 tensor parallel plans are only + supported when using 'tensorwise' float8 recipes. 
+    """
+    if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
+        return
+    for parallel_style in tp_plan.values():
+        if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
+            parallel_style, Float8RowwiseParallel
+        ):
+            raise ValueError(
+                "%s and %s are only compatible with 'tensorwise' float8 recipes"
+                % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
+            )
+
+
+def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]) -> bool:
+    """
+    Return True if the fp8 recipe name refers to 'tensorwise' scaling.
+    """
+    return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"
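
Note on recipe resolution: the `fp8_recipe_name` strings accepted by `convert_to_float8_training` are resolved by torchao, not by torchtune. A minimal sketch of that mapping, assuming torchao v0.9+ (the variable names below are illustrative only):

```python
from torchao.float8 import Float8LinearConfig

# Named recipes accepted by `fp8_recipe_name`; torchao resolves the string.
tensorwise_cfg = Float8LinearConfig.from_recipe_name("tensorwise")
rowwise_cfg = Float8LinearConfig.from_recipe_name("rowwise")
rowwise_gw_hp_cfg = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp")

# Default path when `fp8_recipe_name` is None: tensorwise scaling with the
# FSDP float8 all-gather optimization enabled, as in convert_to_float8_training.
default_cfg = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
assert default_cfg.enable_fsdp_float8_all_gather
```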
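For the training-loop side, here is a minimal standalone sketch of the ordering the recipe relies on: swap `nn.Linear` for `Float8Linear`, shard with FSDP2, then refresh the float8 scales once per optimizer step. It calls the torchao APIs directly rather than going through the recipe; the toy model, dimensions, optimizer, and launch command are placeholders, and actual float8 matmuls require an SM89+/H100-class GPU plus a PyTorch build recent enough to expose `fully_shard` from `torch.distributed.fsdp`:

```python
# Launch (hypothetical script name): torchrun --nproc_per_node=2 fp8_fsdp2_sketch.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point on recent PyTorch
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)


def main() -> None:
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Toy stand-in for the fine-tuned model; float8 matmuls need dims that are
    # multiples of 16.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024, bias=False),
        torch.nn.Linear(1024, 1024, bias=False),
    ).cuda().to(torch.bfloat16)

    # Same default as the recipe: tensorwise scaling + float8 all-gather.
    model = convert_to_float8_training(
        model, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    )

    # Shard after the float8 swap so FSDP2 can all-gather float8 weights.
    for layer in model:
        fully_shard(layer)
    fully_shard(model)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for _ in range(3):
        x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
        loss = model(x).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
        # One fused all-reduce refreshes every float8 scale, mirroring the
        # precompute_float8_dynamic_scale_for_fsdp call added to the recipe.
        precompute_float8_dynamic_scale_for_fsdp(model)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```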