|
7 | 7 | from collections.abc import Sequence |
8 | 8 | from contextlib import contextmanager |
9 | 9 | from dataclasses import dataclass |
10 | | -from typing import List, Optional |
| 10 | +from typing import cast, List, Optional, Union |
11 | 11 |
|
12 | 12 | import torch |
| 13 | +import torch.distributed._functional_collectives as funcol |
13 | 14 | import torch.nn as nn |
14 | 15 |
|
15 | 16 | from torch.distributed._tensor import ( |
|
20 | 21 | Shard, |
21 | 22 | ) |
22 | 23 | from torch.distributed.device_mesh import _mesh_resources, DeviceMesh |
| 24 | +from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed |
23 | 25 | from torch.distributed.tensor._dtensor_spec import DTensorSpec |
24 | 26 | from torch.distributed.tensor._redistribute import redistribute_local_tensor |
25 | 27 | from torch.distributed.tensor.placement_types import _StridedShard, Placement |
@@ -49,6 +51,101 @@ class MixedPrecisionPolicy: |
49 | 51 | reduce_dtype: Optional[torch.dtype] = None |
50 | 52 |
|
51 | 53 |
|
| 54 | +@dataclass(frozen=True) |
| 55 | +class SimpleFSDPPartial(Partial): |
| 56 | + gradient_divide_factor: Optional[float] = None |
| 57 | + reduce_dtype: Optional[torch.dtype] = None |
| 58 | + data_parallel_size: Optional[int] = None |
| 59 | + force_sum_reduction_for_comms: bool = False |
| 60 | + |
| 61 | + def _get_gradient_divide_factors( |
| 62 | + self, |
| 63 | + ) -> tuple[Optional[float], Optional[float], Union[None, str, torch.distributed.ReduceOp], Optional[str]]: |
| 64 | + """ |
| 65 | + the logic follows |
| 66 | + https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688 |
| 67 | + """ |
| 68 | + if self.gradient_divide_factor is None: |
| 69 | + return None, None, None, None |
| 70 | + |
| 71 | + overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16) |
| 72 | + |
| 73 | + if not overflow_risk and not self.force_sum_reduction_for_comms: |
| 74 | + if self.gradient_divide_factor == self.data_parallel_size: |
| 75 | + if self.data_parallel_size == 1: |
| 76 | + return None, None, "SUM", "SUM" |
| 77 | + return None, None, "AVG", "AVG" |
| 78 | + else: |
| 79 | + reduce_scatter_op = torch.distributed._make_nccl_premul_sum( |
| 80 | + 1 / self.gradient_divide_factor |
| 81 | + ) |
| 82 | + return None, None, reduce_scatter_op, "SUM" |
| 83 | + |
| 84 | + pre_factor: Optional[float] |
| 85 | + if overflow_risk: |
| 86 | + # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid |
| 87 | + # overflow/underflow. For N data parallel workers, each worker computes |
| 88 | + # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid |
| 89 | + # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. |
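| | + # e.g., gradient_divide_factor=64 yields pre_factor=8 and post_factor=8 |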
| 90 | + pre_factor = 1 |
| 91 | + while ( |
| 92 | + self.gradient_divide_factor % pre_factor == 0 |
| 93 | + and self.gradient_divide_factor / pre_factor > pre_factor |
| 94 | + ): |
| 95 | + pre_factor *= 2 |
| 96 | + post_factor = self.gradient_divide_factor / pre_factor |
| 97 | + else: |
| 98 | + # Prefer post-multiplying as it operates on less data and is thus faster |
| 99 | + pre_factor, post_factor = None, self.gradient_divide_factor |
| 100 | + |
| 101 | + return pre_factor, post_factor, "SUM", "SUM" |
| 102 | + |
| 103 | + def _reduce_value( |
| 104 | + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int |
| 105 | + ) -> torch.Tensor: |
| 106 | + # for all_reduce in DDP |
| 107 | + ( |
| 108 | + pre_factor, |
| 109 | + post_factor, |
| 110 | + _, |
| 111 | + all_reduce_op, |
| 112 | + ) = self._get_gradient_divide_factors() |
| 113 | + if pre_factor is not None: |
| 114 | + _div_if_needed(tensor, pre_factor) |
| 115 | + tensor = funcol.all_reduce( |
| 116 | + tensor, reduceOp=all_reduce_op, group=(mesh, mesh_dim) |
| 117 | + ) |
| 118 | + if post_factor is not None: |
| 119 | + _div_if_needed(tensor, post_factor) |
| 120 | + return tensor |
| 121 | + |
| 122 | + def _reduce_shard_value( |
| 123 | + self, |
| 124 | + tensor: torch.Tensor, |
| 125 | + mesh: DeviceMesh, |
| 126 | + mesh_dim: int, |
| 127 | + shard_spec: Placement, |
| 128 | + ) -> torch.Tensor: |
| 129 | + # for reduce_scatter in FSDP |
| 130 | + ( |
| 131 | + pre_factor, |
| 132 | + post_factor, |
| 133 | + reduce_scatter_op, |
| 134 | + _, |
| 135 | + ) = self._get_gradient_divide_factors() |
| 136 | + |
| 137 | + if pre_factor is not None: |
| 138 | + _div_if_needed(tensor, pre_factor) |
| 139 | + shard_spec = cast(Shard, shard_spec) |
| 140 | + tensor = shard_spec._reduce_shard_tensor( |
| 141 | + tensor, mesh, reduce_scatter_op, mesh_dim |
| 142 | + ) |
| 143 | + |
| 144 | + if post_factor is not None: |
| 145 | + _div_if_needed(tensor, post_factor) |
| 146 | + return tensor |
| 147 | + |
| 148 | + |
52 | 149 | def _distribute_dtensor( |
53 | 150 | tensor: DTensor, |
54 | 151 | device_mesh: DeviceMesh, |
@@ -192,18 +289,27 @@ def __init__( |
192 | 289 | mode, |
193 | 290 | regional_ac, |
194 | 291 | mp_policy, |
| 292 | + gradient_divide_factor, |
195 | 293 | ): |
196 | 294 | super().__init__() |
197 | 295 | self.device_mesh = device_mesh |
198 | 296 | self.param_sharding = param_sharding |
199 | 297 | self.mode = mode |
200 | 298 | self.compute_placements = [Replicate()] * self.device_mesh.ndim |
201 | | - self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim |
| 299 | + self.grad_placements = [ |
| 300 | + SimpleFSDPPartial( |
| 301 | + reduce_op="avg", |
| 302 | + gradient_divide_factor=gradient_divide_factor, |
| 303 | + reduce_dtype=mp_policy.reduce_dtype if mp_policy else None, |
| 304 | + data_parallel_size=self.device_mesh.size(), |
| 305 | + ) |
| 306 | + if gradient_divide_factor is not None |
| 307 | + else Partial(reduce_op="avg") |
| 308 | + ] * self.device_mesh.ndim |
202 | 309 | self.regional_ac = regional_ac |
203 | 310 | mp_policy = mp_policy or MixedPrecisionPolicy() |
204 | 311 | self.param_dtype = mp_policy.param_dtype |
205 | 312 | self.reduce_dtype = mp_policy.reduce_dtype |
206 | | - self.ep_mesh_name, self.tp_mesh_name = "ep", "tp" |
207 | 313 |
|
208 | 314 | def replicate_compute(self, x): |
209 | 315 | # data parallel runtime replicate parameters and do local compute |
@@ -286,6 +392,7 @@ def data_parallel( |
286 | 392 | ac_mode: str = "none", |
287 | 393 | mp_policy: Optional[MixedPrecisionPolicy] = None, |
288 | 394 | shard_dim: int = 0, |
| 395 | + gradient_divide_factor: Optional[float] = None, |
289 | 396 | ): |
290 | 397 | if mode == "replicate": |
291 | 398 | param_sharding = (Replicate(),) |
@@ -348,6 +455,7 @@ def data_parallel( |
348 | 455 | mode, |
349 | 456 | regional_ac, |
350 | 457 | mp_policy=mp_policy, |
| 458 | + gradient_divide_factor=gradient_divide_factor, |
351 | 459 | ), |
352 | 460 | ) |
353 | 461 | return model |
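
For intuition, here is a minimal standalone sketch (editorial, not part of this patch) of the fp16 pre/post division split that the while loop in _get_gradient_divide_factors implements; split_divide_factor is a hypothetical helper introduced only for illustration.

# Editorial sketch: mirrors the pre/post division split used when reduce_dtype
# risks overflow (e.g. fp16). pre is doubled until it reaches roughly
# sqrt(factor) (or stops dividing factor evenly); post covers the remainder.
def split_divide_factor(factor: float) -> tuple[int, float]:
    pre = 1
    while factor % pre == 0 and factor / pre > pre:
        pre *= 2
    return pre, factor / pre

print(split_divide_factor(64))  # (8, 8.0): divide by 8 before and by 8 after the reduce
print(split_divide_factor(8))   # (4, 2.0)
print(split_divide_factor(2))   # (2, 1.0)

Gradients are divided by the pre factor before the SUM reduction and by the post factor afterwards, so the combined effect equals dividing by the full gradient_divide_factor while the intermediate fp16 values stay within a safer range.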