Recompute: fix bug with transformer attention mask
JZ-LIANG committed Aug 6, 2021
1 parent c91b1e0 commit 6dcdbc5
Showing 1 changed file with 9 additions and 7 deletions.
python/paddle/distributed/fleet/utils/recompute.py (16 changes: 9 additions & 7 deletions)
@@ -145,23 +145,25 @@ def backward(ctx, *args):
 
             # run backward() with only tensor that requires grad
             forward_outputs_with_grad = []
-            backward_inputs = list(args)
+            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
+            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
+            # tensor that need grad does not match.
+            # the following backward_inputs_with_grad is used to avoid this case.
+            backward_inputs_with_grad = []
             for i in range(len(outputs)):
                 if isinstance(outputs[i],
                               core.VarBase) and not outputs[i].stop_gradient:
                     forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs_with_grad.append(args[i])
+
             if len(forward_outputs_with_grad) == 0:
                 raise RuntimeError(
                     "none of output has requires_grad=True, this recompute() is not necessary"
                 )
 
-            assert len(backward_inputs) == len(
-                forward_outputs_with_grad
-            ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
-                len(forward_outputs_with_grad), len(backward_inputs))
-
             # actually backward
-            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+            paddle.autograd.backward(forward_outputs_with_grad,
+                                     backward_inputs_with_grad)
 
             grads = list(inp._grad_ivar() for inp in detached_inputs
                          if isinstance(inp, core.VarBase))
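
For context, a minimal sketch (not part of this commit) of the scenario the NOTE above describes, assuming the helper defined in this file is importable as shown and runs in dygraph mode; attention_block, the tensor shapes, and the mask values are illustrative only:

import paddle
# Assumption: the public recompute() helper defined in this file is importable this way.
from paddle.distributed.fleet.utils.recompute import recompute


def attention_block(hidden, attn_mask):
    # hidden requires grad; attn_mask is a constant input (stop_gradient=True).
    scores = paddle.matmul(hidden, hidden, transpose_y=True) + attn_mask
    # Returning attn_mask from the recomputed segment is what triggered the bug:
    # PyLayer forces its stop_gradient to False, so the number of outputs that
    # "need grad" no longer matched the gradients passed into backward().
    return scores, attn_mask


hidden = paddle.rand([2, 8, 8])
hidden.stop_gradient = False
attn_mask = paddle.zeros([2, 8, 8])  # stop_gradient stays True by default

scores, mask = recompute(attention_block, hidden, attn_mask)
scores.sum().backward()  # with this fix, gradients are matched only to scores

The change pairs each forward output that actually requires grad with its incoming gradient (backward_inputs_with_grad) instead of asserting that every output does.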
