[float8] all-reduce amax on dp mesh instead of global pg (#933)
* [float8] all-reduce amax on dp mesh instead of global pg

* linter

* improve comments

* move hp tensor inside if

* linter

* linter

* linter

* linter

* linter

weifengpy authored Sep 26, 2024
1 parent 72cc27d commit da0bbe3
Showing 4 changed files with 44 additions and 6 deletions.
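For context, the heart of the change is which process group the amax max-reduction runs on. The following is a minimal, illustrative sketch (not part of this commit) assuming 4 GPUs launched with torchrun and a hypothetical 2x2 pp/dp device mesh; it contrasts the old global reduction with the new dp-submesh reduction:

# sketch.py -- illustrative only; run with: torchrun --nproc_per_node=4 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# 2D layout: 2 pipeline stages x 2 data-parallel replicas per stage.
global_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("pp", "dp"))
dp_mesh = global_mesh["dp"]

x = torch.randn(16, 16, device="cuda")
amax = torch.max(torch.abs(x))

# Old behavior: reduce over the default (global) process group. Under pipeline
# parallelism, ranks on other stages may never reach this collective, so the
# reduction can hang or mix amax values across stages.
# dist.all_reduce(amax, op=dist.ReduceOp.MAX)

# New behavior: reduce only across the data-parallel replicas of this stage.
dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_mesh.get_group())

dist.destroy_process_group()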
32 changes: 31 additions & 1 deletion test/float8/test_fsdp2/test_fsdp2.py
@@ -17,10 +17,12 @@
import torch.nn as nn
from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import DTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torchao.float8.float8_tensor import GemmInputRole
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
@@ -293,6 +295,34 @@ def _get_curr_active_memory_mb(self) -> int:
        return round(mem_stats["active_bytes.all.current"] / 1e6)


class Test2DParallelMultiThread(FSDPTestMultiThread, TestFloat8Common):
    @property
    def world_size(self) -> int:
        return 4

    def test_amax_allreduce_device_mesh(self):
        dp_size = 2
        pp_size = self.world_size // dp_size
        global_mesh = init_device_mesh("cuda", (pp_size, dp_size), mesh_dim_names=("pp", "dp"))
        dp_mesh = global_mesh["dp"]
        pp_mesh = global_mesh["pp"]

        if self.rank in [0, 1]:
            # rank 0 and 1 are the 1st stage in the pipeline
            # rank 2 and 3 are doing nothing but waiting for the 1st stage
            torch.manual_seed(42 + self.rank)
            hp_tensor = torch.randn(768, 32, device="cuda")
            float8_tensor = hp_tensor_to_float8_dynamic(
                hp_tensor,
                torch.float8_e4m3fn,
                Float8LinearConfig(
                    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
                ),
                gemm_input_role=GemmInputRole.WEIGHT,
                reduce_amax=True,
                device_mesh=dp_mesh,
            )

class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
    @property
    def world_size(self) -> int:
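The new test only exercises the happy path: the cast returns on ranks 0 and 1 while ranks 2 and 3 stay idle. A hypothetical extra assertion (not in this commit) could additionally verify that the amax really was reduced over the dp group only; placed inside the same "if self.rank in [0, 1]:" branch it might look like:

            # Hypothetical check, reusing hp_tensor and dp_mesh from the test above.
            from torchao.float8.float8_utils import tensor_to_amax

            manual_amax = torch.max(torch.abs(hp_tensor))
            torch.distributed.all_reduce(
                manual_amax, op=torch.distributed.ReduceOp.MAX, group=dp_mesh.get_group()
            )
            dp_amax = tensor_to_amax(hp_tensor, reduce_amax=True, device_mesh=dp_mesh)
            torch.testing.assert_close(dp_amax, manual_amax)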
3 changes: 2 additions & 1 deletion torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
    device_mesh=None,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
Expand All @@ -52,7 +53,7 @@ def hp_tensor_to_float8_dynamic(
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax, device_mesh)
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
14 changes: 10 additions & 4 deletions torchao/float8/float8_utils.py
@@ -98,23 +98,29 @@ def amax_history_to_scale_stack(


@torch.no_grad()
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
def tensor_to_amax(
    x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
) -> torch.Tensor:
    amax = torch.max(torch.abs(x))

    # If the user asked for distributed reduction, do it.
    # If the user did not ask for it, assume that it will
    # happen elsewhere.
    if reduce_amax and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
        pg = device_mesh.get_group() if device_mesh is not None else None
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=pg)

    return amax


@torch.no_grad()
def tensor_to_scale(
    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
    x: torch.Tensor,
    float8_dtype: torch.dtype,
    reduce_amax: bool = False,
    device_mesh=None,
) -> torch.Tensor:
    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
    amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
    return amax_to_scale(amax, float8_dtype, x.dtype)


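In short, when a device_mesh is supplied the reduction group becomes that mesh's process group instead of None (the default, i.e. global, group). For the 2x2 mesh used in the test above, each pipeline stage gets its own dp group, which a small illustrative check (not in the commit) can make visible:

# Illustrative only: inspect which ranks the dp submesh's group spans.
dp_group = dp_mesh.get_group()
print(torch.distributed.get_process_group_ranks(dp_group))
# ranks 0 and 1 print [0, 1]; ranks 2 and 3 print [2, 3]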
1 change: 1 addition & 0 deletions torchao/float8/fsdp_utils.py
@@ -216,6 +216,7 @@ def fsdp_pre_all_gather(self, mesh):
            self._linear_mm_config,
            reduce_amax=True,
            gemm_input_role=GemmInputRole.WEIGHT,
            device_mesh=mesh,
        )
        return (float8_tensor._data,), (float8_tensor._scale,)

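FSDP2 supplies the mesh argument itself: fully_shard records the mesh a module is sharded over and passes it to fsdp_pre_all_gather before each all-gather, so with this change the amax reduction is automatically scoped to the FSDP (data-parallel) ranks. A hedged sketch of the end-to-end wiring (the model and exact config values are illustrative, not from this commit):

# Illustrative wiring: enable float8 all-gather and shard over the dp submesh so
# WeightWithDynamicFloat8CastTensor.fsdp_pre_all_gather receives mesh=dp_mesh.
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8.config import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training

config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
model = convert_to_float8_training(model, config=config)  # model: any module with nn.Linear layers
fully_shard(model, mesh=dp_mesh)  # dp_mesh: the "dp" slice of a 2D device mesh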
