Commit fc14199

fix simplefsdp gradient_divide_factor
1 parent 5d8e2d5 commit fc14199

File tree

2 files changed (+92, -6 lines)

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 4 deletions

@@ -132,11 +132,8 @@ def parallelize_deepseekv3(
                 ac_mode=job_config.activation_checkpoint.mode,
                 mp_policy=mp_policy,
                 shard_dim=experts_shard_dim,
+                gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
             )
-            # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-            # transformer_block.moe.experts.set_gradient_divide_factor(
-            #     parallel_dims.fsdp_gradient_divide_factor,
-            # )

     model = data_parallel(
         model,
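
For reference, here is a minimal single-process sketch (not part of the commit) of why the
expert weights need an explicit divide factor: their gradients are reduced over a smaller
expert-parallel submesh, yet they should still be divided by the full data-parallel world
size that the loss was averaged over. The dp_world_size and ep_group_size values below are
assumptions for illustration only.

    import torch

    dp_world_size = 8   # assumed stand-in for parallel_dims.fsdp_gradient_divide_factor
    ep_group_size = 2   # assumed size of the smaller submesh the expert grads reduce over

    # Pretend these are the per-rank expert gradients that the reduction would sum.
    per_rank_grads = [torch.full((4,), float(i + 1)) for i in range(ep_group_size)]

    summed = torch.stack(per_rank_grads).sum(dim=0)   # reduce_op="sum"
    custom_avg = summed / dp_world_size               # divide by gradient_divide_factor
    naive_avg = summed / ep_group_size                # what reduce_op="avg" alone would give

    print(custom_avg)  # tensor([0.3750, 0.3750, 0.3750, 0.3750])
    print(naive_avg)   # tensor([1.5000, 1.5000, 1.5000, 1.5000])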

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 91 additions & 2 deletions

@@ -20,6 +20,7 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +50,84 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None


+class SimpleFSDPPartial(Partial):
+    # A subclass of Partial placement that allows user to perform gradient reduction with a custom
+    # factor (gradient_divide_factor) other than the default world size.
+    def __init__(
+        self,
+        reduce_op: str,
+        gradient_divide_factor: float,
+        reduce_dtype: Optional[torch.dtype] = None,
+    ):
+        self.gradient_divide_factor = gradient_divide_factor
+        self.reduce_dtype = reduce_dtype
+        super().__init__(reduce_op)
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[Optional[float], Optional[float]]:
+        """
+        the logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+
+        pre_factor: Optional[float]
+        if overflow_risk:
+            # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+            # overflow/underflow. For N data parallel workers, each worker computes
+            # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+            # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+            pre_factor = 1
+            while (
+                self.gradient_divide_factor % pre_factor == 0
+                and self.gradient_divide_factor / pre_factor > pre_factor
+            ):
+                pre_factor *= 2
+            post_factor = self.gradient_divide_factor / pre_factor
+        else:
+            # Prefer post-multiplying as it operates on less data and is thus faster
+            pre_factor, post_factor = None, self.gradient_divide_factor
+
+        return pre_factor, post_factor
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # for all_reduce in DDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # for reduce_scatter in FSDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_gradient_divide_factors()
+
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
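
The pre/post split above can be exercised on its own. The sketch below (not from the
commit) restates _get_gradient_divide_factors as a free function and prints the factors
for two reduce dtypes; the divide factor of 64 is an arbitrary example value.

    from typing import Optional

    import torch

    def split_divide_factor(
        gradient_divide_factor: float, reduce_dtype: Optional[torch.dtype]
    ) -> tuple[Optional[float], float]:
        overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
        if not overflow_risk:
            # Single post-divide: it runs on the already-reduced (smaller) output.
            return None, gradient_divide_factor
        # fp16 path: split the factor into roughly sqrt(N) applied before the
        # reduction and the remainder applied after, to keep intermediates in range.
        pre_factor = 1
        while (
            gradient_divide_factor % pre_factor == 0
            and gradient_divide_factor / pre_factor > pre_factor
        ):
            pre_factor *= 2
        return pre_factor, gradient_divide_factor / pre_factor

    print(split_divide_factor(64, torch.float16))   # (8, 8.0)  -> divide before and after
    print(split_divide_factor(64, torch.bfloat16))  # (None, 64) -> divide once, after
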
@@ -192,18 +271,26 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="sum",
+                gradient_divide_factor=gradient_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+            )
+            if gradient_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"

     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
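
The placement selection above keeps the previous default when no factor is passed. A small
single-process sketch (not from the commit; the dataclass below only mimics the real
SimpleFSDPPartial constructor) of how the per-mesh-dimension placement list is built:

    from dataclasses import dataclass
    from typing import Optional

    import torch
    from torch.distributed.tensor import Partial

    @dataclass
    class _FakeSimpleFSDPPartial:
        reduce_op: str
        gradient_divide_factor: float
        reduce_dtype: Optional[torch.dtype] = None

    def build_grad_placements(gradient_divide_factor, reduce_dtype, ndim):
        # Mirrors the list construction in the diff: one placement per mesh dim,
        # falling back to the original Partial("avg") when no factor is given.
        return [
            _FakeSimpleFSDPPartial("sum", gradient_divide_factor, reduce_dtype)
            if gradient_divide_factor is not None
            else Partial(reduce_op="avg")
        ] * ndim

    # No factor: the old behavior, an "avg"-reducing Partial on every mesh dim.
    print(build_grad_placements(None, torch.bfloat16, ndim=2))
    # With a factor: "sum"-reducing placements that carry gradient_divide_factor=64.0.
    print(build_grad_placements(64.0, torch.bfloat16, ndim=2))
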
@@ -286,6 +373,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)

@@ -348,6 +436,7 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
+            gradient_divide_factor=gradient_divide_factor,
         ),
     )
     return model
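
Finally, a quick numeric check (values are illustrative assumptions) that splitting the
division around the reduction does not change the result: dividing each shard by
pre_factor, summing, then dividing by post_factor matches summing once and dividing by the
full gradient_divide_factor.

    import torch

    gradient_divide_factor = 64.0
    pre_factor, post_factor = 8.0, 8.0           # the fp16 split for a factor of 64
    grads = [torch.randn(16) for _ in range(4)]  # stand-ins for per-rank gradient shards

    split = torch.stack([g / pre_factor for g in grads]).sum(dim=0) / post_factor
    direct = torch.stack(grads).sum(dim=0) / gradient_divide_factor
    print(torch.allclose(split, direct))  # True, up to floating-point rounding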
