
Commit

[Dy2St] Fix cond_block_grad error when handling no-need-grad vars (PaddlePaddle#43034)

* Fix cond_block_grad error when handling no-need-grad vars

* Add comment and UT
0x45f committed May 30, 2022
1 parent aedd459 commit 7ac84bd
Showing 2 changed files with 51 additions and 7 deletions.
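
For context: in dygraph-to-static (Dy2St), an `if` on a tensor lowers to a conditional_block op, and conditional_block_grad copies the gradients produced inside the block back to the parent scope. The bug: a variable can need a gradient outside the block yet never produce one inside it, e.g. when it only feeds a non-differentiable op such as argmax in the taken branch; the copy step then failed. A minimal sketch of the pattern, mirroring the unit test added below (names here are illustrative, not from the patch):

    import numpy as np
    import paddle

    @paddle.jit.to_static
    def branchy(x, y):
        # The tensor condition lowers to a conditional_block op.
        if paddle.to_tensor([1]) == 1:
            _ = x.argmax(axis=-1)  # x yields no gradient in this branch
            out = y * 2.0
        else:
            out = x.sum() + y
        return out

    x = paddle.randn([2, 3])
    x.stop_gradient = False
    y = paddle.randn([2, 3])
    y.stop_gradient = False

    # Before this fix, backward() could error in conditional_block_grad
    # because x has an outside grad var but no inside grad var; with the
    # fix, x's gradient is zero-filled (assumed behavior).
    branchy(x, y).sum().backward()
    print(np.allclose(x.grad.numpy(), 0.0))  # expected: True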
23 changes: 16 additions & 7 deletions paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -143,7 +143,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
                /* keep_kid_scopes */ false);
 
       AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
-                                       inside_grads, outside_grads);
+                                       inside_grads, outside_grads, inputs);
       return;
     }
 

@@ -155,27 +155,36 @@ class ConditionalBlockGradOp : public ConditionalOp {
       const platform::Place &place, const framework::Scope &cur_scope,
       const framework::Scope &parent_scope,
       const std::vector<std::string> &inside_grads,
-      const std::vector<std::string> &outside_grads) const {
+      const std::vector<std::string> &outside_grads,
+      const std::vector<std::string> &inputs) const {
+    std::vector<std::string> assign_zero_outside_grads;
+    std::vector<std::string> assign_zero_inputs;
     for (size_t i = 0; i < outside_grads.size(); ++i) {
       const std::string &outside_grad_name = outside_grads[i];
       const std::string &inside_grad_name = inside_grads[i];
       VLOG(4) << "inside_grad_name = " << inside_grad_name
               << ", outside_grad_name = " << outside_grad_name;
-      framework::Variable *inside_var =
-          cur_scope.FindLocalVar(inside_grad_name);
-      if (inside_var == nullptr) {
-        continue;
-      }
       framework::Variable *outside_var =
           parent_scope.FindVar(outside_grad_name);
       if (outside_var == nullptr) {
         continue;
       }
+      framework::Variable *inside_var =
+          cur_scope.FindLocalVar(inside_grad_name);
+      if (inside_var == nullptr) {
+        assign_zero_outside_grads.emplace_back(outside_grad_name);
+        assign_zero_inputs.emplace_back(inputs[i]);
+        continue;
+      }
       platform::DeviceContext *dev_ctx =
           platform::DeviceContextPool::Instance().Get(place);
       framework::VisitVarType(*inside_var,
                               AssignFunctor(outside_var, *dev_ctx));
     }
+    // Assign zero to the grad_vars that are in outside_grads but not in
+    // inside_grads
+    AssignZeroToParentScope(place, parent_scope, assign_zero_inputs,
+                            assign_zero_outside_grads);
   }
 
   void AssignZeroToParentScope(
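
For readability, a rough Python analogy of the updated AssignLocalGradientToParentScope (hypothetical sketch: plain dicts stand in for framework::Scope, numpy arrays for tensors):

    import numpy as np

    def assign_local_grads_to_parent(cur_scope, parent_scope, forward_vals,
                                     inside_grads, outside_grads, inputs):
        # cur_scope / parent_scope / forward_vals are dicts modeling Scope
        # lookups; forward_vals maps input names to their forward values.
        assign_zero = []  # (input name, outside grad name) pairs to zero-fill
        for in_g, out_g, x in zip(inside_grads, outside_grads, inputs):
            if out_g not in parent_scope:  # outside grad var not found: skip
                continue
            if in_g not in cur_scope:
                # No gradient was produced inside the block (e.g. the var
                # only fed a no-grad op): remember it instead of skipping.
                assign_zero.append((x, out_g))
                continue
            # Normal path, analogous to VisitVarType + AssignFunctor.
            parent_scope[out_g] = cur_scope[in_g]
        # Analogue of AssignZeroToParentScope: zero grads shaped like the
        # corresponding forward inputs.
        for x, out_g in assign_zero:
            parent_scope[out_g] = np.zeros_like(forward_vals[x])

The key behavioral change: a missing inside grad is no longer silently skipped (which left the outside grad var uninitialized), and it is recorded and zero-filled after the loop; the outside-grad lookup is also moved ahead of the inside-grad lookup.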
35 changes: 35 additions & 0 deletions (Dy2St if/else unit test)
@@ -424,6 +424,41 @@ def test_ast_to_func(self):
 ProgramTranslator().enable(False)
 
 
+class IfElseNet(paddle.nn.Layer):
+    def __init__(self):
+        super(IfElseNet, self).__init__()
+        self.param = self.create_parameter(
+            shape=[3, 2], dtype='float32', is_bias=False)
+
+    @paddle.jit.to_static
+    def forward(self, a, b, c):
+        a = paddle.matmul(a, self.param)
+        a = paddle.reshape(a, (2, 4))
+        cond = paddle.to_tensor([10])
+        if cond == 10:
+            a_argmax = a.argmax(axis=-1)
+            b = b + self.param
+        else:
+            print(c)
+        return b
+
+
+class TestDy2StIfElseBackward(unittest.TestCase):
+    def test_run_backward(self):
+        a = paddle.randn((4, 3), dtype='float32')
+        a.stop_gradient = False
+        b = paddle.to_tensor([10]).astype('float32')
+        b.stop_gradient = False
+        c = paddle.to_tensor([2])
+        c.stop_gradient = False
+
+        net = IfElseNet()
+        net.train()
+        out = net(a, b, c)
+        out.backward()
+        self.assertTrue(np.allclose((b + net.param).numpy(), out.numpy()))
+
+
 if __name__ == '__main__':
     with paddle.fluid.framework._test_eager_guard():
         unittest.main()
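
A hedged extension of the new UT (not part of this commit): since `a` only feeds argmax inside the taken branch, the zero-fill should leave it with an all-zero gradient, which TestDy2StIfElseBackward could assert directly:

    # Hypothetical extra method for TestDy2StIfElseBackward; assumes the
    # same imports and IfElseNet as above. Not in the committed UT.
    def test_no_need_grad_input_gets_zero_grad(self):
        a = paddle.randn((4, 3), dtype='float32')
        a.stop_gradient = False
        b = paddle.to_tensor([10]).astype('float32')
        b.stop_gradient = False
        c = paddle.to_tensor([2])
        c.stop_gradient = False
        net = IfElseNet()
        net.train()
        out = net(a, b, c)
        out.backward()
        # `a` reaches the cond block but produces no inside grad, so its
        # gradient is expected to be zero-filled (assumption).
        self.assertTrue(np.allclose(a.grad.numpy(), 0.0))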
