Skip to content

Add an option to use fp8-all-gather only without fp8 computation. #1093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/float8/test_fsdp2/test_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def _test_transformer_parity(
module,
optim,
local_inp,
precompute,
config=float8_linear_config2,
precompute=precompute,
compile_transformer_block=compile_transformer_block,
)

Expand Down
188 changes: 188 additions & 0 deletions test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import copy
import pytest
from typing import Optional

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
import torch._dynamo.testing
import torch.distributed as dist
import torch.nn as nn
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import DTensor, init_device_mesh
from torchao.float8.float8_tensor import GemmInputRole
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
)
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
swap_linear_layers,
)
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole
from torchao.testing.float8.fsdp2_utils import check_parity_fp8_comm_only

is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
if not is_cuda_8_9:
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


class Float8CommTestLinear(torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
fp8_param = hp_tensor_to_float8_dynamic(
self.weight,
torch.float8_e4m3fn,
None, # mm_linear_config,
reduce_amax=False,
gemm_input_role=GemmInputRole.WEIGHT,
)
weight_orig = fp8_param.to_original_precision()
output = torch.matmul(input, weight_orig.t())
if self.bias is not None:
output = output + self.bias.to(output.dtype)
return output

@classmethod
def from_float(
cls,
mod,
):
with torch.device("meta"):
new_mod = cls(
mod.in_features,
mod.out_features,
bias=(mod.bias is not None),
)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def convert_to_float8_comm_test_layers(
module: nn.Module,
) -> nn.Module:
from_float = lambda m: Float8CommTestLinear.from_float(
m,
)
return swap_linear_layers(
module,
from_float,
)


class TestFloat8Common:
def broadcast_module(self, module: nn.Module) -> None:
# Broadcast for multi-threaded process group tests since seed is per
# process, not per thread
for param in module.parameters():
dist.broadcast(param, src=0)

def init_transformer(self, weight_tying: bool, dtype: Optional[torch.dtype] = None) -> nn.Module:
torch.manual_seed(42)
args = ModelArgs(
n_layers=3,
dim=768,
n_heads=12,
dropout_p=0.0,
weight_tying=weight_tying,
vocab_size=32,
)
module = Transformer(args).cuda()
if dtype is not None:
module = module.to(dtype=dtype)
self.broadcast_module(module)
return module


class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)


@skip_if_lt_x_gpu(2)
def test_transformer_parity(self):
self.run_subtests(
{
"compile_transformer_block": [False, True],
"precompute": [False, True],
"scaling_type_weight": [ScalingType.DYNAMIC],
"dtype": [torch.float32, torch.bfloat16],
},
self._test_transformer_parity,
)

def _test_transformer_parity(
self,
precompute: bool,
scaling_type_weight: ScalingType,
compile_transformer_block: bool,
dtype: Optional[torch.dtype] = None,
):
if scaling_type_weight is ScalingType.DELAYED and precompute:
return

module = self.init_transformer(weight_tying=False, dtype=dtype)

local_inp = torch.randint(
0, module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)

# reference modules
ref_module = copy.deepcopy(module)
convert_to_float8_comm_test_layers(
ref_module,
)

# fp8 comm-only modules
float8_linear_config2 = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
enable_fsdp_float8_all_gather=True,
use_fp8_all_gather_only=True,
)
convert_to_float8_training(
module,
config=float8_linear_config2,
)

for layer_id, transformer_block in module.layers.named_children():
if compile_transformer_block:
transformer_block = torch.compile(transformer_block, dynamic=False)
fully_shard(transformer_block)
module.layers.register_module(layer_id, transformer_block)
fully_shard(module)

ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)

check_parity_fp8_comm_only(
self,
ref_module,
ref_optim,
module,
optim,
local_inp,
config=float8_linear_config2,
precompute=precompute,
compile=compile_transformer_block,
)



if __name__ == "__main__":
run_tests()
22 changes: 21 additions & 1 deletion torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
# LICENSE file in the root directory of this source tree.

import enum
import logging
from dataclasses import dataclass
from typing import Optional

import torch

logger: logging.Logger = logging.getLogger()

