Commit a63b709: Add comment and UT

0x45f committed May 27, 2022
1 parent 9c1efa2 commit a63b709
Showing 2 changed files with 39 additions and 2 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -182,15 +182,17 @@ class ConditionalBlockGradOp : public ConditionalOp {
      framework::Variable *inside_var =
          cur_scope.FindLocalVar(inside_grad_name);
      if (inside_var == nullptr) {
-        assign_zero_outside_grads.push_back(outside_grad_name);
-        assign_zero_inputs.push_back(inputs[i]);
+        assign_zero_outside_grads.emplace_back(outside_grad_name);
+        assign_zero_inputs.emplace_back(inputs[i]);
         continue;
      }
      platform::DeviceContext *dev_ctx =
          platform::DeviceContextPool::Instance().Get(place);
      framework::VisitVarType(*inside_var,
                              AssignFunctor(outside_var, *dev_ctx));
    }
+    // Assign zero to the grad_vars that are in outside_grads but not in
+    // inside_grads
    AssignZeroToParentScope(place, parent_scope, assign_zero_inputs,
                            assign_zero_outside_grads);
  }
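The new comment documents the second half of the backward pass: any gradient that the parent scope expects (listed in outside_grads) but that the conditional block never produced (missing from inside_grads) is filled with zeros via AssignZeroToParentScope. Below is a minimal Python sketch of that idea for readers who do not want to trace the C++; the function name and scope layout are invented for illustration and are not Paddle APIs.

# Illustrative sketch only (not Paddle code): zero-fill gradients that the
# parent scope asks for but that the taken branch never produced.
import numpy as np

def assign_zero_to_parent_scope(parent_scope, inside_grads, outside_grads):
    """For every outside grad with no matching inside grad, write zeros
    of the right shape into the parent scope (hypothetical layout)."""
    for name in outside_grads:
        if name in inside_grads:
            continue  # the branch produced this gradient; nothing to do
        shape = parent_scope[name].shape  # forward value provides the shape
        parent_scope[name + '@GRAD'] = np.zeros(shape, dtype='float32')

# Toy usage: 'c' feeds only the never-taken else branch, so its grad is zeroed.
parent_scope = {'c': np.array([2.0], dtype='float32')}
assign_zero_to_parent_scope(parent_scope, inside_grads=set(), outside_grads=['c'])
print(parent_scope['c@GRAD'])  # -> [0.]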
@@ -424,6 +424,41 @@ def test_ast_to_func(self):
        ProgramTranslator().enable(False)


class IfElseNet(paddle.nn.Layer):
    def __init__(self):
        super(IfElseNet, self).__init__()
        self.param = self.create_parameter(
            shape=[3, 2], dtype='float32', is_bias=False)

    @paddle.jit.to_static
    def forward(self, a, b, c):
        a = paddle.matmul(a, self.param)
        a = paddle.reshape(a, (2, 4))
        cond = paddle.to_tensor([10])
        if cond == 10:
            a_argmax = a.argmax(axis=-1)
            b = b + self.param
        else:
            print(c)
        return b


class TestDy2StIfElseBackward(unittest.TestCase):
    def test_run_backward(self):
        a = paddle.randn((4, 3), dtype='float32')
        a.stop_gradient = False
        b = paddle.to_tensor([10]).astype('float32')
        b.stop_gradient = False
        c = paddle.to_tensor([2])
        c.stop_gradient = False

        net = IfElseNet()
        net.train()
        out = net(a, b, c)
        out.backward()
        self.assertTrue(np.allclose((b + net.param).numpy(), out.numpy()))


if __name__ == '__main__':
    with paddle.fluid.framework._test_eager_guard():
        unittest.main()
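For readers who want to try the same pattern outside the unittest harness, here is a self-contained sketch of a dygraph-to-static if/else with a constant condition followed by backward(). It is illustrative only, assumes a Paddle 2.x install, and is not part of this commit.

import paddle

# Minimal dy2stat if/else mirroring the structure of IfElseNet above: the
# condition is a constant tensor, so only the first branch ever runs, and
# backward() has to handle variables that never reach the output.
@paddle.jit.to_static
def branch_fn(x, y):
    cond = paddle.to_tensor([10])
    if cond == 10:
        out = x * 2 + y   # the branch that is always taken
    else:
        out = x - y       # never executed at runtime
    return out

x = paddle.randn([3], dtype='float32')
x.stop_gradient = False
y = paddle.randn([3], dtype='float32')
y.stop_gradient = False

out = branch_fn(x, y).sum()
out.backward()
print(x.grad)  # expected to be 2s: d(sum(x * 2 + y)) / dx
print(y.grad)  # expected to be 1s: d(sum(x * 2 + y)) / dy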

1 comment on commit a63b709

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can ask the reviewer(s) to approve and merge. 🎉
