From dcdeea4b39952a77759c27257f84fba301cc4ea3 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 1 Apr 2025 15:33:58 -0700 Subject: [PATCH] Enable FP8 full finetune distributed **Summary:** This commit adds FP8 finetuning to the `full_finetune_distributed` recipe as an optional feature. For Llama3-8B, we saw up to 14.7% improvement in finetuning throughput with no degradation in memory usage or accuracy. This feature is currently gated on PyTorch nightlies since it depends on recent features added there. However, it will be available in the next torchtune release. To use this feature, add the following to your config.yaml: ``` enable_fp8_training: true fp8_recipe_name: tensorwise # or rowwise, or rowwise_with_gw_hp ``` The default setting uses tensorwise scaling + `enable_fsdp_float8_all_gather=True`, which led to the largest speedups in our experiments. Based on https://github.com/pytorch/torchtune/pull/2404 by @nathan-az **Experimentation:** All experiments were run on 4x H100 GPUs with 94GB memory each. We finetune the model on the cleaned alpaca dataset for 1 epoch, using a batch size of 16 with torch.compile. We use the following commits from all 3 repos: ``` torchtune: b818e03 (https://github.com/andrewor14/torchtune/blob/fp8-finetuning) torchao: 5a78b70 torch: 1017927 ``` For Llama3-8B, fp8 finetuning saw 14.7% faster finetuning with no change in memory usage or quantized accuracy compared to the bf16 baseline: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2773.473 (+0.000%) 18.481 (+0.000%) 18.481 (+0.000%) 34.291 (+0.000%) fp8_noname 3182.220 (+14.738%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_tensorwise 3159.676 (+13.925%) 18.484 (+0.014%) 18.484 (+0.014%) 34.325 (+0.097%) fp8_rowwise 2790.424 (+0.611%) 18.496 (+0.078%) 18.496 (+0.078%) 34.327 (+0.103%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.584 (+0.000) 9.419 (+0.000) fp8_noname 0.585 (+0.000) 9.431 (+0.012) fp8_tensorwise 0.584 (+0.000) 9.421 (+0.002) fp8_rowwise 0.583 (-0.002) 9.421 (+0.002) ``` A few more observations here: - The best tok/s improvement was from the default setting (`fp8_noname`) - `fp8_rowwise` was the worst fp8 configuration, though still marginally better than the baseline For Llama3.1-8B, we observed similar observations, with up to 14.3% faster finetuning and no change in quantized accuracy. However, memory usage did increase minimally (+2%) for most fp8 settings: ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 2768.292 (+0.000%) 18.541 (+0.000%) 18.541 (+0.000%) 34.270 (+0.000%) fp8_noname 3164.370 (+14.308%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_tensorwise 3136.952 (+13.317%) 18.542 (+0.008%) 18.542 (+0.008%) 34.963 (+2.021%) fp8_rowwise 2790.672 (+0.808%) 18.554 (+0.073%) 18.554 (+0.073%) 34.389 (+0.348%) fp8_rowwise_with_gw_hp 3144.678 (+13.596%) 18.551 (+0.056%) 18.551 (+0.056%) 34.966 (+2.032%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.594 (+0.000) 9.087 (+0.000) fp8_noname 0.593 (-0.001) 9.070 (-0.017) fp8_tensorwise 0.593 (-0.001) 9.061 (-0.026) fp8_rowwise 0.593 (-0.000) 9.086 (-0.001) fp8_rowwise_with_gw_hp 0.595 (+0.001) 9.087 (+0.000) ``` Llama3.2-3B saw up to 16.5% faster finetuning for rowwise with high precision `grad_weight`, which is a bigger improvement than just tensorwise. Similarly, there are no degradations in memory usage or quantized accuracy. ``` experiment_name tok/s peak_mem_active peak_mem_alloc peak_mem_reserved ---------------------- ------------------- ----------------- ---------------- ------------------- full 6502.143 (+0.000%) 15.917 (+0.000%) 15.917 (+0.000%) 30.090 (+0.000%) fp8_noname 7205.386 (+10.816%) 15.917 (+0.003%) 15.917 (+0.003%) 30.010 (-0.266%) fp8_tensorwise 7222.198 (+11.074%) 15.917 (+0.003%) 15.917 (+0.003%) 30.010 (-0.266%) fp8_rowwise 6387.968 (-1.756%) 15.916 (-0.002%) 15.916 (-0.002%) 29.158 (-3.096%) fp8_rowwise_with_gw_hp 7573.698 (+16.480%) 15.917 (+0.001%) 15.917 (+0.001%) 29.516 (-1.908%) experiment_name hellaswag_acc wikitext_word_perplexity ---------------------- --------------- -------------------------- full 0.533 (+0.000) 12.407 (+0.000) fp8_noname 0.533 (+0.000) 12.414 (+0.007) fp8_tensorwise 0.533 (+0.000) 12.412 (+0.005) fp8_rowwise 0.533 (-0.000) 12.420 (+0.013) fp8_rowwise_with_gw_hp 0.534 (+0.001) 12.416 (+0.009) ``` **Test Plan:** Experiment command: ``` tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full \ enable_fp8_training=true \ fp8_recipe_name=tensorwise \ epochs=1 \ batch_size=16 \ compile=true \ dataset._component_=torchtune.datasets.alpaca_cleaned_dataset \ checkpointer.output_dir="$LOG_DIR" \ output_dir="${LOG_DIR}/metrics" \ metric_logger.log_dir="${LOG_DIR}/metrics" ``` (full script: https://github.com/andrewor14/torchtune/blob/fp8-finetuning-debug/run_it.sh) Unit tests: ``` pytest tests -k test_convert_to_float8_training pytest tests -k test_is_fp8_tensorwise_scaling ``` --- recipes/full_finetune_distributed.py | 30 +++++++ tests/torchtune/training/test_quantization.py | 71 +++++++++++++++ torchtune/models/llama3/_parallelism.py | 90 +++++++++++++------ torchtune/training/quantization.py | 72 ++++++++++++++- 4 files changed, 234 insertions(+), 29 deletions(-) create mode 100644 tests/torchtune/training/test_quantization.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 9dbed787db..255f826b33 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -19,6 +19,7 @@ from torch.distributed._tensor import DTensor from torch.distributed.tensor.parallel import parallelize_module from torch.optim import Optimizer +from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp from torchdata.stateful_dataloader import StatefulDataLoader from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler from torchtune import config, modules, training, utils @@ -33,6 +34,10 @@ TrainingProgress, ) from torchtune.training.lr_schedulers import get_lr +from torchtune.training.quantization import ( + convert_to_float8_training, + is_fp8_tensorwise_scaling, +) from tqdm import tqdm @@ -184,6 +189,8 @@ def __init__(self, cfg: DictConfig) -> None: self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) self._clip_grad_norm = cfg.get("clip_grad_norm", None) self._checkpoint_client = CheckpointClient(cfg) + self._enable_fp8_training = cfg.get("enable_fp8_training", False) + self._fp8_recipe_name = cfg.get("fp8_recipe_name", None) self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None) if self._run_val_every_n_steps is not None: @@ -567,6 +574,19 @@ def _setup_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) + if self._enable_fp8_training: + # Requires https://github.com/pytorch/pytorch/pull/148922 + if torch.__version__ < "2.8.0.dev20250318": + raise RuntimeError( + "Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later." + ) + if self.tp_plan is not None: + raise ValueError( + "FP8 training does not support tensor parallelism yet. " + "This will be enabled in the near future." + ) + model = convert_to_float8_training(model, self._fp8_recipe_name) + # Apply tensor parallelism to the model if self.parallel_dims.tp_enabled: if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload: @@ -922,6 +942,16 @@ def train(self) -> None: if self._lr_scheduler is not None: self._lr_scheduler.step() + # If float8 training is enabled, perform a single all-reduce to compute the + # scale for all float8 parameters efficiently instead of doing many small + # all-reduces for each parameter + if ( + self._enable_fp8_training + and is_fp8_tensorwise_scaling(self._fp8_recipe_name) + and self.dp_degree > 1 + ): + precompute_float8_dynamic_scale_for_fsdp(self._model) + loss_to_log = running_loss.detach().item() / num_tokens pbar.update(1) pbar.set_description( diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py new file mode 100644 index 0000000000..6581dca99c --- /dev/null +++ b/tests/torchtune/training/test_quantization.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch + +from torchao.float8.float8_linear import Float8Linear + +from torchtune.models.llama3 import base_llama_tp_plan +from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan +from torchtune.training.quantization import ( + _validate_float8_tp_plan, + convert_to_float8_training, + is_fp8_tensorwise_scaling, +) + + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float) + self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float) + + def example_inputs(self): + return (torch.randn(1, 512).to(torch.float),) + + def forward(self, x): + x = self.linear(x) + x = self.output(x) + return x + + +class TestFloat8: + def test_convert_to_float8_training(self): + """ + Test that target linear layers are converted to Float8Linear. + """ + m = M() + example_inputs = torch.randn(1, 512).to(torch.float) + m = convert_to_float8_training(m) + assert isinstance(m.linear, Float8Linear) + assert not isinstance(m.output, Float8Linear) + with pytest.raises(Exception): + m = convert_to_float8_training(m, "unrecognized_recipe_name") + + # TODO: enable when FP8 + TP is supported + def _test_validate_float8_tp_plan(self): + """ + Test that only float8 TP plan is only valid for "tensorwise" float8 recipes. + """ + _validate_float8_tp_plan(base_llama_tp_plan()) + _validate_float8_tp_plan(base_llama_tp_plan(), "anything") + _validate_float8_tp_plan(_fp8_llama_tp_plan()) + _validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise") + with pytest.raises(ValueError): + _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise") + with pytest.raises(ValueError): + _validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp") + + def test_is_fp8_tensorwise_scaling(self): + """ + Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling. + """ + assert is_fp8_tensorwise_scaling(None) + assert is_fp8_tensorwise_scaling("tensorwise") + assert not is_fp8_tensorwise_scaling("rowwise") + assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp") diff --git a/torchtune/models/llama3/_parallelism.py b/torchtune/models/llama3/_parallelism.py index 3d636c653b..67261018ed 100644 --- a/torchtune/models/llama3/_parallelism.py +++ b/torchtune/models/llama3/_parallelism.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict +from typing import Dict, Type from torch import nn @@ -17,31 +17,56 @@ ) from torch.distributed.tensor.parallel.style import ParallelStyle -# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models -BASE_LLAMA_TP_TRAINING_PLAN = { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), output_layouts=Shard(1) - ), - "norm": SequenceParallel(), - "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()), - "layers.*.attn": PrepareModuleInput( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), - ), - "layers.*.mlp": PrepareModuleInput( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "layers.*.sa_norm": SequenceParallel(), - "layers.*.mlp_norm": SequenceParallel(), - "layers.*.attn.q_proj": ColwiseParallel(), - "layers.*.attn.k_proj": ColwiseParallel(), - "layers.*.attn.v_proj": ColwiseParallel(), - "layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)), - "layers.*.mlp.w1": ColwiseParallel(), - "layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)), - "layers.*.mlp.w3": ColwiseParallel(), -} +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, +) + + +def _get_base_llama_tp_training_plan( + layerwise_colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel, + layerwise_rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel, + layerwise_prepare_module_input_cls: Type[ParallelStyle] = PrepareModuleInput, +) -> Dict[str, ParallelStyle]: + """ + Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models. + """ + return { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "norm": SequenceParallel(), + "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()), + "layers.*.attn": layerwise_prepare_module_input_cls( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "layers.*.mlp": layerwise_prepare_module_input_cls( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "layers.*.sa_norm": SequenceParallel(), + "layers.*.mlp_norm": SequenceParallel(), + "layers.*.attn.q_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.k_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.v_proj": layerwise_colwise_parallel_cls(), + "layers.*.attn.output_proj": layerwise_rowwise_parallel_cls( + output_layouts=Shard(1) + ), + "layers.*.mlp.w1": layerwise_colwise_parallel_cls(), + "layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)), + "layers.*.mlp.w3": layerwise_colwise_parallel_cls(), + } + + +BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan() + +FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan( + layerwise_colwise_parallel_cls=Float8ColwiseParallel, + layerwise_rowwise_parallel_cls=Float8RowwiseParallel, + layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput, +) BASE_LLAMA_TP_INFERENCE_PLAN = { "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), @@ -70,3 +95,16 @@ def base_llama_tp_plan( Dict[str, Any]: The tensor parallel plan for Llama3 model. """ return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN + + +# TODO: expose this once tested +def _fp8_llama_tp_plan() -> Dict[str, ParallelStyle]: + """ + Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both + rowwise and colwise computation, currently only compatible with float8 fine-tuning with + "tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models. + + Returns: + Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model. + """ + return FP8_LLAMA_TP_TRAINING_PLAN diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 32432f2853..17995b7c47 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -4,18 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Callable, Optional +from typing import Callable, Dict, Optional from torch import nn +from torch.distributed.tensor.parallel.style import ParallelStyle from torchao.dtypes import TensorCoreTiledLayout - +from torchao.float8 import ( + convert_to_float8_training as _convert_to_float8_training_torchao, + Float8LinearConfig, +) +from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, +) from torchao.quantization import ( int4_weight_only, int8_dynamic_activation_int4_weight, quantize_, ) - from torchao.quantization.qat import ( Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer, @@ -26,6 +33,7 @@ enable_4w_fake_quant, enable_8da4w_fake_quant, ) + from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear @@ -219,3 +227,61 @@ def swap_lora_linear_with_qat( activation_qat_config, weight_qat_config, ) + + +def convert_to_float8_training( + model: nn.Module, + fp8_recipe_name: Optional[str] = None, +) -> nn.Module: + """ + Prepare the model for float8 training by swapping all `nn.Linear` with `Float8Linear`. + + Args: + model (nn.Module): The model to swap linear layers on + fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes, + one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified, + defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See + https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150 + for more details. + + Returns: + (nn.Module) The new model with `Float8Linear`. + """ + if fp8_recipe_name is not None: + fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name) + else: + fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True) + return _convert_to_float8_training_torchao( + model, + config=fp8_config, + module_filter_fn=lambda mod, fqn: fqn != "output", + ) + + +# TODO: validate this in full_finetune_distributed recipe once FP8 + TP is enabled +def _validate_float8_tp_plan( + tp_plan: Optional[Dict[str, ParallelStyle]], + fp8_recipe_name: Optional[str] = None, +) -> None: + """ + Validate that the provided tensor parallel plan is compatible with the + float8 settings. Specifically, float8 tensor parallel plans are only + supported when using 'tensorwise' float8 recipes. + """ + if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name): + return + for parallel_style in tp_plan.values(): + if isinstance(parallel_style, Float8ColwiseParallel) or isinstance( + parallel_style, Float8RowwiseParallel + ): + raise ValueError( + "%s and %s are only compatible with 'tensorwise' float8 recipes" + % (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__) + ) + + +def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]): + """ + Return True if the fp8 recipe name refers to 'tensorwwise' scaling. + """ + return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"