diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h
index 3783e320f57c4..c72470086590f 100644
--- a/include/tvm/relay/feature.h
+++ b/include/tvm/relay/feature.h
@@ -160,6 +160,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
   return DetectFeature(expr) + DetectFeature(mod);
 }
 
+/*!
+ * \brief Check that the expression only uses features in the given feature set.
+ *
+ * \param expr The expression.
+ * \param fs The allowed feature set.
+ */
+void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);
+
+/*!
+ * \brief Check that the module only uses features in the given feature set.
+ *
+ * \param mod The module.
+ * \param fs The allowed feature set.
+ */
+void CheckFeature(const IRModule& mod, const FeatureSet& fs);
+
+/*!
+ * \brief Check that the expression and the module only use features in the given feature set.
+ *
+ * \param expr The expression.
+ * \param mod The module.
+ * \param fs The allowed feature set.
+ */
+inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
+  CheckFeature(expr, fs);
+  CheckFeature(mod, fs);
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index d322710ec95a3..9da3d7806f29c 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
 */
 TVM_DLL Pass ToANormalForm();
 
+/*!
+ * \brief ToANormalForm, but applied to a single expression (which may be an incomplete graph).
+ *
+ * \param expr The expression.
+ *
+ * \return The transformed program.
+ */
+TVM_DLL Expr ToANormalForm(const Expr& expr);
+
 /*!
  * \brief Turn an expression into continuation passing style(CPS).
 *
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index 1b7ed77e9b576..2675f1da88b0c 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -17,6 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """A prelude containing useful global functions and ADT definitions."""
 from tvm.ir import IRModule, TypeCall
+from tvm import relay
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
 from .expr import Var, GlobalVar, If, const
@@ -1237,6 +1238,7 @@ def __init__(self, mod=None):
             mod = IRModule()
         self.mod = mod
         self.load_prelude()
+        self.mod = relay.transform.ToANormalForm()(self.mod)
 
     def get_name(self, canonical, dtype):
         """Get name corresponding to the canonical name"""
diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc
index a145b28d55e8f..4ed84d0ccf98b 100644
--- a/src/relay/analysis/feature.cc
+++ b/src/relay/analysis/feature.cc
@@ -46,6 +46,7 @@ FeatureSet DetectFeature(const Expr& expr) {
       ExprVisitor::VisitExpr(expr);
     } else {
       if (!IsAtomic(expr)) {
+        std::cout << AsText(expr) << std::endl;
         fs += fGraph;
       }
     }
@@ -88,10 +89,8 @@
 FeatureSet DetectFeature(const IRModule& mod) {
   FeatureSet fs = FeatureSet::No();
-  if (mod.defined()) {
-    for (const auto& f : mod->functions) {
-      fs += DetectFeature(f.second);
-    }
+  for (const auto& f : mod->functions) {
+    fs += DetectFeature(f.second);
   }
   return fs;
 }
@@ -106,5 +105,16 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
 
 TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
 
+void CheckFeature(const Expr& expr, const FeatureSet& fs) {
+  CHECK(DetectFeature(expr).is_subset_of(fs)) << AsText(expr, false) << "\nhas more features than\n"
+                                              << fs << " supported";
+}
+
+void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
+  for (const auto& f : mod->functions) {
+    CheckFeature(f.second, fs);
+  }
+}
+
 }  // namespace relay
 }  // namespace tvm
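As a quick illustration of what the new check enforces (not part of the patch; the shapes and names are made up), the existing `relay.analysis.detect_feature` API reports `fGraph` exactly when a value is shared without a `let`, and ANF conversion removes it:

```python
import tvm
from tvm import relay
from tvm.relay.analysis import detect_feature, Feature

x = relay.var("x", shape=(10, 10))
y = x + x                               # `y` is reused below, so the body is a dataflow graph
shared = relay.Function([x], y + y)
assert Feature.fGraph in detect_feature(shared)

# After ToANormalForm the shared value is let-bound, so fGraph disappears
# and CheckFeature(..., FeatureSet::All() - fGraph) would accept it.
anf = relay.transform.ToANormalForm()(tvm.IRModule.from_expr(shared))
assert Feature.fGraph not in detect_feature(anf)
```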
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 0cebba72c3759..1bf416b36fe39 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <tvm/relay/feature.h>
 #include
 #include
@@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
 Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   const auto* x = e.as<GlobalVarNode>();
-  if (mod.defined() && (x)) {
+  if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
       return n->body;
@@ -404,7 +405,7 @@ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
 /*! \brief ReverseType(t) -> t. Get the original value. */
 Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
-  auto val = [&](const Expr& e) { return GetField(e, 0); };
+  auto val = [&](const Expr& e) { return ll->Push(GetField(e, 0)); };
   auto val_type = [&](const Type& forward_type) { return forward_type; };
   return LiftTensor(val, val_type, forward_type, e, ll);
 }
@@ -508,7 +509,8 @@
                             return Call(bpv, {});
                           }),
                           TupleType::Empty(), {});
-      ll->Push(RefWrite(bp, nbp));
+      ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
+      // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
       return ret;
     });
   }
@@ -516,8 +518,10 @@
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
-    Expr e = GetRef<Expr>(op);
-    return Pair(e, RefCreate(ZerosLike(e)));
+    return LetList::With([&](LetList* ll) {
+      Expr e = ll->Push(GetRef<Expr>(op));
+      return Pair(e, RefCreate(ZerosLike(e)));
+    });
   }
 
   Expr VisitExpr_(const IfNode* op) final {
@@ -568,6 +572,10 @@ bool MissingGrad(const Expr& e) {
 }
 
 Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
+  CheckFeature(re, FeatureSet::All() - fGraph);
+  if (mod.defined()) {
+    CheckFeature(mod.value(), FeatureSet::All() - fGraph);
+  }
   auto e = DeGlobal(mod, re);
   auto f = e.as<FunctionNode>();
   CHECK(f) << "input need to be a function";
@@ -619,7 +627,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
     };
     return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
   });
-  return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+  CheckFeature(ret, FeatureSet::All() - fGraph);
+  return ret;
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc
index f06246667a8ba..4c3b1d00b6b68 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -63,6 +63,7 @@
 #include
 #include
 #include
+#include <tvm/relay/feature.h>
 #include
 
 #include "let_list.h"
@@ -82,7 +83,6 @@ class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
   explicit InputVisitor(IRModule module) : module_(module) {}
 
   Expr VisitExpr_(const VarNode* op, const Type& t) final {
-    std::cout << op->type_annotation << std::endl;
     return WrapExpr(GetRef<Var>(op), op->type_annotation);
   }
@@ -93,7 +93,7 @@
 
  private:
   IRModule module_;
-  Expr WrapExpr(const Expr expr, const Type& type) {
+  Expr WrapExpr(const Expr expr, const Type& type, LetList* ll) {
     if (type.as<TensorTypeNode>()) {
       return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
@@ -293,7 +293,9 @@
 };
 
 Expr LazyGradientInit(const Expr& e, IRModule mod) {
-  return LazyGradientInitializer(mod).Transform(e);
+  auto ret = LazyGradientInitializer(mod).Transform(e);
+  CheckFeature(e, mod, FeatureSet::All() - fGraph);
+  return ret;
 }
 
 namespace transform {
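For context, a minimal Python sketch (not part of the patch) of how the gradient pass is driven; the shapes are illustrative and the assertions mirror the existing `test_add` style of the test suite:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.transform import gradient
from tvm.relay.testing import run_infer_type

t = relay.TensorType((10, 10), "float32")
x = relay.var("x", t)
func = run_infer_type(relay.Function([x], x + x))
back_func = run_infer_type(gradient(func))  # reverse-mode AD; the result is now checked to be graph-free

ex = create_executor()
inp = tvm.nd.array(np.random.rand(10, 10).astype("float32"))
forward, (grad,) = ex.evaluate(back_func)(inp)
np.testing.assert_allclose(forward.asnumpy(), 2 * inp.asnumpy(), rtol=1e-5)
np.testing.assert_allclose(grad.asnumpy(), 2 * np.ones((10, 10), dtype="float32"), rtol=1e-5)
```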
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index 63bd04d526dee..46983c5aff091 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -92,6 +92,7 @@
 #include
 #include
 #include
+#include <tvm/relay/feature.h>
 #include
 #include
 #include
@@ -1195,9 +1196,12 @@ IRModule PartialEval(const IRModule& m) {
 namespace transform {
 
 Pass PartialEval() {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-      [=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
-  return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
+                                                                            PassContext pc) {
+    CheckFeature(m, FeatureSet::All() - fGraph);
+    return relay::PartialEval(m);
+  };
+  return CreateModulePass(pass_func, 1, "PartialEval", {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 06e0d56e19194..367b491382c35 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -252,32 +252,6 @@ Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
   return Compound(e, Match(data, clauses, m->complete), v);
 }
 
-Expr ToANormalFormAux(const Expr& e) {
-  /* When you lift a lambda, what is inside is also being lift.
-   *
-   * So we must determine the scope of the lambda before determining the scope of it's body.
-   *
-   * To make this more principled,
-   * we always determine the scope of parent before determining the scope of children.
-   *
-   * So we calculate all the dependency between nodes.
-   */
-  support::Arena arena;
-  DependencyGraph dg = DependencyGraph::Create(&arena, e);
-  /* In order to model new subscopes created by lambda, if else and pattern matching,
-   * we also assign scope to edge as well.
-   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
-   *
-   * So, the scope of the whole expr is global.
-   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
-   *
-   * Every scope additionally contain a LetList which collect all value of that scope.
-   * We do an additional pass to fill all the LetList and we are done.
-   */
-  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
-  return Fill::ToANormalForm(e, dg, &scopes.first);
-}
-
 IRModule ToANormalForm(const IRModule& m) {
   DLOG(INFO) << "ToANF:" << std::endl << m;
@@ -288,7 +262,7 @@ IRModule ToANormalForm(const IRModule& m) {
     if (const auto* n = it.second.as<FunctionNode>()) {
       if (n->GetAttr(attr::kCompiler).defined()) continue;
     }
-    Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
+    Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second);
     CHECK_EQ(FreeVars(ret).size(), 0)
         << AsText(ret) << "should not has free vars: " << FreeVars(ret);
     updates.Set(it.first, Downcast<Function>(ret));
@@ -305,13 +279,41 @@
 
 namespace transform {
 
+Expr ToANormalForm(const Expr& e) {
+  /* When you lift a lambda, its body gets lifted as well.
+   *
+   * So we must determine the scope of the lambda before determining the scope of its body.
+   *
+   * To make this more principled,
+   * we always determine the scope of a parent before determining the scope of its children.
+   *
+   * So we calculate all the dependencies between nodes.
+   */
+  support::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  /* In order to model new subscopes created by lambdas, if-else and pattern matching,
+   * we also assign scopes to edges.
+   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
+   *
+   * So, the scope of the whole expr is global.
+   * The scope of any subexpr is the lowest common ancestor of all its incoming edges.
+   *
+   * Every scope additionally contains a LetList which collects all values of that scope.
+   * We do an additional pass to fill all the LetLists and we are done.
+   */
+  std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
+  return Fill::ToANormalForm(e, dg, &scopes.first);
+}
+
 Pass ToANormalForm() {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
   return CreateModulePass(pass_func, 1, "ToANormalForm", {});
 }
 
-TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
+  return ToANormalForm();
+});
 
 }  // namespace transform
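To make the scope discussion above concrete, here is a small sketch (not part of the patch; shapes are illustrative) of the ANF pass driven from Python; a value shared through the dataflow graph comes out bound by an explicit `let`:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2))
y = x + x                                   # reused below: dataflow sharing, i.e. fGraph
mod = tvm.IRModule.from_expr(relay.Function([x], y * y))
anf = relay.transform.ToANormalForm()(mod)
print(anf["main"])                          # the shared `x + x` is now bound by a let
```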
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 6972d5a76b777..7c11ce5d4cd93 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -52,6 +52,7 @@
  */
 #include
 #include
+#include <tvm/relay/feature.h>
 #include
 #include
@@ -301,11 +302,13 @@
 }
 
 Function ToCPS(const Function& f, const IRModule& m) {
+  CheckFeature(f, m, FeatureSet::All() - fGraph);
   CPSMap cps;
   return ToCPS(f, m, &cps);
 }
 
 Function UnCPS(const Function& f) {
+  CheckFeature(f, FeatureSet::All() - fGraph);
   CHECK_GT(f->params.size(), 0);
   std::vector<Var> new_params;
   for (const auto& p : f->params) {
diff --git a/tests/python/relay/test_analysis_feature.py b/tests/python/relay/test_analysis_feature.py
index ec5deb3c4e60d..2b32376a95152 100644
--- a/tests/python/relay/test_analysis_feature.py
+++ b/tests/python/relay/test_analysis_feature.py
@@ -39,7 +39,6 @@ def test_prelude():
         Feature.fIf,
         Feature.fConstructor,
         Feature.fMatch,
-        Feature.fGraph
     ])
 
 
@@ -65,7 +64,6 @@ def test_ad():
         Feature.fRefCreate,
         Feature.fRefRead,
         Feature.fRefWrite,
-        Feature.fGraph
     ])
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 4838c6a4e7fce..296d3e5e9354f 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -45,6 +45,18 @@ def test_id():
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
 
+def test_relu():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], op.nn.relu(x))
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+    # gradient will implicitly check that no graph appears in the result
+
+
 def test_add():
     shape = (10, 10)
     dtype = 'float32'
@@ -72,12 +84,14 @@ def test_check_grad():
 
 
 def test_temp_add():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
    dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], y + y)
+    y = scope.let("y", x + x)
+    scope.ret(y + y)
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
@@ -280,12 +294,14 @@ def test_if():
 
 
 def test_grad_tuple():
+    scope = relay.ScopeBuilder()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
     x = relay.var("x", t)
-    y = x + x
-    func = relay.Function([x], relay.Tuple([y + y, y]))
+    y = scope.let("y", x + x)
+    scope.ret(relay.Tuple([y + y, y]))
+    func = relay.Function([x], scope.get())
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
     assert back_func.checked_type == relay.FuncType([t], relay.TupleType([relay.TupleType([t, t]), relay.TupleType([t])]))
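The test changes above switch from building an implicitly shared graph to an explicit `let` via `relay.ScopeBuilder`; a standalone sketch of that pattern (names and shapes are illustrative, not taken from the patch):

```python
from tvm import relay
from tvm.relay.analysis import detect_feature, Feature

sb = relay.ScopeBuilder()
x = relay.var("x", shape=(10, 10))
y = sb.let("y", x + x)        # bind the shared value explicitly instead of reusing the node
sb.ret(y + y)
func = relay.Function([x], sb.get())
assert Feature.fGraph not in detect_feature(func)   # so the stricter gradient check accepts it
```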
diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py
index ddb5b5dab6757..743b766839024 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for merge composite."""
+import pytest
 import tvm
 from tvm import relay, tir
 from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard
@@ -999,14 +1000,4 @@ def _check_type_false(extract):
 
 if __name__ == "__main__":
     test_simple_merge()
-    test_branch_merge()
-    test_multiple_patterns()
-    test_optional_pattern()
-    test_merge_order()
-    test_parallel_merge()
-    test_multiple_input_subgraphs()
-    test_reuse_call_merge()
-    test_tuple_get_item_merge()
-    test_pattern_with_check()
-    test_diamond_not_merge()
-    test_type_check()
+    #pytest.main([__file__])
diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py
index 45593b43ecb12..95805d285b59e 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
 import numpy as np
 import tvm
 from tvm import te
@@ -173,10 +174,9 @@ def test_function_invalidate():
 def test_head_cons():
     mod = tvm.IRModule()
     p = Prelude(mod)
-    hd = p.hd
     t = TypeVar("t")
     x = Var("x", t)
-    body = hd(p.cons(x, p.nil()))
+    body = p.hd(p.cons(x, p.nil()))
     f = Function([x], body, None, [t])
     res = dcpe(f, mod)
     assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
@@ -340,23 +340,4 @@
 
 
 if __name__ == '__main__':
-    test_nat_update()
-    test_ref()
-    test_tuple()
-    test_empty_ad()
-    test_const_inline()
-    test_ad()
-    test_if_ref()
-    test_function_invalidate()
-    test_head_cons()
-    test_map()
-    test_loop()
-    test_swap_loop()
-    test_abs_diff()
-    test_double()
-    test_nat_id()
-    test_global_match_nat_id()
-    test_match_nat_id()
-    test_concat()
-    test_triangle_number()
-    test_tuple_match()
+    pytest.main([__file__])
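Finally, a rough sketch (not part of the patch; the exact pipeline used by the tests' `dcpe` helper may differ) of how partial evaluation is typically driven now that the pass checks for the absence of `fGraph`: normalize first, then partially evaluate and clean up:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2))
mod = tvm.IRModule.from_expr(relay.Function([x], x + relay.const(0.0)))

seq = tvm.transform.Sequential([
    relay.transform.ToANormalForm(),              # guarantees there is no fGraph in the input
    relay.transform.PartialEvaluate(),
    relay.transform.DeadCodeElimination(inline_once=True),
])
print(seq(mod)["main"])
```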