Skip to content

[float8 moe training] FSDP support #2413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions test/prototype/moe_training/test_fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
pytest.skip(
"CUDA not available or compute capability < 8.9", allow_module_level=True
)

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
from torchtitan.experiments.llama4.model.args import TransformerModelArgs
from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
import warnings

warnings.warn("torchtitan not installed, skipping MoE tests.")
pytest.skip(allow_module_level=True)


def test_moe_float8_training_fsdp():
assert torch.cuda.is_available()

# setup distributed for fsdp
setup_distributed()

# define model args
target_fqns = ["experts"]
model_args = TransformerModelArgs(
moe_enabled=True,
num_experts=8,
dim=256,
)
init_std = 0.02
device = torch.device("cuda")

# reference bf16 MoE
ref_model = MoE(model_args).to(torch.bfloat16).cuda()
torch.manual_seed(42)
ref_model.init_weights(init_std, device)

# target MoE for testing conversion
model = copy.deepcopy(ref_model)

# assert starting params are identical for both models
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
assert torch.equal(param1, param2)

# convert MoE to float8 training
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
for target_fqn in target_fqns:
if target_fqn in cur_fqn:
return True
return False

# quantize test model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# validate that only the experts were converted
_validate_model_conversion(
model,
target_fqns=target_fqns,
)

# FSDP2
fully_shard(model)
fully_shard(ref_model)

# inputs
batch, seq, dim = 8, 2048, 256
ref_x = torch.randn(
batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
)
x = ref_x.detach().clone().requires_grad_(True)

# forward pass
ref_out = ref_model(ref_x)
out = model(x)

# validate output
out_sqnr = compute_error(out, ref_out)
assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

# compute loss
labels = torch.ones_like(ref_out)
ref_loss = F.mse_loss(ref_out, labels)
out_loss = F.mse_loss(out, labels)

# backward pass
ref_loss.backward()
out_loss.backward()

# validate input gradient
input_grad_sqnr = compute_error(x.grad, ref_x.grad)
assert input_grad_sqnr.item() >= 30.0, (
f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
)

# validate param gradients
for param1, param2 in zip(model.parameters(), ref_model.parameters()):
param_grad_sqnr = compute_error(param1.grad, param2.grad)
assert param_grad_sqnr.item() >= 25.0, (
f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
)

dist.destroy_process_group()


def _validate_model_conversion(
root_module: nn.Module,
target_fqns: list[str],
):
def _recursive_validate(
module: nn.Module,
cur_fqn: str,
):
is_allowed_module = cur_fqn in target_fqns

# check current module params
for param_name, param in module.named_parameters(recurse=False):
is_converted_type = isinstance(param, ScaledGroupedMMTensor)
if is_converted_type:
assert is_allowed_module, (
f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
)
if not is_allowed_module:
assert not is_converted_type, (
f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
)

# recursively check child modules
for child_name, child_module in module.named_children():
child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
_recursive_validate(child_module, child_fqn)

_recursive_validate(root_module, "")


def setup_distributed():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ class ExperimentConfig:

@dataclass(frozen=True)
class ExperimentResult:
torch_time_us: float
triton_time_us: bool
triton_speedup: float
time_us: float


@dataclass(frozen=True)
Expand Down Expand Up @@ -98,46 +96,34 @@ def warmup(func, *args, **kwargs):
for _ in range(10):
func(*args, **kwargs)

def forward_backward(A, B_t, offs, use_triton=True):
def forward_backward(A, B_t, offs):
out = _scaled_grouped_mm(
A,
B_t,
offs=offs,
out_dtype=torch.bfloat16,
use_triton_for_per_group_scales=use_triton,
)
out.sum().backward()
torch.cuda.synchronize()

# benchmark torch
torch_func = torch.compile(forward_backward) if args.compile else forward_backward
warmup(torch_func, A, B_t, offs, use_triton=False)
warmup(torch_func, A, B_t, offs)
start_time_ns = time.perf_counter_ns()
torch_func(A, B_t, offs, use_triton=False)
torch_func(A, B_t, offs)
torch_time_ns = time.perf_counter_ns() - start_time_ns
torch_time_us = torch_time_ns / 1e3