class ScalingType(enum.Enum):
DELAYED = "delayed"
Expand Down Expand Up @@ -220,9 +222,14 @@ class Float8LinearConfig:
# For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
# TODO(future PR): either enable by default or have a warning and set up the
# tests so that the warning does not spam the CI stdout.

force_recompute_fp8_weight_in_bwd: bool = False

# If True, we only use fp8-all-gather to reduce the communication cost.
# The gemm computation is still done in the original precision.
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# other casting configs will be ignored.
use_fp8_all_gather_only: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
Expand Down Expand Up @@ -261,6 +268,19 @@ def __post_init__(self):
is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
assert is_disabled_1 == is_disabled_2, \
f"incompatible operand precision for {gemm_name}"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

# See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
if (
self.enable_fsdp_float8_all_gather
and not self.force_recompute_fp8_weight_in_bwd
):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)



# Pre-made recipes for common configurations
Expand Down
30 changes: 26 additions & 4 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,12 @@ def cast_weight_to_float8_t(
)
return weight_fp8.t()

def cast_weight_to_original_t(self, weight: torch.Tensor):
if isinstance(weight, Float8Tensor):
return weight.to_original_precision().t()
else:
return weight.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
Expand Down Expand Up @@ -550,10 +556,7 @@ def float8_post_forward(self):
self.is_amax_initialized = True
self.amax_and_scale_synced = False

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
has_any_axiswise_scaling = (
self.config.cast_config_input.scaling_granularity is ScalingGranularity.AXISWISE or
self.config.cast_config_weight.scaling_granularity is ScalingGranularity.AXISWISE or
Expand Down Expand Up @@ -595,6 +598,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.linear_mm_config,
self.config,
)
return output

def forward_original_precision_matmul(self, input: torch.Tensor) -> torch.Tensor:
if self.config.force_recompute_fp8_weight_in_bwd:
orig_weight_t = checkpoint.checkpoint(self.cast_weight_to_original_t, self.weight)
else:
orig_weight_t = self.cast_weight_to_original_t(self.weight)

output = torch.matmul(input, orig_weight_t)
return output

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

if self.config.use_fp8_all_gather_only:
output = self.forward_original_precision_matmul(input)
else:
output = self.forward_fp8_matmul(input)

if self.bias is not None:
output = output + self.bias.to(output.dtype)
Expand Down
53 changes: 49 additions & 4 deletions torchao/testing/float8/fsdp2_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import contextlib
from typing import List, Optional

import torchao.float8.config as config

import torch
import torch.distributed as dist
import torch.nn as nn
from torchao.float8.config import Float8LinearConfig, ScalingType

import torchao.float8.config as config
from torchao.float8.config import (
Float8LinearConfig,
ScalingType,
)

from torchao.float8.float8_linear_utils import (
linear_requires_sync,
sync_float8_amax_and_scale_history,
Expand All @@ -21,8 +25,8 @@ def check_parity_no_mp(
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
config: Float8LinearConfig,
precompute: bool = False,
config: Optional[Float8LinearConfig] = None,
compile_transformer_block: bool = False,
):
# TODO(before land): reorder args and make config not optional
Expand Down Expand Up @@ -84,3 +88,44 @@ def check_parity_bf16_mp(
):
param_bf16.detach().copy_(param_fp32)
test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")


def check_parity_fp8_comm_only(
test_cls,
ref_model: nn.Module,
ref_optim: torch.optim.Optimizer,
fsdp_model: nn.Module,
fsdp_optim: torch.optim.Optimizer,
local_inp: torch.Tensor,
config: Float8LinearConfig,
precompute: bool = False,
compile: bool = False,
):
for iter_idx in range(10):
losses: List[torch.Tensor] = []
for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):

optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(model(local_inp).sum())
losses[-1].backward()
if model is ref_model:
for name, param in model.named_parameters():
dist.all_reduce(param.grad)
param.grad.div_(dist.get_world_size())

if linear_requires_sync(config):
sync_float8_amax_and_scale_history(model)

optim.step()
if (
model is fsdp_model
and precompute
and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
):
precompute_float8_dynamic_scale_for_fsdp(model)

if compile:
# When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now.
assert (torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}"
else:
test_cls.assertEqual(losses[0], losses[1], f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
Loading