fix gradient clip
LiYuRio committed Dec 26, 2023
1 parent d98f1c2 commit 0c6c724
Showing 5 changed files with 243 additions and 326 deletions.
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4023,6 +4023,7 @@ void SequenceMaskScalarInferMeta(const MetaTensor& x,

void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims({1});
out->set_dtype(x.dtype());
}

void SqueezeInferMeta(const MetaTensor& x,
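Note: the added `out->set_dtype(x.dtype());` makes `squared_l2_norm` report its input's dtype instead of leaving the output dtype unset. The dtype matters because `_dygraph_clip` in clip.py (below) buckets the per-gradient squared norms by dtype before summing them. A minimal sketch of that bucketing, assuming standard paddle dtype objects; not taken from the Paddle source:

import paddle

def bucket_squared_norms(squared_norms):
    # Group per-gradient squared norms the way _dygraph_clip does (sketch).
    fp16_like, fp32, fp64 = [], [], []
    for s in squared_norms:
        if s.dtype in (paddle.float16, paddle.bfloat16):
            fp16_like.append(s)
        elif s.dtype == paddle.float32:
            fp32.append(s)
        else:
            fp64.append(s)
    return fp16_like, fp32, fp64
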
50 changes: 0 additions & 50 deletions python/paddle/distributed/auto_parallel/api.py
@@ -515,9 +515,7 @@ def __init__(self, optimizer, shard_fn=None):
optimizer.helper = paddle.base.layer_helper.LayerHelper(
optimizer.__class__.__name__
)
# solve global_clip for auto_parallel
self._shard_clip = False
self._generate_flag = False
if (
hasattr(optimizer, "_grad_clip")
and optimizer._grad_clip is not None
@@ -564,40 +562,15 @@ def _shard_accumulator(self, param):
placements=placements,
)

def generate_pp_mesh(self, all_process_ids=[]):
pp_mesh = None
if len(all_process_ids) <= 1:
return pp_mesh
else:
mesh = np.array(all_process_ids)
for i in range(mesh.shape[-1]):
ranks = mesh[:, i].tolist()
if dist.get_rank() in ranks:
pp_mesh = dist.ProcessMesh(ranks)
return pp_mesh

def step(self):
if not isinstance(self._inner_opt._parameter_list[0], dict):
params_grads = []
all_process_ids = []
for param in self._inner_opt._parameter_list:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
if (
not self._generate_flag
and self._shard_clip
and param.is_dist()
):
if param.process_mesh.process_ids not in all_process_ids:
all_process_ids.append(param.process_mesh.process_ids)
if not self._generate_flag and self._shard_clip:
self._inner_opt._grad_clip._pp_mesh = self.generate_pp_mesh(
all_process_ids
)
self._generate_flag = True
for p, g in params_grads:
self._shard_accumulator(p)
self._inner_opt._apply_optimize(
@@ -606,36 +579,13 @@ def step(self):
else:
for param_group in self._inner_opt._param_groups:
params_grads = defaultdict(lambda: [])
all_process_ids = []
shard_clip_flag = False
for param in param_group['params']:
if param.stop_gradient:
continue
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads['params'].append((param, grad_var))
if (
not self._generate_flag
and "grad_clip" in param_group.keys()
and isinstance(
param_group["grad_clip"],
paddle.nn.ClipGradByGlobalNorm,
)
and param.is_dist()
):
if (
param.process_mesh.process_ids
not in all_process_ids
):
all_process_ids.append(
param.process_mesh.process_ids
)
shard_clip_flag = True

if shard_clip_flag:
param_group["grad_clip"]._pp_mesh = self.generate_pp_mesh(
all_process_ids
)
params_grads.update(
{k: v for k, v in param_group.items() if k != 'params'}
)
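For context on the removed helper: `generate_pp_mesh` stacked the distinct per-parameter process-id lists and returned a ProcessMesh over the column of ranks containing the current rank, i.e. one rank per pipeline stage. A standalone sketch with hypothetical ranks (not part of this commit); the reshard-based clipping in clip.py below is what makes this setup unnecessary:

import numpy as np

def generate_pp_mesh_sketch(all_process_ids, cur_rank):
    # Nothing to build unless parameters live on more than one stage mesh.
    if len(all_process_ids) <= 1:
        return None
    mesh = np.array(all_process_ids)   # shape: [num_stages, ranks_per_stage]
    for i in range(mesh.shape[-1]):
        ranks = mesh[:, i].tolist()    # one rank taken from every stage
        if cur_rank in ranks:
            return ranks               # the original wrapped this in dist.ProcessMesh
    return None

# Two hypothetical pipeline stages, each spanning two ranks:
assert generate_pp_mesh_sketch([[0, 1], [2, 3]], cur_rank=2) == [0, 2]
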
64 changes: 17 additions & 47 deletions python/paddle/nn/clip.py
@@ -18,6 +18,7 @@

import paddle
import paddle.autograd as imperative_base
import paddle.distributed as dist
from paddle import _C_ops
from paddle.base import core, framework, unique_name
from paddle.base.data_feeder import check_variable_and_dtype
@@ -661,8 +662,6 @@ def __init__(
# are so many hard code depends on `add_n` in the legacy static
# manual hybrid-parallel.
self._async_add_n = None
# just for auto parallel.
self._pp_mesh = None

def __str__(self):
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@@ -673,6 +672,8 @@ def _dygraph_clip(self, params_grads):
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
src_mesh = params_grads[0][0].process_mesh

for p, g in params_grads:
if g is None:
continue
@@ -689,6 +690,14 @@
merge_grad = get_tensor_from_selected_rows(merge_grad)

sum_square = _squared_l2_norm(merge_grad)

# if the gradient mesh is not equal to src mesh
# do reshard to get the result of squared_l2 from other pp stage mesh
if src_mesh is not None and g.process_mesh != src_mesh:
sum_square = dist.reshard(
sum_square, src_mesh, sum_square.placements
)

if (
sum_square.dtype == core.VarDesc.VarType.FP16
or sum_square.dtype == core.VarDesc.VarType.BF16
@@ -715,64 +724,21 @@ def async_add_n(var_list):
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = async_add_n(sum_square_list_fp16)
if self._pp_mesh is not None:
# sync pp
global_norm_var_fp16 = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
global_norm_var_fp16._local_value().reshape([-1]),
self._pp_mesh,
[paddle.distributed.Partial()],
)
)
global_norm_var_fp16 = paddle.distributed.reshard(
global_norm_var_fp16,
self._pp_mesh,
[paddle.distributed.Replicate()],
)
global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = async_add_n(sum_square_list_fp32)
if self._pp_mesh is not None:
# sync pp
global_norm_var_fp32 = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
global_norm_var_fp32._local_value().reshape([-1]),
self._pp_mesh,
[paddle.distributed.Partial()],
)
)
global_norm_var_fp32 = paddle.distributed.reshard(
global_norm_var_fp32,
self._pp_mesh,
[paddle.distributed.Replicate()],
)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0:
global_norm_var_fp64 = async_add_n(sum_square_list)
if self._pp_mesh is not None:
# sync pp
global_norm_var_fp64 = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
global_norm_var_fp64._local_value().reshape([-1]),
self._pp_mesh,
[paddle.distributed.Partial()],
)
)
global_norm_var_fp64 = paddle.distributed.reshard(
global_norm_var_fp64,
self._pp_mesh,
[paddle.distributed.Replicate()],
)
global_norm_var.append(global_norm_var_fp64)
if self._pp_mesh is not None:
global_norm_var = [t._local_value() for t in global_norm_var]

global_norm_var = async_add_n(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var)
max_global_norm = paddle.full(
shape=[], dtype=global_norm_var.dtype, fill_value=self.clip_norm
shape=[], dtype=sum_dtype, fill_value=self.clip_norm
)

need_clip = False
@@ -800,6 +766,10 @@ def async_add_n(var_list):
if clip_var.dtype != g.dtype
else clip_var
)
if clip_input.process_mesh != g.process_mesh:
clip_input = paddle.distributed.reshard(
clip_input, g.process_mesh, clip_input.placements
)
new_grad = paddle.multiply(g, clip_input)
params_and_grads.append((p, new_grad))
else:
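Taken together, the clip.py changes drop the `_pp_mesh` synchronization in favor of per-tensor reshards: each squared norm is moved onto the mesh of the first parameter before the sum, and the clip coefficient is moved onto each gradient's own mesh before scaling. A condensed sketch of that flow, assuming every parameter and gradient is a distributed tensor under auto parallel; `_squared_l2_norm` and the dtype bucketing are simplified, so this is not the Paddle source:

import paddle
import paddle.distributed as dist

def clip_by_global_norm_sketch(params_grads, clip_norm):
    # Mesh of the first parameter; every partial squared norm is moved here.
    src_mesh = params_grads[0][0].process_mesh
    squares = []
    for _, g in params_grads:
        s = paddle.sum(paddle.square(g))  # stand-in for _squared_l2_norm
        if src_mesh is not None and g.process_mesh != src_mesh:
            s = dist.reshard(s, src_mesh, s.placements)
        squares.append(s)
    global_norm = paddle.sqrt(paddle.add_n(squares))
    max_norm = paddle.full([], clip_norm, dtype=global_norm.dtype)
    clip_coef = max_norm / paddle.maximum(global_norm, max_norm)
    clipped = []
    for p, g in params_grads:
        coef = clip_coef.astype(g.dtype) if clip_coef.dtype != g.dtype else clip_coef
        # Move the coefficient onto the gradient's mesh before scaling,
        # mirroring the new reshard of clip_input in _dygraph_clip.
        if coef.process_mesh != g.process_mesh:
            coef = dist.reshard(coef, g.process_mesh, coef.placements)
        clipped.append((p, paddle.multiply(g, coef)))
    return clipped
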
4 changes: 2 additions & 2 deletions python/paddle/optimizer/optimizer.py
@@ -1198,7 +1198,7 @@ def _create_optimization_pass(
# need to filter again here.
if (
param_and_grad[1] is None
or not param_and_grad[0]._is_initialized()
or not param_and_grad[1]._is_initialized()
):
continue
if param_and_grad[0].stop_gradient is False:
@@ -1209,7 +1209,7 @@
for param_and_grad in parameters_and_grads['params']:
if (
param_and_grad[1] is None
or not param_and_grad[0]._is_initialized()
or not param_and_grad[1]._is_initialized()
):
continue
if param_and_grad[0].stop_gradient is False:
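The optimizer.py change filters on the gradient's initialization (`param_and_grad[1]`) rather than the parameter's, presumably so that parameters whose gradients never materialized on this rank are skipped instead of being updated with an empty tensor. A trivial restatement of the filter as a hypothetical helper, not Paddle source:

def should_skip_update(param, grad):
    # Skip when the gradient is absent or was never initialized (e.g. it was
    # never produced on this rank), or when the parameter is frozen.
    if grad is None or not grad._is_initialized():
        return True
    return param.stop_gradient
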
