diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h
index 3783e320f57c4..7df881938f504 100644
--- a/include/tvm/relay/feature.h
+++ b/include/tvm/relay/feature.h
@@ -29,6 +29,7 @@
 #include
 #include
+#include
 
 namespace tvm {
 namespace relay {
@@ -124,6 +125,11 @@ class FeatureSet {
    */
   bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }
 
+  /*!
+   * \brief return a string representation.
+   */
+  std::string ToString() const;
+
  private:
   std::bitset bs_;
   FeatureSet() = default;
@@ -160,6 +166,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
   return DetectFeature(expr) + DetectFeature(mod);
 }
 
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param expr The expression.
+ * \param fs The feature set of the program.
+ */
+void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);
+
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param mod The module.
+ * \param fs The feature set of the program.
+ */
+void CheckFeature(const IRModule& mod, const FeatureSet& fs);
+
+/*!
+ * \brief Check the feature of the program.
+ *
+ * \param expr The expression.
+ * \param mod The module.
+ * \param fs The feature set of the program.
+ */
+inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
+  CheckFeature(expr, fs);
+  CheckFeature(mod, fs);
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index d322710ec95a3..de2bcc4f4318c 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
  */
 TVM_DLL Pass ToANormalForm();
 
+/*!
+ * \brief ToANormalForm but on incomplete graph.
+ *
+ * \param expr the graph.
+ *
+ * \return The transformed program.
+ */
+TVM_DLL Expr ToANormalForm(const Expr& expr);
+
 /*!
  * \brief Turn an expression into continuation passing style(CPS).
  *
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index 1b7ed77e9b576..893c855f95856 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -17,6 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
 from tvm.ir import IRModule, TypeCall
+from tvm.relay.transform import ToANormalFormExpr
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
 from .expr import Var, GlobalVar, If, const
@@ -204,7 +205,6 @@ def define_tensor_concatenate(self):
         self.prelude.mod[concat_var] = \
             Function([x, y], Match(x, [case], False), tensor_type_var(), [])
 
-
     def define_tensor_expand_dims(self):
         """Defines a function to grow a tensor_t's rank by
         adding one dimension in front of the original tensor_t.
@@ -511,8 +511,9 @@ def define_tensor_array_stack(self):
                                      self.prelude.hd(tensor_array_expand_dims),
                                      self.prelude.tl(tensor_array_expand_dims))
         output_tensor_type_var, _ = self._get_adt_by_shape(output_shape)
-        self.prelude.mod[stack_var] = Function([tensor_array], tensors,
-                                               output_tensor_type_var(), [])
+        self.prelude.mod[stack_var] = \
+            Function([tensor_array], tensors,
+                     output_tensor_type_var(), [])
 
     def define_tensor_array_gather(self):
         """Defines a function to return the selected values in a tensor array as tensor_t.
@@ -809,7 +810,7 @@ def define_tensor_concat(self):
                                 tensor4_var(op.concatenate([t41, t42], axis=0)))],
                        False))
         # op.concatenate does not support tensor with rank higher than 4
-        self.prelude.mod[concat_var] =\
+        self.prelude.mod[concat_var] = \
             Function([x, y], Match(x, [tensor1_case,
                                        tensor2_case,
                                        tensor3_case,
@@ -1167,7 +1168,7 @@ def define_tensor_array_gather(self):
         current = Var("current", scalar_type('int32'))
         limit = Var("limit", scalar_type('int32'))
         indices_ = Var('indices_', TensorType([Any()], 'int32'))
-        helper_body =\
+        helper_body = \
             If(equal(current, const(0)),
                stack_var(accu),
                helper_var(
@@ -1187,7 +1188,7 @@ def define_tensor_array_gather(self):
         indices_shape = op.shape_of(indices)
         limit = op.take(indices_shape, const(0))
         body = helper_var(tensor_array, self.prelude.nil(), limit, limit, indices)
-        self.prelude.mod[gather_var] =\
+        self.prelude.mod[gather_var] = \
             Function([tensor_array, indices], body, tensor_type_var(), [])
 
     def define_tensor_array_stack(self):
@@ -1205,7 +1206,8 @@ def define_tensor_array_stack(self):
         tensors = self.prelude.foldl(concat_var,
                                      self.prelude.hd(tensor_array_expand_dims),
                                      self.prelude.tl(tensor_array_expand_dims))
-        self.prelude.mod[stack_var] = Function([tensor_array], tensors, tensor_type_var(), [])
+        self.prelude.mod[stack_var] = \
+            Function([tensor_array], ToANormalFormExpr(tensors), tensor_type_var(), [])
 
     def register(self):
         """Register all tensor array ops in Prelude"""
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index cc92141b73db4..de3f9861c96ec 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -510,11 +510,26 @@ def ToANormalForm():
 
     Returns
     -------
-    ret: Union[tvm.transform.Pass, tvm.relay.Expr]
+    ret : Union[tvm.transform.Pass, tvm.relay.Expr]
         The registered pass that transforms an expression into A Normal Form.
     """
     return _ffi_api.ToANormalForm()
 
+
+def ToANormalFormExpr(e):
+    """ToANormalForm, but on expression level.
+
+    Parameters
+    ----------
+    e : Expr
+        The graph expression.
+
+    Returns
+    -------
+    ret : Expr
+        The transformed expression.
+    """
+    return _ffi_api.ToANormalFormExpr(e)
+
+
 def ToBasicBlockNormalForm():
     """Turn an expression to Basic Block Normal Form.
 
     We define a block as a group of expressions implied by the scope structure.
diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc
index a145b28d55e8f..63f5e711bfcd2 100644
--- a/src/relay/analysis/feature.cc
+++ b/src/relay/analysis/feature.cc
@@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) {
   return fd.fs;
 }
 
+std::string FeatureSet::ToString() const {
+  std::string ret;
+  ret += "[";
+  size_t detected = 0;
+#define DETECT_FEATURE(FEATURE_NAME) \
+  ++detected;                        \
+  if (bs_[FEATURE_NAME]) {           \
+    ret += #FEATURE_NAME;            \
+    ret += ", ";                     \
+  }
+  DETECT_FEATURE(fVar);
+  DETECT_FEATURE(fGlobalVar);
+  DETECT_FEATURE(fConstant);
+  DETECT_FEATURE(fTuple);
+  DETECT_FEATURE(fTupleGetItem);
+  DETECT_FEATURE(fFunction);
+  DETECT_FEATURE(fOp);
+  DETECT_FEATURE(fCall);
+  DETECT_FEATURE(fLet);
+  DETECT_FEATURE(fIf);
+  DETECT_FEATURE(fRefCreate);
+  DETECT_FEATURE(fRefRead);
+  DETECT_FEATURE(fRefWrite);
+  DETECT_FEATURE(fConstructor);
+  DETECT_FEATURE(fMatch);
+  DETECT_FEATURE(fGraph);
+  DETECT_FEATURE(fLetRec);
+#undef DETECT_FEATURE
+  CHECK(detected == feature_count) << "some feature not printed";
+  ret += "]";
+  return ret;
+}
+
 FeatureSet DetectFeature(const IRModule& mod) {
   FeatureSet fs = FeatureSet::No();
-  if (mod.defined()) {
-    for (const auto& f : mod->functions) {
-      fs += DetectFeature(f.second);
-    }
+  for (const auto& f : mod->functions) {
+    fs += DetectFeature(f.second);
   }
   return fs;
 }
@@ -106,5 +137,17 @@ Array PyDetectFeature(const Expr& expr, const Optional& mod)
 
 TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
 
+void CheckFeature(const Expr& expr, const FeatureSet& fs) {
+  auto dfs = DetectFeature(expr);
+  CHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
+                              << "\nhas unsupported feature: " << (dfs - fs).ToString();
+}
+
+void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
+  for (const auto& f : mod->functions) {
+    CheckFeature(f.second, fs);
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 0cebba72c3759..7894c34de55db 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
 
 Expr DeGlobal(const Optional& mod, const Expr& e) {
   const auto* x = e.as();
-  if (mod.defined() && (x)) {
+  if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef(x));
     if (auto* n = base_func.as()) {
       return n->body;
@@ -354,9 +355,9 @@ Expr LiftTensor(const std::function& f,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
   if (forward_type.as()) {
-    auto ret = f(e);
+    auto ret = ll->Push(f(e));
     ret->checked_type_ = tf(forward_type);
-    return ret;
+    return std::move(ret);
   } else if (auto* tt = forward_type.as()) {
     tvm::Array fields;
     tvm::Array types;
@@ -365,7 +366,7 @@ Expr LiftTensor(const std::function& f,
       fields.push_back(field);
       types.push_back(field->checked_type_);
     }
-    auto ret = Tuple(fields);
+    auto ret = ll->Push(Tuple(fields));
     ret->checked_type_ = TupleType(types);
     return std::move(ret);
   } else {
@@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L
   }
 }
 
+// TODO(@M.K.): why take Expr?
 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value.
  */
 Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
-  auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
+  auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
   auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
   return LiftTensor(rev, rev_type, forward_type, e, ll);
 }
@@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
 
 /*! \brief ReverseType(t) -> t. Get the gradient. */
 Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
-  auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
+  auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
   auto grad_type = [&](const Type& forward_type) { return forward_type; };
   return LiftTensor(grad, grad_type, forward_type, e, ll);
 }
 
 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
   if (t.as()) {
-    ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
+    ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
   } else if (auto* tt = t.as()) {
     for (size_t i = 0; i < tt->fields.size(); ++i) {
       UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
@@ -448,6 +450,24 @@ struct ReverseAD : ExprMutator {
     throw;
   }
 
+  Expr Remap(const Expr& e) {
+    struct Remapper : ExprMutator {
+      std::shared_ptr ad_vars;
+      LetList* ll;
+      Remapper(const std::shared_ptr& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
+      Expr VisitExpr_(const VarNode* var) final {
+        // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+        auto var_ref = GetRef(var);
+        if (ad_vars->count(var_ref) == 0) {
+          return std::move(var_ref);
+        } else {
+          return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
+        }
+      }
+    };
+    return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
+  }
+
   Expr VisitCheckpoint(const CallNode* call) {
     const OpNode* op_node = call->op.as();
     CHECK(op_node) << "expected op in call";
@@ -455,7 +475,7 @@ struct ReverseAD : ExprMutator {
     CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
     auto x = call->args[0];
     return LetList::With([&](LetList* ll) {
-      auto x_var = ll->Push(x);
+      auto x_var = ll->Push(Remap(x));
       auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
       auto bpv = ll->Push(RefRead(bp));
       Expr nbp = Function({}, LetList::With([&](LetList* ll) {
@@ -508,7 +528,8 @@ struct ReverseAD : ExprMutator {
                         return Call(bpv, {});
                       }),
                       TupleType::Empty(), {});
-      ll->Push(RefWrite(bp, nbp));
+      ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
+      // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
      return ret;
    });
  }
 
@@ -516,8 +537,10 @@
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
-    Expr e = GetRef(op);
-    return Pair(e, RefCreate(ZerosLike(e)));
+    return LetList::With([&](LetList* ll) {
+      Expr e = ll->Push(GetRef(op));
+      return Pair(e, RefCreate(ZerosLike(e)));
+    });
   }
 
   Expr VisitExpr_(const IfNode* op) final {
@@ -528,7 +551,7 @@
   Expr VisitExpr_(const VarNode* var) final {
     // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
     auto var_ref = GetRef(var);
-    if (!ad_vars->count(var_ref)) {
+    if (ad_vars->count(var_ref) == 0) {
       auto res = Downcast(ExprMutator::VisitExpr_(var));
       (*ad_vars)[var_ref] = res;
     }
@@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) {
 }
 
 Expr Gradient(const Expr& re, const Optional& mod) {
+  CheckFeature(re, FeatureSet::All() - fGraph);
+  if (mod.defined()) {
+    CheckFeature(mod.value(), FeatureSet::All() - fGraph);
+  }
   auto e = DeGlobal(mod, re);
   auto f = e.as();
   CHECK(f) << "input need to be a function";
@@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional& mod) {
     };
     return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
   });
-  return Function(f->params, body, GradRetType(GetRef(f)), {});
+  auto ret = Function(f->params, body, GradRetType(GetRef(f)), {});
+  CheckFeature(ret, FeatureSet::All() - fGraph);
+  return std::move(ret);
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc
index f06246667a8ba..de9406ec309dc 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -63,6 +63,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "let_list.h"
@@ -70,92 +71,54 @@
 namespace tvm {
 namespace relay {
 
-/*!
- * \brief Visitor appropriately wraps tensors with Raw constructor
- *
- * Recursively looks at the type of the expression (TensorType or TupleType are only supported for
- * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if
- * TupleType
- */
-class InputVisitor : public ExprFunctor {
+class LazyGradientInitializer : public ExprMutator, public TypeMutator {
  public:
-  explicit InputVisitor(IRModule module) : module_(module) {}
-
-  Expr VisitExpr_(const VarNode* op, const Type& t) final {
-    std::cout << op->type_annotation << std::endl;
-    return WrapExpr(GetRef(op), op->type_annotation);
-  }
-
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return WrapExpr(GetRef(op), t);
+  explicit LazyGradientInitializer(IRModule module) : module_(module) {
+    module_->ImportFromStd("gradient.rly");
   }
 
- private:
-  IRModule module_;
-
-  Expr WrapExpr(const Expr expr, const Type& type) {
+  Expr WrapExpr(const Var& var, const Type& type, LetList* ll) {
     if (type.as()) {
-      return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
+      return Call(module_->GetConstructor("GradCell", "Raw"), {var}, Attrs(), {type});
     } else if (auto* type_anno = type.as()) {
       tvm::Array fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+        fields.push_back(WrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }
-    return expr;
-  }
-};
-
-/*!
- * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors
- *
- * Recursively looks at the type of the expression
- * and either use the FromGradCell function if TypeCall to GradCell
- * or unfold and recursively visit if TupleType
- */
-class OutputVisitor : public ExprFunctor {
- public:
-  explicit OutputVisitor(IRModule module) : module_(module) {}
-
-  Expr VisitExpr_(const CallNode* op, const Type& t) final {
-    return UnwrapExpr(GetRef(op), t);
+    return var;
   }
 
-  Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final {
-    return UnwrapExpr(GetRef(op), t);
-  }
-
- private:
-  IRModule module_;
-
-  Expr UnwrapExpr(const Expr expr, const Type& type) {
+  Expr UnwrapExpr(const Var& var, const Type& type, LetList* ll) {
     if (auto* type_call = type.as()) {
       if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) {
-        return Call(module_->GetGlobalVar("FromGradCell"), {expr});
+        return Call(module_->GetGlobalVar("FromGradCell"), {var});
       }
-      return expr;
+      return var;
     } else if (auto* type_anno = type.as()) {
       tvm::Array fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
        const Type& t = type_anno->fields[i];
-        fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t));
+        fields.push_back(UnwrapExpr(ll->Push(TupleGetItem(var, i)), t, ll));
      }
      Expr tuple = Tuple(fields);
      return tuple;
    }
-    return expr;
+    return var;
   }
-};
-
-class LazyGradientInitializer : public ExprMutator, public TypeMutator {
- public:
-  explicit LazyGradientInitializer(IRModule module) : module_(module) {
-    module_->ImportFromStd("gradient.rly");
+  // Turn off memo for constant node.
+  Expr VisitExpr(const Expr& e) final {
+    if (e.as()) {
+      return ExprFunctor::VisitExpr(e);
+    } else {
+      return ExprMutator::VisitExpr(e);
+    }
   }
 
   /*!
@@ -165,23 +128,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
    * input/output types should only be a combination of TupleTypes and TensorTypes
    */
   Expr Transform(const Expr& e) {
-    auto* f = (e).as();
+    auto* f = e.as();
     auto* transformed = this->Mutate(e).as();
 
+    CHECK(f);
+    CHECK(transformed);
+
     if (e.same_as(GetRef(transformed))) {
       return GetRef(transformed);
     }
 
-    // wrap inputs of Tensor type using InputVisitor class
-    tvm::Array args;
-    for (Var var : f->params) {
-      Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type());
-      args.push_back(wrappedInput);
-    }
-    Expr transformedExpr = Call(GetRef(transformed), args);
-
-    // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
-    Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type);
+    auto tensorOutput = LetList::With([&](LetList* ll) {
+      // wrap inputs of Tensor type using InputVisitor class
+      tvm::Array args;
+      for (const Var& var : f->params) {
+        args.push_back(WrapExpr(var, var->checked_type(), ll));
+      }
+      Expr transformedExpr = Call(GetRef(transformed), args);
+      // unwrap outputs of GradCell type into Tensor type using OutputVisitor class
+      return UnwrapExpr(ll->Push(transformedExpr), transformed->ret_type, ll);
+    });
 
     return Function(f->params, tensorOutput, f->ret_type, Array());
   }
@@ -293,7 +259,10 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
 };
 
 Expr LazyGradientInit(const Expr& e, IRModule mod) {
-  return LazyGradientInitializer(mod).Transform(e);
+  CheckFeature(e, mod, FeatureSet::All() - fGraph);
+  auto ret = LazyGradientInitializer(mod).Transform(e);
+  CheckFeature(ret, mod, FeatureSet::All() - fGraph);
+  return ret;
 }
 
 namespace transform {
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index 63bd04d526dee..e07dbea59bd1b 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -92,6 +92,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1181,6 +1182,7 @@ Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); }
 }  // namespace partial_eval
 
 IRModule PartialEval(const IRModule& m) {
+  CheckFeature(m, FeatureSet::All() - fGraph);
   relay::partial_eval::PartialEvaluator pe(m);
   std::vector gvs;
   for (const auto& p : m->functions) {
@@ -1189,6 +1191,7 @@ IRModule PartialEval(const IRModule& m) {
   for (const auto& gv : gvs) {
     pe.VisitGlobalVar(gv);
   }
+  CheckFeature(m, FeatureSet::All() - fGraph);
   return m;
 }
 
@@ -1197,7 +1200,7 @@ namespace transform {
 
 Pass PartialEval() {
   runtime::TypedPackedFunc pass_func =
      [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
-  return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
+  return CreateModulePass(pass_func, 1, "PartialEval", {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 50d0fbb5f17b7..63708c45bfe38 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -117,7 +117,8 @@ inline Expr TransformF(const std::function& func, const Expr&
 * if so, the compute cost of the expression is bounded so it can be copy without graph mode.
 */
 inline bool IsAtomic(const Expr& e) {
-  return e.as() || e.as() || e.as() || e.as();
+  return e.as() || e.as() || e.as() || e.as() ||
+         e.as();  // Constant is always by reference.
 }
 
 /*!
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 06e0d56e19194..adb757b9de0cd 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -252,32 +252,6 @@ Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
   return Compound(e, Match(data, clauses, m->complete), v);
 }
 
-Expr ToANormalFormAux(const Expr& e) {
-  /* When you lift a lambda, what is inside is also being lift.
-   *
-   * So we must determine the scope of the lambda before determining the scope of it's body.
-   *
-   * To make this more principled,
-   * we always determine the scope of parent before determining the scope of children.
-   *
-   * So we calculate all the dependency between nodes.
-   */
-  support::Arena arena;
-  DependencyGraph dg = DependencyGraph::Create(&arena, e);
-  /* In order to model new subscopes created by lambda, if else and pattern matching,
-   * we also assign scope to edge as well.
-   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
-   *
-   * So, the scope of the whole expr is global.
-   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
-   *
-   * Every scope additionally contain a LetList which collect all value of that scope.
-   * We do an additional pass to fill all the LetList and we are done.
-   */
-  std::pair scopes = CalcScope(dg);
-  return Fill::ToANormalForm(e, dg, &scopes.first);
-}
-
 IRModule ToANormalForm(const IRModule& m) {
   DLOG(INFO) << "ToANF:" << std::endl << m;
@@ -288,7 +262,7 @@ IRModule ToANormalForm(const IRModule& m) {
     if (const auto* n = it.second.as()) {
       if (n->GetAttr(attr::kCompiler).defined()) continue;
     }
-    Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
+    Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second);
     CHECK_EQ(FreeVars(ret).size(), 0)
         << AsText(ret) << "should not has free vars: " << FreeVars(ret);
     updates.Set(it.first, Downcast(ret));
@@ -305,13 +279,45 @@ IRModule ToANormalForm(const IRModule& m) {
 
 namespace transform {
 
+Expr ToANormalForm(const Expr& e) {
+  /* When you lift a lambda, what is inside is also being lift.
+   *
+   * So we must determine the scope of the lambda before determining the scope of it's body.
+   *
+   * To make this more principled,
+   * we always determine the scope of parent before determining the scope of children.
+   *
+   * So we calculate all the dependency between nodes.
+   */
+  support::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  /* In order to model new subscopes created by lambda, if else and pattern matching,
+   * we also assign scope to edge as well.
+   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
+   *
+   * So, the scope of the whole expr is global.
+   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
+   *
+   * Every scope additionally contain a LetList which collect all value of that scope.
+   * We do an additional pass to fill all the LetList and we are done.
+   */
+  std::pair scopes = CalcScope(dg);
+  return Fill::ToANormalForm(e, dg, &scopes.first);
+}
+
 Pass ToANormalForm() {
   runtime::TypedPackedFunc pass_func =
      [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
   return CreateModulePass(pass_func, 1, "ToANormalForm", {});
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
+  return ToANormalForm();
+});
+
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) {
+  return ToANormalForm(e);
+});
 
 }  // namespace transform
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 6972d5a76b777..7c11ce5d4cd93 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -52,6 +52,7 @@
  */
 #include
 #include
+#include
 #include
 #include
 
@@ -301,11 +302,13 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
 }
 
 Function ToCPS(const Function& f, const IRModule& m) {
+  CheckFeature(f, m, FeatureSet::All() - fGraph);
   CPSMap cps;
   return ToCPS(f, m, &cps);
 }
 
 Function UnCPS(const Function& f) {
+  CheckFeature(f, FeatureSet::All() - fGraph);
   CHECK_GT(f->params.size(), 0);
   std::vector new_params;
   for (const auto& p : f->params) {
diff --git a/tests/python/relay/test_analysis_feature.py b/tests/python/relay/test_analysis_feature.py
index ec5deb3c4e60d..2b32376a95152 100644
--- a/tests/python/relay/test_analysis_feature.py
+++ b/tests/python/relay/test_analysis_feature.py
@@ -39,7 +39,6 @@ def test_prelude():
         Feature.fIf,
         Feature.fConstructor,
         Feature.fMatch,
-        Feature.fGraph
     ])
 
 
@@ -65,7 +64,6 @@ def test_ad():
         Feature.fRefCreate,
         Feature.fRefRead,
         Feature.fRefWrite,
-        Feature.fGraph
     ])
 
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 2c749c9341497..b8624b46eca81 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -32,17 +32,21 @@ def test_cross_entropy_with_logits_grad():
     x = relay.var("x", shape=(2, 5), dtype=dtype)
     y = relay.var("y", shape=(2, 5), dtype=dtype)
     check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
-
+
+
 def test_checkpoint():
     inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
     output = relay.multiply(relay.add(inputs[0], inputs[1]),
                             relay.add(inputs[2], inputs[3]))
     check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))
 
-    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
-                             relay.multiply(inputs[2], inputs[3])])
-    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
-                                relay.TupleGetItem(out_tuple, 1))
+    scope = relay.ScopeBuilder()
+    out_tuple = scope.let("out_tuple",
+                          relay.Tuple([relay.add(inputs[0], inputs[1]),
+                                       relay.multiply(inputs[2], inputs[3])]))
+    scope.ret(relay.subtract(relay.annotation.checkpoint(relay.TupleGetItem(out_tuple, 0)),
+                             relay.TupleGetItem(out_tuple, 1)))
+    out_single = scope.get()
     check_grad(relay.Function(inputs, out_single))
 
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 4838c6a4e7fce..296d3e5e9354f 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -45,6 +45,18 @@ def test_id():
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
 
+def test_relu():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], op.nn.relu(x))
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    # gradient will implicitly check that no graph appear in result
+
+
 def test_add():
     shape = (10, 10)
     dtype = 'float32'
@@ -72,12 +84,14 @@ def test_check_grad():
 
 
 def test_temp_add():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], y + y)
+    y = scope.let("y", x + x)
+    scope.ret(y + y)
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
@@ -280,12 +294,14 @@ def test_if():
 
 
 def test_grad_tuple():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
    dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], relay.Tuple([y + y, y]))
+    y = scope.let("y", x + x)
+    scope.ret(relay.Tuple([y + y, y]))
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py
index 414926802870a..377164e08b733 100644
--- a/tests/python/relay/test_pass_lazy_gradient_init.py
+++ b/tests/python/relay/test_pass_lazy_gradient_init.py
@@ -229,6 +229,24 @@ def test_multivar_reverse_ad():
     assert_allclose(grad_x.asnumpy(), y.asnumpy())
     assert_allclose(grad_y.asnumpy(), x.asnumpy())
 
+def test_partial_eval():
+    """Test transformation following reverse mode ad and PartialEval"""
+    mod = tvm.IRModule()
+
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+
+    func = relay.Function([], relay.const(np.ones(shape, dtype)))
+    func = run_infer_type(func)
+    back_func = transform.gradient(func)
+    back_func = run_infer_type(back_func)
+
+    mod["main"] = back_func
+    back_func = mod["main"]
+
+    transform.PartialEvaluate()(mod)
+
 def test_after_partial_eval():
     """Test transformation following reverse mode ad and PartialEval"""
     mod = tvm.IRModule()
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index ddb5b5dab6757..aef6ab5afa967 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for merge composite."""
+import pytest
 import tvm
 from tvm import relay, tir
 from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
@@ -213,7 +214,7 @@ def expected():
         r = relay.Call(add_relu, [a, b])
         return relay.Function([a, b], r)
 
-    check_result(pattern_table, before(), expected(), import_prelude=True)
+    check_result(pattern_table, before(), expected())
 
 
 def test_branch_merge():
@@ -998,15 +999,4 @@ def _check_type_false(extract):
 
 
 if __name__ == "__main__":
-    test_simple_merge()
-    test_branch_merge()
-    test_multiple_patterns()
-    test_optional_pattern()
-    test_merge_order()
-    test_parallel_merge()
-    test_multiple_input_subgraphs()
-    test_reuse_call_merge()
-    test_tuple_get_item_merge()
-    test_pattern_with_check()
-    test_diamond_not_merge()
-    test_type_check()
+    pytest.main([__file__])
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
index 45593b43ecb12..95805d285b59e 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 import numpy as np
 import tvm
 from tvm import te
@@ -173,10 +174,9 @@ def test_function_invalidate():
 def test_head_cons():
     mod = tvm.IRModule()
     p = Prelude(mod)
-    hd = p.hd
     t = TypeVar("t")
     x = Var("x", t)
-    body = hd(p.cons(x, p.nil()))
+    body = p.hd(p.cons(x, p.nil()))
     f = Function([x], body, None, [t])
     res = dcpe(f, mod)
     assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
@@ -340,23 +340,4 @@ def test_tuple_match():
 
 
 if __name__ == '__main__':
-    test_nat_update()
-    test_ref()
-    test_tuple()
-    test_empty_ad()
-    test_const_inline()
-    test_ad()
-    test_if_ref()
-    test_function_invalidate()
-    test_head_cons()
-    test_map()
-    test_loop()
-    test_swap_loop()
-    test_abs_diff()
-    test_double()
-    test_nat_id()
-    test_global_match_nat_id()
-    test_match_nat_id()
-    test_concat()
-    test_triangle_number()
-    test_tuple_match()
+    pytest.main([__file__])
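Usage sketches (illustrative notes, not part of the patch). First, the graph-form property that the new CheckFeature calls reject can already be observed from Python through relay.analysis.detect_feature, which this patch leaves registered as before; the import path below follows test_analysis_feature.py and is an assumption rather than something this diff changes:

import tvm
from tvm import relay
from tvm.relay.analysis import detect_feature, Feature

x = relay.var("x", shape=(10, 10))
y = x + x                       # the call node bound to `y` is reused below without a let,
f = relay.Function([x], y * y)  # so the program is in "graph form"
assert Feature.fGraph in detect_feature(f)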
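A minimal sketch of the expression-level ToANormalFormExpr entry point added above (variable names are made up for illustration). It mirrors how the prelude now normalizes `tensors` before wrapping them in a Function, and per the C++ doc comment the input may be an "incomplete graph" with free variables:

import tvm
from tvm import relay

x = relay.var("x", shape=(10, 10))
y = x + x
body = y * y                                        # shared subexpression -> graph form
anf_body = relay.transform.ToANormalFormExpr(body)  # free variable `x` is allowed here
func = relay.Function([x], anf_body)                # `y` is now introduced by an explicit let
print(func)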
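Because Gradient now checks FeatureSet::All() - fGraph on both its input and its output, callers have to hand it let-structured programs; the updated tests do this with relay.ScopeBuilder. A sketch along the lines of test_temp_add, using the InferType pass instead of the tests' run_infer_type helper:

import tvm
from tvm import relay

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
sb = relay.ScopeBuilder()
y = sb.let("y", x + x)   # explicit let instead of reusing the call node
sb.ret(y + y)
mod = tvm.IRModule.from_expr(relay.Function([x], sb.get()))
mod = relay.transform.InferType()(mod)
back_func = relay.transform.gradient(mod["main"], mod)  # passes the fGraph check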
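PartialEval gets the same no-fGraph restriction at module level; the new test_partial_eval above drives it over the output of reverse-mode AD. A compressed sketch of that flow (the extra re-inference steps in the test are elided here):

import numpy as np
import tvm
from tvm import relay

shape = (10, 10)
func = relay.Function([], relay.const(np.ones(shape, dtype="float32")))
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod["main"] = relay.transform.gradient(mod["main"], mod)  # AD output is already graph-free
mod = relay.transform.PartialEvaluate()(mod)              # module-level CheckFeature must pass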