diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 6313850b48871..b8624b46eca81 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -43,8 +43,8 @@ def test_checkpoint(): scope = relay.ScopeBuilder() out_tuple = scope.let("out_tuple", relay.Tuple([relay.add(inputs[0], inputs[1]), - relay.add(inputs[2], inputs[3])])) - scope.ret(relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0), + relay.multiply(inputs[2], inputs[3])])) + scope.ret(relay.subtract(relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)), relay.TupleGetItem(out_tuple, 1))) out_single = scope.get() check_grad(relay.Function(inputs, out_single))