# benchmark triton
warmup(forward_backward, A, B_t, offs, use_triton=True)
start_time_ns = time.perf_counter_ns()
forward_backward(A, B_t, offs, use_triton=True)
triton_time_ns = time.perf_counter_ns() - start_time_ns
triton_time_us = triton_time_ns / 1e3
time_us = torch_time_ns / 1e3

return ExperimentResult(
torch_time_us=round(torch_time_us, 3),
triton_time_us=round(triton_time_us, 3),
triton_speedup=round(torch_time_us / triton_time_us, 3),
time_us=round(time_us, 3),
)


def print_results(experiments: List[Experiment]):
headers = [
"A_shape",
"B_shape",
"torch_time_us",
"triton_time_us",
"triton_speedup",
"time_us",
]
rows = []
for experiment in experiments:
Expand All @@ -147,9 +133,7 @@ def print_results(experiments: List[Experiment]):
[
A_shape,
B_shape,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
experiment.result.triton_speedup,
experiment.result.time_us,
]
)
print(tabulate(rows, headers=headers))
Expand Down
8 changes: 1 addition & 7 deletions torchao/prototype/moe_training/conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ class MoETrainingConfig(AOBaseConfig):
For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
"""

# temporary config flag for testing/benchmarking, will remove before graduating out of prototype
use_triton_for_per_group_scales: bool = True


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
Expand Down Expand Up @@ -71,7 +68,6 @@ def _swap_params(
Returns:
nn.Module: The modified module with swapped linear layers.
"""
use_triton = config.use_triton_for_per_group_scales if config is not None else False
if isinstance(module, nn.Parameter) and (
module_filter_fn is None or module_filter_fn(module, "")
):
Expand All @@ -80,9 +76,7 @@ def _swap_params(
f"Does not support a root nn.Parameter with children: {module}"
)
if not isinstance(module.data, ScaledGroupedMMTensor):
new_data = ScaledGroupedMMTensor(
module.data, use_triton_for_per_group_scales=use_triton
)
new_data = ScaledGroupedMMTensor(module.data)
return nn.Parameter(new_data, requires_grad=module.requires_grad)
return module

Expand Down
20 changes: 2 additions & 18 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
)
from torchao.prototype.moe_training.utils import (
_is_column_major,
_to_2d_jagged_float8_tensor_colwise,
_to_2d_jagged_float8_tensor_rowwise,
)


Expand All @@ -26,7 +24,6 @@ def _scaled_grouped_mm(
B_t: torch.Tensor,
offs: torch.Tensor,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
use_triton_for_per_group_scales: bool = True,
) -> torch.Tensor:
"""
This function performs dynamic float8 quantization with row-wise scaling
Expand Down Expand Up @@ -143,7 +140,6 @@ def forward(
# Store what we need for backward.
ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
ctx.out_dtype = out_dtype
ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales

# Perform scaled grouped GEMM and return result.
# output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
Expand All @@ -167,7 +163,6 @@ def forward(
def backward(ctx, grad_output: torch.Tensor):
A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
out_dtype = ctx.out_dtype
use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales

# Convert grad_output to float8, row-major for left operand of grouped GEMM
# needed for grad_A: grad_output @ B
Expand Down Expand Up @@ -216,27 +211,16 @@ def backward(ctx, grad_output: torch.Tensor):

# grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
# Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
per_group_rowwise_scale_func = (
triton_fp8_row_major_jagged_rowwise_scales
if use_triton_for_per_group_scales
else _to_2d_jagged_float8_tensor_rowwise
)
per_group_colwise_scale_func = (
triton_fp8_col_major_jagged_colwise_scales
if use_triton_for_per_group_scales
else _to_2d_jagged_float8_tensor_colwise
)

grad_output_t_fp8_row_major, grad_output_t_scales = (
per_group_rowwise_scale_func(
triton_fp8_row_major_jagged_rowwise_scales(
grad_output_t_row_major,
offs,
torch.float8_e4m3fn,
round_scales_to_power_of_2=True,
)
)

A_fp8_col_major, A_scales = per_group_colwise_scale_func(
A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
A_col_major,
offs,
torch.float8_e4m3fn,
Expand Down
Loading
Loading