30 changes: 30 additions & 0 deletions recipes/full_finetune_distributed.py
@@ -19,6 +19,7 @@
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
from torchtune import config, modules, training, utils
@@ -33,6 +34,10 @@
TrainingProgress,
)
from torchtune.training.lr_schedulers import get_lr
from torchtune.training.quantization import (
convert_to_float8_training,
is_fp8_tensorwise_scaling,
)

from tqdm import tqdm

@@ -184,6 +189,8 @@ def __init__(self, cfg: DictConfig) -> None:
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)
self._enable_fp8_training = cfg.get("enable_fp8_training", False)
self._fp8_recipe_name = cfg.get("fp8_recipe_name", None)
Contributor Author:
@vkuzo does this level of configuration look good to you? Any other fields you think I should expose?
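
For context, a minimal sketch of what toggling these fields could look like on the config side (the key names mirror the cfg.get calls in this diff; the OmegaConf construction is illustrative only, not how torchtune actually loads recipe configs):

```python
from omegaconf import OmegaConf

# Illustrative overrides only -- key names mirror the cfg.get() calls above.
cfg = OmegaConf.create(
    {
        "enable_fp8_training": True,  # swap eligible nn.Linear layers to Float8Linear
        "fp8_recipe_name": None,      # None -> default "tensorwise" recipe with fp8 all-gather
    }
)
```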

vkuzo:
I remember you mentioned a performance regression in previous experiments, and we root-caused it to the fact that Float8Linear was applied to parts of the model that torch.compile was not applied to. I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.

Collaborator @nathan-az (Apr 7, 2025):
@vkuzo @andrewor14 From the torchao source it appears the rowwise scaling options do not support force_recompute_fp8_weight_in_bwd by default.

Given we are defaulting to enable_fsdp_float8_all_gather=True when a recipe name isn't provided, should we also set force_recompute_fp8_weight_in_bwd, or should this be an option here in the config?

I noticed a TODO in torchao to set this to True by default in the future - I wonder if we should do the same in torchtune now.

Contributor Author:
I think @vkuzo mentioned that force_recompute_fp8_weight_in_bwd is no longer needed after a certain commit. This PR gates on the nightlies, so we don't need this flag anymore.
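
For reference, a hedged sketch of the two Float8LinearConfig paths discussed in this thread, mirroring the convert_to_float8_training wrapper added later in this PR (the recipe names and the all-gather flag are the ones referenced above):

```python
from torchao.float8 import Float8LinearConfig

# Named-recipe path: "tensorwise", "rowwise", or "rowwise_with_gw_hp".
rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")

# Default path when no recipe name is given: tensorwise scaling with float8
# all-gather under FSDP. Per the discussion above, force_recompute_fp8_weight_in_bwd
# is no longer required on recent nightlies, so it is not set here.
default_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
```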

Contributor:
> I'm not familiar with how torch.compile is applied in torchtune, but it would be good to ensure Float8Linear is not applied to any regions which are not using compile.

Regarding @vkuzo's comment, torchtune does per-layer compile with separate compilation of the loss. You can see the utility we use here. This means that there may be two cases that are not covered:

  1. output projection (i.e. LM head)
  2. any linear layers in e.g. a vision encoder for a multimodal model

Given the numbers in the test plan, it seems (1) is not a hard blocker. We could try it out with e.g. Llama 3.2 Vision just to make sure there are no regressions there. Otherwise I don't have any huge concerns here.

Contributor Author:
Sounds good. Yes, we already do (1) in this PR; will try Llama 3.2 Vision.
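
As a follow-up on (2), a hedged sketch of how the module_filter_fn used later in this PR could be extended to also skip uncompiled regions; the "encoder." prefix is hypothetical and not an actual torchtune FQN guarantee:

```python
from torchao.float8 import convert_to_float8_training as _torchao_convert

def _skip_uncompiled_linears(mod, fqn: str) -> bool:
    # Exclude the LM head (already excluded in this PR) and, hypothetically,
    # any linear layers outside the compiled decoder layers, e.g. a vision
    # encoder in a multimodal model.
    return fqn != "output" and not fqn.startswith("encoder.")

# model = _torchao_convert(model, config=fp8_config, module_filter_fn=_skip_uncompiled_linears)
```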


self._run_val_every_n_steps = cfg.get("run_val_every_n_steps", None)
if self._run_val_every_n_steps is not None:
@@ -567,6 +574,19 @@ def _setup_model(
if self._compile:
training.compile_model(model, verbose=self._is_rank_zero)

if self._enable_fp8_training:
# Requires https://github.com/pytorch/pytorch/pull/148922
if torch.__version__ < "2.8.0.dev20250318":
raise RuntimeError(
"Float8 fine-tuning requires PyTorch 2.8.0.dev20250318 or later."
)
if self.tp_plan is not None:
raise ValueError(
"FP8 training does not support tensor parallelism yet. "
"This will be enabled in the near future."
)
model = convert_to_float8_training(model, self._fp8_recipe_name)

# Apply tensor parallelism to the model
if self.parallel_dims.tp_enabled:
if not self.parallel_dims.dp_enabled and self.fsdp_cpu_offload:
@@ -922,6 +942,16 @@ def train(self) -> None:
if self._lr_scheduler is not None:
self._lr_scheduler.step()

# If float8 training is enabled, perform a single all-reduce to compute the
# scale for all float8 parameters efficiently instead of doing many small
# all-reduces for each parameter
if (
self._enable_fp8_training
and is_fp8_tensorwise_scaling(self._fp8_recipe_name)
and self.dp_degree > 1
):
precompute_float8_dynamic_scale_for_fsdp(self._model)

loss_to_log = running_loss.detach().item() / num_tokens
pbar.update(1)
pbar.set_description(
71 changes: 71 additions & 0 deletions tests/torchtune/training/test_quantization.py
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch

from torchao.float8.float8_linear import Float8Linear

from torchtune.models.llama3 import base_llama_tp_plan
from torchtune.models.llama3._parallelism import _fp8_llama_tp_plan
from torchtune.training.quantization import (
_validate_float8_tp_plan,
convert_to_float8_training,
is_fp8_tensorwise_scaling,
)


class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
self.output = torch.nn.Linear(256, 512, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 512).to(torch.float),)

def forward(self, x):
x = self.linear(x)
x = self.output(x)
return x


class TestFloat8:
def test_convert_to_float8_training(self):
"""
Test that target linear layers are converted to Float8Linear.
"""
m = M()
example_inputs = torch.randn(1, 512).to(torch.float)
m = convert_to_float8_training(m)
assert isinstance(m.linear, Float8Linear)
assert not isinstance(m.output, Float8Linear)
with pytest.raises(Exception):
m = convert_to_float8_training(m, "unrecognized_recipe_name")

# TODO: enable when FP8 + TP is supported
def _test_validate_float8_tp_plan(self):
"""
Test that the float8 TP plan is only valid for "tensorwise" float8 recipes.
"""
_validate_float8_tp_plan(base_llama_tp_plan())
_validate_float8_tp_plan(base_llama_tp_plan(), "anything")
_validate_float8_tp_plan(_fp8_llama_tp_plan())
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "tensorwise")
with pytest.raises(ValueError):
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise")
with pytest.raises(ValueError):
_validate_float8_tp_plan(_fp8_llama_tp_plan(), "rowwise_with_gw_hp")

def test_is_fp8_tensorwise_scaling(self):
"""
Test that `is_fp8_tensorwise_scaling` returns True only for tensorwise scaling.
"""
assert is_fp8_tensorwise_scaling(None)
assert is_fp8_tensorwise_scaling("tensorwise")
assert not is_fp8_tensorwise_scaling("rowwise")
assert not is_fp8_tensorwise_scaling("rowwise_with_gw_hp")
90 changes: 64 additions & 26 deletions torchtune/models/llama3/_parallelism.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict
from typing import Dict, Type

from torch import nn

@@ -17,31 +17,56 @@
)
from torch.distributed.tensor.parallel.style import ParallelStyle

# Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models
BASE_LLAMA_TP_TRAINING_PLAN = {
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"norm": SequenceParallel(),
"output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
"layers.*.attn": PrepareModuleInput(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"layers.*.mlp": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"layers.*.sa_norm": SequenceParallel(),
"layers.*.mlp_norm": SequenceParallel(),
"layers.*.attn.q_proj": ColwiseParallel(),
"layers.*.attn.k_proj": ColwiseParallel(),
"layers.*.attn.v_proj": ColwiseParallel(),
"layers.*.attn.output_proj": RowwiseParallel(output_layouts=Shard(1)),
"layers.*.mlp.w1": ColwiseParallel(),
"layers.*.mlp.w2": RowwiseParallel(output_layouts=Shard(1)),
"layers.*.mlp.w3": ColwiseParallel(),
}
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
PrepareFloat8ModuleInput,
)


def _get_base_llama_tp_training_plan(
layerwise_colwise_parallel_cls: Type[ParallelStyle] = ColwiseParallel,
layerwise_rowwise_parallel_cls: Type[ParallelStyle] = RowwiseParallel,
layerwise_prepare_module_input_cls: Type[ParallelStyle] = PrepareModuleInput,
) -> Dict[str, ParallelStyle]:
"""
Define the Tensor Parallel plan for Llama3 model, which will also be shared with 3.1, 3.2, and 3.3 models.
"""
return {
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(), output_layouts=Shard(1)
),
"norm": SequenceParallel(),
"output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
"layers.*.attn": layerwise_prepare_module_input_cls(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
),
"layers.*.mlp": layerwise_prepare_module_input_cls(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"layers.*.sa_norm": SequenceParallel(),
"layers.*.mlp_norm": SequenceParallel(),
"layers.*.attn.q_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.k_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.v_proj": layerwise_colwise_parallel_cls(),
"layers.*.attn.output_proj": layerwise_rowwise_parallel_cls(
output_layouts=Shard(1)
),
"layers.*.mlp.w1": layerwise_colwise_parallel_cls(),
"layers.*.mlp.w2": layerwise_rowwise_parallel_cls(output_layouts=Shard(1)),
"layers.*.mlp.w3": layerwise_colwise_parallel_cls(),
}


BASE_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan()

FP8_LLAMA_TP_TRAINING_PLAN = _get_base_llama_tp_training_plan(
layerwise_colwise_parallel_cls=Float8ColwiseParallel,
layerwise_rowwise_parallel_cls=Float8RowwiseParallel,
layerwise_prepare_module_input_cls=PrepareFloat8ModuleInput,
)

BASE_LLAMA_TP_INFERENCE_PLAN = {
"tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
@@ -70,3 +95,16 @@ def base_llama_tp_plan(
Dict[str, Any]: The tensor parallel plan for Llama3 model.
"""
return BASE_LLAMA_TP_INFERENCE_PLAN if inference else BASE_LLAMA_TP_TRAINING_PLAN


# TODO: expose this once tested
def _fp8_llama_tp_plan() -> Dict[str, ParallelStyle]:
"""
Return the tensor parallel plan for Llama3 model that uses float8 for all-gather for both
rowwise and colwise computation, currently only compatible with float8 fine-tuning with
"tensorwise" scaling. This tensor parallel plan is shared between 3.1, 3.2, and 3.3 models.

Returns:
Dict[str, Any]: The float8-enabled tensor parallel plan for Llama3 model.
"""
return FP8_LLAMA_TP_TRAINING_PLAN
72 changes: 69 additions & 3 deletions torchtune/training/quantization.py
@@ -4,18 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional
from typing import Callable, Dict, Optional

from torch import nn
from torch.distributed.tensor.parallel.style import ParallelStyle

from torchao.dtypes import TensorCoreTiledLayout

from torchao.float8 import (
convert_to_float8_training as _convert_to_float8_training_torchao,
Float8LinearConfig,
)
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
)
from torchao.quantization import (
int4_weight_only,
int8_dynamic_activation_int4_weight,
quantize_,
)

from torchao.quantization.qat import (
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
@@ -26,6 +33,7 @@
enable_4w_fake_quant,
enable_8da4w_fake_quant,
)

from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


@@ -219,3 +227,61 @@ def swap_lora_linear_with_qat(
activation_qat_config,
weight_qat_config,
)


def convert_to_float8_training(
model: nn.Module,
fp8_recipe_name: Optional[str] = None,
) -> nn.Module:
"""
Prepare the model for float8 training by swapping eligible `nn.Linear` layers (all except the output projection) with `Float8Linear`.

Args:
model (nn.Module): The model to swap linear layers on
fp8_recipe_name (Optional[str]): name to identify one of the pre-made recipes,
one of "tensorwise", "rowwise", and "rowwise_with_gw_hp". If not specified,
Contributor:
Noob float8 question: it seems like tensorwise gives the best perf in your runs, is that typically the case? I assume rowwise scaling will have less degradation, though it doesn't seem to show up in the eval numbers at all. (Mainly want to gain some understanding of what a sensible default is for folks to use.)

Contributor Author:
My understanding is that rowwise is typically a bit slower but more accurate. I think we don't see the "more accurate" part in these experiments because tensorwise already preserves accuracy very well for this workload, but it's possible we'll need rowwise for other workloads still.

defaults to "tensorwise" with "enable_fsdp_float8_all_gather=True". See
https://github.com/pytorch/ao/blob/v0.9.0/torchao/float8/config.py#L150
for more details.

Returns:
(nn.Module) The new model with `Float8Linear`.
"""
if fp8_recipe_name is not None:
fp8_config = Float8LinearConfig.from_recipe_name(fp8_recipe_name)
else:
fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
return _convert_to_float8_training_torchao(
model,
config=fp8_config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)


# TODO: validate this in full_finetune_distributed recipe once FP8 + TP is enabled
def _validate_float8_tp_plan(
tp_plan: Optional[Dict[str, ParallelStyle]],
fp8_recipe_name: Optional[str] = None,
) -> None:
"""
Validate that the provided tensor parallel plan is compatible with the
float8 settings. Specifically, float8 tensor parallel plans are only
supported when using 'tensorwise' float8 recipes.
"""
if tp_plan is None or is_fp8_tensorwise_scaling(fp8_recipe_name):
return
for parallel_style in tp_plan.values():
if isinstance(parallel_style, Float8ColwiseParallel) or isinstance(
parallel_style, Float8RowwiseParallel
):
raise ValueError(
"%s and %s are only compatible with 'tensorwise' float8 recipes"
% (Float8ColwiseParallel.__name__, Float8RowwiseParallel.__name__)
)


def is_fp8_tensorwise_scaling(fp8_recipe_name: Optional[str]):
"""
Return True if the fp8 recipe name refers to 'tensorwise' scaling.
"""
return fp8_recipe_name is None or fp8_recipe_name == "tensorwise"