Recompute: fix bug with transformer attention mask
JZ-LIANG committed Aug 6, 2021
1 parent c91b1e0 commit 6dcdbc5
Showing 1 changed file with 9 additions and 7 deletions.

python/paddle/distributed/fleet/utils/recompute.py
@@ -145,23 +145,25 @@ def backward(ctx, *args):

         # run backward() with only tensor that requires grad
         forward_outputs_with_grad = []
-        backward_inputs = list(args)
+        # NOTE: In a Transformer-like network, if the user puts the attention mask into the recompute
+        # segment's outputs, PyLayer forces the stop_gradient of the attention mask to False, which makes
+        # the number of tensors that need grad mismatch the number of backward inputs.
+        # The following backward_inputs_with_grad is used to avoid this case.
+        backward_inputs_with_grad = []
         for i in range(len(outputs)):
             if isinstance(outputs[i],
                           core.VarBase) and not outputs[i].stop_gradient:
                 forward_outputs_with_grad.append(outputs[i])
+                backward_inputs_with_grad.append(args[i])

         if len(forward_outputs_with_grad) == 0:
             raise RuntimeError(
                 "none of output has requires_grad=True, this recompute() is not necessary"
             )

-        assert len(backward_inputs) == len(
-            forward_outputs_with_grad
-        ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
-            len(forward_outputs_with_grad), len(backward_inputs))
-
         # actually backward
-        paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+        paddle.autograd.backward(forward_outputs_with_grad,
+                                 backward_inputs_with_grad)

         grads = list(inp._grad_ivar() for inp in detached_inputs
                      if isinstance(inp, core.VarBase))
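To make the scenario in the NOTE concrete, here is a hypothetical sketch of a recomputed Transformer-style block that returns the attention mask alongside its activations. The block, layer sizes, and tensor shapes are invented for illustration; only the recompute entry point from paddle.distributed.fleet.utils and the stop_gradient behaviour come from the changed file.

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F
    from paddle.distributed.fleet.utils import recompute


    class ToyAttentionBlock(nn.Layer):
        """A made-up attention block that returns the mask it was given."""

        def __init__(self, hidden=64):
            super().__init__()
            self.qkv = nn.Linear(hidden, 3 * hidden)
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, x, attn_mask):
            q, k, v = paddle.split(self.qkv(x), 3, axis=-1)
            scores = paddle.matmul(q, k, transpose_y=True) + attn_mask
            out = paddle.matmul(F.softmax(scores, axis=-1), v)
            # Returning attn_mask from the recomputed segment is the case the
            # NOTE describes: PyLayer flips its stop_gradient to False, so it
            # used to look like one more tensor needing a gradient.
            return self.proj(out), attn_mask


    block = ToyAttentionBlock()
    x = paddle.randn([2, 8, 64])         # [batch, seq_len, hidden]
    x.stop_gradient = False              # we want dLoss/dx
    attn_mask = paddle.zeros([2, 8, 8])  # additive mask, needs no gradient

    out, mask = recompute(block, x, attn_mask)
    out.mean().backward()                # only `out` feeds gradients backward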
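The old code passed every incoming gradient to paddle.autograd.backward, so the two lists could fall out of step when some recompute outputs (such as an attention mask) did not require grad; the fix builds backward_inputs_with_grad in lock-step with forward_outputs_with_grad. paddle.autograd.backward pairs the output tensors with their incoming gradients by position, as this standalone sketch (with invented tensors) shows:

    import paddle

    x = paddle.randn([3])
    x.stop_gradient = False
    y1 = x * 2.0
    y2 = x + 1.0

    # Each entry of grad_tensors is matched to the output tensor at the same
    # position, so the two lists must have the same length and ordering.
    paddle.autograd.backward([y1, y2], [paddle.ones([3]), paddle.ones([3])])
    print(x.grad)  # each element is 2*1 + 1*1 = 3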

1 comment on commit 6dcdbc5

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can now ask the reviewer(s) to approve and merge. 🎉
