Commit 4e14a0f

fix simplefsdp gradient_divide_factor
1 parent 5d8e2d5 commit 4e14a0f

File tree: 4 files changed (+119 lines, -8 lines)

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 4 deletions
@@ -132,11 +132,8 @@ def parallelize_deepseekv3(
                 ac_mode=job_config.activation_checkpoint.mode,
                 mp_policy=mp_policy,
                 shard_dim=experts_shard_dim,
+                gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
             )
-            # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-            # transformer_block.moe.experts.set_gradient_divide_factor(
-            #     parallel_dims.fsdp_gradient_divide_factor,
-            # )
 
     model = data_parallel(
         model,
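Context for this hunk (an illustrative aside, not part of the commit): with expert parallelism, the experts are FSDP-sharded over a smaller dp_mod_ep sub-mesh, while their gradients still aggregate contributions from the full data-parallel world, so the divide factor cannot be inferred from the wrapping mesh. The old TODO relied on a post-hoc set_gradient_divide_factor call; this change threads the factor through data_parallel instead. A back-of-the-envelope sketch with placeholder numbers (the names below are not torchtitan APIs):

    # Illustrative arithmetic only; names and numbers are placeholders.
    dp_replicate, dp_shard = 2, 8        # full data-parallel degree = 16
    ep = 4                               # expert-parallel degree
    full_dp = dp_replicate * dp_shard    # roughly what fsdp_gradient_divide_factor conveys
    dp_mod_ep = full_dp // ep            # size of the mesh actually wrapping the experts -> 4

    # Dividing expert gradients by the wrapping mesh size (4) would over-scale them;
    # the caller passes the full factor (16) down to SimpleFSDP instead.
    print(f"divide expert grads by {full_dp}, not {dp_mod_ep}")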

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 111 additions & 3 deletions
@@ -7,9 +7,10 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import cast, List, Optional
 
 import torch
+import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
 
 from torch.distributed._tensor import (
@@ -20,6 +21,7 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +51,101 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None
 
 
+@dataclass(frozen=True)
+class SimpleFSDPPartial(Partial):
+    gradient_divide_factor: Optional[float] = None
+    reduce_dtype: Optional[torch.dtype] = None
+    data_parallel_size: Optional[int] = None
+    force_sum_reduction_for_comms: bool = False
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[Optional[float], Optional[float], str, str,]:
+        """
+        the logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        if self.gradient_divide_factor is None:
+            return None, None, None, None
+
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+
+        if not overflow_risk and not self.force_sum_reduction_for_comms:
+            if self.gradient_divide_factor == self.data_parallel_size:
+                if self.data_parallel_size == 1:
+                    return None, None, "SUM", "SUM"
+                return None, None, "AVG", "AVG"
+            else:
+                reduce_scatter_op = torch.distributed._make_nccl_premul_sum(
+                    1 / self.gradient_divide_factor
+                )
+                return None, None, reduce_scatter_op, "SUM"
+
+        pre_factor: Optional[float]
+        if overflow_risk:
+            # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+            # overflow/underflow. For N data parallel workers, each worker computes
+            # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+            # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+            pre_factor = 1
+            while (
+                self.gradient_divide_factor % pre_factor == 0
+                and self.gradient_divide_factor / pre_factor > pre_factor
+            ):
+                pre_factor *= 2
+            post_factor = self.gradient_divide_factor / pre_factor
+        else:
+            # Prefer post-multiplying as it operates on less data and is thus faster
+            pre_factor, post_factor = None, self.gradient_divide_factor
+
+        return pre_factor, post_factor, "SUM", "SUM"
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # for all_reduce in DDP
+        (
+            pre_factor,
+            post_factor,
+            _,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        tensor = funcol.all_reduce(
+            tensor, reduceOp=all_reduce_op, group=(mesh, mesh_dim)
+        )
+        if post_factor is not None:
+            _div_if_needed(tensor, post_factor)
+        return tensor
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # for reduce_scatter in FSDP
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            _,
+        ) = self._get_gradient_divide_factors()
+
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        shard_spec = cast(Shard, shard_spec)
+        tensor = shard_spec._reduce_shard_tensor(
+            tensor, mesh, reduce_scatter_op, mesh_dim
+        )
+
+        if post_factor is not None:
+            _div_if_needed(tensor, post_factor)
+        return tensor
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
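As a quick sanity check of the branching in _get_gradient_divide_factors, a minimal single-process sketch (no collectives are issued; it assumes the class is importable from the module path in this diff):

    import torch

    from torchtitan.experiments.simple_fsdp.simple_fsdp import SimpleFSDPPartial

    # bf16 reduction with factor == DP size: plain AVG for both reduce-scatter and all-reduce.
    p = SimpleFSDPPartial(
        reduce_op="avg",
        gradient_divide_factor=8,
        reduce_dtype=torch.bfloat16,
        data_parallel_size=8,
    )
    print(p._get_gradient_divide_factors())  # (None, None, 'AVG', 'AVG')

    # fp16 reduction risks overflow: split the division into ~sqrt(N) before/after a SUM.
    p = SimpleFSDPPartial(
        reduce_op="avg",
        gradient_divide_factor=16,
        reduce_dtype=torch.float16,
        data_parallel_size=16,
    )
    print(p._get_gradient_divide_factors())  # (4, 4.0, 'SUM', 'SUM')

    # bf16 but factor != DP size (the EP experts case): pre-multiplied SUM for reduce-scatter.
    p = SimpleFSDPPartial(
        reduce_op="avg",
        gradient_divide_factor=16,
        reduce_dtype=torch.bfloat16,
        data_parallel_size=4,
    )
    print(p._get_gradient_divide_factors())  # (None, None, <nccl premul_sum ReduceOp>, 'SUM')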
@@ -192,18 +289,27 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="avg",
+                gradient_divide_factor=gradient_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+                data_parallel_size=self.device_mesh.size(),
+            )
+            if gradient_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"
 
     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
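How these placements take effect (an illustrative aside, not from this diff): when a gradient DTensor carrying this Partial placement is redistributed, DTensor's redistribute path calls _reduce_value for a Replicate target (the all-reduce/DDP case) and _reduce_shard_value for a Shard target (the reduce-scatter/FSDP case), which is what routes the divide factor into the reductions shown above. A minimal distributed sketch, assuming a torchrun launch with one GPU per rank and the module path from this diff:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Replicate, Shard

    from torchtitan.experiments.simple_fsdp.simple_fsdp import SimpleFSDPPartial

    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    placement = SimpleFSDPPartial(
        reduce_op="avg",
        gradient_divide_factor=float(mesh.size()),
        reduce_dtype=torch.float32,
        data_parallel_size=mesh.size(),
    )

    # Pretend each rank holds an unreduced (partial) gradient.
    local_grad = torch.ones(16, 8, device="cuda")
    grad = DTensor.from_local(local_grad, mesh, (placement,), run_check=False)

    sharded = grad.redistribute(mesh, (Shard(0),))        # goes through _reduce_shard_value
    replicated = grad.redistribute(mesh, (Replicate(),))  # goes through _reduce_value
    dist.destroy_process_group()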
@@ -286,6 +392,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +455,7 @@ def data_parallel(
                 mode,
                 regional_ac,
                 mp_policy=mp_policy,
+                gradient_divide_factor=gradient_divide_factor,
             ),
         )
     return model
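For completeness, a hedged usage sketch of the new keyword end to end. Only ac_mode, mp_policy, shard_dim, and gradient_divide_factor appear in the hunks above; the leading positional arguments, the "fully_shard" mode string, and the MixedPrecisionPolicy fields are assumptions based on the surrounding torchtitan code, not part of this diff:

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh

    from torchtitan.experiments.simple_fsdp.simple_fsdp import (
        data_parallel,
        MixedPrecisionPolicy,
    )

    # Assumed launch: torchrun --nproc_per_node=<N>, one GPU per rank.
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("dp",))

    module = nn.Linear(64, 64, device="cuda")
    module = data_parallel(
        module,
        mesh,
        mode="fully_shard",  # assumed; the hunk above only shows the "replicate" branch
        ac_mode="none",
        mp_policy=MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        ),
        shard_dim=0,
        # Divide reduced gradients by the full data-parallel degree; for EP-wrapped
        # experts this can be larger than the size of the mesh passed in here.
        gradient_divide_factor=float(dist.get_world_size()),
    )
    dist.destroy_process_group()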

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
     ),
     "debugmodel_flex_attn": DeepSeekV3ModelArgs(
         vocab_size=2048,

torchtitan/models/llama3/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -29,7 +29,11 @@
 
 llama3_configs = {
     "debugmodel": TransformerModelArgs(
-        dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000
+        dim=256,
+        n_layers=6,
+        n_heads=16,
+        vocab_size=2048,
+        rope_theta=500000,
     ),
     "debugmodel_flex_attn": TransformerModelArgs(
         dim=256,
