@@ -348,6 +348,13 @@ def reduce_gradients(self, parameter_list, hcg):
         with framework.no_grad():
             for param in parameter_list:
                 g_var = self._get_param_grad(param)
+                if g_var is None:
+                    if hasattr(param, "main_grad"):
+                        g_var = paddle.zeros_like(param, dtype=paddle.float32)
+                        param.main_grad = g_var
+                    else:
+                        g_var = paddle.zeros_like(param, dtype=param.dtype)
+                        param.grad = g_var
                 if g_var is not None:
                     reduce_op = ReduceOp.AVG
                     if not self.use_reduce_avg:
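For readers outside the Paddle codebase, a minimal sketch of the fallback the new branch introduces (not part of the PR; get_param_grad is a hypothetical stand-in for self._get_param_grad): when a parameter has no local gradient, a zero tensor of the matching dtype is used instead, so the subsequent reduce can still run for that parameter.

import paddle

def get_param_grad(param):
    # Hypothetical stand-in for self._get_param_grad in the optimizer.
    if hasattr(param, "main_grad"):
        return param.main_grad
    return param.grad

def grad_to_reduce(param):
    # Mirror of the added branch: fall back to zeros when no gradient exists,
    # keeping the dtype consistent with main_grad (fp32) or the parameter itself.
    g_var = get_param_grad(param)
    if g_var is None:
        if hasattr(param, "main_grad"):
            g_var = paddle.zeros_like(param, dtype=paddle.float32)
        else:
            g_var = paddle.zeros_like(param, dtype=param.dtype)
    return g_var

The actual change also writes the zero tensor back to param.main_grad or param.grad, presumably because later steps (such as the fused-buffer copy in the second file) read the gradient from the parameter again.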
python/paddle/distributed/fleet/utils/tensor_fusion_helper.py (15 changes: 6 additions & 9 deletions)
@@ -619,22 +619,19 @@ def _copy_grad_to_buffer(self, param):
         )

         grad_var = param.main_grad if self.use_main_grad else param.grad
-        assert grad_var is not None, (
-            f"The current parameter[{param.name}] has no gradient, its stop_grdient is {param.stop_gradient}"
-        )
-        grad_var.stop_gradient = True
-        grad_var.flatten_()

-        tmp_var.add_(grad_var)
-        tmp_var.get_tensor()._set_dims(param.shape)
+        if grad_var is not None:
+            grad_var.stop_gradient = True
+            grad_var.flatten_()
+            tmp_var.add_(grad_var)
+            grad_var._clear()
+        tmp_var.get_tensor()._set_dims(param.shape)
         if self.use_main_grad:
-            param.main_grad._clear()
             if not self._free_grads_in_comm:
                 param.main_grad = tmp_var
                 param.main_grad.name = "main_grad@" + param.name
         else:
-            param.grad._clear()
             if not self._free_grads_in_comm:
                 param._copy_gradient_from(tmp_var)

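Read as a whole, the method body after this change looks roughly like the following sketch, reconstructed from the hunk above with the self attributes passed as plain arguments (not the verbatim file contents): the hard assert is gone, a missing gradient is simply skipped, and the fused-buffer view is only re-attached to the parameter when gradients are kept across communication.

def copy_grad_to_buffer(param, tmp_var, use_main_grad, free_grads_in_comm):
    # tmp_var is the flattened slice of the fused gradient buffer for `param`.
    grad_var = param.main_grad if use_main_grad else param.grad
    if grad_var is not None:
        grad_var.stop_gradient = True
        grad_var.flatten_()
        tmp_var.add_(grad_var)   # accumulate the local grad into the buffer slice
        grad_var._clear()        # free the standalone grad storage

    tmp_var.get_tensor()._set_dims(param.shape)
    if use_main_grad:
        if not free_grads_in_comm:
            param.main_grad = tmp_var
            param.main_grad.name = "main_grad@" + param.name
    else:
        if not free_grads_in_comm:
            param._copy_gradient_from(tmp_var)

Together with the zero-fill in reduce_gradients, a parameter that produced no gradient on a rank no longer trips an assertion here; its (possibly zero) gradient is folded into the fused buffer like any other.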