diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc
index 55714592ded7..3419cb670a28 100644
--- a/src/relay/transforms/first_order_gradient.cc
+++ b/src/relay/transforms/first_order_gradient.cc
@@ -174,11 +174,14 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 
   ADValue VisitExpr_(const TupleGetItemNode* op) final {
-    Expr e = GetRef<Expr>(op);
     ADValue tup = VisitExpr(op->tuple);
-    auto tt = op->tuple->checked_type().as<TupleTypeNode>();
+    TupleType tt = Downcast<TupleType>(op->tuple->checked_type());
     size_t idx = op->index;
-    auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
+    // reconstruct projection using let-bound variable to avoid duplicating input tuple
+    TupleGetItem orig = TupleGetItem(tup->get<ADTensor>().forward, idx);
+    orig->checked_type_ = op->checked_type();
+    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
+    // for orig = pi(tup, i), pi_grad(tup, i, g) = G where pi(G, i) = g and pi(G, j) = 0 for j != i
     backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) {
       auto& ad_tup = tup->get<ADTensor>();
       std::vector<Expr> updated_grads;
@@ -193,16 +196,26 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
   }
 
   ADValue VisitExpr_(const TupleNode* op) final {
-    Expr e = GetRef<Expr>(op);
-    std::vector<ADValue> fields;
+    auto tt = Downcast<TupleType>(op->checked_type());
+    std::vector<ADValue> ad_fields;
+    std::vector<Expr> field_bindings;
     for (const auto& f : op->fields) {
-      fields.push_back(VisitExpr(f));
+      ADValue f_ad = VisitExpr(f);
+      if (!dynamic_cast<ADTensor*>(f_ad.get())) {
+        diag_ctx.EmitFatal(Diagnostic::Error(f->span)
+                           << "first-order AD only supports (nested) tuples of tensors");
+      }
+      ad_fields.push_back(f_ad);
+      field_bindings.push_back(f_ad->get<ADTensor>().forward);
     }
-    auto tt = op->checked_type().as<TupleTypeNode>();
-    auto ret = std::make_shared<ADTensor>(ll, e, diag_ctx);
-    backprop_actions.push_back([fields, tt, ret](LetList* ll) {
-      for (size_t i = 0; i < fields.size(); ++i) {
-        auto& ad_field = fields[i]->get<ADTensor>();
+    // reconstruct tuple using let-bound variables to avoid duplication
+    auto orig = Tuple(field_bindings);
+    orig->checked_type_ = tt;
+    auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
+    // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
+    backprop_actions.push_back([ad_fields, tt, ret](LetList* ll) {
+      for (size_t i = 0; i < ad_fields.size(); ++i) {
+        auto& ad_field = ad_fields[i]->get<ADTensor>();
         ad_field.reverse =
             LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll);
       }
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 6228c5fc157b..cd0edf95aba7 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -429,6 +429,23 @@ def test_no_duplication():
     assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
 
 
+def test_no_duplication_tuples():
+    x = tvm.relay.Var("x", type_annotation=tvm.relay.TensorType([12, 12]))
+    y = tvm.relay.Var("y", type_annotation=tvm.relay.TensorType([12, 12]))
+    xy = tvm.relay.nn.dense(x, y)
+
+    t = relay.Tuple([xy, xy])
+
+    m = tvm.relay.sum(xy, keepdims=True)
+    s = tvm.relay.sum(relay.TupleGetItem(t, 0) - m)
+    fn = tvm.relay.Function([x, y], s)
+    fn = run_infer_type(fn)
+    gr = tvm.relay.transform.gradient(fn, mode="first_order")
+
+    counts = count_ops(gr)
+    assert counts["nn.dense"] == 3, "We expect 3 dense (1 forward, two backward)"
+
+
 def test_global_function():
     m = tvm.IRModule()
     shape = (10, 10)
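
Not part of the patch: a minimal plain-Python sketch (hypothetical helper names) of the two backprop actions described in the comments above, assuming g is the incoming adjoint of the projection and G is the incoming adjoint of the whole tuple; the real pass builds the equivalent Relay expressions with LiftedAdd inside a LetList.

def tuple_get_item_backprop(tup_grad, idx, g):
    # for orig = pi(tup, idx): accumulate g into slot idx of the existing
    # tuple adjoint and leave every other slot unchanged
    # (i.e. pi(G, idx) = g and pi(G, j) = 0 for j != idx)
    return [gi + g if j == idx else gi for j, gi in enumerate(tup_grad)]

def tuple_backprop(field_grads, G):
    # for orig = tuple(x1, ..., xn): each field adjoint x_i accumulates pi(G, i)
    return [gi + G[i] for i, gi in enumerate(field_grads)]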