From 81452aafde2612ed7d0d76891151b675ef76f382 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Sat, 26 Oct 2019 17:04:42 -0700 Subject: [PATCH] [Relay][Training] Add checkpoint annotation for checkpointing memory optimization (#4146) * add checkpoint annotation for checkpointing memory optimization * add alpha-equivalence checkpoint test and fix gradient type issue * fix build issues * ignore checkpoint annotation when checking missing gradients * refactor, fix checkpoint compute for tuple and add tests --- python/tvm/relay/op/annotation/annotation.py | 19 ++- src/relay/op/annotation/annotation.cc | 27 ++++ src/relay/pass/de_duplicate.cc | 4 +- src/relay/pass/gradient.cc | 162 +++++++++++++++---- tests/python/relay/test_op_grad_level10.py | 12 ++ tests/python/relay/test_op_level10.py | 121 ++++++++++++++ 6 files changed, 309 insertions(+), 36 deletions(-) diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index 10c898538596..2b9d4bcd81bc 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -17,10 +17,10 @@ """Annotation operations.""" from __future__ import absolute_import as _abs from . import _make +from ..op import register_schedule, schedule_injective from .... import nd as _nd from .... import TVMContext as _TVMContext - def on_device(data, device): """Annotate an expression with a certain device type. @@ -61,3 +61,20 @@ def stop_fusion(data): The annotated expression. """ return _make.stop_fusion(data) + +def checkpoint(data): + """Annotate an expression to be a checkpoint for the checkpointing memory optimization. + + Parameters + ---------- + data : tvm.relay.Expr + The expression to be annotated. + + Returns + ------- + result : tvm.relay.Expr + The annotated expression. + """ + return _make.checkpoint(data) + +register_schedule("annotation.checkpoint", schedule_injective) diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index eeacc6cbf999..5a8ad33c63a7 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -144,5 +144,32 @@ Mark the end of bitpacking. return {topi::identity(inputs[0])}; }); +TVM_REGISTER_API("relay.op.annotation._make.checkpoint") +.set_body_typed([](Expr data) { + static const Op& op = Op::Get("annotation.checkpoint"); + return CallNode::make(op, {data}, Attrs{}, {}); +}); + +RELAY_REGISTER_OP("annotation.checkpoint") +.describe(R"code( +Mark a checkpoint for checkpointing memory optimization. +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_support_level(10) +.add_type_rel("Identity", IdentityRel) +.set_attr("TOpPattern", kOpaque) +.set_attr("TOpIsStateful", false) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype, const Target& target) -> Array { + Array outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + outputs.push_back(topi::identity(inputs[i])); + } + return outputs; + }); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index 332803cb71ba..38acdcde94b0 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) { } Expr VisitExpr(const Expr& e) final { - return ExprMutator::VisitExpr(e); + auto ret = ExprMutator::VisitExpr(e); + ret->checked_type_ = e->checked_type_; + return ret; } Expr VisitExpr_(const VarNode* op) final { diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 8b06b8721994..b93c110a71c6 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -273,24 +273,29 @@ Type ReverseType(const Type& t) { * by doing a structure preserving map. */ Expr LiftTensor(const std::function& f, - const Type& t, + const std::function& tf, + const Type& forward_type, const Expr& e, LetList* ll) { CHECK(IsAtomic(e)) << e; - if (t.as()) { + if (forward_type.as()) { auto ret = f(e); - ret->checked_type_ = t; + ret->checked_type_ = tf(forward_type); return ret; - } else if (auto* tt = t.as()) { + } else if (auto* tt = forward_type.as()) { tvm::Array fields; + tvm::Array types; for (size_t i = 0; i < tt->fields.size(); ++i) { - fields.push_back(LiftTensor(f, - tt->fields[i], - ll->Push(GetField(e, i)), - ll)); + auto field = LiftTensor(f, + tf, + tt->fields[i], + ll->Push(GetField(e, i)), + ll); + fields.push_back(field); + types.push_back(field->checked_type_); } auto ret = TupleNode::make(fields); - ret->checked_type_ = t; + ret->checked_type_ = TupleTypeNode::make(types); return std::move(ret); } else { LOG(FATAL) << "unsupported input/output type: " << tt; @@ -298,25 +303,63 @@ Expr LiftTensor(const std::function& f, } } +/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr, + * by stitching the references in the AD values. + */ +void TransferGrads(const Type& forward_type, + const Expr& from, + const Expr& to, + LetList* ll) { + CHECK(IsAtomic(from)) << from; + CHECK(IsAtomic(to)) << to; + if (forward_type.as()) { + auto from_ref = TupleGetItemNode::make(from, 1); + auto to_ref = TupleGetItemNode::make(to, 1); + ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref))); + } else if (auto* tt = forward_type.as()) { + for (size_t i = 0; i < tt->fields.size(); ++i) { + TransferGrads(tt->fields[i], + ll->Push(TupleGetItemNode::make(from, i)), + ll->Push(TupleGetItemNode::make(to, i)), + ll); + } + } else { + LOG(FATAL) << "Unsupported input/output type: " << forward_type; + throw; + } +} + /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ -Expr GetRev(const Type& t, const Expr& e, LetList* ll) { +Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e)))); }; - return LiftTensor(rev, t, e, ll); + auto rev_type = [&](const Type& forward_type) { + return ReverseType(forward_type); + }; + return LiftTensor(rev, rev_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the original value. */ -Expr GetValue(const Type& t, const Expr& e, LetList* ll) { - return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll); +Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { + auto val = [&](const Expr& e) { + return GetField(e, 0); + }; + auto val_type = [&](const Type& forward_type) { + return forward_type; + }; + return LiftTensor(val, val_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the gradient. */ -Expr GetGrad(const Type& t, const Expr& e, LetList* ll) { +Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { auto grad = [&](const Expr& e) { return ll->Push(RefReadNode::make(GetField(e, 1))); }; - return LiftTensor(grad, t, e, ll); + auto grad_type = [&](const Type& forward_type) { + return forward_type; + }; + return LiftTensor(grad, grad_type, forward_type, e, ll); } void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { @@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { } } +Expr BPEmpty() { + Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {}); + return RefCreateNode::make(unitF); +} + struct ReverseAD : ExprMutator { + using ADVarMap = std::unordered_map; + Var bp; + std::shared_ptr ad_vars; const OpMap rev_map = Op::GetAttr("FPrimalGradient"); - explicit ReverseAD(const Var& bp) : bp(bp) { } + explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) + : bp(bp), ad_vars(ad_vars) { } Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; throw; } - Expr VisitExpr_(const CallNode* op) final { - if (const OpNode* op_node = op->op.as()) { + Expr VisitCheckpoint(const CallNode *call) { + const OpNode* op_node = call->op.as(); + CHECK(op_node) << "expected op in call"; + Op op_ref = GetRef(op_node); + CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation"; + auto x = call->args[0]; + return LetList::With([&](LetList* ll) { + auto x_var = ll->Push(x); + auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); + auto bpv = ll->Push(RefReadNode::make(bp)); + Expr nbp = FunctionNode::make( + {}, + LetList::With([&](LetList* ll) { + // we need a new ReverseAD visitor to avoid clobbering the bp local var + auto dup_bp = ll->Push(BPEmpty()); + ReverseAD dup_diff(dup_bp, ad_vars); + auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); + + TransferGrads(call->checked_type(), ret, dup_ad, ll); + ll->Push(CallNode::make(RefReadNode::make(dup_bp), {})); + return CallNode::make(bpv, {}); + }), + TupleTypeNode::make({}), + {}); + ll->Push(RefWriteNode::make(bp, nbp)); + return ret; + }); + } + + Expr VisitExpr_(const CallNode* call) final { + if (const OpNode* op_node = call->op.as()) { Op op_ref = GetRef(op_node); + + if (op_ref->name == "annotation.checkpoint") { + return VisitCheckpoint(call); + } + + CHECK(rev_map.count(op_ref)) + << op_node->name << " does not have reverse mode defined"; return LetList::With([&](LetList* ll) { std::vector args; - for (const auto& arg : op->args) { + for (const auto& arg : call->args) { args.push_back(ll->Push(VisitExpr(arg))); } std::vector orig_args; for (size_t i = 0; i < args.size(); i++) { - orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll)); + orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll)); } - Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args); - orig->checked_type_ = op->checked_type(); + Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args); + orig->checked_type_ = call->checked_type(); Var orig_var = ll->Push(orig); - orig_var->checked_type_ = op->checked_type(); - auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll)); + orig_var->checked_type_ = call->checked_type(); + auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto bpv = ll->Push(RefReadNode::make(bp)); Expr nbp = FunctionNode::make( {}, LetList::With([&](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll)); + tvm::Array rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); CHECK(args.size() == rev.size()); for (size_t i = 0; i < args.size(); ++i) { - UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll); + UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); } return CallNode::make(bpv, {}); }), @@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator { return ret; }); } - return ExprMutator::VisitExpr_(op); + return ExprMutator::VisitExpr_(call); } Expr VisitExpr_(const ConstantNode* op) final { @@ -396,16 +484,22 @@ struct ReverseAD : ExprMutator { VisitExpr(op->false_branch)); } + Expr VisitExpr_(const VarNode* var) final { + // memoize Var -> ADVar so we don't end up with free Vars when checkpointing + auto var_ref = GetRef(var); + if (!ad_vars->count(var_ref)) { + auto res = Downcast(ExprMutator::VisitExpr_(var)); + (*ad_vars)[var_ref] = res; + } + + return ad_vars->at(var_ref); + } + Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } }; -Expr BPEmpty() { - Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {}); - return RefCreateNode::make(unitF); -} - bool MissingGrad(const Expr& e) { struct MGVisitor : ExprVisitor { const OpMap rev_map = Op::GetAttr("FPrimalGradient"); @@ -413,7 +507,7 @@ bool MissingGrad(const Expr& e) { void VisitExpr_(const OpNode* op) final { Op op_ref = GetRef(op); - if (!rev_map.count(op_ref)) { + if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) { op_names.insert(op_ref->name); } ExprVisitor::VisitExpr_(op); @@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) { CHECK(!MissingGrad(e)) << "input has operators with missing gradients"; Expr body = LetList::With([&](LetList* ll) { Var bp = ll->Push(BPEmpty()); - Expr rev = ReverseAD(bp)(e); + Expr rev = ReverseAD(bp, std::make_shared())(e); std::vector args; for (const auto& p : f->params) { args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p))))); diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 7aa9e0bc135f..acf3b75e0cb5 100644 --- a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad(): x = relay.var("x", shape=(2, 5)) y = relay.var("y", shape=(2, 5)) check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1) + +def test_checkpoint(): + inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)] + output = relay.multiply(relay.add(inputs[0], inputs[1]), + relay.add(inputs[2], inputs[3])) + check_grad(relay.Function(inputs, relay.annotation.checkpoint(output))) + + out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]), + relay.multiply(inputs[2], inputs[3])]) + out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0), + relay.TupleGetItem(out_tuple, 1)) + check_grad(relay.Function(inputs, out_single)) if __name__ == "__main__": diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index e828fa30de56..d9e29d8bbd9f 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -31,6 +31,127 @@ def run_infer_type(expr): entry = mod["main"] return entry if isinstance(expr, relay.Function) else entry.body +def test_checkpoint(): + dtype = "float32" + xs = [relay.var("x{}".format(i), dtype) for i in range(4)] + f = relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])) + f_checkpoint = relay.annotation.checkpoint(f) + + func, func_checkpoint = relay.Function(xs, f), relay.Function(xs, f_checkpoint) + f, f_checkpoint = run_infer_type(func), run_infer_type(func_checkpoint) + assert f.checked_type == f_checkpoint.checked_type + + inputs = [np.random.uniform() for _ in range(len(xs))] + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + f_res = intrp.evaluate(f)(*inputs) + f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs) + tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0) + +def test_checkpoint_alpha_equal(): + xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)] + f = relay.Function(xs, relay.annotation.checkpoint( + relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])) + )) + df = transform.gradient(run_infer_type(f)) + + # run PE and DCE + with transform.PassContext(opt_level=3): + passes = [transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True)] + mod = transform.Sequential(passes)(relay.Module.from_expr(df)) + df = mod["main"] + + df_parsed = relay.parser.fromtext( + """ + v0.0.4 + fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], + %z: Tensor[(1), float32], %w: Tensor[(1), float32]) + -> (Tensor[(1), float32], + (Tensor[(1), float32], Tensor[(1), float32], + Tensor[(1), float32], Tensor[(1), float32])) { + %0 = add(%x, %y); + %1 = add(%z, %w); + let %x1: Tensor[(1), float32] = multiply(%0, %1); + let %x2: Tensor[(1), float32] = ones_like(%x1); + let %x3: Tensor[(1), float32] = add(%x, %y); + let %x4: Tensor[(1), float32] = add(%z, %w); + %2 = zeros_like(%x3); + %3 = multiply(%x2, %x4); + %4 = collapse_sum_like(%3, %x3); + let %x5: Tensor[(1), float32] = add(%2, %4); + %5 = zeros_like(%x4); + %6 = multiply(%x2, %x3); + %7 = collapse_sum_like(%6, %x4); + let %x6: Tensor[(1), float32] = add(%5, %7); + %8 = zeros_like(%x); + %9 = collapse_sum_like(%x5, %x); + %10 = add(%8, %9); + %11 = zeros_like(%y); + %12 = collapse_sum_like(%x5, %y); + %13 = add(%11, %12); + %14 = zeros_like(%z); + %15 = collapse_sum_like(%x6, %z); + %16 = add(%14, %15); + %17 = zeros_like(%w); + %18 = collapse_sum_like(%x6, %w); + %19 = add(%17, %18); + %20 = (%10, %13, %16, %19); + (%x1, %20) + } + """ + ) + + relay.analysis.assert_alpha_equal(df, df_parsed) + +def test_checkpoint_alpha_equal_tuple(): + xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)] + f = relay.Function(xs, relay.annotation.checkpoint( + relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])]) + )) + df = transform.gradient(run_infer_type(f)) + + # run PE and DCE + with transform.PassContext(opt_level=3): + passes = [transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True)] + mod = transform.Sequential(passes)(relay.Module.from_expr(df)) + df = mod["main"] + + df_parsed = relay.parser.fromtext( + """ + v0.0.4 + fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32], + %z: Tensor[(1), float32], %w: Tensor[(1), float32]) + -> ((Tensor[(1), float32], Tensor[(1), float32]), + (Tensor[(1), float32], Tensor[(1), float32], + Tensor[(1), float32], Tensor[(1), float32])) { + let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */; + let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */; + let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */; + let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */; + %0 = (%x1, %x2); + %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */; + %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */; + %3 = add(%1, %2) /* ty=Tensor[(1), float32] */; + %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */; + %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */; + %6 = add(%4, %5) /* ty=Tensor[(1), float32] */; + %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */; + %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */; + %9 = add(%7, %8) /* ty=Tensor[(1), float32] */; + %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */; + %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */; + %12 = add(%10, %11) /* ty=Tensor[(1), float32] */; + %13 = (%3, %6, %9, %12); + (%0, %13) + } + """ + ) + + relay.analysis.assert_alpha_equal(df, df_parsed) + def test_collapse_sum_like(): shape = (3, 4, 5, 6) shape_like = (4, 5, 6)