-
Notifications
You must be signed in to change notification settings - Fork 193
/
Copy pathclip-grad-v2.2.patch
53 lines (50 loc) · 2.25 KB
/
clip-grad-v2.2.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index e8d0d02..91c663e 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = not hasattr(param, 'shared') or not param.shared
@@ -63,7 +64,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -72,6 +76,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # TODO: moe
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -96,7 +101,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ if grads_in_moe:
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm,
+ group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
+
else:
+ # TODO: moe
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type