diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc
index 3e454777fc53..12cf4a1bffda 100644
--- a/src/relay/pass/gradient.cc
+++ b/src/relay/pass/gradient.cc
@@ -344,13 +344,40 @@ Expr Gradient(const Expr& re, const Module& mod) {
       args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
     }
     auto c = ll->Push(CallNode::make(rev, args));
-    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    std::function<void(const Expr&, const Type&)> init_grad;
+    init_grad = [&](const Expr& e, const Type& t) {
+      if (t.as<TensorTypeNode>()) {
+        ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
+      } else if (auto tt = t.as<TupleTypeNode>()) {
+        CHECK_GT(tt->fields.size(), 0);
+        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
+      } else {
+        LOG(FATAL) << "unhandled type " << t;
+        throw;
+      }
+    };
+    init_grad(c, f->body->checked_type());
     ll->Push(CallNode::make(RefReadNode::make(bp), {}));
     std::vector<Expr> ret;
     for (const auto& a : args) {
       ret.push_back(RefReadNode::make(GetField(a, 1)));
     }
-    return Pair(GetField(c, 0), TupleNode::make(ret));
+    std::function<Expr(const Expr&, const Type&)> get_final_result;
+    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
+      if (t.as<TensorTypeNode>()) {
+        return GetField(e, 0);
+      } else if (auto tt = t.as<TupleTypeNode>()) {
+        tvm::Array<Expr> fields;
+        for (size_t i = 0; i < tt->fields.size(); ++i) {
+          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
+        }
+        return TupleNode::make(fields);
+      } else {
+        LOG(FATAL) << "unhandled type " << t;
+        throw;
+      }
+    };
+    return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
   });
   return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
 }
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index c3b1971506d5..0b85262a795e 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -252,11 +252,28 @@ def test_if():
     net = relay.log(net)
     func = relay.Function(free_vars(net), net)
     func = run_infer_type(func)
-    net = run_infer_type(func)
-    net = gradient(net, mode='higher_order')
+    net = gradient(func, mode='higher_order')
     net = run_infer_type(net)
 
 
+def test_grad_tuple():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    y = x + x
+    func = relay.Function([x], relay.Tuple([y + y, y]))
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
+    ex = create_executor()
+    x = rand(dtype, *shape)
+    (forward_four, forward_two), (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward_four.asnumpy(), 4 * x.asnumpy())
+    tvm.testing.assert_allclose(forward_two.asnumpy(), 2 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
@@ -269,3 +286,4 @@ def test_if():
     test_ref()
     test_square_second_order()
     test_if()
+    test_grad_tuple()
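
Note (not part of the patch): the C++ change generalizes Gradient from tensor-only return types to tuples. init_grad recurses over the function's checked return type and writes ones into the gradient reference of the first tensor leaf it reaches, while get_final_result rebuilds the forward value by projecting each field out of the (value, grad_ref) pairs. The code below is a condensed, standalone version of the new test_grad_tuple -- a minimal sketch assuming a TVM build from around this change, where gradient is exported from tvm.relay.transform and the test suite's run_infer_type helper can be emulated with relay.Module plus the InferType pass.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.transform import gradient

    def run_infer_type(expr):
        # Local stand-in for the test suite's run_infer_type helper.
        mod = relay.Module.from_expr(expr)
        mod = relay.transform.InferType()(mod)
        return mod["main"]

    t = relay.TensorType((10, 10), 'float32')
    x = relay.var("x", t)
    y = x + x
    func = run_infer_type(relay.Function([x], relay.Tuple([y + y, y])))
    back_func = run_infer_type(gradient(func))

    # The transformed function returns a pair:
    # ((forward tuple fields...), (one gradient per parameter,)).
    xv = tvm.nd.array(np.random.rand(10, 10).astype('float32'))
    (forward_four, forward_two), (grad,) = relay.create_executor().evaluate(back_func)(xv)

    # init_grad seeds only the first tuple field with ones, so the gradient
    # is d(y + y)/dx = 4 everywhere; the second field contributes nothing.
    np.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(xv.asnumpy()))

As the test's type assertion shows, the result type of the transformed function becomes a pair of the original tuple output and a tuple with one gradient per parameter.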