fix first-order AD tuple/projection expr duplication (#8318)
altanh authored Jun 24, 2021
1 parent 3e28716 commit 4f9e614
Showing 2 changed files with 41 additions and 11 deletions.
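
In short: the first-order AD pass emits forward values through a LetList, so each ADTensor's forward field refers to a let-bound variable. The Tuple and TupleGetItem visitors previously handed the original expression (GetRef<Expr>(op)) to ADTensor, which re-emitted the whole tuple subtree at every use; this change rebuilds the tuple/projection from the already let-bound forward variables instead, so each field is computed exactly once. It also adds a diagnostic rejecting tuples whose fields are not (nested) tensors, plus a regression test.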
35 changes: 24 additions & 11 deletions src/relay/transforms/first_order_gradient.cc
@@ -174,11 +174,14 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 
   ADValue VisitExpr_(const TupleGetItemNode* op) final {
-    Expr e = GetRef<Expr>(op);
     ADValue tup = VisitExpr(op->tuple);
-    auto tt = op->tuple->checked_type().as<TupleTypeNode>();
+    TupleType tt = Downcast<TupleType>(op->tuple->checked_type());
     size_t idx = op->index;
-    auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
+    // reconstruct projection using let-bound variable to avoid duplicating input tuple
+    TupleGetItem orig = TupleGetItem(tup->get<ADTensor>().forward, idx);
+    orig->checked_type_ = op->checked_type();
+    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
+    // for orig = pi(tup, i), pi_grad(tup, i, g) = G where pi(G, i) = g and pi(G, j) = 0 for j != i
     backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) {
       auto& ad_tup = tup->get<ADTensor>();
       std::vector<Expr> updated_grads;
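
The pi_grad comment above can be restated as a worked rule (notation mine): for a projection y = pi_i(t) of a tuple t = (x_1, ..., x_n) with upstream gradient g, the backprop action builds a tuple gradient G and adds it into t's accumulated adjoint:

\[
\pi_j(G) = \begin{cases} g, & j = i \\ 0, & j \neq i \end{cases}
\qquad\qquad
\bar{t} \mathrel{+}= G
\]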
@@ -193,16 +196,26 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 
   ADValue VisitExpr_(const TupleNode* op) final {
-    Expr e = GetRef<Expr>(op);
-    std::vector<ADValue> fields;
+    auto tt = Downcast<TupleType>(op->checked_type());
+    std::vector<ADValue> ad_fields;
+    std::vector<Expr> field_bindings;
     for (const auto& f : op->fields) {
-      fields.push_back(VisitExpr(f));
+      ADValue f_ad = VisitExpr(f);
+      if (!dynamic_cast<ADTensor*>(f_ad.get())) {
+        diag_ctx.EmitFatal(Diagnostic::Error(f->span)
+                           << "first-order AD only supports (nested) tuples of tensors");
+      }
+      ad_fields.push_back(f_ad);
+      field_bindings.push_back(f_ad->get<ADTensor>().forward);
     }
-    auto tt = op->checked_type().as<TupleTypeNode>();
-    auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
-    backprop_actions.push_back([fields, tt, ret](LetList* ll) {
-      for (size_t i = 0; i < fields.size(); ++i) {
-        auto& ad_field = fields[i]->get<ADTensor>();
+    // reconstruct tuple using let-bound variables to avoid duplication
+    auto orig = Tuple(field_bindings);
+    orig->checked_type_ = tt;
+    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
+    // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
+    backprop_actions.push_back([ad_fields, tt, ret](LetList* ll) {
+      for (size_t i = 0; i < ad_fields.size(); ++i) {
+        auto& ad_field = ad_fields[i]->get<ADTensor>();
         ad_field.reverse =
             LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll);
       }
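Likewise, the tuple_grad comment: for orig = (x_1, ..., x_n) with incoming tuple-shaped gradient G, the backprop action accumulates each projection into the matching field's adjoint — this is exactly the LiftedAdd loop above:

\[
\bar{x}_i \mathrel{+}= \pi_i(G), \qquad i = 1, \dots, n
\]
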
17 changes: 17 additions & 0 deletions tests/python/relay/test_pass_gradient.py
@@ -429,6 +429,23 @@ def test_no_duplication():
     assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
 
 
+def test_no_duplication_tuples():
+    x = tvm.relay.Var("x", type_annotation=tvm.relay.TensorType([12, 12]))
+    y = tvm.relay.Var("y", type_annotation=tvm.relay.TensorType([12, 12]))
+    xy = tvm.relay.nn.dense(x, y)
+
+    t = relay.Tuple([xy, xy])
+
+    m = tvm.relay.sum(xy, keepdims=True)
+    s = tvm.relay.sum(relay.TupleGetItem(t, 0) - m)
+    fn = tvm.relay.Function([x, y], s)
+    fn = run_infer_type(fn)
+    gr = tvm.relay.transform.gradient(fn, mode="first_order")
+
+    counts = count_ops(gr)
+    assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
+
+
 def test_global_function():
     m = tvm.IRModule()
     shape = (10, 10)
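The new test reuses the count_ops helper already defined earlier in test_pass_gradient.py (it is not part of this diff). A minimal sketch of such a helper, assuming relay.analysis.post_order_visit and counting operator calls by name, might look like:

    from collections import defaultdict

    import tvm
    from tvm import relay

    def count_ops(expr):
        # Count how many times each operator is called in a Relay expression.
        counts = defaultdict(int)

        def visit(node):
            # Operator calls have a tvm.ir.Op as their callee.
            if isinstance(node, relay.Call) and isinstance(node.op, tvm.ir.Op):
                counts[node.op.name] += 1

        relay.analysis.post_order_visit(expr, visit)
        return counts

With such a helper, counts["nn.dense"] == 3 checks that the forward dense appears once and each of its two backward dense calls once, rather than being re-emitted per tuple use.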
