Commit 5f0be26: fix simplefsdp gradient_divide_factor

1 parent: 5d8e2d5

2 files changed: +78 -6 lines

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py (5 additions, 4 deletions)
@@ -125,18 +125,19 @@ def parallelize_deepseekv3(
             ):
                 experts_shard_dim = 1

+            # When EP is enabled, the shared experts' gradient reduction is done over
+            # dp_mod_ep_mesh instead of the whole dp_mesh.
+            # We add a `fsdp_gradient_divide_factor` to scale gradients over dp_mesh
+            # to be consistent with data.
             transformer_block.moe.experts = data_parallel(
                 transformer_block.moe.experts,
                 dp_mod_ep_mesh,
                 dp_mode,
                 ac_mode=job_config.activation_checkpoint.mode,
                 mp_policy=mp_policy,
                 shard_dim=experts_shard_dim,
+                reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
             )
-            # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-            # transformer_block.moe.experts.set_gradient_divide_factor(
-            #     parallel_dims.fsdp_gradient_divide_factor,
-            # )

     model = data_parallel(
         model,
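The new comment is the core of the fix: with Expert Parallel enabled, the experts' gradients are reduced over the smaller dp_mod_ep mesh, so averaging over that mesh divides by a different denominator than the rest of the model uses. A toy, single-process sketch of the mismatch, using made-up sizes (the real factor comes from parallel_dims.fsdp_gradient_divide_factor; 8 and 4 below are illustrative only, and the factor is assumed to track the full dp_mesh size, as the added comment suggests):

import torch

dp_size, ep_size = 8, 4                  # illustrative sizes, not from the diff
dp_mod_ep_size = dp_size // ep_size      # mesh the experts' gradient reduce runs over
gradient_divide_factor = float(dp_size)  # assumed: the full dp_mesh size

# Toy per-rank expert gradients on the dp_mod_ep mesh.
grads = [torch.full((3,), float(r + 1)) for r in range(dp_mod_ep_size)]

avg_over_small_mesh = torch.stack(grads).sum(0) / dp_mod_ep_size        # old behavior: Partial("avg")
scaled_by_dp_size = torch.stack(grads).sum(0) / gradient_divide_factor  # new behavior: custom factor

# The two differ by exactly ep_size, which is the inconsistency the commit removes.
assert torch.allclose(avg_over_small_mesh, scaled_by_dp_size * ep_size)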

torchtitan/experiments/simple_fsdp/simple_fsdp.py (73 additions, 2 deletions)
@@ -20,6 +20,7 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
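The one new import is FSDP2's private _div_if_needed helper. For orientation only, its behavior is assumed to be roughly the in-place conditional division sketched below; this is a paraphrase, not the PyTorch source, which may differ in detail.

import torch
from typing import Optional

def div_if_needed_sketch(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
    # Assumed behavior of the imported helper: divide in place only when a
    # meaningful factor is given, so a factor of 1 (or None) is a no-op.
    if div_factor is not None and div_factor > 1:
        tensor.div_(div_factor)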
@@ -49,6 +50,67 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None


+class _ScaledPartial(Partial):
+    # A subclass of the Partial placement that allows the user to perform gradient
+    # reduction with a custom factor (reduction_divide_factor) other than the default world size.
+    def __init__(
+        self,
+        reduction_divide_factor: float,
+        reduce_dtype: Optional[torch.dtype] = None,
+    ):
+        self.reduction_divide_factor = reduction_divide_factor
+        self.reduce_dtype = reduce_dtype
+        super().__init__(reduce_op="sum")
+
+    def _get_reduction_divide_factors(
+        self,
+    ) -> tuple[Optional[float], Optional[float]]:
+        """
+        The logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        assert self.reduce_dtype in (
+            torch.float32,
+            torch.bfloat16,
+        ), "only support reduce_dtype to be fp32/bf16"
+        pre_factor, post_factor = self.reduction_divide_factor, None
+        return pre_factor, post_factor
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # for all_reduce in DDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_reduction_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # for reduce_scatter in FSDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_reduction_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
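_ScaledPartial switches the reduce op from "avg" to "sum" and pre-divides the local gradient by reduction_divide_factor before the collective (post_factor is always None here), so the end result is an average whose denominator is chosen by the caller rather than the reduction mesh size. A single-process sketch of that arithmetic, with no DTensor or process group involved:

import torch

reduction_mesh_size = 2        # e.g. the dp_mod_ep mesh the reduce runs over
reduction_divide_factor = 8.0  # e.g. parallel_dims.fsdp_gradient_divide_factor

local_grads = [torch.randn(4) for _ in range(reduction_mesh_size)]

# What _ScaledPartial does: divide each local tensor first, then sum-reduce.
pre_divided = [g / reduction_divide_factor for g in local_grads]
result = torch.stack(pre_divided).sum(0)

# Equivalent to a plain sum-reduce followed by one division by the custom factor,
# i.e. an average whose denominator is the factor, not the mesh size.
expected = torch.stack(local_grads).sum(0) / reduction_divide_factor
assert torch.allclose(result, expected)

A plausible reason for dividing before the collective rather than after is numerical range: with a low-precision reduce_dtype such as bf16, scaling each shard down first keeps the intermediate sums smaller, which lines up with the assert on reduce_dtype in _get_reduction_divide_factors.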
@@ -192,18 +254,25 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        reduction_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            _ScaledPartial(
+                reduction_divide_factor=reduction_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+            )
+            if reduction_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"

     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
@@ -286,6 +355,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    reduction_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +418,7 @@ def data_parallel(
                 mode,
                 regional_ac,
                 mp_policy=mp_policy,
+                reduction_divide_factor=reduction_divide_factor,
             ),
         )
     return model
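Taken together, the plumbing is: parallelize.py passes parallel_dims.fsdp_gradient_divide_factor only for the EP-sharded experts, data_parallel forwards it into the reduction placement, and a None factor keeps the previous Partial(reduce_op="avg") behavior, so all other call sites are unchanged. A rough sketch of the two call shapes (module and mesh variable names are placeholders, not from the diff):

# EP-sharded experts: reduce over dp_mod_ep_mesh, but divide by the full-dp factor.
experts = data_parallel(
    experts,
    dp_mod_ep_mesh,
    dp_mode,
    mp_policy=mp_policy,
    shard_dim=1,
    reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)

# Everything else: no factor given, so gradients still use Partial(reduce_op="avg").
model = data_parallel(model, dp_mesh, dp_mode, mp_policy=mp_policy)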
