diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 69e173eb57468..7ffbf1933be37 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -182,8 +182,8 @@ class ConditionalBlockGradOp : public ConditionalOp { framework::Variable *inside_var = cur_scope.FindLocalVar(inside_grad_name); if (inside_var == nullptr) { - assign_zero_outside_grads.push_back(outside_grad_name); - assign_zero_inputs.push_back(inputs[i]); + assign_zero_outside_grads.emplace_back(outside_grad_name); + assign_zero_inputs.emplace_back(inputs[i]); continue; } platform::DeviceContext *dev_ctx = @@ -191,6 +191,8 @@ class ConditionalBlockGradOp : public ConditionalOp { framework::VisitVarType(*inside_var, AssignFunctor(outside_var, *dev_ctx)); } + // Assign zero to the grad_vars that are in outside_grads but not in + // inside_grads AssignZeroToParentScope(place, parent_scope, assign_zero_inputs, assign_zero_outside_grads); } diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 9a9e7ee243872..276aa68e895c6 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -424,6 +424,41 @@ def test_ast_to_func(self): ProgramTranslator().enable(False) +class IfElseNet(paddle.nn.Layer): + def __init__(self): + super(IfElseNet, self).__init__() + self.param = self.create_parameter( + shape=[3, 2], dtype='float32', is_bias=False) + + @paddle.jit.to_static + def forward(self, a, b, c): + a = paddle.matmul(a, self.param) + a = paddle.reshape(a, (2, 4)) + cond = paddle.to_tensor([10]) + if cond == 10: + a_argmax = a.argmax(axis=-1) + b = b + self.param + else: + print(c) + return b + + +class TestDy2StIfElseBackward(unittest.TestCase): + def test_run_backward(self): + a = paddle.randn((4, 3), dtype='float32') + a.stop_gradient = False + b = paddle.to_tensor([10]).astype('float32') + b.stop_gradient = False + c = paddle.to_tensor([2]) + c.stop_gradient = False + + net = IfElseNet() + net.train() + out = net(a, b, c) + out.backward() + self.assertTrue(np.allclose((b + net.param).numpy(), out.numpy())) + + if __name__ == '__main__': with paddle.fluid.framework._test_eager_guard(): unittest.main()