[Relay] [Training] Allow gradient to return a tuple (apache#3600)
MarisaKirisame authored and wweic committed Sep 6, 2019
1 parent 549d873 commit 94df86f
Showing 2 changed files with 49 additions and 4 deletions.
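
Before this patch, the higher-order gradient pass assumed the differentiated function returned a single tensor: it seeded that one output with ones and returned `Pair(value, grads)` directly. This commit generalizes both steps so the function body may have tuple type. A minimal sketch of the usage this enables, lifted from the new test below (the `gradient` import path is assumed for this era of TVM; in the test file it is already in scope):

    from tvm import relay
    from tvm.relay.transform import gradient  # import path assumed

    t = relay.TensorType((10, 10), 'float32')
    x = relay.var("x", t)
    y = x + x
    # Function body is a tuple: (4x, 2x). Previously this hit the
    # tensor-only assumption in the pass; now it transforms cleanly.
    func = relay.Function([x], relay.Tuple([y + y, y]))
    back_func = gradient(func)  # type: fn(t) -> ((t, t), (t,))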
src/relay/pass/gradient.cc (29 additions, 2 deletions)
@@ -344,13 +344,40 @@ Expr Gradient(const Expr& re, const Module& mod) {
       args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
     }
     auto c = ll->Push(CallNode::make(rev, args));
-    ll->Push(RefWriteNode::make(GetField(c, 1), OnesLike(GetField(c, 0))));
+    std::function<void(const Expr&, const Type&)> init_grad;
+    init_grad = [&](const Expr& e, const Type& t) {
+      if (t.as<TensorTypeNode>()) {
+        ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
+      } else if (auto tt = t.as<TupleTypeNode>()) {
+        CHECK_GT(tt->fields.size(), 0);
+        init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
+      } else {
+        LOG(FATAL) << "unhandled type " << t;
+        throw;
+      }
+    };
+    init_grad(c, f->body->checked_type());
     ll->Push(CallNode::make(RefReadNode::make(bp), {}));
     std::vector<Expr> ret;
     for (const auto& a : args) {
       ret.push_back(RefReadNode::make(GetField(a, 1)));
     }
-    return Pair(GetField(c, 0), TupleNode::make(ret));
+    std::function<Expr(const Expr&, const Type&)> get_final_result;
+    get_final_result = [&](const Expr& e, const Type& t) -> Expr {
+      if (t.as<TensorTypeNode>()) {
+        return GetField(e, 0);
+      } else if (auto tt = t.as<TupleTypeNode>()) {
+        tvm::Array<Expr> fields;
+        for (size_t i = 0; i < tt->fields.size(); ++i) {
+          fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
+        }
+        return TupleNode::make(fields);
+      } else {
+        LOG(FATAL) << "unhandled type " << t;
+        throw;
+      }
+    };
+    return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
   });
   return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
 }
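Two recursive lambdas do the generalization above: `init_grad` writes a ones-like seed into the reference cell paired with each forward value, and `get_final_result` projects the forward values back out of the (value, grad-ref) pairs, rebuilding tuples field by field. Note the asymmetry: for a tuple-typed result, `init_grad` recurses only into field 0, so only the first tuple field is seeded, while `get_final_result` walks every field. The new test depends on this: with outputs (4x, 2x), the gradient w.r.t. x comes out as 4 (from the first field alone), not 6. A small numpy check of that arithmetic, independent of TVM:

    import numpy as np

    x = np.random.rand(10, 10).astype('float32')
    # Forward outputs of the test function: (4x, 2x).
    # Seeds mirror init_grad: ones for tuple field 0, nothing for field 1.
    seed0, seed1 = np.ones_like(x), np.zeros_like(x)
    # d(4x)/dx = 4, d(2x)/dx = 2; accumulate only the seeded adjoints.
    grad = 4 * seed0 + 2 * seed1
    assert np.allclose(grad, 4 * np.ones_like(x))  # matches test_grad_tuple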
tests/python/relay/test_pass_gradient.py (20 additions, 2 deletions)
@@ -252,11 +252,28 @@ def test_if():
     net = relay.log(net)
     func = relay.Function(free_vars(net), net)
     func = run_infer_type(func)
-    net = run_infer_type(func)
-    net = gradient(net, mode='higher_order')
+    net = gradient(func, mode='higher_order')
     net = run_infer_type(net)
 
 
+def test_grad_tuple():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    y = x + x
+    func = relay.Function([x], relay.Tuple([y + y, y]))
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
+    ex = create_executor()
+    x = rand(dtype, *shape)
+    (forward_four, forward_two), (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward_four.asnumpy(), 4 * x.asnumpy())
+    tvm.testing.assert_allclose(forward_two.asnumpy(), 2 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+
+
 if __name__ == "__main__":
     test_id()
     test_add()
@@ -269,3 +269,4 @@ def test_if():
     test_ref()
     test_square_second_order()
     test_if()
+    test_grad_tuple()
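The `checked_type` asserted in test_grad_tuple decomposes as one forward tuple plus one gradient per input parameter. An illustrative reconstruction of that expected type, using the same API as the test:

    from tvm import relay

    t = relay.TensorType((10, 10), 'float32')
    forward_ty = relay.TupleType([t, t])  # the original (4x, 2x) outputs
    grads_ty = relay.TupleType([t])       # one gradient per input parameter
    expected = relay.FuncType([t], relay.TupleType([forward_ty, grads_ty]))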
