From 6201c5c95c84b40907811ed7613ea3e91ba93e5e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sun, 23 Aug 2020 13:52:29 +0000 Subject: [PATCH 1/8] save lint lint lint fix lint lint update lint save save save lint format format save save fix use a form more suitable for numeric check save --- include/tvm/relay/feature.h | 36 ++++++ include/tvm/relay/transform.h | 9 ++ python/tvm/relay/prelude.py | 2 + src/relay/analysis/feature.cc | 51 ++++++++- src/relay/transforms/gradient.cc | 55 ++++++--- src/relay/transforms/lazy_gradient_init.cc | 107 +++++++----------- src/relay/transforms/partial_eval.cc | 5 +- src/relay/transforms/pass_util.h | 3 +- src/relay/transforms/to_a_normal_form.cc | 58 +++++----- src/relay/transforms/to_cps.cc | 3 + tests/python/relay/test_analysis_feature.py | 2 - tests/python/relay/test_op_grad_level10.py | 14 ++- tests/python/relay/test_pass_gradient.py | 24 +++- .../relay/test_pass_lazy_gradient_init.py | 18 +++ .../python/relay/test_pass_merge_composite.py | 15 +-- tests/python/relay/test_pass_partial_eval.py | 25 +--- 16 files changed, 266 insertions(+), 161 deletions(-) diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 3783e320f57c..92743e7deb5c 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -29,6 +29,7 @@ #include #include +#include namespace tvm { namespace relay { @@ -124,6 +125,13 @@ class FeatureSet { */ bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } + /*! + * \brief Pretty Print the FeatureSet. + * + * \return a string representation. + */ + std::string Print() const; + private: std::bitset bs_; FeatureSet() = default; @@ -160,6 +168,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) { return DetectFeature(expr) + DetectFeature(mod); } +/*! + * \brief Check the feature of the program. + * + * \param expr The expression. + * \param fs The feature set of the program. + */ +void CheckFeature(const RelayExpr& expr, const FeatureSet& fs); + +/*! + * \brief Check the feature of the program. + * + * \param mod The module. + * \param fs The feature set of the program. + */ +void CheckFeature(const IRModule& mod, const FeatureSet& fs); + +/*! + * \brief Check the feature of the program. + * + * \param expr The expression. + * \param mod The module. + * \param fs The feature set of the program. + */ +inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) { + CheckFeature(expr, fs); + CheckFeature(mod, fs); +} + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d322710ec95a..de2bcc4f4318 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm(); */ TVM_DLL Pass ToANormalForm(); +/*! + * \brief ToANormalForm but on incomplete graph. + * + * \param expr the graph. + * + * \return The transformed program. + */ +TVM_DLL Expr ToANormalForm(const Expr& expr); + /*! * \brief Turn an expression into continuation passing style(CPS). 
* diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 1b7ed77e9b57..2675f1da88b0 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -17,6 +17,7 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" from tvm.ir import IRModule, TypeCall +from tvm import relay from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, GlobalVar, If, const @@ -1237,6 +1238,7 @@ def __init__(self, mod=None): mod = IRModule() self.mod = mod self.load_prelude() + self.mod = relay.transform.ToANormalForm()(self.mod) def get_name(self, canonical, dtype): """Get name corresponding to the canonical name""" diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index a145b28d55e8..df743c6ed678 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) { return fd.fs; } +std::string FeatureSet::Print() const { + std::string ret; + ret += "["; + size_t detected = 0; +#define DETECT_FEATURE(FEATURE_NAME) \ + ++detected; \ + if (bs_[FEATURE_NAME]) { \ + ret += #FEATURE_NAME; \ + ret += ", "; \ + } + DETECT_FEATURE(fVar); + DETECT_FEATURE(fGlobalVar); + DETECT_FEATURE(fConstant); + DETECT_FEATURE(fTuple); + DETECT_FEATURE(fTupleGetItem); + DETECT_FEATURE(fFunction); + DETECT_FEATURE(fOp); + DETECT_FEATURE(fCall); + DETECT_FEATURE(fLet); + DETECT_FEATURE(fIf); + DETECT_FEATURE(fRefCreate); + DETECT_FEATURE(fRefRead); + DETECT_FEATURE(fRefWrite); + DETECT_FEATURE(fConstructor); + DETECT_FEATURE(fMatch); + DETECT_FEATURE(fGraph); + DETECT_FEATURE(fLetRec); +#undef DETECT_FEATURE + CHECK(detected == feature_count) << "some feature not printed"; + ret += "]"; + return ret; +} + FeatureSet DetectFeature(const IRModule& mod) { FeatureSet fs = FeatureSet::No(); - if (mod.defined()) { - for (const auto& f : mod->functions) { - fs += DetectFeature(f.second); - } + for (const auto& f : mod->functions) { + fs += DetectFeature(f.second); } return fs; } @@ -106,5 +137,17 @@ Array PyDetectFeature(const Expr& expr, const Optional& mod) TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature); +void CheckFeature(const Expr& expr, const FeatureSet& fs) { + auto dfs = DetectFeature(expr); + CHECK(dfs.is_subset_of(fs)) << AsText(expr, false) + << "\nhas unsupported feature: " << (dfs - fs).Print(); +} + +void CheckFeature(const IRModule& mod, const FeatureSet& fs) { + for (const auto& f : mod->functions) { + CheckFeature(f.second, fs); + } +} + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index 0cebba72c375..b31a1c57a761 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) { Expr DeGlobal(const Optional& mod, const Expr& e) { const auto* x = e.as(); - if (mod.defined() && (x)) { + if (mod.defined() && x) { BaseFunc base_func = mod.value()->Lookup(GetRef(x)); if (auto* n = base_func.as()) { return n->body; @@ -354,7 +355,7 @@ Expr LiftTensor(const std::function& f, LetList* ll) { CHECK(IsAtomic(e)) << e; if (forward_type.as()) { - auto ret = f(e); + auto ret = ll->Push(f(e)); ret->checked_type_ = tf(forward_type); return ret; } else if (auto* tt = forward_type.as()) { @@ -365,9 +366,9 @@ Expr LiftTensor(const 
std::function& f, fields.push_back(field); types.push_back(field->checked_type_); } - auto ret = Tuple(fields); + auto ret = ll->Push(Tuple(fields)); ret->checked_type_ = TupleType(types); - return std::move(ret); + return ret; } else { LOG(FATAL) << "unsupported input/output type: " << tt; throw; @@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L } } +// TODO(@M.K.): why take Expr? /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { - auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); }; + auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); }; auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); }; return LiftTensor(rev, rev_type, forward_type, e, ll); } @@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { /*! \brief ReverseType(t) -> t. Get the gradient. */ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { - auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); }; + auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); }; auto grad_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(grad, grad_type, forward_type, e, ll); } void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { if (t.as()) { - ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad))); + ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad))); } else if (auto* tt = t.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll); @@ -448,6 +450,24 @@ struct ReverseAD : ExprMutator { throw; } + Expr Remap(const Expr& e) { + struct Remapper : ExprMutator { + std::shared_ptr ad_vars; + LetList* ll; + Remapper(const std::shared_ptr& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {} + Expr VisitExpr_(const VarNode* var) final { + // memoize Var -> ADVar so we don't end up with free Vars when checkpointing + auto var_ref = GetRef(var); + if (ad_vars->count(var_ref) == 0) { + return var_ref; + } else { + return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll); + } + } + }; + return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); }); + } + Expr VisitCheckpoint(const CallNode* call) { const OpNode* op_node = call->op.as(); CHECK(op_node) << "expected op in call"; @@ -455,7 +475,7 @@ struct ReverseAD : ExprMutator { CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation"; auto x = call->args[0]; return LetList::With([&](LetList* ll) { - auto x_var = ll->Push(x); + auto x_var = ll->Push(Remap(x)); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); auto bpv = ll->Push(RefRead(bp)); Expr nbp = Function({}, LetList::With([&](LetList* ll) { @@ -508,7 +528,8 @@ struct ReverseAD : ExprMutator { return Call(bpv, {}); }), TupleType::Empty(), {}); - ll->Push(RefWrite(bp, nbp)); + ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); + // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. 
return ret; }); } @@ -516,8 +537,10 @@ struct ReverseAD : ExprMutator { } Expr VisitExpr_(const ConstantNode* op) final { - Expr e = GetRef(op); - return Pair(e, RefCreate(ZerosLike(e))); + return LetList::With([&](LetList* ll) { + Expr e = ll->Push(GetRef(op)); + return Pair(e, RefCreate(ZerosLike(e))); + }); } Expr VisitExpr_(const IfNode* op) final { @@ -528,7 +551,7 @@ struct ReverseAD : ExprMutator { Expr VisitExpr_(const VarNode* var) final { // memoize Var -> ADVar so we don't end up with free Vars when checkpointing auto var_ref = GetRef(var); - if (!ad_vars->count(var_ref)) { + if (ad_vars->count(var_ref) == 0) { auto res = Downcast(ExprMutator::VisitExpr_(var)); (*ad_vars)[var_ref] = res; } @@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) { } Expr Gradient(const Expr& re, const Optional& mod) { + CheckFeature(re, FeatureSet::All() - fGraph); + if (mod.defined()) { + CheckFeature(mod.value(), FeatureSet::All() - fGraph); + } auto e = DeGlobal(mod, re); auto f = e.as(); CHECK(f) << "input need to be a function"; @@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional& mod) { }; return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); - return Function(f->params, body, GradRetType(GetRef(f)), {}); + auto ret = Function(f->params, body, GradRetType(GetRef(f)), {}); + CheckFeature(ret, FeatureSet::All() - fGraph); + return ret; } TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient); diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index f06246667a8b..de9406ec309d 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -63,6 +63,7 @@ #include #include #include +#include #include #include "let_list.h" @@ -70,92 +71,54 @@ namespace tvm { namespace relay { -/*! - * \brief Visitor appropriately wraps tensors with Raw constructor - * - * Recursively looks at the type of the expression (TensorType or TupleType are only supported for - * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if - * TupleType - */ -class InputVisitor : public ExprFunctor { +class LazyGradientInitializer : public ExprMutator, public TypeMutator { public: - explicit InputVisitor(IRModule module) : module_(module) {} - - Expr VisitExpr_(const VarNode* op, const Type& t) final { - std::cout << op->type_annotation << std::endl; - return WrapExpr(GetRef(op), op->type_annotation); - } - - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return WrapExpr(GetRef(op), t); + explicit LazyGradientInitializer(IRModule module) : module_(module) { + module_->ImportFromStd("gradient.rly"); } - private: - IRModule module_; - - Expr WrapExpr(const Expr expr, const Type& type) { + Expr WrapExpr(const Var& var, const Type& type, LetList* ll) { if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); + return Call(module_->GetConstructor("GradCell", "Raw"), {var}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); + fields.push_back(WrapExpr(ll->Push(TupleGetItem(var, i)), t, ll)); } Expr tuple = Tuple(fields); return tuple; } - return expr; - } -}; - -/*! 
- * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors - * - * Recursively looks at the type of the expression - * and either use the FromGradCell function if TypeCall to GradCell - * or unfold and recursively visit if TupleType - */ -class OutputVisitor : public ExprFunctor { - public: - explicit OutputVisitor(IRModule module) : module_(module) {} - - Expr VisitExpr_(const CallNode* op, const Type& t) final { - return UnwrapExpr(GetRef(op), t); + return var; } - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return UnwrapExpr(GetRef(op), t); - } - - private: - IRModule module_; - - Expr UnwrapExpr(const Expr expr, const Type& type) { + Expr UnwrapExpr(const Var& var, const Type& type, LetList* ll) { if (auto* type_call = type.as()) { if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - return Call(module_->GetGlobalVar("FromGradCell"), {expr}); + return Call(module_->GetGlobalVar("FromGradCell"), {var}); } - return expr; + return var; } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); + fields.push_back(UnwrapExpr(ll->Push(TupleGetItem(var, i)), t, ll)); } Expr tuple = Tuple(fields); return tuple; } - return expr; + return var; } -}; -class LazyGradientInitializer : public ExprMutator, public TypeMutator { - public: - explicit LazyGradientInitializer(IRModule module) : module_(module) { - module_->ImportFromStd("gradient.rly"); + // Turn off memo for constant node. + Expr VisitExpr(const Expr& e) final { + if (e.as()) { + return ExprFunctor::VisitExpr(e); + } else { + return ExprMutator::VisitExpr(e); + } } /*! @@ -165,23 +128,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator { * input/output types should only be a combination of TupleTypes and TensorTypes */ Expr Transform(const Expr& e) { - auto* f = (e).as(); + auto* f = e.as(); auto* transformed = this->Mutate(e).as(); + CHECK(f); + CHECK(transformed); + if (e.same_as(GetRef(transformed))) { return GetRef(transformed); } - // wrap inputs of Tensor type using InputVisitor class - tvm::Array args; - for (Var var : f->params) { - Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); - args.push_back(wrappedInput); - } - Expr transformedExpr = Call(GetRef(transformed), args); - - // unwrap outputs of GradCell type into Tensor type using OutputVisitor class - Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); + auto tensorOutput = LetList::With([&](LetList* ll) { + // wrap inputs of Tensor type using InputVisitor class + tvm::Array args; + for (const Var& var : f->params) { + args.push_back(WrapExpr(var, var->checked_type(), ll)); + } + Expr transformedExpr = Call(GetRef(transformed), args); + // unwrap outputs of GradCell type into Tensor type using OutputVisitor class + return UnwrapExpr(ll->Push(transformedExpr), transformed->ret_type, ll); + }); return Function(f->params, tensorOutput, f->ret_type, Array()); } @@ -293,7 +259,10 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator { }; Expr LazyGradientInit(const Expr& e, IRModule mod) { - return LazyGradientInitializer(mod).Transform(e); + CheckFeature(e, mod, FeatureSet::All() - fGraph); + auto ret = LazyGradientInitializer(mod).Transform(e); + CheckFeature(ret, mod, FeatureSet::All() - fGraph); + return ret; } namespace transform { diff 
--git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 63bd04d526de..e07dbea59bd1 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -92,6 +92,7 @@ #include #include #include +#include #include #include #include @@ -1181,6 +1182,7 @@ Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } } // namespace partial_eval IRModule PartialEval(const IRModule& m) { + CheckFeature(m, FeatureSet::All() - fGraph); relay::partial_eval::PartialEvaluator pe(m); std::vector gvs; for (const auto& p : m->functions) { @@ -1189,6 +1191,7 @@ IRModule PartialEval(const IRModule& m) { for (const auto& gv : gvs) { pe.VisitGlobalVar(gv); } + CheckFeature(m, FeatureSet::All() - fGraph); return m; } @@ -1197,7 +1200,7 @@ namespace transform { Pass PartialEval() { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { return relay::PartialEval(m); }; - return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); + return CreateModulePass(pass_func, 1, "PartialEval", {}); } TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval); diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 50d0fbb5f17b..63708c45bfe3 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -117,7 +117,8 @@ inline Expr TransformF(const std::function& func, const Expr& * if so, the compute cost of the expression is bounded so it can be copy without graph mode. */ inline bool IsAtomic(const Expr& e) { - return e.as() || e.as() || e.as() || e.as(); + return e.as() || e.as() || e.as() || e.as() || + e.as(); // Constant is always by reference. } /*! diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 06e0d56e1919..367b491382c3 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -252,32 +252,6 @@ Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) { return Compound(e, Match(data, clauses, m->complete), v); } -Expr ToANormalFormAux(const Expr& e) { - /* When you lift a lambda, what is inside is also being lift. - * - * So we must determine the scope of the lambda before determining the scope of it's body. - * - * To make this more principled, - * we always determine the scope of parent before determining the scope of children. - * - * So we calculate all the dependency between nodes. - */ - support::Arena arena; - DependencyGraph dg = DependencyGraph::Create(&arena, e); - /* In order to model new subscopes created by lambda, if else and pattern matching, - * we also assign scope to edge as well. - * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope. - * - * So, the scope of the whole expr is global. - * The scope of any subexpr, is the lowest common ancestor of all incoming edge. - * - * Every scope additionally contain a LetList which collect all value of that scope. - * We do an additional pass to fill all the LetList and we are done. 
- */ - std::pair scopes = CalcScope(dg); - return Fill::ToANormalForm(e, dg, &scopes.first); -} - IRModule ToANormalForm(const IRModule& m) { DLOG(INFO) << "ToANF:" << std::endl << m; @@ -288,7 +262,7 @@ IRModule ToANormalForm(const IRModule& m) { if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; } - Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second); + Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second); CHECK_EQ(FreeVars(ret).size(), 0) << AsText(ret) << "should not has free vars: " << FreeVars(ret); updates.Set(it.first, Downcast(ret)); @@ -305,13 +279,41 @@ IRModule ToANormalForm(const IRModule& m) { namespace transform { +Expr ToANormalForm(const Expr& e) { + /* When you lift a lambda, what is inside is also being lift. + * + * So we must determine the scope of the lambda before determining the scope of it's body. + * + * To make this more principled, + * we always determine the scope of parent before determining the scope of children. + * + * So we calculate all the dependency between nodes. + */ + support::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + /* In order to model new subscopes created by lambda, if else and pattern matching, + * we also assign scope to edge as well. + * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope. + * + * So, the scope of the whole expr is global. + * The scope of any subexpr, is the lowest common ancestor of all incoming edge. + * + * Every scope additionally contain a LetList which collect all value of that scope. + * We do an additional pass to fill all the LetList and we are done. + */ + std::pair scopes = CalcScope(dg); + return Fill::ToANormalForm(e, dg, &scopes.first); +} + Pass ToANormalForm() { runtime::TypedPackedFunc pass_func = [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm); +TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() { + return ToANormalForm(); +}); } // namespace transform diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 6972d5a76b77..7c11ce5d4cd9 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -52,6 +52,7 @@ */ #include #include +#include #include #include @@ -301,11 +302,13 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { } Function ToCPS(const Function& f, const IRModule& m) { + CheckFeature(f, m, FeatureSet::All() - fGraph); CPSMap cps; return ToCPS(f, m, &cps); } Function UnCPS(const Function& f) { + CheckFeature(f, FeatureSet::All() - fGraph); CHECK_GT(f->params.size(), 0); std::vector new_params; for (const auto& p : f->params) { diff --git a/tests/python/relay/test_analysis_feature.py b/tests/python/relay/test_analysis_feature.py index ec5deb3c4e60..2b32376a9515 100644 --- a/tests/python/relay/test_analysis_feature.py +++ b/tests/python/relay/test_analysis_feature.py @@ -39,7 +39,6 @@ def test_prelude(): Feature.fIf, Feature.fConstructor, Feature.fMatch, - Feature.fGraph ]) @@ -65,7 +64,6 @@ def test_ad(): Feature.fRefCreate, Feature.fRefRead, Feature.fRefWrite, - Feature.fGraph ]) diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py index 2c749c934149..b8624b46eca8 100644 --- 
a/tests/python/relay/test_op_grad_level10.py +++ b/tests/python/relay/test_op_grad_level10.py @@ -32,17 +32,21 @@ def test_cross_entropy_with_logits_grad(): x = relay.var("x", shape=(2, 5), dtype=dtype) y = relay.var("y", shape=(2, 5), dtype=dtype) check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1) - + + def test_checkpoint(): inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)] output = relay.multiply(relay.add(inputs[0], inputs[1]), relay.add(inputs[2], inputs[3])) check_grad(relay.Function(inputs, relay.annotation.checkpoint(output))) - out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]), - relay.multiply(inputs[2], inputs[3])]) - out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0), - relay.TupleGetItem(out_tuple, 1)) + scope = relay.ScopeBuilder() + out_tuple = scope.let("out_tuple", + relay.Tuple([relay.add(inputs[0], inputs[1]), + relay.multiply(inputs[2], inputs[3])])) + scope.ret(relay.subtract(relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)), + relay.TupleGetItem(out_tuple, 1))) + out_single = scope.get() check_grad(relay.Function(inputs, out_single)) diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py index 4838c6a4e7fc..296d3e5e9354 100644 --- a/tests/python/relay/test_pass_gradient.py +++ b/tests/python/relay/test_pass_gradient.py @@ -45,6 +45,18 @@ def test_id(): tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy())) +def test_relu(): + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.var("x", t) + func = relay.Function([x], op.nn.relu(x)) + func = run_infer_type(func) + back_func = run_infer_type(gradient(func)) + assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) + # gradient will implicitly check that no graph appear in result + + def test_add(): shape = (10, 10) dtype = 'float32' @@ -72,12 +84,14 @@ def test_check_grad(): def test_temp_add(): + scope = relay.ScopeBuilder() shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) x = relay.var("x", t) - y = x + x - func = relay.Function([x], y + y) + y = scope.let("y", x + x) + scope.ret(y + y) + func = relay.Function([x], scope.get()) func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) @@ -280,12 +294,14 @@ def test_if(): def test_grad_tuple(): + scope = relay.ScopeBuilder() shape = (10, 10) dtype = 'float32' t = relay.TensorType(shape, dtype) x = relay.var("x", t) - y = x + x - func = relay.Function([x], relay.Tuple([y + y, y])) + y = scope.let("y", x + x) + scope.ret(relay.Tuple([y + y, y])) + func = relay.Function([x], scope.get()) func = run_infer_type(func) back_func = run_infer_type(gradient(func)) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])])) diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index 414926802870..377164e08b73 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -229,6 +229,24 @@ def test_multivar_reverse_ad(): assert_allclose(grad_x.asnumpy(), y.asnumpy()) assert_allclose(grad_y.asnumpy(), x.asnumpy()) +def test_partial_eval(): + """Test transformation following reverse mode ad and PartialEval""" + 
mod = tvm.IRModule() + + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + + func = relay.Function([], relay.const(np.ones(shape, dtype))) + func = run_infer_type(func) + back_func = transform.gradient(func) + back_func = run_infer_type(back_func) + + mod["main"] = back_func + back_func = mod["main"] + + transform.PartialEvaluate()(mod) + def test_after_partial_eval(): """Test transformation following reverse mode ad and PartialEval""" mod = tvm.IRModule() diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index ddb5b5dab675..e6f311dadcdc 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Unit tests for merge composite.""" +import pytest import tvm from tvm import relay, tir from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard @@ -213,7 +214,7 @@ def expected(): r = relay.Call(add_relu, [a, b]) return relay.Function([a, b], r) - check_result(pattern_table, before(), expected(), import_prelude=True) + check_result(pattern_table, before(), expected()) def test_branch_merge(): @@ -999,14 +1000,4 @@ def _check_type_false(extract): if __name__ == "__main__": test_simple_merge() - test_branch_merge() - test_multiple_patterns() - test_optional_pattern() - test_merge_order() - test_parallel_merge() - test_multiple_input_subgraphs() - test_reuse_call_merge() - test_tuple_get_item_merge() - test_pattern_with_check() - test_diamond_not_merge() - test_type_check() + #pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index 45593b43ecb1..95805d285b59 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
+import pytest import numpy as np import tvm from tvm import te @@ -173,10 +174,9 @@ def test_function_invalidate(): def test_head_cons(): mod = tvm.IRModule() p = Prelude(mod) - hd = p.hd t = TypeVar("t") x = Var("x", t) - body = hd(p.cons(x, p.nil())) + body = p.hd(p.cons(x, p.nil())) f = Function([x], body, None, [t]) res = dcpe(f, mod) assert tvm.ir.structural_equal(res, Function([x], x, t, [t])) @@ -340,23 +340,4 @@ def test_tuple_match(): if __name__ == '__main__': - test_nat_update() - test_ref() - test_tuple() - test_empty_ad() - test_const_inline() - test_ad() - test_if_ref() - test_function_invalidate() - test_head_cons() - test_map() - test_loop() - test_swap_loop() - test_abs_diff() - test_double() - test_nat_id() - test_global_match_nat_id() - test_match_nat_id() - test_concat() - test_triangle_number() - test_tuple_match() + pytest.main([__file__]) From fff05bdff257752223e816bacb3dd1845afd27e7 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 12:29:02 +0000 Subject: [PATCH 2/8] save --- src/relay/transforms/gradient.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index b31a1c57a761..7894c34de55d 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -357,7 +357,7 @@ Expr LiftTensor(const std::function& f, if (forward_type.as()) { auto ret = ll->Push(f(e)); ret->checked_type_ = tf(forward_type); - return ret; + return std::move(ret); } else if (auto* tt = forward_type.as()) { tvm::Array fields; tvm::Array types; @@ -368,7 +368,7 @@ Expr LiftTensor(const std::function& f, } auto ret = ll->Push(Tuple(fields)); ret->checked_type_ = TupleType(types); - return ret; + return std::move(ret); } else { LOG(FATAL) << "unsupported input/output type: " << tt; throw; @@ -459,7 +459,7 @@ struct ReverseAD : ExprMutator { // memoize Var -> ADVar so we don't end up with free Vars when checkpointing auto var_ref = GetRef(var); if (ad_vars->count(var_ref) == 0) { - return var_ref; + return std::move(var_ref); } else { return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll); } @@ -648,7 +648,7 @@ Expr Gradient(const Expr& re, const Optional& mod) { }); auto ret = Function(f->params, body, GradRetType(GetRef(f)), {}); CheckFeature(ret, FeatureSet::All() - fGraph); - return ret; + return std::move(ret); } TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient); From e66ec39adce6cfad9a104ed34c02203826d7993e Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 19:34:29 +0000 Subject: [PATCH 3/8] save --- python/tvm/relay/prelude.py | 17 +++++++++-------- python/tvm/relay/transform/transform.py | 17 ++++++++++++++++- src/relay/transforms/to_a_normal_form.cc | 4 ++++ 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 2675f1da88b0..b0a64aa4a4af 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -27,6 +27,7 @@ from .adt import PatternConstructor, PatternVar, PatternWildcard from . import op, transform from .analysis import free_vars +from tvm.relay.transform import ToANormalFormExpr def get_tensor_array_shape(expr, dtype, prelude): """Get the static shape of a tensor array if it has fixed rank shape. 
@@ -205,7 +206,6 @@ def define_tensor_concatenate(self): self.prelude.mod[concat_var] = \ Function([x, y], Match(x, [case], False), tensor_type_var(), []) - def define_tensor_expand_dims(self): """Defines a function to grow a tensor_t's rank by adding one dimension in front of the original tensor_t. @@ -512,8 +512,9 @@ def define_tensor_array_stack(self): self.prelude.hd(tensor_array_expand_dims), self.prelude.tl(tensor_array_expand_dims)) output_tensor_type_var, _ = self._get_adt_by_shape(output_shape) - self.prelude.mod[stack_var] = Function([tensor_array], tensors, - output_tensor_type_var(), []) + self.prelude.mod[stack_var] = \ + Function([tensor_array], tensors, + output_tensor_type_var(), []) def define_tensor_array_gather(self): """Defines a function to return the selected values in a tensor array as tensor_t. @@ -810,7 +811,7 @@ def define_tensor_concat(self): tensor4_var(op.concatenate([t41, t42], axis=0)))], False)) # op.concatenate does not support tensor with rank higher than 4 - self.prelude.mod[concat_var] =\ + self.prelude.mod[concat_var] = \ Function([x, y], Match(x, [tensor1_case, tensor2_case, tensor3_case, @@ -1168,7 +1169,7 @@ def define_tensor_array_gather(self): current = Var("current", scalar_type('int32')) limit = Var("limit", scalar_type('int32')) indices_ = Var('indices_', TensorType([Any()], 'int32')) - helper_body =\ + helper_body = \ If(equal(current, const(0)), stack_var(accu), helper_var( @@ -1188,7 +1189,7 @@ def define_tensor_array_gather(self): indices_shape = op.shape_of(indices) limit = op.take(indices_shape, const(0)) body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices) - self.prelude.mod[gather_var] =\ + self.prelude.mod[gather_var] = \ Function([tensor_array, indices], body, tensor_type_var(), []) def define_tensor_array_stack(self): @@ -1206,7 +1207,8 @@ def define_tensor_array_stack(self): tensors = self.prelude.foldl(concat_var, self.prelude.hd(tensor_array_expand_dims), self.prelude.tl(tensor_array_expand_dims)) - self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), []) + self.prelude.mod[stack_var] = \ + ToANormalFormExpr(Function([tensor_array], tensors, tensor_type_var(), [])) def register(self): """Register all tensor array ops in Prelude""" @@ -1238,7 +1240,6 @@ def __init__(self, mod=None): mod = IRModule() self.mod = mod self.load_prelude() - self.mod = relay.transform.ToANormalForm()(self.mod) def get_name(self, canonical, dtype): """Get name corresponding to the canonical name""" diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cc92141b73db..de3f9861c96e 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -510,11 +510,26 @@ def ToANormalForm(): Returns ------- - ret: Union[tvm.transform.Pass, tvm.relay.Expr] + ret : Union[tvm.transform.Pass, tvm.relay.Expr] The registered pass that transforms an expression into A Normal Form. """ return _ffi_api.ToANormalForm() +def ToANormalFormExpr(e): + """ToANormalForm, but on expression level. + + Parameters + ---------- + e : Expr + The graph expression. + + Returns + ------- + ret : Expr + The transformed expresion. + """ + return _ffi_api.ToANormalFormExpr(e) + def ToBasicBlockNormalForm(): """Turn an expression to Basic Block Normal Form. We define a block as a group of expressions implied by the scope structure. 
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 367b491382c3..adb757b9de0c 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -315,6 +315,10 @@ TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() { return ToANormalForm(); }); +TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) { + return ToANormalForm(e); +}); + } // namespace transform } // namespace relay From f822baf3a2bd79b41fd577656fa2f01d4e8e04bb Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 19:37:35 +0000 Subject: [PATCH 4/8] lint --- python/tvm/relay/prelude.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index b0a64aa4a4af..073454743952 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -18,6 +18,7 @@ """A prelude containing useful global functions and ADT definitions.""" from tvm.ir import IRModule, TypeCall from tvm import relay +from tvm.relay.transform import ToANormalFormExpr from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .expr import Var, GlobalVar, If, const @@ -27,7 +28,6 @@ from .adt import PatternConstructor, PatternVar, PatternWildcard from . import op, transform from .analysis import free_vars -from tvm.relay.transform import ToANormalFormExpr def get_tensor_array_shape(expr, dtype, prelude): """Get the static shape of a tensor array if it has fixed rank shape. From 39f15310129b7db3a7f2acdaf411586d9e199a28 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 19:39:21 +0000 Subject: [PATCH 5/8] save --- include/tvm/relay/feature.h | 6 ++---- src/relay/analysis/feature.cc | 2 +- tests/python/relay/test_pass_merge_composite.py | 3 +-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 92743e7deb5c..7df881938f50 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -126,11 +126,9 @@ class FeatureSet { bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } /*! - * \brief Pretty Print the FeatureSet. - * - * \return a string representation. + * \brief return a string representation. 
*/ - std::string Print() const; + std::string ToString() const; private: std::bitset bs_; diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index df743c6ed678..6d9b888cc8f4 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -86,7 +86,7 @@ FeatureSet DetectFeature(const Expr& expr) { return fd.fs; } -std::string FeatureSet::Print() const { +std::string FeatureSet::ToString() const { std::string ret; ret += "["; size_t detected = 0; diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index e6f311dadcdc..aef6ab5afa96 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -999,5 +999,4 @@ def _check_type_false(extract): if __name__ == "__main__": - test_simple_merge() - #pytest.main([__file__]) + pytest.main([__file__]) From b16dfc66b4e04de8ef4bb85debb17f0970bced89 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 19:43:13 +0000 Subject: [PATCH 6/8] lint --- python/tvm/relay/prelude.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 073454743952..0cf824130094 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -17,7 +17,6 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """A prelude containing useful global functions and ADT definitions.""" from tvm.ir import IRModule, TypeCall -from tvm import relay from tvm.relay.transform import ToANormalFormExpr from .ty import GlobalTypeVar, TensorType, Any, scalar_type From a6efb72dfe806de08f2ae28f0a831f9a436a1d87 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 20:07:21 +0000 Subject: [PATCH 7/8] fix --- src/relay/analysis/feature.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 6d9b888cc8f4..63f5e711bfcd 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeat void CheckFeature(const Expr& expr, const FeatureSet& fs) { auto dfs = DetectFeature(expr); CHECK(dfs.is_subset_of(fs)) << AsText(expr, false) - << "\nhas unsupported feature: " << (dfs - fs).Print(); + << "\nhas unsupported feature: " << (dfs - fs).ToString(); } void CheckFeature(const IRModule& mod, const FeatureSet& fs) { From 1ba13fc78d7f87dfa2da3f9799d7fea62e8d69f3 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Mon, 24 Aug 2020 21:49:49 +0000 Subject: [PATCH 8/8] fix --- python/tvm/relay/prelude.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 0cf824130094..893c855f9585 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -1207,7 +1207,7 @@ def define_tensor_array_stack(self): self.prelude.hd(tensor_array_expand_dims), self.prelude.tl(tensor_array_expand_dims)) self.prelude.mod[stack_var] = \ - ToANormalFormExpr(Function([tensor_array], tensors, tensor_type_var(), [])) + Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), []) def register(self): """Register all tensor array ops in Prelude"""
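
The pieces added across this series compose as follows: Gradient, LazyGradientInit, PartialEval, ToCPS and UnCPS now guard their input with CheckFeature(e, FeatureSet::All() - fGraph), and the new expression-level ToANormalFormExpr is the way to remove graph-form sharing before calling them. Below is a minimal Python sketch of that flow; it assumes the import paths used by the test files above for detect_feature and Feature (tvm.relay.analysis), and the variable names are illustrative only, not part of the patch.

import tvm
from tvm import relay
from tvm.relay.analysis import detect_feature, Feature
from tvm.relay.transform import ToANormalFormExpr

shape = (10, 10)
t = relay.TensorType(shape, "float32")
x = relay.var("x", t)
y = x + x              # `y` is reused below, so the body is in graph form (fGraph)
body = y + y
assert Feature.fGraph in detect_feature(body)

# ToANormalFormExpr introduces a let binding for the shared node, which is what the
# CheckFeature(..., FeatureSet::All() - fGraph) guards in the passes above expect.
anf_body = ToANormalFormExpr(body)
assert Feature.fGraph not in detect_feature(anf_body)
func = relay.Function([x], anf_body)

This is also why the gradient tests above build shared values with relay.ScopeBuilder (an explicit let) instead of relying on implicit sharing, and why the prelude's tensor_array_stack passes its body through ToANormalFormExpr before registering the function.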